From 587081a643af2e9e6011750eccd469c0ef0f16c8 Mon Sep 17 00:00:00 2001 From: Alex Suraci Date: Mon, 19 Nov 2018 11:34:45 -0500 Subject: [PATCH] postgres: refactor error handling to fix retrying prior to this change, many of the functions in the ExecTx callback would wrap the error before returning it. this made it impossible to check for the error code. instead, the error wrapping has been moved to be external to the `ExecTx` callback, so that the error code can be checked and serialization failures can be retried. --- storage/sql/crud.go | 226 ++++++++++++++++++++++++++++---------------- storage/sql/sql.go | 14 ++- 2 files changed, 152 insertions(+), 88 deletions(-) diff --git a/storage/sql/crud.go b/storage/sql/crud.go index d7c055ab..a1406e20 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -134,7 +134,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { } func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { r, err := getAuthRequest(tx, id) if err != nil { return err @@ -144,6 +144,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) if err != nil { return err } + _, err = tx.Exec(` update auth_request set @@ -163,21 +164,31 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) a.ConnectorID, a.ConnectorData, a.Expiry, r.ID, ) - if err != nil { - return fmt.Errorf("update auth request: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update auth request: %v", err) + } + return nil } func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) { - return getAuthRequest(c, id) + req, err := getAuthRequest(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.AuthRequest{}, storage.ErrNotFound + } + + return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err) + } + + return req, nil } func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { err = q.QueryRow(` - select + select id, client_id, response_types, scopes, redirect_uri, nonce, state, force_approval_prompt, logged_in, claims_user_id, claims_username, claims_email, claims_email_verified, @@ -192,10 +203,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { &a.ConnectorID, &a.ConnectorData, &a.Expiry, ) if err != nil { - if err == sql.ErrNoRows { - return a, storage.ErrNotFound - } - return a, fmt.Errorf("select auth request: %v", err) + return a, err } return a, nil } @@ -269,20 +277,22 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { if c.alreadyExistsCheck(err) { return storage.ErrAlreadyExists } - return fmt.Errorf("insert refresh_token: %v", err) + return fmt.Errorf("insert refresh token: %v", err) } return nil } func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { r, err := getRefresh(tx, id) if err != nil { return err } + if r, err = updater(r); err != nil { return err } + _, err = tx.Exec(` update refresh_token set @@ -308,15 +318,25 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok r.ConnectorID, r.ConnectorData, r.Token, r.CreatedAt, r.LastUsed, id, ) - if err != nil { - return fmt.Errorf("update refresh token: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update refresh token: %v", err) + } + return nil } func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { - return getRefresh(c, id) + req, err := getRefresh(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.RefreshToken{}, storage.ErrNotFound + } + + return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err) + } + + return req, nil } func getRefresh(q querier, id string) (storage.RefreshToken, error) { @@ -342,14 +362,15 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { from refresh_token; `) if err != nil { - return nil, fmt.Errorf("query: %v", err) + return nil, fmt.Errorf("select refresh tokens: %v", err) } var tokens []storage.RefreshToken for rows.Next() { r, err := scanRefresh(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan refresh token: %s", err) } + tokens = append(tokens, r) } if err := rows.Err(); err != nil { @@ -367,10 +388,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) { &r.Token, &r.CreatedAt, &r.LastUsed, ) if err != nil { - if err == sql.ErrNoRows { - return r, storage.ErrNotFound - } - return r, fmt.Errorf("scan refresh_token: %v", err) + return r, err } return r, nil } @@ -381,12 +399,11 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) // TODO(ericchiang): errors may cause a transaction be rolled back by the SQL // server. Test this, and consider adding a COUNT() command beforehand. old, err := getKeys(tx) - if err != nil { - if err != storage.ErrNotFound { - return fmt.Errorf("get keys: %v", err) - } + if err == sql.ErrNoRows { firstUpdate = true old = storage.Keys{} + } else if err != nil { + return err } nk, err := updater(old) @@ -405,12 +422,12 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) encoder(nk.SigningKeyPub), nk.NextRotation, ) if err != nil { - return fmt.Errorf("insert: %v", err) + return err } } else { _, err = tx.Exec(` update keys - set + set verification_keys = $1, signing_key = $2, signing_key_pub = $3, @@ -421,15 +438,24 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID, ) if err != nil { - return fmt.Errorf("update: %v", err) + return err } } return nil }) } -func (c *conn) GetKeys() (keys storage.Keys, err error) { - return getKeys(c) +func (c *conn) GetKeys() (storage.Keys, error) { + keys, err := getKeys(c) + if err != nil { + if err == sql.ErrNoRows { + return storage.Keys{}, storage.ErrNotFound + } + + return storage.Keys{}, fmt.Errorf("select keys: %s", err) + } + + return keys, nil } func getKeys(q querier) (keys storage.Keys, err error) { @@ -443,20 +469,18 @@ func getKeys(q querier) (keys storage.Keys, err error) { decoder(&keys.SigningKeyPub), &keys.NextRotation, ) if err != nil { - if err == sql.ErrNoRows { - return keys, storage.ErrNotFound - } - return keys, fmt.Errorf("query keys: %v", err) + return keys, err } return keys, nil } func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { cli, err := getClient(tx, id) if err != nil { return err } + nc, err := updater(cli) if err != nil { return err @@ -474,11 +498,13 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage where id = $7; `, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id, ) - if err != nil { - return fmt.Errorf("update client: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update client: %v", err) + } + + return nil } func (c *conn) CreateClient(cli storage.Client) error { @@ -509,7 +535,16 @@ func getClient(q querier, id string) (storage.Client, error) { } func (c *conn) GetClient(id string) (storage.Client, error) { - return getClient(c, id) + client, err := getClient(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.Client{}, storage.ErrNotFound + } + + return storage.Client{}, fmt.Errorf("select client: %v", err) + } + + return client, nil } func (c *conn) ListClients() ([]storage.Client, error) { @@ -525,12 +560,12 @@ func (c *conn) ListClients() ([]storage.Client, error) { for rows.Next() { cli, err := scanClient(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan client: %s", err) } clients = append(clients, cli) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return clients, nil } @@ -541,10 +576,7 @@ func scanClient(s scanner) (cli storage.Client, err error) { &cli.Public, &cli.Name, &cli.LogoURL, ) if err != nil { - if err == sql.ErrNoRows { - return cli, storage.ErrNotFound - } - return cli, fmt.Errorf("get client: %v", err) + return cli, err } return cli, nil } @@ -571,7 +603,7 @@ func (c *conn) CreatePassword(p storage.Password) error { } func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { p, err := getPassword(tx, email) if err != nil { return err @@ -581,6 +613,7 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st if err != nil { return err } + _, err = tx.Exec(` update password set @@ -589,15 +622,25 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st `, np.Hash, np.Username, np.UserID, p.Email, ) - if err != nil { - return fmt.Errorf("update password: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update password: %v", err) + } + return nil } func (c *conn) GetPassword(email string) (storage.Password, error) { - return getPassword(c, email) + pass, err := getPassword(c, email) + if err != nil { + if err == sql.ErrNoRows { + return storage.Password{}, storage.ErrNotFound + } + + return storage.Password{}, fmt.Errorf("get password: %s", err) + } + + return pass, nil } func getPassword(q querier, email string) (p storage.Password, err error) { @@ -622,12 +665,12 @@ func (c *conn) ListPasswords() ([]storage.Password, error) { for rows.Next() { p, err := scanPassword(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan password: %s", err) } passwords = append(passwords, p) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return passwords, nil } @@ -637,10 +680,7 @@ func scanPassword(s scanner) (p storage.Password, err error) { &p.Email, &p.Hash, &p.Username, &p.UserID, ) if err != nil { - if err == sql.ErrNoRows { - return p, storage.ErrNotFound - } - return p, fmt.Errorf("select password: %v", err) + return p, err } return p, nil } @@ -666,7 +706,7 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { } func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { s, err := getOfflineSessions(tx, userID, connID) if err != nil { return err @@ -676,6 +716,7 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( if err != nil { return err } + _, err = tx.Exec(` update offline_session set @@ -684,15 +725,26 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( `, encoder(newSession.Refresh), s.UserID, s.ConnID, ) - if err != nil { - return fmt.Errorf("update offline session: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update offline session: %v", err) + } + + return nil } func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) { - return getOfflineSessions(c, userID, connID) + sessions, err := getOfflineSessions(c, userID, connID) + if err != nil { + if err == sql.ErrNoRows { + return storage.OfflineSessions{}, storage.ErrNotFound + } + + return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err) + } + + return sessions, nil } func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { @@ -709,10 +761,7 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { &o.UserID, &o.ConnID, decoder(&o.Refresh), ) if err != nil { - if err == sql.ErrNoRows { - return o, storage.ErrNotFound - } - return o, fmt.Errorf("select offline session: %v", err) + return o, err } return o, nil } @@ -738,7 +787,7 @@ func (c *conn) CreateConnector(connector storage.Connector) error { } func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { connector, err := getConnector(tx, id) if err != nil { return err @@ -748,9 +797,10 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto if err != nil { return err } + _, err = tx.Exec(` update connector - set + set type = $1, name = $2, resource_version = $3, @@ -759,15 +809,26 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto `, newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID, ) - if err != nil { - return fmt.Errorf("update connector: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update connector: %v", err) + } + + return nil } func (c *conn) GetConnector(id string) (storage.Connector, error) { - return getConnector(c, id) + connector, err := getConnector(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.Connector{}, storage.ErrNotFound + } + + return storage.Connector{}, fmt.Errorf("get connector: %s", err) + } + + return connector, nil } func getConnector(q querier, id string) (storage.Connector, error) { @@ -784,10 +845,7 @@ func scanConnector(s scanner) (c storage.Connector, err error) { &c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config, ) if err != nil { - if err == sql.ErrNoRows { - return c, storage.ErrNotFound - } - return c, fmt.Errorf("select connector: %v", err) + return c, err } return c, nil } @@ -805,12 +863,12 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) { for rows.Next() { conn, err := scanConnector(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan connector: %s", err) } connectors = append(connectors, conn) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return connectors, nil } diff --git a/storage/sql/sql.go b/storage/sql/sql.go index 7f20cf9d..b51f6fcc 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -44,13 +44,14 @@ var ( // The "github.com/lib/pq" driver is the default flavor. All others are // translations of this. flavorPostgres = flavor{ - // The default behavior for Postgres transactions is consistent reads, not consistent writes. - // For each transaction opened, ensure it has the correct isolation level. + // The default behavior for Postgres transactions is consistent reads, not + // consistent writes. For each transaction opened, ensure it has the + // correct isolation level. // // See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html // - // NOTE(ericchiang): For some reason using `SET SESSION CHARACTERISTICS AS TRANSACTION` at a - // session level didn't work for some edge cases. Might be something worth exploring. + // Be careful not to wrap sql errors in the callback 'fn', otherwise + // serialization failures will not be detected and retried. executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error { ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -66,6 +67,11 @@ var ( } if err := fn(tx); err != nil { + if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "serialization_failure" { + // serialization error; retry + continue + } + return err }