forked from mystiq/dex
postgres: refactor error handling to fix retrying
prior to this change, many of the functions in the ExecTx callback would wrap the error before returning it. this made it impossible to check for the error code. instead, the error wrapping has been moved to be external to the `ExecTx` callback, so that the error code can be checked and serialization failures can be retried.
This commit is contained in:
parent
5d67da1472
commit
587081a643
2 changed files with 152 additions and 88 deletions
|
@ -134,7 +134,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
|||
}
|
||||
|
||||
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
r, err := getAuthRequest(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -144,6 +144,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
update auth_request
|
||||
set
|
||||
|
@ -163,16 +164,26 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
|
|||
a.ConnectorID, a.ConnectorData,
|
||||
a.Expiry, r.ID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update auth request: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||
return getAuthRequest(c, id)
|
||||
req, err := getAuthRequest(c, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.AuthRequest{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
|
||||
|
@ -192,10 +203,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
|
|||
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return a, storage.ErrNotFound
|
||||
}
|
||||
return a, fmt.Errorf("select auth request: %v", err)
|
||||
return a, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
@ -269,20 +277,22 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
|||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert refresh_token: %v", err)
|
||||
return fmt.Errorf("insert refresh token: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
r, err := getRefresh(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r, err = updater(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
update refresh_token
|
||||
set
|
||||
|
@ -308,15 +318,25 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
|
|||
r.ConnectorID, r.ConnectorData,
|
||||
r.Token, r.CreatedAt, r.LastUsed, id,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update refresh token: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
|
||||
return getRefresh(c, id)
|
||||
req, err := getRefresh(c, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.RefreshToken{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
|
||||
|
@ -342,14 +362,15 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
|||
from refresh_token;
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query: %v", err)
|
||||
return nil, fmt.Errorf("select refresh tokens: %v", err)
|
||||
}
|
||||
var tokens []storage.RefreshToken
|
||||
for rows.Next() {
|
||||
r, err := scanRefresh(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("scan refresh token: %s", err)
|
||||
}
|
||||
|
||||
tokens = append(tokens, r)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
|
@ -367,10 +388,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
|
|||
&r.Token, &r.CreatedAt, &r.LastUsed,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return r, storage.ErrNotFound
|
||||
}
|
||||
return r, fmt.Errorf("scan refresh_token: %v", err)
|
||||
return r, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
@ -381,12 +399,11 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
|
|||
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
|
||||
// server. Test this, and consider adding a COUNT() command beforehand.
|
||||
old, err := getKeys(tx)
|
||||
if err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
return fmt.Errorf("get keys: %v", err)
|
||||
}
|
||||
if err == sql.ErrNoRows {
|
||||
firstUpdate = true
|
||||
old = storage.Keys{}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nk, err := updater(old)
|
||||
|
@ -405,7 +422,7 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
|
|||
encoder(nk.SigningKeyPub), nk.NextRotation,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert: %v", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
_, err = tx.Exec(`
|
||||
|
@ -421,15 +438,24 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
|
|||
encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) GetKeys() (keys storage.Keys, err error) {
|
||||
return getKeys(c)
|
||||
func (c *conn) GetKeys() (storage.Keys, error) {
|
||||
keys, err := getKeys(c)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.Keys{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.Keys{}, fmt.Errorf("select keys: %s", err)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func getKeys(q querier) (keys storage.Keys, err error) {
|
||||
|
@ -443,20 +469,18 @@ func getKeys(q querier) (keys storage.Keys, err error) {
|
|||
decoder(&keys.SigningKeyPub), &keys.NextRotation,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return keys, storage.ErrNotFound
|
||||
}
|
||||
return keys, fmt.Errorf("query keys: %v", err)
|
||||
return keys, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
cli, err := getClient(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nc, err := updater(cli)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -474,11 +498,13 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage
|
|||
where id = $7;
|
||||
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update client: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) CreateClient(cli storage.Client) error {
|
||||
|
@ -509,7 +535,16 @@ func getClient(q querier, id string) (storage.Client, error) {
|
|||
}
|
||||
|
||||
func (c *conn) GetClient(id string) (storage.Client, error) {
|
||||
return getClient(c, id)
|
||||
client, err := getClient(c, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.Client{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.Client{}, fmt.Errorf("select client: %v", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *conn) ListClients() ([]storage.Client, error) {
|
||||
|
@ -525,12 +560,12 @@ func (c *conn) ListClients() ([]storage.Client, error) {
|
|||
for rows.Next() {
|
||||
cli, err := scanClient(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("scan client: %s", err)
|
||||
}
|
||||
clients = append(clients, cli)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("scan: %s", err)
|
||||
}
|
||||
return clients, nil
|
||||
}
|
||||
|
@ -541,10 +576,7 @@ func scanClient(s scanner) (cli storage.Client, err error) {
|
|||
&cli.Public, &cli.Name, &cli.LogoURL,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return cli, storage.ErrNotFound
|
||||
}
|
||||
return cli, fmt.Errorf("get client: %v", err)
|
||||
return cli, err
|
||||
}
|
||||
return cli, nil
|
||||
}
|
||||
|
@ -571,7 +603,7 @@ func (c *conn) CreatePassword(p storage.Password) error {
|
|||
}
|
||||
|
||||
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
p, err := getPassword(tx, email)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -581,6 +613,7 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
update password
|
||||
set
|
||||
|
@ -589,15 +622,25 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
|
|||
`,
|
||||
np.Hash, np.Username, np.UserID, p.Email,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update password: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) GetPassword(email string) (storage.Password, error) {
|
||||
return getPassword(c, email)
|
||||
pass, err := getPassword(c, email)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.Password{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.Password{}, fmt.Errorf("get password: %s", err)
|
||||
}
|
||||
|
||||
return pass, nil
|
||||
}
|
||||
|
||||
func getPassword(q querier, email string) (p storage.Password, err error) {
|
||||
|
@ -622,12 +665,12 @@ func (c *conn) ListPasswords() ([]storage.Password, error) {
|
|||
for rows.Next() {
|
||||
p, err := scanPassword(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("scan password: %s", err)
|
||||
}
|
||||
passwords = append(passwords, p)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("scan: %s", err)
|
||||
}
|
||||
return passwords, nil
|
||||
}
|
||||
|
@ -637,10 +680,7 @@ func scanPassword(s scanner) (p storage.Password, err error) {
|
|||
&p.Email, &p.Hash, &p.Username, &p.UserID,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return p, storage.ErrNotFound
|
||||
}
|
||||
return p, fmt.Errorf("select password: %v", err)
|
||||
return p, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
@ -666,7 +706,7 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
|||
}
|
||||
|
||||
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
s, err := getOfflineSessions(tx, userID, connID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -676,6 +716,7 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
update offline_session
|
||||
set
|
||||
|
@ -684,15 +725,26 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
|
|||
`,
|
||||
encoder(newSession.Refresh), s.UserID, s.ConnID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update offline session: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
|
||||
return getOfflineSessions(c, userID, connID)
|
||||
sessions, err := getOfflineSessions(c, userID, connID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.OfflineSessions{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
|
||||
|
@ -709,10 +761,7 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
|
|||
&o.UserID, &o.ConnID, decoder(&o.Refresh),
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return o, storage.ErrNotFound
|
||||
}
|
||||
return o, fmt.Errorf("select offline session: %v", err)
|
||||
return o, err
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
@ -738,7 +787,7 @@ func (c *conn) CreateConnector(connector storage.Connector) error {
|
|||
}
|
||||
|
||||
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
connector, err := getConnector(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -748,6 +797,7 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
update connector
|
||||
set
|
||||
|
@ -759,15 +809,26 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
|
|||
`,
|
||||
newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update connector: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) GetConnector(id string) (storage.Connector, error) {
|
||||
return getConnector(c, id)
|
||||
connector, err := getConnector(c, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.Connector{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.Connector{}, fmt.Errorf("get connector: %s", err)
|
||||
}
|
||||
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
func getConnector(q querier, id string) (storage.Connector, error) {
|
||||
|
@ -784,10 +845,7 @@ func scanConnector(s scanner) (c storage.Connector, err error) {
|
|||
&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return c, storage.ErrNotFound
|
||||
}
|
||||
return c, fmt.Errorf("select connector: %v", err)
|
||||
return c, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
@ -805,12 +863,12 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) {
|
|||
for rows.Next() {
|
||||
conn, err := scanConnector(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("scan connector: %s", err)
|
||||
}
|
||||
connectors = append(connectors, conn)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("scan: %s", err)
|
||||
}
|
||||
return connectors, nil
|
||||
}
|
||||
|
|
|
@ -44,13 +44,14 @@ var (
|
|||
// The "github.com/lib/pq" driver is the default flavor. All others are
|
||||
// translations of this.
|
||||
flavorPostgres = flavor{
|
||||
// The default behavior for Postgres transactions is consistent reads, not consistent writes.
|
||||
// For each transaction opened, ensure it has the correct isolation level.
|
||||
// The default behavior for Postgres transactions is consistent reads, not
|
||||
// consistent writes. For each transaction opened, ensure it has the
|
||||
// correct isolation level.
|
||||
//
|
||||
// See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html
|
||||
//
|
||||
// NOTE(ericchiang): For some reason using `SET SESSION CHARACTERISTICS AS TRANSACTION` at a
|
||||
// session level didn't work for some edge cases. Might be something worth exploring.
|
||||
// Be careful not to wrap sql errors in the callback 'fn', otherwise
|
||||
// serialization failures will not be detected and retried.
|
||||
executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
defer cancel()
|
||||
|
@ -66,6 +67,11 @@ var (
|
|||
}
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "serialization_failure" {
|
||||
// serialization error; retry
|
||||
continue
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue