This repository has been archived on 2022-08-17. You can view files and clone it, but cannot push or open issues or pull requests.
dex/db/migrate_test.go

224 lines
5.6 KiB
Go
Raw Normal View History

2015-08-20 04:10:36 +05:30
package db
import (
"fmt"
"os"
"sort"
"strconv"
2015-08-20 04:10:36 +05:30
"testing"
"github.com/go-gorp/gorp"
"github.com/kylelemons/godebug/pretty"
2015-08-20 04:10:36 +05:30
)
func initDB(dsn string) *gorp.DbMap {
c, err := NewConnection(Config{DSN: dsn})
if err != nil {
panic(fmt.Sprintf("error making db connection: %q", err))
}
if _, err := c.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", migrationTable)); err != nil {
panic(fmt.Sprintf("failed to drop migration table: %v", err))
}
2015-08-20 04:10:36 +05:30
if err = c.DropTablesIfExists(); err != nil {
panic(fmt.Sprintf("Unable to drop database tables: %v", err))
}
return c
}
// TestGetPlannedMigrations is a sanity check, ensuring that at least one
// migration can be found.
func TestGetPlannedMigrations(t *testing.T) {
dsn := os.Getenv("DEX_TEST_DSN")
if dsn == "" {
t.Skip("Test will not run without DEX_TEST_DSN environment variable.")
2015-08-20 04:10:36 +05:30
return
}
dbMap := initDB(dsn)
ms, err := GetPlannedMigrations(dbMap)
if err != nil {
pwd, err := os.Getwd()
t.Logf("pwd: %v", pwd)
t.Fatalf("unexpected err: %q", err)
}
if len(ms) == 0 {
t.Fatalf("expected non-empty migrations")
}
}
func TestMigrateClientMetadata(t *testing.T) {
2016-06-21 02:36:08 +05:30
// oldClientModel exists to model what the client model looked like at
// migration time. Without using this, the test fails because there's no
// columns for the new fields.
type oldClientModel struct {
ID string `db:"id"`
Secret []byte `db:"secret"`
Metadata string `db:"metadata"`
DexAdmin bool `db:"dex_admin"`
}
register(table{
name: clientTableName,
model: oldClientModel{},
autoinc: false,
pkey: []string{"id"},
})
dsn := os.Getenv("DEX_TEST_DSN")
if dsn == "" {
t.Skip("Test will not run without DEX_TEST_DSN environment variable.")
return
}
dbMap := initDB(dsn)
nMigrations := 9
n, err := MigrateMaxMigrations(dbMap, nMigrations)
if err != nil {
t.Fatalf("failed to perform initial migration: %v", err)
}
if n != nMigrations {
t.Fatalf("expected to perform %d migrations, got %d", nMigrations, n)
}
tests := []struct {
before string
after string
}{
// only update rows without a "redirect_uris" key
{
`{"redirectURLs":["foo"]}`,
`{"redirectURLs" : ["foo"], "redirect_uris" : ["foo"]}`,
},
{
`{"redirectURLs":["foo","bar"]}`,
`{"redirectURLs" : ["foo","bar"], "redirect_uris" : ["foo","bar"]}`,
},
{
`{"redirect_uris":["foo"],"another_field":8}`,
`{"redirect_uris":["foo"],"another_field":8}`,
},
{
`{"redirectURLs" : ["foo"], "redirect_uris" : ["foo"]}`,
`{"redirectURLs" : ["foo"], "redirect_uris" : ["foo"]}`,
},
}
for i, tt := range tests {
2016-06-21 02:36:08 +05:30
model := &oldClientModel{
ID: strconv.Itoa(i),
Secret: []byte("verysecret"),
Metadata: tt.before,
}
if err := dbMap.Insert(model); err != nil {
t.Fatalf("could not insert model: %v", err)
}
}
n, err = MigrateMaxMigrations(dbMap, 1)
if err != nil {
t.Fatalf("failed to perform initial migration: %v", err)
}
if n != 1 {
t.Fatalf("expected to perform 1 migration, got %d", n)
}
for i, tt := range tests {
id := strconv.Itoa(i)
2016-06-21 02:36:08 +05:30
m, err := dbMap.Get(oldClientModel{}, id)
if err != nil {
2016-02-13 02:48:49 +05:30
t.Errorf("case %d: failed to get model: %v", i, err)
continue
}
2016-06-21 02:36:08 +05:30
cim, ok := m.(*oldClientModel)
if !ok {
t.Errorf("case %d: unrecognized model type: %T", i, m)
continue
}
if cim.Metadata != tt.after {
t.Errorf("case %d: want=%q, got=%q", i, tt.after, cim.Metadata)
}
}
}
// TestMigrationNumber11 exercises the eleventh migration. Each case resets
// the database and applies the first ten migrations. It then seeds
// authd_user rows with the case's SQL and applies exactly one more
// migration. Per the test data, emails are expected to come out lowercased,
// and the migration is expected to fail when two emails differ only by
// case. Requires DEX_TEST_DSN; skipped otherwise.
func TestMigrationNumber11(t *testing.T) {
	dsn := os.Getenv("DEX_TEST_DSN")
	if dsn == "" {
		t.Skip("Test will not run without DEX_TEST_DSN environment variable.")
		return
	}

	tests := []struct {
		sqlStmt    string
		wantEmails []string
		wantError  bool
	}{
		{
			sqlStmt: `INSERT INTO authd_user
(id, email, email_verified, display_name, admin, created_at)
VALUES
(1, 'Foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())),
(2, 'Bar@example.com', TRUE, 'foo', FALSE, extract(epoch from now()))
;`,
			wantEmails: []string{"foo@example.com", "bar@example.com"},
			wantError:  false,
		},
		{
			sqlStmt: `INSERT INTO authd_user
(id, email, email_verified, display_name, admin, created_at)
VALUES
(1, 'Foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())),
(2, 'foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())),
(3, 'bar@example.com', TRUE, 'foo', FALSE, extract(epoch from now()))
;`,
			wantError: true,
		},
	}

	// migrateExactly asks for want migrations and reports an error either
	// when migrating fails or when a different number were performed.
	migrateExactly := func(dbMap *gorp.DbMap, want int) error {
		performed, err := MigrateMaxMigrations(dbMap, want)
		if err != nil {
			return err
		}
		if performed != want {
			return fmt.Errorf("expected to perform %d migrations, performed %d", want, performed)
		}
		return nil
	}

	// runCase runs one scenario against a freshly reset database and returns
	// a descriptive error on failure, or nil when the case passes.
	runCase := func(sqlStmt string, wantEmails []string, wantError bool) error {
		dbMap := initDB(dsn)

		// Bring the schema up to just before the migration under test.
		const migrationsBefore = 10
		if err := migrateExactly(dbMap, migrationsBefore); err != nil {
			return fmt.Errorf("failed to perform initial migration: %v", err)
		}
		if _, err := dbMap.Exec(sqlStmt); err != nil {
			return fmt.Errorf("failed to insert users: %v", err)
		}

		migrateErr := migrateExactly(dbMap, 1)
		switch {
		case migrateErr != nil && wantError:
			// The expected failure occurred; nothing more to check.
			return nil
		case migrateErr != nil:
			return fmt.Errorf("failed to perform migration: %v", migrateErr)
		case wantError:
			return fmt.Errorf("expected an error when migrating")
		}

		var gotEmails []string
		if _, err := dbMap.Select(&gotEmails, `SELECT email FROM authd_user;`); err != nil {
			return fmt.Errorf("could not get user emails: %v", err)
		}

		// Row order from the SELECT is not guaranteed; compare as sorted sets.
		sort.Strings(wantEmails)
		sort.Strings(gotEmails)
		if diff := pretty.Compare(wantEmails, gotEmails); diff != "" {
			return fmt.Errorf("wantEmails != gotEmails: %s", diff)
		}
		return nil
	}

	for i, tt := range tests {
		if err := runCase(tt.sqlStmt, tt.wantEmails, tt.wantError); err != nil {
			t.Errorf("case %d: %v", i, err)
		}
	}
}