Merge pull request #749 from ericchiang/postgres-timezones
storage: fix postgres timezone handling
This commit is contained in:
commit
c58dd948c7
4 changed files with 121 additions and 31 deletions
|
@ -48,6 +48,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
|
|||
{"PasswordCRUD", testPasswordCRUD},
|
||||
{"KeysCRUD", testKeysCRUD},
|
||||
{"GarbageCollection", testGC},
|
||||
{"TimezoneSupport", testTimezones},
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -370,14 +371,23 @@ func testKeysCRUD(t *testing.T, s storage.Storage) {
|
|||
}
|
||||
|
||||
func testGC(t *testing.T, s storage.Storage) {
|
||||
n := time.Now().UTC()
|
||||
est, err := time.LoadLocation("America/New_York")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pst, err := time.LoadLocation("America/Los_Angeles")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expiry := time.Now().In(est)
|
||||
c := storage.AuthCode{
|
||||
ID: storage.NewID(),
|
||||
ClientID: "foobar",
|
||||
RedirectURI: "https://localhost:80/callback",
|
||||
Nonce: "foobar",
|
||||
Scopes: []string{"openid", "email"},
|
||||
Expiry: n.Add(time.Second),
|
||||
Expiry: expiry,
|
||||
ConnectorID: "ldap",
|
||||
ConnectorData: []byte(`{"some":"data"}`),
|
||||
Claims: storage.Claims{
|
||||
|
@ -393,14 +403,21 @@ func testGC(t *testing.T, s storage.Storage) {
|
|||
t.Fatalf("failed creating auth code: %v", err)
|
||||
}
|
||||
|
||||
if _, err := s.GarbageCollect(n); err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
}
|
||||
if _, err := s.GetAuthCode(c.ID); err != nil {
|
||||
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||
for _, tz := range []*time.Location{time.UTC, est, pst} {
|
||||
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
|
||||
if err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
} else {
|
||||
if result.AuthCodes != 0 || result.AuthRequests != 0 {
|
||||
t.Errorf("expected no garbage collection results, got %#v", result)
|
||||
}
|
||||
}
|
||||
if _, err := s.GetAuthCode(c.ID); err != nil {
|
||||
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
|
||||
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
} else if r.AuthCodes != 1 {
|
||||
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
|
||||
|
@ -422,7 +439,7 @@ func testGC(t *testing.T, s storage.Storage) {
|
|||
State: "bar",
|
||||
ForceApprovalPrompt: true,
|
||||
LoggedIn: true,
|
||||
Expiry: n,
|
||||
Expiry: expiry,
|
||||
ConnectorID: "ldap",
|
||||
ConnectorData: []byte(`{"some":"data"}`),
|
||||
Claims: storage.Claims{
|
||||
|
@ -438,14 +455,21 @@ func testGC(t *testing.T, s storage.Storage) {
|
|||
t.Fatalf("failed creating auth request: %v", err)
|
||||
}
|
||||
|
||||
if _, err := s.GarbageCollect(n); err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
}
|
||||
if _, err := s.GetAuthRequest(a.ID); err != nil {
|
||||
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||
for _, tz := range []*time.Location{time.UTC, est, pst} {
|
||||
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
|
||||
if err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
} else {
|
||||
if result.AuthCodes != 0 || result.AuthRequests != 0 {
|
||||
t.Errorf("expected no garbage collection results, got %#v", result)
|
||||
}
|
||||
}
|
||||
if _, err := s.GetAuthRequest(a.ID); err != nil {
|
||||
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
|
||||
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
} else if r.AuthRequests != 1 {
|
||||
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
|
||||
|
@ -457,3 +481,49 @@ func testGC(t *testing.T, s storage.Storage) {
|
|||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// testTimezones tests that backends either fully support timezones or
|
||||
// do the correct standardization.
|
||||
func testTimezones(t *testing.T, s storage.Storage) {
|
||||
est, err := time.LoadLocation("America/New_York")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Create an expiry with timezone info. Only expect backends to be
|
||||
// accurate to the millisecond
|
||||
expiry := time.Now().In(est).Round(time.Millisecond)
|
||||
|
||||
c := storage.AuthCode{
|
||||
ID: storage.NewID(),
|
||||
ClientID: "foobar",
|
||||
RedirectURI: "https://localhost:80/callback",
|
||||
Nonce: "foobar",
|
||||
Scopes: []string{"openid", "email"},
|
||||
Expiry: expiry,
|
||||
ConnectorID: "ldap",
|
||||
ConnectorData: []byte(`{"some":"data"}`),
|
||||
Claims: storage.Claims{
|
||||
UserID: "1",
|
||||
Username: "jane",
|
||||
Email: "jane.doe@example.com",
|
||||
EmailVerified: true,
|
||||
Groups: []string{"a", "b"},
|
||||
},
|
||||
}
|
||||
if err := s.CreateAuthCode(c); err != nil {
|
||||
t.Fatalf("failed creating auth code: %v", err)
|
||||
}
|
||||
got, err := s.GetAuthCode(c.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get auth code: %v", err)
|
||||
}
|
||||
|
||||
// Ensure that if the resulting time is converted to the same
|
||||
// timezone, it's the same value. We DO NOT expect timezones
|
||||
// to be preserved.
|
||||
gotTime := got.Expiry.In(est)
|
||||
wantTime := expiry
|
||||
if !gotTime.Equal(wantTime) {
|
||||
t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ func (c *conn) migrate() (int, error) {
|
|||
_, err := c.Exec(`
|
||||
create table if not exists migrations (
|
||||
num integer not null,
|
||||
at timestamp not null
|
||||
at timestamptz not null
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
|
@ -100,7 +100,7 @@ var migrations = []migration{
|
|||
connector_id text not null,
|
||||
connector_data bytea,
|
||||
|
||||
expiry timestamp not null
|
||||
expiry timestamptz not null
|
||||
);
|
||||
|
||||
create table auth_code (
|
||||
|
@ -119,7 +119,7 @@ var migrations = []migration{
|
|||
connector_id text not null,
|
||||
connector_data bytea,
|
||||
|
||||
expiry timestamp not null
|
||||
expiry timestamptz not null
|
||||
);
|
||||
|
||||
create table refresh_token (
|
||||
|
@ -151,7 +151,7 @@ var migrations = []migration{
|
|||
verification_keys bytea not null, -- JSON array
|
||||
signing_key bytea not null, -- JSON object
|
||||
signing_key_pub bytea not null, -- JSON object
|
||||
next_rotation timestamp not null
|
||||
next_rotation timestamptz not null
|
||||
);
|
||||
`,
|
||||
},
|
||||
|
|
|
@ -4,6 +4,7 @@ package sql
|
|||
import (
|
||||
"database/sql"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/cockroachdb/cockroach-go/crdb"
|
||||
|
@ -28,6 +29,9 @@ type flavor struct {
|
|||
//
|
||||
// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44
|
||||
executeTx func(db *sql.DB, fn func(*sql.Tx) error) error
|
||||
|
||||
// Does the flavor support timezones?
|
||||
supportsTimezones bool
|
||||
}
|
||||
|
||||
// A regexp with a replacement string.
|
||||
|
@ -69,6 +73,8 @@ var (
|
|||
}
|
||||
return tx.Commit()
|
||||
},
|
||||
|
||||
supportsTimezones: true,
|
||||
}
|
||||
|
||||
flavorSQLite3 = flavor{
|
||||
|
@ -80,7 +86,7 @@ var (
|
|||
{matchLiteral("boolean"), "integer"},
|
||||
// Translate other types.
|
||||
{matchLiteral("bytea"), "blob"},
|
||||
// {matchLiteral("timestamp"), "integer"},
|
||||
{matchLiteral("timestamptz"), "timestamp"},
|
||||
// SQLite doesn't have a "now()" method, replace with "date('now')"
|
||||
{regexp.MustCompile(`\bnow\(\)`), "date('now')"},
|
||||
},
|
||||
|
@ -107,6 +113,22 @@ func (f flavor) translate(query string) string {
|
|||
return query
|
||||
}
|
||||
|
||||
// translateArgs translates query parameters that may be unique to
|
||||
// a specific SQL flavor. For example, standardizing "time.Time"
|
||||
// types to UTC for clients that don't provide timezone support.
|
||||
func (c *conn) translateArgs(args []interface{}) []interface{} {
|
||||
if c.flavor.supportsTimezones {
|
||||
return args
|
||||
}
|
||||
|
||||
for i, arg := range args {
|
||||
if t, ok := arg.(time.Time); ok {
|
||||
args[i] = t.UTC()
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// conn is the main database connection.
|
||||
type conn struct {
|
||||
db *sql.DB
|
||||
|
@ -122,17 +144,17 @@ func (c *conn) Close() error {
|
|||
|
||||
func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
query = c.flavor.translate(query)
|
||||
return c.db.Exec(query, args...)
|
||||
return c.db.Exec(query, c.translateArgs(args)...)
|
||||
}
|
||||
|
||||
func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
query = c.flavor.translate(query)
|
||||
return c.db.Query(query, args...)
|
||||
return c.db.Query(query, c.translateArgs(args)...)
|
||||
}
|
||||
|
||||
func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||
query = c.flavor.translate(query)
|
||||
return c.db.QueryRow(query, args...)
|
||||
return c.db.QueryRow(query, c.translateArgs(args)...)
|
||||
}
|
||||
|
||||
// ExecTx runs a method which operates on a transaction.
|
||||
|
@ -163,15 +185,15 @@ type trans struct {
|
|||
|
||||
func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
query = t.c.flavor.translate(query)
|
||||
return t.tx.Exec(query, args...)
|
||||
return t.tx.Exec(query, t.c.translateArgs(args)...)
|
||||
}
|
||||
|
||||
func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
query = t.c.flavor.translate(query)
|
||||
return t.tx.Query(query, args...)
|
||||
return t.tx.Query(query, t.c.translateArgs(args)...)
|
||||
}
|
||||
|
||||
func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||
query = t.c.flavor.translate(query)
|
||||
return t.tx.QueryRow(query, args...)
|
||||
return t.tx.QueryRow(query, t.c.translateArgs(args)...)
|
||||
}
|
||||
|
|
|
@ -44,11 +44,9 @@ type GCResult struct {
|
|||
AuthCodes int64
|
||||
}
|
||||
|
||||
// Storage is the storage interface used by the server. Implementations, at minimum
|
||||
// require compare-and-swap atomic actions.
|
||||
//
|
||||
// Implementations are expected to perform their own garbage collection of
|
||||
// expired objects (expect keys, which are handled by the server).
|
||||
// Storage is the storage interface used by the server. Implementations are
|
||||
// required to be able to perform atomic compare-and-swap updates and either
|
||||
// support timezones or standardize on UTC.
|
||||
type Storage interface {
|
||||
Close() error
|
||||
|
||||
|
|
Reference in a new issue