Merge pull request #749 from ericchiang/postgres-timezones
storage: fix postgres timezone handling
This commit is contained in:
commit
c58dd948c7
4 changed files with 121 additions and 31 deletions
|
@ -48,6 +48,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
|
||||||
{"PasswordCRUD", testPasswordCRUD},
|
{"PasswordCRUD", testPasswordCRUD},
|
||||||
{"KeysCRUD", testKeysCRUD},
|
{"KeysCRUD", testKeysCRUD},
|
||||||
{"GarbageCollection", testGC},
|
{"GarbageCollection", testGC},
|
||||||
|
{"TimezoneSupport", testTimezones},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -370,14 +371,23 @@ func testKeysCRUD(t *testing.T, s storage.Storage) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testGC(t *testing.T, s storage.Storage) {
|
func testGC(t *testing.T, s storage.Storage) {
|
||||||
n := time.Now().UTC()
|
est, err := time.LoadLocation("America/New_York")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
pst, err := time.LoadLocation("America/Los_Angeles")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiry := time.Now().In(est)
|
||||||
c := storage.AuthCode{
|
c := storage.AuthCode{
|
||||||
ID: storage.NewID(),
|
ID: storage.NewID(),
|
||||||
ClientID: "foobar",
|
ClientID: "foobar",
|
||||||
RedirectURI: "https://localhost:80/callback",
|
RedirectURI: "https://localhost:80/callback",
|
||||||
Nonce: "foobar",
|
Nonce: "foobar",
|
||||||
Scopes: []string{"openid", "email"},
|
Scopes: []string{"openid", "email"},
|
||||||
Expiry: n.Add(time.Second),
|
Expiry: expiry,
|
||||||
ConnectorID: "ldap",
|
ConnectorID: "ldap",
|
||||||
ConnectorData: []byte(`{"some":"data"}`),
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
Claims: storage.Claims{
|
Claims: storage.Claims{
|
||||||
|
@ -393,14 +403,21 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("failed creating auth code: %v", err)
|
t.Fatalf("failed creating auth code: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.GarbageCollect(n); err != nil {
|
for _, tz := range []*time.Location{time.UTC, est, pst} {
|
||||||
|
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
|
||||||
|
if err != nil {
|
||||||
t.Errorf("garbage collection failed: %v", err)
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
|
} else {
|
||||||
|
if result.AuthCodes != 0 || result.AuthRequests != 0 {
|
||||||
|
t.Errorf("expected no garbage collection results, got %#v", result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if _, err := s.GetAuthCode(c.ID); err != nil {
|
if _, err := s.GetAuthCode(c.ID); err != nil {
|
||||||
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
|
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||||
t.Errorf("garbage collection failed: %v", err)
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
} else if r.AuthCodes != 1 {
|
} else if r.AuthCodes != 1 {
|
||||||
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
|
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
|
||||||
|
@ -422,7 +439,7 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
State: "bar",
|
State: "bar",
|
||||||
ForceApprovalPrompt: true,
|
ForceApprovalPrompt: true,
|
||||||
LoggedIn: true,
|
LoggedIn: true,
|
||||||
Expiry: n,
|
Expiry: expiry,
|
||||||
ConnectorID: "ldap",
|
ConnectorID: "ldap",
|
||||||
ConnectorData: []byte(`{"some":"data"}`),
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
Claims: storage.Claims{
|
Claims: storage.Claims{
|
||||||
|
@ -438,14 +455,21 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("failed creating auth request: %v", err)
|
t.Fatalf("failed creating auth request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.GarbageCollect(n); err != nil {
|
for _, tz := range []*time.Location{time.UTC, est, pst} {
|
||||||
|
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
|
||||||
|
if err != nil {
|
||||||
t.Errorf("garbage collection failed: %v", err)
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
|
} else {
|
||||||
|
if result.AuthCodes != 0 || result.AuthRequests != 0 {
|
||||||
|
t.Errorf("expected no garbage collection results, got %#v", result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if _, err := s.GetAuthRequest(a.ID); err != nil {
|
if _, err := s.GetAuthRequest(a.ID); err != nil {
|
||||||
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
|
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||||
t.Errorf("garbage collection failed: %v", err)
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
} else if r.AuthRequests != 1 {
|
} else if r.AuthRequests != 1 {
|
||||||
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
|
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
|
||||||
|
@ -457,3 +481,49 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testTimezones tests that backends either fully support timezones or
|
||||||
|
// do the correct standardization.
|
||||||
|
func testTimezones(t *testing.T, s storage.Storage) {
|
||||||
|
est, err := time.LoadLocation("America/New_York")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Create an expiry with timezone info. Only expect backends to be
|
||||||
|
// accurate to the millisecond
|
||||||
|
expiry := time.Now().In(est).Round(time.Millisecond)
|
||||||
|
|
||||||
|
c := storage.AuthCode{
|
||||||
|
ID: storage.NewID(),
|
||||||
|
ClientID: "foobar",
|
||||||
|
RedirectURI: "https://localhost:80/callback",
|
||||||
|
Nonce: "foobar",
|
||||||
|
Scopes: []string{"openid", "email"},
|
||||||
|
Expiry: expiry,
|
||||||
|
ConnectorID: "ldap",
|
||||||
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
|
Claims: storage.Claims{
|
||||||
|
UserID: "1",
|
||||||
|
Username: "jane",
|
||||||
|
Email: "jane.doe@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Groups: []string{"a", "b"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := s.CreateAuthCode(c); err != nil {
|
||||||
|
t.Fatalf("failed creating auth code: %v", err)
|
||||||
|
}
|
||||||
|
got, err := s.GetAuthCode(c.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get auth code: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that if the resulting time is converted to the same
|
||||||
|
// timezone, it's the same value. We DO NOT expect timezones
|
||||||
|
// to be preserved.
|
||||||
|
gotTime := got.Expiry.In(est)
|
||||||
|
wantTime := expiry
|
||||||
|
if !gotTime.Equal(wantTime) {
|
||||||
|
t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ func (c *conn) migrate() (int, error) {
|
||||||
_, err := c.Exec(`
|
_, err := c.Exec(`
|
||||||
create table if not exists migrations (
|
create table if not exists migrations (
|
||||||
num integer not null,
|
num integer not null,
|
||||||
at timestamp not null
|
at timestamptz not null
|
||||||
);
|
);
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -100,7 +100,7 @@ var migrations = []migration{
|
||||||
connector_id text not null,
|
connector_id text not null,
|
||||||
connector_data bytea,
|
connector_data bytea,
|
||||||
|
|
||||||
expiry timestamp not null
|
expiry timestamptz not null
|
||||||
);
|
);
|
||||||
|
|
||||||
create table auth_code (
|
create table auth_code (
|
||||||
|
@ -119,7 +119,7 @@ var migrations = []migration{
|
||||||
connector_id text not null,
|
connector_id text not null,
|
||||||
connector_data bytea,
|
connector_data bytea,
|
||||||
|
|
||||||
expiry timestamp not null
|
expiry timestamptz not null
|
||||||
);
|
);
|
||||||
|
|
||||||
create table refresh_token (
|
create table refresh_token (
|
||||||
|
@ -151,7 +151,7 @@ var migrations = []migration{
|
||||||
verification_keys bytea not null, -- JSON array
|
verification_keys bytea not null, -- JSON array
|
||||||
signing_key bytea not null, -- JSON object
|
signing_key bytea not null, -- JSON object
|
||||||
signing_key_pub bytea not null, -- JSON object
|
signing_key_pub bytea not null, -- JSON object
|
||||||
next_rotation timestamp not null
|
next_rotation timestamptz not null
|
||||||
);
|
);
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
|
|
|
@ -4,6 +4,7 @@ package sql
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Sirupsen/logrus"
|
"github.com/Sirupsen/logrus"
|
||||||
"github.com/cockroachdb/cockroach-go/crdb"
|
"github.com/cockroachdb/cockroach-go/crdb"
|
||||||
|
@ -28,6 +29,9 @@ type flavor struct {
|
||||||
//
|
//
|
||||||
// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44
|
// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44
|
||||||
executeTx func(db *sql.DB, fn func(*sql.Tx) error) error
|
executeTx func(db *sql.DB, fn func(*sql.Tx) error) error
|
||||||
|
|
||||||
|
// Does the flavor support timezones?
|
||||||
|
supportsTimezones bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// A regexp with a replacement string.
|
// A regexp with a replacement string.
|
||||||
|
@ -69,6 +73,8 @@ var (
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
},
|
},
|
||||||
|
|
||||||
|
supportsTimezones: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
flavorSQLite3 = flavor{
|
flavorSQLite3 = flavor{
|
||||||
|
@ -80,7 +86,7 @@ var (
|
||||||
{matchLiteral("boolean"), "integer"},
|
{matchLiteral("boolean"), "integer"},
|
||||||
// Translate other types.
|
// Translate other types.
|
||||||
{matchLiteral("bytea"), "blob"},
|
{matchLiteral("bytea"), "blob"},
|
||||||
// {matchLiteral("timestamp"), "integer"},
|
{matchLiteral("timestamptz"), "timestamp"},
|
||||||
// SQLite doesn't have a "now()" method, replace with "date('now')"
|
// SQLite doesn't have a "now()" method, replace with "date('now')"
|
||||||
{regexp.MustCompile(`\bnow\(\)`), "date('now')"},
|
{regexp.MustCompile(`\bnow\(\)`), "date('now')"},
|
||||||
},
|
},
|
||||||
|
@ -107,6 +113,22 @@ func (f flavor) translate(query string) string {
|
||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// translateArgs translates query parameters that may be unique to
|
||||||
|
// a specific SQL flavor. For example, standardizing "time.Time"
|
||||||
|
// types to UTC for clients that don't provide timezone support.
|
||||||
|
func (c *conn) translateArgs(args []interface{}) []interface{} {
|
||||||
|
if c.flavor.supportsTimezones {
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, arg := range args {
|
||||||
|
if t, ok := arg.(time.Time); ok {
|
||||||
|
args[i] = t.UTC()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
// conn is the main database connection.
|
// conn is the main database connection.
|
||||||
type conn struct {
|
type conn struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -122,17 +144,17 @@ func (c *conn) Close() error {
|
||||||
|
|
||||||
func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||||
query = c.flavor.translate(query)
|
query = c.flavor.translate(query)
|
||||||
return c.db.Exec(query, args...)
|
return c.db.Exec(query, c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
query = c.flavor.translate(query)
|
query = c.flavor.translate(query)
|
||||||
return c.db.Query(query, args...)
|
return c.db.Query(query, c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row {
|
func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||||
query = c.flavor.translate(query)
|
query = c.flavor.translate(query)
|
||||||
return c.db.QueryRow(query, args...)
|
return c.db.QueryRow(query, c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecTx runs a method which operates on a transaction.
|
// ExecTx runs a method which operates on a transaction.
|
||||||
|
@ -163,15 +185,15 @@ type trans struct {
|
||||||
|
|
||||||
func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) {
|
func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||||
query = t.c.flavor.translate(query)
|
query = t.c.flavor.translate(query)
|
||||||
return t.tx.Exec(query, args...)
|
return t.tx.Exec(query, t.c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
query = t.c.flavor.translate(query)
|
query = t.c.flavor.translate(query)
|
||||||
return t.tx.Query(query, args...)
|
return t.tx.Query(query, t.c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row {
|
func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||||
query = t.c.flavor.translate(query)
|
query = t.c.flavor.translate(query)
|
||||||
return t.tx.QueryRow(query, args...)
|
return t.tx.QueryRow(query, t.c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,11 +44,9 @@ type GCResult struct {
|
||||||
AuthCodes int64
|
AuthCodes int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Storage is the storage interface used by the server. Implementations, at minimum
|
// Storage is the storage interface used by the server. Implementations are
|
||||||
// require compare-and-swap atomic actions.
|
// required to be able to perform atomic compare-and-swap updates and either
|
||||||
//
|
// support timezones or standardize on UTC.
|
||||||
// Implementations are expected to perform their own garbage collection of
|
|
||||||
// expired objects (expect keys, which are handled by the server).
|
|
||||||
type Storage interface {
|
type Storage interface {
|
||||||
Close() error
|
Close() error
|
||||||
|
|
||||||
|
|
Reference in a new issue