forked from mystiq/dex
storage: fix postgres timezone handling
Dex's Postgres client currently uses the `timestamp` datatype for storing times. This lops of timezones with no conversion, causing times to lose locality information. We could convert all times to UTC before storing them, but this is a backward incompatible change for upgrades, since the new version of dex would still be reading times from the database with no locality. Because of this intrinsic issue that current Postgres users don't save any timezone data, we chose to treat any existing installation as corrupted and change the datatype used for times to `timestamptz`. This is a breaking change, but it seems hard to offer an alternative that's both correct and backward compatible. Additionally, an internal flag has been added to SQL flavors, `supportsTimezones`. This allows us to handle SQLite3, which doesn't support timezones, while still storing timezones in other flavors. Flavors that don't support timezones are explicitly converted to UTC.
This commit is contained in:
parent
dd3133072c
commit
fd20b213bb
4 changed files with 121 additions and 31 deletions
|
@ -48,6 +48,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
|
||||||
{"PasswordCRUD", testPasswordCRUD},
|
{"PasswordCRUD", testPasswordCRUD},
|
||||||
{"KeysCRUD", testKeysCRUD},
|
{"KeysCRUD", testKeysCRUD},
|
||||||
{"GarbageCollection", testGC},
|
{"GarbageCollection", testGC},
|
||||||
|
{"TimezoneSupport", testTimezones},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -370,14 +371,23 @@ func testKeysCRUD(t *testing.T, s storage.Storage) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testGC(t *testing.T, s storage.Storage) {
|
func testGC(t *testing.T, s storage.Storage) {
|
||||||
n := time.Now().UTC()
|
est, err := time.LoadLocation("America/New_York")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
pst, err := time.LoadLocation("America/Los_Angeles")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiry := time.Now().In(est)
|
||||||
c := storage.AuthCode{
|
c := storage.AuthCode{
|
||||||
ID: storage.NewID(),
|
ID: storage.NewID(),
|
||||||
ClientID: "foobar",
|
ClientID: "foobar",
|
||||||
RedirectURI: "https://localhost:80/callback",
|
RedirectURI: "https://localhost:80/callback",
|
||||||
Nonce: "foobar",
|
Nonce: "foobar",
|
||||||
Scopes: []string{"openid", "email"},
|
Scopes: []string{"openid", "email"},
|
||||||
Expiry: n.Add(time.Second),
|
Expiry: expiry,
|
||||||
ConnectorID: "ldap",
|
ConnectorID: "ldap",
|
||||||
ConnectorData: []byte(`{"some":"data"}`),
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
Claims: storage.Claims{
|
Claims: storage.Claims{
|
||||||
|
@ -393,14 +403,21 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("failed creating auth code: %v", err)
|
t.Fatalf("failed creating auth code: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.GarbageCollect(n); err != nil {
|
for _, tz := range []*time.Location{time.UTC, est, pst} {
|
||||||
t.Errorf("garbage collection failed: %v", err)
|
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
|
||||||
}
|
if err != nil {
|
||||||
if _, err := s.GetAuthCode(c.ID); err != nil {
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
} else {
|
||||||
|
if result.AuthCodes != 0 || result.AuthRequests != 0 {
|
||||||
|
t.Errorf("expected no garbage collection results, got %#v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := s.GetAuthCode(c.ID); err != nil {
|
||||||
|
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
|
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||||
t.Errorf("garbage collection failed: %v", err)
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
} else if r.AuthCodes != 1 {
|
} else if r.AuthCodes != 1 {
|
||||||
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
|
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
|
||||||
|
@ -422,7 +439,7 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
State: "bar",
|
State: "bar",
|
||||||
ForceApprovalPrompt: true,
|
ForceApprovalPrompt: true,
|
||||||
LoggedIn: true,
|
LoggedIn: true,
|
||||||
Expiry: n,
|
Expiry: expiry,
|
||||||
ConnectorID: "ldap",
|
ConnectorID: "ldap",
|
||||||
ConnectorData: []byte(`{"some":"data"}`),
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
Claims: storage.Claims{
|
Claims: storage.Claims{
|
||||||
|
@ -438,14 +455,21 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
t.Fatalf("failed creating auth request: %v", err)
|
t.Fatalf("failed creating auth request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.GarbageCollect(n); err != nil {
|
for _, tz := range []*time.Location{time.UTC, est, pst} {
|
||||||
t.Errorf("garbage collection failed: %v", err)
|
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
|
||||||
}
|
if err != nil {
|
||||||
if _, err := s.GetAuthRequest(a.ID); err != nil {
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
} else {
|
||||||
|
if result.AuthCodes != 0 || result.AuthRequests != 0 {
|
||||||
|
t.Errorf("expected no garbage collection results, got %#v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := s.GetAuthRequest(a.ID); err != nil {
|
||||||
|
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
|
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||||
t.Errorf("garbage collection failed: %v", err)
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
} else if r.AuthRequests != 1 {
|
} else if r.AuthRequests != 1 {
|
||||||
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
|
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
|
||||||
|
@ -457,3 +481,49 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testTimezones tests that backends either fully support timezones or
|
||||||
|
// do the correct standardization.
|
||||||
|
func testTimezones(t *testing.T, s storage.Storage) {
|
||||||
|
est, err := time.LoadLocation("America/New_York")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Create an expiry with timezone info. Only expect backends to be
|
||||||
|
// accurate to the millisecond
|
||||||
|
expiry := time.Now().In(est).Round(time.Millisecond)
|
||||||
|
|
||||||
|
c := storage.AuthCode{
|
||||||
|
ID: storage.NewID(),
|
||||||
|
ClientID: "foobar",
|
||||||
|
RedirectURI: "https://localhost:80/callback",
|
||||||
|
Nonce: "foobar",
|
||||||
|
Scopes: []string{"openid", "email"},
|
||||||
|
Expiry: expiry,
|
||||||
|
ConnectorID: "ldap",
|
||||||
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
|
Claims: storage.Claims{
|
||||||
|
UserID: "1",
|
||||||
|
Username: "jane",
|
||||||
|
Email: "jane.doe@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Groups: []string{"a", "b"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := s.CreateAuthCode(c); err != nil {
|
||||||
|
t.Fatalf("failed creating auth code: %v", err)
|
||||||
|
}
|
||||||
|
got, err := s.GetAuthCode(c.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get auth code: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that if the resulting time is converted to the same
|
||||||
|
// timezone, it's the same value. We DO NOT expect timezones
|
||||||
|
// to be preserved.
|
||||||
|
gotTime := got.Expiry.In(est)
|
||||||
|
wantTime := expiry
|
||||||
|
if !gotTime.Equal(wantTime) {
|
||||||
|
t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ func (c *conn) migrate() (int, error) {
|
||||||
_, err := c.Exec(`
|
_, err := c.Exec(`
|
||||||
create table if not exists migrations (
|
create table if not exists migrations (
|
||||||
num integer not null,
|
num integer not null,
|
||||||
at timestamp not null
|
at timestamptz not null
|
||||||
);
|
);
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -100,7 +100,7 @@ var migrations = []migration{
|
||||||
connector_id text not null,
|
connector_id text not null,
|
||||||
connector_data bytea,
|
connector_data bytea,
|
||||||
|
|
||||||
expiry timestamp not null
|
expiry timestamptz not null
|
||||||
);
|
);
|
||||||
|
|
||||||
create table auth_code (
|
create table auth_code (
|
||||||
|
@ -119,7 +119,7 @@ var migrations = []migration{
|
||||||
connector_id text not null,
|
connector_id text not null,
|
||||||
connector_data bytea,
|
connector_data bytea,
|
||||||
|
|
||||||
expiry timestamp not null
|
expiry timestamptz not null
|
||||||
);
|
);
|
||||||
|
|
||||||
create table refresh_token (
|
create table refresh_token (
|
||||||
|
@ -151,7 +151,7 @@ var migrations = []migration{
|
||||||
verification_keys bytea not null, -- JSON array
|
verification_keys bytea not null, -- JSON array
|
||||||
signing_key bytea not null, -- JSON object
|
signing_key bytea not null, -- JSON object
|
||||||
signing_key_pub bytea not null, -- JSON object
|
signing_key_pub bytea not null, -- JSON object
|
||||||
next_rotation timestamp not null
|
next_rotation timestamptz not null
|
||||||
);
|
);
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
|
|
|
@ -4,6 +4,7 @@ package sql
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Sirupsen/logrus"
|
"github.com/Sirupsen/logrus"
|
||||||
"github.com/cockroachdb/cockroach-go/crdb"
|
"github.com/cockroachdb/cockroach-go/crdb"
|
||||||
|
@ -28,6 +29,9 @@ type flavor struct {
|
||||||
//
|
//
|
||||||
// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44
|
// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44
|
||||||
executeTx func(db *sql.DB, fn func(*sql.Tx) error) error
|
executeTx func(db *sql.DB, fn func(*sql.Tx) error) error
|
||||||
|
|
||||||
|
// Does the flavor support timezones?
|
||||||
|
supportsTimezones bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// A regexp with a replacement string.
|
// A regexp with a replacement string.
|
||||||
|
@ -69,6 +73,8 @@ var (
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
},
|
},
|
||||||
|
|
||||||
|
supportsTimezones: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
flavorSQLite3 = flavor{
|
flavorSQLite3 = flavor{
|
||||||
|
@ -80,7 +86,7 @@ var (
|
||||||
{matchLiteral("boolean"), "integer"},
|
{matchLiteral("boolean"), "integer"},
|
||||||
// Translate other types.
|
// Translate other types.
|
||||||
{matchLiteral("bytea"), "blob"},
|
{matchLiteral("bytea"), "blob"},
|
||||||
// {matchLiteral("timestamp"), "integer"},
|
{matchLiteral("timestamptz"), "timestamp"},
|
||||||
// SQLite doesn't have a "now()" method, replace with "date('now')"
|
// SQLite doesn't have a "now()" method, replace with "date('now')"
|
||||||
{regexp.MustCompile(`\bnow\(\)`), "date('now')"},
|
{regexp.MustCompile(`\bnow\(\)`), "date('now')"},
|
||||||
},
|
},
|
||||||
|
@ -107,6 +113,22 @@ func (f flavor) translate(query string) string {
|
||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// translateArgs translates query parameters that may be unique to
|
||||||
|
// a specific SQL flavor. For example, standardizing "time.Time"
|
||||||
|
// types to UTC for clients that don't provide timezone support.
|
||||||
|
func (c *conn) translateArgs(args []interface{}) []interface{} {
|
||||||
|
if c.flavor.supportsTimezones {
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, arg := range args {
|
||||||
|
if t, ok := arg.(time.Time); ok {
|
||||||
|
args[i] = t.UTC()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
// conn is the main database connection.
|
// conn is the main database connection.
|
||||||
type conn struct {
|
type conn struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -122,17 +144,17 @@ func (c *conn) Close() error {
|
||||||
|
|
||||||
func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||||
query = c.flavor.translate(query)
|
query = c.flavor.translate(query)
|
||||||
return c.db.Exec(query, args...)
|
return c.db.Exec(query, c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
query = c.flavor.translate(query)
|
query = c.flavor.translate(query)
|
||||||
return c.db.Query(query, args...)
|
return c.db.Query(query, c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row {
|
func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||||
query = c.flavor.translate(query)
|
query = c.flavor.translate(query)
|
||||||
return c.db.QueryRow(query, args...)
|
return c.db.QueryRow(query, c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecTx runs a method which operates on a transaction.
|
// ExecTx runs a method which operates on a transaction.
|
||||||
|
@ -163,15 +185,15 @@ type trans struct {
|
||||||
|
|
||||||
func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) {
|
func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||||
query = t.c.flavor.translate(query)
|
query = t.c.flavor.translate(query)
|
||||||
return t.tx.Exec(query, args...)
|
return t.tx.Exec(query, t.c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
query = t.c.flavor.translate(query)
|
query = t.c.flavor.translate(query)
|
||||||
return t.tx.Query(query, args...)
|
return t.tx.Query(query, t.c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row {
|
func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||||
query = t.c.flavor.translate(query)
|
query = t.c.flavor.translate(query)
|
||||||
return t.tx.QueryRow(query, args...)
|
return t.tx.QueryRow(query, t.c.translateArgs(args)...)
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,11 +44,9 @@ type GCResult struct {
|
||||||
AuthCodes int64
|
AuthCodes int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Storage is the storage interface used by the server. Implementations, at minimum
|
// Storage is the storage interface used by the server. Implementations are
|
||||||
// require compare-and-swap atomic actions.
|
// required to be able to perform atomic compare-and-swap updates and either
|
||||||
//
|
// support timezones or standardize on UTC.
|
||||||
// Implementations are expected to perform their own garbage collection of
|
|
||||||
// expired objects (expect keys, which are handled by the server).
|
|
||||||
type Storage interface {
|
type Storage interface {
|
||||||
Close() error
|
Close() error
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue