postgres: refactor error handling to fix retrying

prior to this change, many of the functions in the ExecTx callback would
wrap the error before returning it. this made it impossible to check
for the error code.

instead, the error wrapping has been moved to be external to the
`ExecTx` callback, so that the error code can be checked and
serialization failures can be retried.
This commit is contained in:
Alex Suraci 2018-11-19 11:34:45 -05:00
parent 5d67da1472
commit 587081a643
2 changed files with 152 additions and 88 deletions

View file

@ -134,7 +134,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
}
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
r, err := getAuthRequest(tx, id)
if err != nil {
return err
@ -144,6 +144,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
if err != nil {
return err
}
_, err = tx.Exec(`
update auth_request
set
@ -163,21 +164,31 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
a.ConnectorID, a.ConnectorData,
a.Expiry, r.ID,
)
if err != nil {
return fmt.Errorf("update auth request: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update auth request: %v", err)
}
return nil
}
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
return getAuthRequest(c, id)
req, err := getAuthRequest(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.AuthRequest{}, storage.ErrNotFound
}
return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err)
}
return req, nil
}
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
err = q.QueryRow(`
select
select
id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_email, claims_email_verified,
@ -192,10 +203,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
)
if err != nil {
if err == sql.ErrNoRows {
return a, storage.ErrNotFound
}
return a, fmt.Errorf("select auth request: %v", err)
return a, err
}
return a, nil
}
@ -269,20 +277,22 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert refresh_token: %v", err)
return fmt.Errorf("insert refresh token: %v", err)
}
return nil
}
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
r, err := getRefresh(tx, id)
if err != nil {
return err
}
if r, err = updater(r); err != nil {
return err
}
_, err = tx.Exec(`
update refresh_token
set
@ -308,15 +318,25 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed, id,
)
if err != nil {
return fmt.Errorf("update refresh token: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update refresh token: %v", err)
}
return nil
}
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
return getRefresh(c, id)
req, err := getRefresh(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.RefreshToken{}, storage.ErrNotFound
}
return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err)
}
return req, nil
}
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
@ -342,14 +362,15 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
from refresh_token;
`)
if err != nil {
return nil, fmt.Errorf("query: %v", err)
return nil, fmt.Errorf("select refresh tokens: %v", err)
}
var tokens []storage.RefreshToken
for rows.Next() {
r, err := scanRefresh(rows)
if err != nil {
return nil, err
return nil, fmt.Errorf("scan refresh token: %s", err)
}
tokens = append(tokens, r)
}
if err := rows.Err(); err != nil {
@ -367,10 +388,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
&r.Token, &r.CreatedAt, &r.LastUsed,
)
if err != nil {
if err == sql.ErrNoRows {
return r, storage.ErrNotFound
}
return r, fmt.Errorf("scan refresh_token: %v", err)
return r, err
}
return r, nil
}
@ -381,12 +399,11 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
// server. Test this, and consider adding a COUNT() command beforehand.
old, err := getKeys(tx)
if err != nil {
if err != storage.ErrNotFound {
return fmt.Errorf("get keys: %v", err)
}
if err == sql.ErrNoRows {
firstUpdate = true
old = storage.Keys{}
} else if err != nil {
return err
}
nk, err := updater(old)
@ -405,12 +422,12 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
encoder(nk.SigningKeyPub), nk.NextRotation,
)
if err != nil {
return fmt.Errorf("insert: %v", err)
return err
}
} else {
_, err = tx.Exec(`
update keys
set
set
verification_keys = $1,
signing_key = $2,
signing_key_pub = $3,
@ -421,15 +438,24 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
)
if err != nil {
return fmt.Errorf("update: %v", err)
return err
}
}
return nil
})
}
func (c *conn) GetKeys() (keys storage.Keys, err error) {
return getKeys(c)
func (c *conn) GetKeys() (storage.Keys, error) {
keys, err := getKeys(c)
if err != nil {
if err == sql.ErrNoRows {
return storage.Keys{}, storage.ErrNotFound
}
return storage.Keys{}, fmt.Errorf("select keys: %s", err)
}
return keys, nil
}
func getKeys(q querier) (keys storage.Keys, err error) {
@ -443,20 +469,18 @@ func getKeys(q querier) (keys storage.Keys, err error) {
decoder(&keys.SigningKeyPub), &keys.NextRotation,
)
if err != nil {
if err == sql.ErrNoRows {
return keys, storage.ErrNotFound
}
return keys, fmt.Errorf("query keys: %v", err)
return keys, err
}
return keys, nil
}
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
cli, err := getClient(tx, id)
if err != nil {
return err
}
nc, err := updater(cli)
if err != nil {
return err
@ -474,11 +498,13 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage
where id = $7;
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id,
)
if err != nil {
return fmt.Errorf("update client: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update client: %v", err)
}
return nil
}
func (c *conn) CreateClient(cli storage.Client) error {
@ -509,7 +535,16 @@ func getClient(q querier, id string) (storage.Client, error) {
}
func (c *conn) GetClient(id string) (storage.Client, error) {
return getClient(c, id)
client, err := getClient(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.Client{}, storage.ErrNotFound
}
return storage.Client{}, fmt.Errorf("select client: %v", err)
}
return client, nil
}
func (c *conn) ListClients() ([]storage.Client, error) {
@ -525,12 +560,12 @@ func (c *conn) ListClients() ([]storage.Client, error) {
for rows.Next() {
cli, err := scanClient(rows)
if err != nil {
return nil, err
return nil, fmt.Errorf("scan client: %s", err)
}
clients = append(clients, cli)
}
if err := rows.Err(); err != nil {
return nil, err
return nil, fmt.Errorf("scan: %s", err)
}
return clients, nil
}
@ -541,10 +576,7 @@ func scanClient(s scanner) (cli storage.Client, err error) {
&cli.Public, &cli.Name, &cli.LogoURL,
)
if err != nil {
if err == sql.ErrNoRows {
return cli, storage.ErrNotFound
}
return cli, fmt.Errorf("get client: %v", err)
return cli, err
}
return cli, nil
}
@ -571,7 +603,7 @@ func (c *conn) CreatePassword(p storage.Password) error {
}
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
p, err := getPassword(tx, email)
if err != nil {
return err
@ -581,6 +613,7 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
if err != nil {
return err
}
_, err = tx.Exec(`
update password
set
@ -589,15 +622,25 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
`,
np.Hash, np.Username, np.UserID, p.Email,
)
if err != nil {
return fmt.Errorf("update password: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update password: %v", err)
}
return nil
}
func (c *conn) GetPassword(email string) (storage.Password, error) {
return getPassword(c, email)
pass, err := getPassword(c, email)
if err != nil {
if err == sql.ErrNoRows {
return storage.Password{}, storage.ErrNotFound
}
return storage.Password{}, fmt.Errorf("get password: %s", err)
}
return pass, nil
}
func getPassword(q querier, email string) (p storage.Password, err error) {
@ -622,12 +665,12 @@ func (c *conn) ListPasswords() ([]storage.Password, error) {
for rows.Next() {
p, err := scanPassword(rows)
if err != nil {
return nil, err
return nil, fmt.Errorf("scan password: %s", err)
}
passwords = append(passwords, p)
}
if err := rows.Err(); err != nil {
return nil, err
return nil, fmt.Errorf("scan: %s", err)
}
return passwords, nil
}
@ -637,10 +680,7 @@ func scanPassword(s scanner) (p storage.Password, err error) {
&p.Email, &p.Hash, &p.Username, &p.UserID,
)
if err != nil {
if err == sql.ErrNoRows {
return p, storage.ErrNotFound
}
return p, fmt.Errorf("select password: %v", err)
return p, err
}
return p, nil
}
@ -666,7 +706,7 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
}
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
s, err := getOfflineSessions(tx, userID, connID)
if err != nil {
return err
@ -676,6 +716,7 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
if err != nil {
return err
}
_, err = tx.Exec(`
update offline_session
set
@ -684,15 +725,26 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
`,
encoder(newSession.Refresh), s.UserID, s.ConnID,
)
if err != nil {
return fmt.Errorf("update offline session: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update offline session: %v", err)
}
return nil
}
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
return getOfflineSessions(c, userID, connID)
sessions, err := getOfflineSessions(c, userID, connID)
if err != nil {
if err == sql.ErrNoRows {
return storage.OfflineSessions{}, storage.ErrNotFound
}
return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err)
}
return sessions, nil
}
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
@ -709,10 +761,7 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
&o.UserID, &o.ConnID, decoder(&o.Refresh),
)
if err != nil {
if err == sql.ErrNoRows {
return o, storage.ErrNotFound
}
return o, fmt.Errorf("select offline session: %v", err)
return o, err
}
return o, nil
}
@ -738,7 +787,7 @@ func (c *conn) CreateConnector(connector storage.Connector) error {
}
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
connector, err := getConnector(tx, id)
if err != nil {
return err
@ -748,9 +797,10 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
if err != nil {
return err
}
_, err = tx.Exec(`
update connector
set
set
type = $1,
name = $2,
resource_version = $3,
@ -759,15 +809,26 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
`,
newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID,
)
if err != nil {
return fmt.Errorf("update connector: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update connector: %v", err)
}
return nil
}
func (c *conn) GetConnector(id string) (storage.Connector, error) {
return getConnector(c, id)
connector, err := getConnector(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.Connector{}, storage.ErrNotFound
}
return storage.Connector{}, fmt.Errorf("get connector: %s", err)
}
return connector, nil
}
func getConnector(q querier, id string) (storage.Connector, error) {
@ -784,10 +845,7 @@ func scanConnector(s scanner) (c storage.Connector, err error) {
&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config,
)
if err != nil {
if err == sql.ErrNoRows {
return c, storage.ErrNotFound
}
return c, fmt.Errorf("select connector: %v", err)
return c, err
}
return c, nil
}
@ -805,12 +863,12 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) {
for rows.Next() {
conn, err := scanConnector(rows)
if err != nil {
return nil, err
return nil, fmt.Errorf("scan connector: %s", err)
}
connectors = append(connectors, conn)
}
if err := rows.Err(); err != nil {
return nil, err
return nil, fmt.Errorf("scan: %s", err)
}
return connectors, nil
}

View file

@ -44,13 +44,14 @@ var (
// The "github.com/lib/pq" driver is the default flavor. All others are
// translations of this.
flavorPostgres = flavor{
// The default behavior for Postgres transactions is consistent reads, not consistent writes.
// For each transaction opened, ensure it has the correct isolation level.
// The default behavior for Postgres transactions is consistent reads, not
// consistent writes. For each transaction opened, ensure it has the
// correct isolation level.
//
// See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html
//
// NOTE(ericchiang): For some reason using `SET SESSION CHARACTERISTICS AS TRANSACTION` at a
// session level didn't work for some edge cases. Might be something worth exploring.
// Be careful not to wrap sql errors in the callback 'fn', otherwise
// serialization failures will not be detected and retried.
executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error {
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
@ -66,6 +67,11 @@ var (
}
if err := fn(tx); err != nil {
if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "serialization_failure" {
// serialization error; retry
continue
}
return err
}