diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index fac7553f..76d39780 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -13,10 +13,24 @@ import ( var neverExpire = time.Now().Add(time.Hour * 24 * 365 * 100) +// StorageFactory is a method for creating a new storage. The returned storage sould be initialized +// but shouldn't have any existing data in it. +type StorageFactory func() storage.Storage + // RunTestSuite runs a set of conformance tests against a storage. -func RunTestSuite(t *testing.T, s storage.Storage) { - t.Run("UpdateAuthRequest", func(t *testing.T) { testUpdateAuthRequest(t, s) }) - t.Run("CreateRefresh", func(t *testing.T) { testCreateRefresh(t, s) }) +func RunTestSuite(t *testing.T, sf StorageFactory) { + tests := []struct { + name string + run func(t *testing.T, s storage.Storage) + }{ + {"UpdateAuthRequest", testUpdateAuthRequest}, + {"CreateRefresh", testCreateRefresh}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.run(t, sf()) + }) + } } func testUpdateAuthRequest(t *testing.T, s storage.Storage) { diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go index cb3cbe2f..c0011f39 100644 --- a/storage/kubernetes/storage_test.go +++ b/storage/kubernetes/storage_test.go @@ -4,6 +4,7 @@ import ( "os" "testing" + "github.com/coreos/dex/storage" "github.com/coreos/dex/storage/conformance" ) @@ -73,5 +74,8 @@ func TestURLFor(t *testing.T) { func TestStorage(t *testing.T) { client := loadClient(t) - conformance.RunTestSuite(t, client) + conformance.RunTestSuite(t, func() storage.Storage { + // TODO(erichiang): Tear down namespaces between each iteration. + return client + }) } diff --git a/storage/memory/memory_test.go b/storage/memory/memory_test.go index 73f2a790..56c5f93d 100644 --- a/storage/memory/memory_test.go +++ b/storage/memory/memory_test.go @@ -7,6 +7,5 @@ import ( ) func TestStorage(t *testing.T) { - s := New() - conformance.RunTestSuite(t, s) + conformance.RunTestSuite(t, New) } diff --git a/storage/sql/config_test.go b/storage/sql/config_test.go index e4b317b4..55a9d69c 100644 --- a/storage/sql/config_test.go +++ b/storage/sql/config_test.go @@ -1 +1,90 @@ package sql + +import ( + "fmt" + "os" + "runtime" + "testing" + "time" + + "github.com/coreos/dex/storage" + "github.com/coreos/dex/storage/conformance" +) + +func withTimeout(t time.Duration, f func()) { + c := make(chan struct{}) + defer close(c) + + go func() { + select { + case <-c: + case <-time.After(t): + // Dump a stack trace of the program. Useful for debugging deadlocks. + buf := make([]byte, 2<<20) + fmt.Fprintf(os.Stderr, "%s\n", buf[:runtime.Stack(buf, true)]) + panic("test took too long") + } + }() + + f() +} + +func cleanDB(c *conn) error { + _, err := c.Exec(` + delete from client; + delete from auth_request; + delete from auth_code; + delete from refresh_token; + delete from keys; + `) + return err +} + +func TestSQLite3(t *testing.T) { + newStorage := func() storage.Storage { + // NOTE(ericchiang): In memory means we only get one connection at a time. If we + // ever write tests that require using multiple connections, for instance to test + // transactions, we need to move to a file based system. + s := &SQLite3{":memory:"} + conn, err := s.open() + if err != nil { + t.Fatal(err) + } + return conn + } + + withTimeout(time.Second*10, func() { + conformance.RunTestSuite(t, newStorage) + }) +} + +func TestPostgres(t *testing.T) { + if os.Getenv("DEX_POSTGRES_HOST") == "" { + t.Skip("postgres envs not set, skipping tests") + } + p := Postgres{ + Database: os.Getenv("DEX_POSTGRES_DATABASE"), + User: os.Getenv("DEX_POSTGRES_USER"), + Password: os.Getenv("DEX_POSTGRES_PASSWORD"), + Host: os.Getenv("DEX_POSTGRES_HOST"), + SSL: PostgresSSL{ + Mode: sslDisable, // Postgres container doesn't support SSL. + }, + ConnectionTimeout: 5, + } + conn, err := p.open() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + newStorage := func() storage.Storage { + if err := cleanDB(conn); err != nil { + t.Fatal(err) + } + return conn + } + withTimeout(time.Minute*1, func() { + conformance.RunTestSuite(t, newStorage) + }) +}