Merge pull request #749 from ericchiang/postgres-timezones

storage: fix postgres timezone handling
This commit is contained in:
Eric Chiang 2016-12-16 15:36:12 -08:00 committed by GitHub
commit c58dd948c7
4 changed files with 121 additions and 31 deletions

View file

@ -48,6 +48,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
{"PasswordCRUD", testPasswordCRUD}, {"PasswordCRUD", testPasswordCRUD},
{"KeysCRUD", testKeysCRUD}, {"KeysCRUD", testKeysCRUD},
{"GarbageCollection", testGC}, {"GarbageCollection", testGC},
{"TimezoneSupport", testTimezones},
}) })
} }
@ -370,14 +371,23 @@ func testKeysCRUD(t *testing.T, s storage.Storage) {
} }
func testGC(t *testing.T, s storage.Storage) { func testGC(t *testing.T, s storage.Storage) {
n := time.Now().UTC() est, err := time.LoadLocation("America/New_York")
if err != nil {
t.Fatal(err)
}
pst, err := time.LoadLocation("America/Los_Angeles")
if err != nil {
t.Fatal(err)
}
expiry := time.Now().In(est)
c := storage.AuthCode{ c := storage.AuthCode{
ID: storage.NewID(), ID: storage.NewID(),
ClientID: "foobar", ClientID: "foobar",
RedirectURI: "https://localhost:80/callback", RedirectURI: "https://localhost:80/callback",
Nonce: "foobar", Nonce: "foobar",
Scopes: []string{"openid", "email"}, Scopes: []string{"openid", "email"},
Expiry: n.Add(time.Second), Expiry: expiry,
ConnectorID: "ldap", ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`), ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{ Claims: storage.Claims{
@ -393,14 +403,21 @@ func testGC(t *testing.T, s storage.Storage) {
t.Fatalf("failed creating auth code: %v", err) t.Fatalf("failed creating auth code: %v", err)
} }
if _, err := s.GarbageCollect(n); err != nil { for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err) t.Errorf("garbage collection failed: %v", err)
} else {
if result.AuthCodes != 0 || result.AuthRequests != 0 {
t.Errorf("expected no garbage collection results, got %#v", result)
}
} }
if _, err := s.GetAuthCode(c.ID); err != nil { if _, err := s.GetAuthCode(c.ID); err != nil {
t.Errorf("expected to be able to get auth code after GC: %v", err) t.Errorf("expected to be able to get auth code after GC: %v", err)
} }
}
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil { if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err) t.Errorf("garbage collection failed: %v", err)
} else if r.AuthCodes != 1 { } else if r.AuthCodes != 1 {
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes) t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
@ -422,7 +439,7 @@ func testGC(t *testing.T, s storage.Storage) {
State: "bar", State: "bar",
ForceApprovalPrompt: true, ForceApprovalPrompt: true,
LoggedIn: true, LoggedIn: true,
Expiry: n, Expiry: expiry,
ConnectorID: "ldap", ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`), ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{ Claims: storage.Claims{
@ -438,14 +455,21 @@ func testGC(t *testing.T, s storage.Storage) {
t.Fatalf("failed creating auth request: %v", err) t.Fatalf("failed creating auth request: %v", err)
} }
if _, err := s.GarbageCollect(n); err != nil { for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err) t.Errorf("garbage collection failed: %v", err)
} else {
if result.AuthCodes != 0 || result.AuthRequests != 0 {
t.Errorf("expected no garbage collection results, got %#v", result)
}
} }
if _, err := s.GetAuthRequest(a.ID); err != nil { if _, err := s.GetAuthRequest(a.ID); err != nil {
t.Errorf("expected to be able to get auth code after GC: %v", err) t.Errorf("expected to be able to get auth code after GC: %v", err)
} }
}
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil { if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err) t.Errorf("garbage collection failed: %v", err)
} else if r.AuthRequests != 1 { } else if r.AuthRequests != 1 {
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests) t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
@ -457,3 +481,49 @@ func testGC(t *testing.T, s storage.Storage) {
t.Errorf("expected storage.ErrNotFound, got %v", err) t.Errorf("expected storage.ErrNotFound, got %v", err)
} }
} }
// testTimezones tests that backends either fully support timezones or
// do the correct standardization.
func testTimezones(t *testing.T, s storage.Storage) {
est, err := time.LoadLocation("America/New_York")
if err != nil {
t.Fatal(err)
}
// Create an expiry with timezone info. Only expect backends to be
// accurate to the millisecond
expiry := time.Now().In(est).Round(time.Millisecond)
c := storage.AuthCode{
ID: storage.NewID(),
ClientID: "foobar",
RedirectURI: "https://localhost:80/callback",
Nonce: "foobar",
Scopes: []string{"openid", "email"},
Expiry: expiry,
ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "jane.doe@example.com",
EmailVerified: true,
Groups: []string{"a", "b"},
},
}
if err := s.CreateAuthCode(c); err != nil {
t.Fatalf("failed creating auth code: %v", err)
}
got, err := s.GetAuthCode(c.ID)
if err != nil {
t.Fatalf("failed to get auth code: %v", err)
}
// Ensure that if the resulting time is converted to the same
// timezone, it's the same value. We DO NOT expect timezones
// to be preserved.
gotTime := got.Expiry.In(est)
wantTime := expiry
if !gotTime.Equal(wantTime) {
t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
}
}

View file

