forked from mystiq/dex
*: add functional tests for case insensitive emails
This commit is contained in:
parent
9bc68edae7
commit
208afd3b01
2 changed files with 108 additions and 5 deletions
|
@ -3,10 +3,12 @@ package db
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
)
|
||||
|
||||
func initDB(dsn string) *gorp.DbMap {
|
||||
|
@ -14,6 +16,9 @@ func initDB(dsn string) *gorp.DbMap {
|
|||
if err != nil {
|
||||
panic(fmt.Sprintf("error making db connection: %q", err))
|
||||
}
|
||||
if _, err := c.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", migrationTable)); err != nil {
|
||||
panic(fmt.Sprintf("failed to drop migration table: %v", err))
|
||||
}
|
||||
if err = c.DropTablesIfExists(); err != nil {
|
||||
panic(fmt.Sprintf("Unable to drop database tables: %v", err))
|
||||
}
|
||||
|
@ -119,3 +124,84 @@ func TestMigrateClientMetadata(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationNumber11(t *testing.T) {
|
||||
dsn := os.Getenv("DEX_TEST_DSN")
|
||||
if dsn == "" {
|
||||
t.Skip("Test will not run without DEX_TEST_DSN environment variable.")
|
||||
return
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
sqlStmt string
|
||||
wantEmails []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
sqlStmt: `INSERT INTO authd_user
|
||||
(id, email, email_verified, display_name, admin, created_at)
|
||||
VALUES
|
||||
(1, 'Foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())),
|
||||
(2, 'Bar@example.com', TRUE, 'foo', FALSE, extract(epoch from now()))
|
||||
;`,
|
||||
wantEmails: []string{"foo@example.com", "bar@example.com"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
sqlStmt: `INSERT INTO authd_user
|
||||
(id, email, email_verified, display_name, admin, created_at)
|
||||
VALUES
|
||||
(1, 'Foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())),
|
||||
(2, 'foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())),
|
||||
(3, 'bar@example.com', TRUE, 'foo', FALSE, extract(epoch from now()))
|
||||
;`,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
migrateN := func(dbMap *gorp.DbMap, n int) error {
|
||||
nPerformed, err := MigrateMaxMigrations(dbMap, n)
|
||||
if err == nil && n != nPerformed {
|
||||
err = fmt.Errorf("expected to perform %d migrations, performed %d", n, nPerformed)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := func() error {
|
||||
dbMap := initDB(dsn)
|
||||
|
||||
nMigrations := 10
|
||||
if err := migrateN(dbMap, nMigrations); err != nil {
|
||||
return fmt.Errorf("failed to perform initial migration: %v", err)
|
||||
}
|
||||
if _, err := dbMap.Exec(tt.sqlStmt); err != nil {
|
||||
return fmt.Errorf("failed to insert users: %v", err)
|
||||
}
|
||||
if err := migrateN(dbMap, 1); err != nil {
|
||||
if tt.wantError {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to perform migration: %v", err)
|
||||
}
|
||||
|
||||
if tt.wantError {
|
||||
return fmt.Errorf("expected an error when migrating")
|
||||
}
|
||||
|
||||
var gotEmails []string
|
||||
if _, err := dbMap.Select(&gotEmails, `SELECT email FROM authd_user;`); err != nil {
|
||||
return fmt.Errorf("could not get user emails: %v", err)
|
||||
}
|
||||
|
||||
sort.Strings(tt.wantEmails)
|
||||
sort.Strings(gotEmails)
|
||||
if diff := pretty.Compare(tt.wantEmails, gotEmails); diff != "" {
|
||||
return fmt.Errorf("wantEmails != gotEmails: %s", diff)
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
t.Errorf("case %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -104,6 +104,15 @@ func TestNewUser(t *testing.T) {
|
|||
},
|
||||
err: user.ErrorDuplicateEmail,
|
||||
},
|
||||
{
|
||||
user: user.User{
|
||||
ID: "ID-same",
|
||||
Email: "email-1@example.com",
|
||||
DisplayName: "A lower case version of the original email",
|
||||
CreatedAt: now,
|
||||
},
|
||||
err: user.ErrorDuplicateEmail,
|
||||
},
|
||||
{
|
||||
user: user.User{
|
||||
Email: "AnotherEmail@example.com",
|
||||
|
@ -421,12 +430,19 @@ func findRemoteIdentity(rids []user.RemoteIdentity, rid user.RemoteIdentity) int
|
|||
|
||||
func TestGetByEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
email string
|
||||
wantErr error
|
||||
email string
|
||||
wantEmail string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
email: "Email-1@example.com",
|
||||
wantErr: nil,
|
||||
email: "Email-1@example.com",
|
||||
wantEmail: "Email-1@example.com",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
email: "EMAIL-1@example.com", // Emails should be case insensitive.
|
||||
wantEmail: "Email-1@example.com",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
email: "NoSuchEmail@example.com",
|
||||
|
@ -446,9 +462,10 @@ func TestGetByEmail(t *testing.T) {
|
|||
|
||||
if gotErr != nil {
|
||||
t.Errorf("case %d: want nil err:% q", i, gotErr)
|
||||
continue
|
||||
}
|
||||
|
||||
if tt.email != gotUser.Email {
|
||||
if tt.wantEmail != gotUser.Email {
|
||||
t.Errorf("case %d: want=%q, got=%q", i, tt.email, gotUser.Email)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue