diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 5ebbf31c..490ce1d9 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -48,6 +48,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) { {"PasswordCRUD", testPasswordCRUD}, {"KeysCRUD", testKeysCRUD}, {"GarbageCollection", testGC}, + {"TimezoneSupport", testTimezones}, }) } @@ -370,14 +371,23 @@ func testKeysCRUD(t *testing.T, s storage.Storage) { } func testGC(t *testing.T, s storage.Storage) { - n := time.Now().UTC() + est, err := time.LoadLocation("America/New_York") + if err != nil { + t.Fatal(err) + } + pst, err := time.LoadLocation("America/Los_Angeles") + if err != nil { + t.Fatal(err) + } + + expiry := time.Now().In(est) c := storage.AuthCode{ ID: storage.NewID(), ClientID: "foobar", RedirectURI: "https://localhost:80/callback", Nonce: "foobar", Scopes: []string{"openid", "email"}, - Expiry: n.Add(time.Second), + Expiry: expiry, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ @@ -393,14 +403,21 @@ func testGC(t *testing.T, s storage.Storage) { t.Fatalf("failed creating auth code: %v", err) } - if _, err := s.GarbageCollect(n); err != nil { - t.Errorf("garbage collection failed: %v", err) - } - if _, err := s.GetAuthCode(c.ID); err != nil { - t.Errorf("expected to be able to get auth code after GC: %v", err) + for _, tz := range []*time.Location{time.UTC, est, pst} { + result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) + if err != nil { + t.Errorf("garbage collection failed: %v", err) + } else { + if result.AuthCodes != 0 || result.AuthRequests != 0 { + t.Errorf("expected no garbage collection results, got %#v", result) + } + } + if _, err := s.GetAuthCode(c.ID); err != nil { + t.Errorf("expected to be able to get auth code after GC: %v", err) + } } - if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil { + if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) } else if r.AuthCodes != 1 { t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes) @@ -422,7 +439,7 @@ func testGC(t *testing.T, s storage.Storage) { State: "bar", ForceApprovalPrompt: true, LoggedIn: true, - Expiry: n, + Expiry: expiry, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ @@ -438,14 +455,21 @@ func testGC(t *testing.T, s storage.Storage) { t.Fatalf("failed creating auth request: %v", err) } - if _, err := s.GarbageCollect(n); err != nil { - t.Errorf("garbage collection failed: %v", err) - } - if _, err := s.GetAuthRequest(a.ID); err != nil { - t.Errorf("expected to be able to get auth code after GC: %v", err) + for _, tz := range []*time.Location{time.UTC, est, pst} { + result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) + if err != nil { + t.Errorf("garbage collection failed: %v", err) + } else { + if result.AuthCodes != 0 || result.AuthRequests != 0 { + t.Errorf("expected no garbage collection results, got %#v", result) + } + } + if _, err := s.GetAuthRequest(a.ID); err != nil { + t.Errorf("expected to be able to get auth code after GC: %v", err) + } } - if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil { + if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) } else if r.AuthRequests != 1 { t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests) @@ -457,3 +481,49 @@ func testGC(t *testing.T, s storage.Storage) { t.Errorf("expected storage.ErrNotFound, got %v", err) } } + +// testTimezones tests that backends either fully support timezones or +// do the correct standardization. +func testTimezones(t *testing.T, s storage.Storage) { + est, err := time.LoadLocation("America/New_York") + if err != nil { + t.Fatal(err) + } + // Create an expiry with timezone info. Only expect backends to be + // accurate to the millisecond + expiry := time.Now().In(est).Round(time.Millisecond) + + c := storage.AuthCode{ + ID: storage.NewID(), + ClientID: "foobar", + RedirectURI: "https://localhost:80/callback", + Nonce: "foobar", + Scopes: []string{"openid", "email"}, + Expiry: expiry, + ConnectorID: "ldap", + ConnectorData: []byte(`{"some":"data"}`), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + } + if err := s.CreateAuthCode(c); err != nil { + t.Fatalf("failed creating auth code: %v", err) + } + got, err := s.GetAuthCode(c.ID) + if err != nil { + t.Fatalf("failed to get auth code: %v", err) + } + + // Ensure that if the resulting time is converted to the same + // timezone, it's the same value. We DO NOT expect timezones + // to be preserved. + gotTime := got.Expiry.In(est) + wantTime := expiry + if !gotTime.Equal(wantTime) { + t.Fatalf("expected expiry %v got %v", wantTime, gotTime) + } +} diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index d9c254d3..3bb410aa 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -9,7 +9,7 @@ func (c *conn) migrate() (int, error) { _, err := c.Exec(` create table if not exists migrations ( num integer not null, - at timestamp not null + at timestamptz not null ); `) if err != nil { @@ -100,7 +100,7 @@ var migrations = []migration{ connector_id text not null, connector_data bytea, - expiry timestamp not null + expiry timestamptz not null ); create table auth_code ( @@ -119,7 +119,7 @@ var migrations = []migration{ connector_id text not null, connector_data bytea, - expiry timestamp not null + expiry timestamptz not null ); create table refresh_token ( @@ -151,7 +151,7 @@ var migrations = []migration{ verification_keys bytea not null, -- JSON array signing_key bytea not null, -- JSON object signing_key_pub bytea not null, -- JSON object - next_rotation timestamp not null + next_rotation timestamptz not null ); `, }, diff --git a/storage/sql/sql.go b/storage/sql/sql.go index c1f57b0e..0d3026e4 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -4,6 +4,7 @@ package sql import ( "database/sql" "regexp" + "time" "github.com/Sirupsen/logrus" "github.com/cockroachdb/cockroach-go/crdb" @@ -28,6 +29,9 @@ type flavor struct { // // See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44 executeTx func(db *sql.DB, fn func(*sql.Tx) error) error + + // Does the flavor support timezones? + supportsTimezones bool } // A regexp with a replacement string. @@ -69,6 +73,8 @@ var ( } return tx.Commit() }, + + supportsTimezones: true, } flavorSQLite3 = flavor{ @@ -80,7 +86,7 @@ var ( {matchLiteral("boolean"), "integer"}, // Translate other types. {matchLiteral("bytea"), "blob"}, - // {matchLiteral("timestamp"), "integer"}, + {matchLiteral("timestamptz"), "timestamp"}, // SQLite doesn't have a "now()" method, replace with "date('now')" {regexp.MustCompile(`\bnow\(\)`), "date('now')"}, }, @@ -107,6 +113,22 @@ func (f flavor) translate(query string) string { return query } +// translateArgs translates query parameters that may be unique to +// a specific SQL flavor. For example, standardizing "time.Time" +// types to UTC for clients that don't provide timezone support. +func (c *conn) translateArgs(args []interface{}) []interface{} { + if c.flavor.supportsTimezones { + return args + } + + for i, arg := range args { + if t, ok := arg.(time.Time); ok { + args[i] = t.UTC() + } + } + return args +} + // conn is the main database connection. type conn struct { db *sql.DB @@ -122,17 +144,17 @@ func (c *conn) Close() error { func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) { query = c.flavor.translate(query) - return c.db.Exec(query, args...) + return c.db.Exec(query, c.translateArgs(args)...) } func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) { query = c.flavor.translate(query) - return c.db.Query(query, args...) + return c.db.Query(query, c.translateArgs(args)...) } func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row { query = c.flavor.translate(query) - return c.db.QueryRow(query, args...) + return c.db.QueryRow(query, c.translateArgs(args)...) } // ExecTx runs a method which operates on a transaction. @@ -163,15 +185,15 @@ type trans struct { func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) { query = t.c.flavor.translate(query) - return t.tx.Exec(query, args...) + return t.tx.Exec(query, t.c.translateArgs(args)...) } func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) { query = t.c.flavor.translate(query) - return t.tx.Query(query, args...) + return t.tx.Query(query, t.c.translateArgs(args)...) } func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row { query = t.c.flavor.translate(query) - return t.tx.QueryRow(query, args...) + return t.tx.QueryRow(query, t.c.translateArgs(args)...) } diff --git a/storage/storage.go b/storage/storage.go index cd480326..22a9ea50 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -44,11 +44,9 @@ type GCResult struct { AuthCodes int64 } -// Storage is the storage interface used by the server. Implementations, at minimum -// require compare-and-swap atomic actions. -// -// Implementations are expected to perform their own garbage collection of -// expired objects (expect keys, which are handled by the server). +// Storage is the storage interface used by the server. Implementations are +// required to be able to perform atomic compare-and-swap updates and either +// support timezones or standardize on UTC. type Storage interface { Close() error