@ -9,7 +9,7 @@ func (c *conn) migrate() (int, error) {
_, err := c.Exec(` _, err := c.Exec(`
create table if not exists migrations ( create table if not exists migrations (
num integer not null, num integer not null,
at timestamp not null at timestamptz not null
); );
`) `)
if err != nil { if err != nil {
@ -100,7 +100,7 @@ var migrations = []migration{
connector_id text not null, connector_id text not null,
connector_data bytea, connector_data bytea,
expiry timestamp not null expiry timestamptz not null
); );
create table auth_code ( create table auth_code (
@ -119,7 +119,7 @@ var migrations = []migration{
connector_id text not null, connector_id text not null,
connector_data bytea, connector_data bytea,
expiry timestamp not null expiry timestamptz not null
); );
create table refresh_token ( create table refresh_token (
@ -151,7 +151,7 @@ var migrations = []migration{
verification_keys bytea not null, -- JSON array verification_keys bytea not null, -- JSON array
signing_key bytea not null, -- JSON object signing_key bytea not null, -- JSON object
signing_key_pub bytea not null, -- JSON object signing_key_pub bytea not null, -- JSON object
next_rotation timestamp not null next_rotation timestamptz not null
); );
`, `,
}, },

View file

@ -4,6 +4,7 @@ package sql
import ( import (
"database/sql" "database/sql"
"regexp" "regexp"
"time"
"github.com/Sirupsen/logrus" "github.com/Sirupsen/logrus"
"github.com/cockroachdb/cockroach-go/crdb" "github.com/cockroachdb/cockroach-go/crdb"
@ -28,6 +29,9 @@ type flavor struct {
// //
// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44 // See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44
executeTx func(db *sql.DB, fn func(*sql.Tx) error) error executeTx func(db *sql.DB, fn func(*sql.Tx) error) error
// Does the flavor support timezones?
supportsTimezones bool
} }
// A regexp with a replacement string. // A regexp with a replacement string.
@ -69,6 +73,8 @@ var (
} }
return tx.Commit() return tx.Commit()
}, },
supportsTimezones: true,
} }
flavorSQLite3 = flavor{ flavorSQLite3 = flavor{
@ -80,7 +86,7 @@ var (
{matchLiteral("boolean"), "integer"}, {matchLiteral("boolean"), "integer"},
// Translate other types. // Translate other types.
{matchLiteral("bytea"), "blob"}, {matchLiteral("bytea"), "blob"},
// {matchLiteral("timestamp"), "integer"}, {matchLiteral("timestamptz"), "timestamp"},
// SQLite doesn't have a "now()" method, replace with "date('now')" // SQLite doesn't have a "now()" method, replace with "date('now')"
{regexp.MustCompile(`\bnow\(\)`), "date('now')"}, {regexp.MustCompile(`\bnow\(\)`), "date('now')"},
}, },
@ -107,6 +113,22 @@ func (f flavor) translate(query string) string {
return query return query
} }
// translateArgs translates query parameters that may be unique to
// a specific SQL flavor. For example, standardizing "time.Time"
// types to UTC for clients that don't provide timezone support.
func (c *conn) translateArgs(args []interface{}) []interface{} {
if c.flavor.supportsTimezones {
return args
}
for i, arg := range args {
if t, ok := arg.(time.Time); ok {
args[i] = t.UTC()
}
}
return args
}
// conn is the main database connection. // conn is the main database connection.
type conn struct { type conn struct {
db *sql.DB db *sql.DB
@ -122,17 +144,17 @@ func (c *conn) Close() error {
func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) { func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) {
query = c.flavor.translate(query) query = c.flavor.translate(query)
return c.db.Exec(query, args...) return c.db.Exec(query, c.translateArgs(args)...)
} }
func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) { func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) {
query = c.flavor.translate(query) query = c.flavor.translate(query)
return c.db.Query(query, args...) return c.db.Query(query, c.translateArgs(args)...)
} }
func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row { func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row {
query = c.flavor.translate(query) query = c.flavor.translate(query)
return c.db.QueryRow(query, args...) return c.db.QueryRow(query, c.translateArgs(args)...)
} }
// ExecTx runs a method which operates on a transaction. // ExecTx runs a method which operates on a transaction.
@ -163,15 +185,15 @@ type trans struct {
func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) { func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) {
query = t.c.flavor.translate(query) query = t.c.flavor.translate(query)
return t.tx.Exec(query, args...) return t.tx.Exec(query, t.c.translateArgs(args)...)
} }
func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) { func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) {
query = t.c.flavor.translate(query) query = t.c.flavor.translate(query)
return t.tx.Query(query, args...) return t.tx.Query(query, t.c.translateArgs(args)...)
} }
func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row { func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row {
query = t.c.flavor.translate(query) query = t.c.flavor.translate(query)
return t.tx.QueryRow(query, args...) return t.tx.QueryRow(query, t.c.translateArgs(args)...)
} }

View file

@ -44,11 +44,9 @@ type GCResult struct {
AuthCodes int64 AuthCodes int64
} }
// Storage is the storage interface used by the server. Implementations, at minimum // Storage is the storage interface used by the server. Implementations are
// require compare-and-swap atomic actions. // required to be able to perform atomic compare-and-swap updates and either
// // support timezones or standardize on UTC.
// Implementations are expected to perform their own garbage collection of
// expired objects (expect keys, which are handled by the server).
type Storage interface { type Storage interface {
Close() error Close() error