From 208afd3b0186fedb53d7ce1567a50028c4435add Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Tue, 1 Mar 2016 10:54:12 -0800 Subject: [PATCH] *: add functional tests for case insensitive emails --- db/migrate_test.go | 86 +++++++++++++++++++++++++++++++ functional/repo/user_repo_test.go | 27 ++++++++-- 2 files changed, 108 insertions(+), 5 deletions(-) diff --git a/db/migrate_test.go b/db/migrate_test.go index a7251689..e4de0534 100644 --- a/db/migrate_test.go +++ b/db/migrate_test.go @@ -3,10 +3,12 @@ package db import ( "fmt" "os" + "sort" "strconv" "testing" "github.com/go-gorp/gorp" + "github.com/kylelemons/godebug/pretty" ) func initDB(dsn string) *gorp.DbMap { @@ -14,6 +16,9 @@ func initDB(dsn string) *gorp.DbMap { if err != nil { panic(fmt.Sprintf("error making db connection: %q", err)) } + if _, err := c.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", migrationTable)); err != nil { + panic(fmt.Sprintf("failed to drop migration table: %v", err)) + } if err = c.DropTablesIfExists(); err != nil { panic(fmt.Sprintf("Unable to drop database tables: %v", err)) } @@ -119,3 +124,84 @@ func TestMigrateClientMetadata(t *testing.T) { } } } + +func TestMigrationNumber11(t *testing.T) { + dsn := os.Getenv("DEX_TEST_DSN") + if dsn == "" { + t.Skip("Test will not run without DEX_TEST_DSN environment variable.") + return + } + + tests := []struct { + sqlStmt string + wantEmails []string + wantError bool + }{ + { + sqlStmt: `INSERT INTO authd_user + (id, email, email_verified, display_name, admin, created_at) + VALUES + (1, 'Foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())), + (2, 'Bar@example.com', TRUE, 'foo', FALSE, extract(epoch from now())) + ;`, + wantEmails: []string{"foo@example.com", "bar@example.com"}, + wantError: false, + }, + { + sqlStmt: `INSERT INTO authd_user + (id, email, email_verified, display_name, admin, created_at) + VALUES + (1, 'Foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())), + (2, 'foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())), + (3, 'bar@example.com', TRUE, 'foo', FALSE, extract(epoch from now())) + ;`, + wantError: true, + }, + } + migrateN := func(dbMap *gorp.DbMap, n int) error { + nPerformed, err := MigrateMaxMigrations(dbMap, n) + if err == nil && n != nPerformed { + err = fmt.Errorf("expected to perform %d migrations, performed %d", n, nPerformed) + } + return err + } + + for i, tt := range tests { + err := func() error { + dbMap := initDB(dsn) + + nMigrations := 10 + if err := migrateN(dbMap, nMigrations); err != nil { + return fmt.Errorf("failed to perform initial migration: %v", err) + } + if _, err := dbMap.Exec(tt.sqlStmt); err != nil { + return fmt.Errorf("failed to insert users: %v", err) + } + if err := migrateN(dbMap, 1); err != nil { + if tt.wantError { + return nil + } + return fmt.Errorf("failed to perform migration: %v", err) + } + + if tt.wantError { + return fmt.Errorf("expected an error when migrating") + } + + var gotEmails []string + if _, err := dbMap.Select(&gotEmails, `SELECT email FROM authd_user;`); err != nil { + return fmt.Errorf("could not get user emails: %v", err) + } + + sort.Strings(tt.wantEmails) + sort.Strings(gotEmails) + if diff := pretty.Compare(tt.wantEmails, gotEmails); diff != "" { + return fmt.Errorf("wantEmails != gotEmails: %s", diff) + } + return nil + }() + if err != nil { + t.Errorf("case %d: %v", i, err) + } + } +} diff --git a/functional/repo/user_repo_test.go b/functional/repo/user_repo_test.go index 3fa699df..3dda44cc 100644 --- a/functional/repo/user_repo_test.go +++ b/functional/repo/user_repo_test.go @@ -104,6 +104,15 @@ func TestNewUser(t *testing.T) { }, err: user.ErrorDuplicateEmail, }, + { + user: user.User{ + ID: "ID-same", + Email: "email-1@example.com", + DisplayName: "A lower case version of the original email", + CreatedAt: now, + }, + err: user.ErrorDuplicateEmail, + }, { user: user.User{ Email: "AnotherEmail@example.com", @@ -421,12 +430,19 @@ func findRemoteIdentity(rids []user.RemoteIdentity, rid user.RemoteIdentity) int func TestGetByEmail(t *testing.T) { tests := []struct { - email string - wantErr error + email string + wantEmail string + wantErr error }{ { - email: "Email-1@example.com", - wantErr: nil, + email: "Email-1@example.com", + wantEmail: "Email-1@example.com", + wantErr: nil, + }, + { + email: "EMAIL-1@example.com", // Emails should be case insensitive. + wantEmail: "Email-1@example.com", + wantErr: nil, }, { email: "NoSuchEmail@example.com", @@ -446,9 +462,10 @@ func TestGetByEmail(t *testing.T) { if gotErr != nil { t.Errorf("case %d: want nil err:% q", i, gotErr) + continue } - if tt.email != gotUser.Email { + if tt.wantEmail != gotUser.Email { t.Errorf("case %d: want=%q, got=%q", i, tt.email, gotUser.Email) } }