diff --git a/storage/sql/crud.go b/storage/sql/crud.go index d7c055ab..a1406e20 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -134,7 +134,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { } func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { r, err := getAuthRequest(tx, id) if err != nil { return err @@ -144,6 +144,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) if err != nil { return err } + _, err = tx.Exec(` update auth_request set @@ -163,21 +164,31 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) a.ConnectorID, a.ConnectorData, a.Expiry, r.ID, ) - if err != nil { - return fmt.Errorf("update auth request: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update auth request: %v", err) + } + return nil } func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) { - return getAuthRequest(c, id) + req, err := getAuthRequest(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.AuthRequest{}, storage.ErrNotFound + } + + return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err) + } + + return req, nil } func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { err = q.QueryRow(` - select + select id, client_id, response_types, scopes, redirect_uri, nonce, state, force_approval_prompt, logged_in, claims_user_id, claims_username, claims_email, claims_email_verified, @@ -192,10 +203,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { &a.ConnectorID, &a.ConnectorData, &a.Expiry, ) if err != nil { - if err == sql.ErrNoRows { - return a, storage.ErrNotFound - } - return a, fmt.Errorf("select auth request: %v", err) + return a, err } return a, nil } @@ -269,20 +277,22 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { if c.alreadyExistsCheck(err) { return storage.ErrAlreadyExists } - return fmt.Errorf("insert refresh_token: %v", err) + return fmt.Errorf("insert refresh token: %v", err) } return nil } func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { r, err := getRefresh(tx, id) if err != nil { return err } + if r, err = updater(r); err != nil { return err } + _, err = tx.Exec(` update refresh_token set @@ -308,15 +318,25 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok r.ConnectorID, r.ConnectorData, r.Token, r.CreatedAt, r.LastUsed, id, ) - if err != nil { - return fmt.Errorf("update refresh token: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update refresh token: %v", err) + } + return nil } func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { - return getRefresh(c, id) + req, err := getRefresh(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.RefreshToken{}, storage.ErrNotFound + } + + return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err) + } + + return req, nil } func getRefresh(q querier, id string) (storage.RefreshToken, error) { @@ -342,14 +362,15 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { from refresh_token; `) if err != nil { - return nil, fmt.Errorf("query: %v", err) + return nil, fmt.Errorf("select refresh tokens: %v", err) } var tokens []storage.RefreshToken for rows.Next() { r, err := scanRefresh(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan refresh token: %s", err) } + tokens = append(tokens, r) } if err := rows.Err(); err != nil { @@ -367,10 +388,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) { &r.Token, &r.CreatedAt, &r.LastUsed, ) if err != nil { - if err == sql.ErrNoRows { - return r, storage.ErrNotFound - } - return r, fmt.Errorf("scan refresh_token: %v", err) + return r, err } return r, nil } @@ -381,12 +399,11 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) // TODO(ericchiang): errors may cause a transaction be rolled back by the SQL // server. Test this, and consider adding a COUNT() command beforehand. old, err := getKeys(tx) - if err != nil { - if err != storage.ErrNotFound { - return fmt.Errorf("get keys: %v", err) - } + if err == sql.ErrNoRows { firstUpdate = true old = storage.Keys{} + } else if err != nil { + return err } nk, err := updater(old) @@ -405,12 +422,12 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) encoder(nk.SigningKeyPub), nk.NextRotation, ) if err != nil { - return fmt.Errorf("insert: %v", err) + return err } } else { _, err = tx.Exec(` update keys - set + set verification_keys = $1, signing_key = $2, signing_key_pub = $3, @@ -421,15 +438,24 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID, ) if err != nil { - return fmt.Errorf("update: %v", err) + return err } } return nil }) } -func (c *conn) GetKeys() (keys storage.Keys, err error) { - return getKeys(c) +func (c *conn) GetKeys() (storage.Keys, error) { + keys, err := getKeys(c) + if err != nil { + if err == sql.ErrNoRows { + return storage.Keys{}, storage.ErrNotFound + } + + return storage.Keys{}, fmt.Errorf("select keys: %s", err) + } + + return keys, nil } func getKeys(q querier) (keys storage.Keys, err error) { @@ -443,20 +469,18 @@ func getKeys(q querier) (keys storage.Keys, err error) { decoder(&keys.SigningKeyPub), &keys.NextRotation, ) if err != nil { - if err == sql.ErrNoRows { - return keys, storage.ErrNotFound - } - return keys, fmt.Errorf("query keys: %v", err) + return keys, err } return keys, nil } func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { cli, err := getClient(tx, id) if err != nil { return err } + nc, err := updater(cli) if err != nil { return err @@ -474,11 +498,13 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage where id = $7; `, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id, ) - if err != nil { - return fmt.Errorf("update client: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update client: %v", err) + } + + return nil } func (c *conn) CreateClient(cli storage.Client) error { @@ -509,7 +535,16 @@ func getClient(q querier, id string) (storage.Client, error) { } func (c *conn) GetClient(id string) (storage.Client, error) { - return getClient(c, id) + client, err := getClient(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.Client{}, storage.ErrNotFound + } + + return storage.Client{}, fmt.Errorf("select client: %v", err) + } + + return client, nil } func (c *conn) ListClients() ([]storage.Client, error) { @@ -525,12 +560,12 @@ func (c *conn) ListClients() ([]storage.Client, error) { for rows.Next() { cli, err := scanClient(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan client: %s", err) } clients = append(clients, cli) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return clients, nil } @@ -541,10 +576,7 @@ func scanClient(s scanner) (cli storage.Client, err error) { &cli.Public, &cli.Name, &cli.LogoURL, ) if err != nil { - if err == sql.ErrNoRows { - return cli, storage.ErrNotFound - } - return cli, fmt.Errorf("get client: %v", err) + return cli, err } return cli, nil } @@ -571,7 +603,7 @@ func (c *conn) CreatePassword(p storage.Password) error { } func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { p, err := getPassword(tx, email) if err != nil { return err @@ -581,6 +613,7 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st if err != nil { return err } + _, err = tx.Exec(` update password set @@ -589,15 +622,25 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st `, np.Hash, np.Username, np.UserID, p.Email, ) - if err != nil { - return fmt.Errorf("update password: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update password: %v", err) + } + return nil } func (c *conn) GetPassword(email string) (storage.Password, error) { - return getPassword(c, email) + pass, err := getPassword(c, email) + if err != nil { + if err == sql.ErrNoRows { + return storage.Password{}, storage.ErrNotFound + } + + return storage.Password{}, fmt.Errorf("get password: %s", err) + } + + return pass, nil } func getPassword(q querier, email string) (p storage.Password, err error) { @@ -622,12 +665,12 @@ func (c *conn) ListPasswords() ([]storage.Password, error) { for rows.Next() { p, err := scanPassword(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan password: %s", err) } passwords = append(passwords, p) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return passwords, nil } @@ -637,10 +680,7 @@ func scanPassword(s scanner) (p storage.Password, err error) { &p.Email, &p.Hash, &p.Username, &p.UserID, ) if err != nil { - if err == sql.ErrNoRows { - return p, storage.ErrNotFound - } - return p, fmt.Errorf("select password: %v", err) + return p, err } return p, nil } @@ -666,7 +706,7 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { } func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { s, err := getOfflineSessions(tx, userID, connID) if err != nil { return err @@ -676,6 +716,7 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( if err != nil { return err } + _, err = tx.Exec(` update offline_session set @@ -684,15 +725,26 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( `, encoder(newSession.Refresh), s.UserID, s.ConnID, ) - if err != nil { - return fmt.Errorf("update offline session: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update offline session: %v", err) + } + + return nil } func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) { - return getOfflineSessions(c, userID, connID) + sessions, err := getOfflineSessions(c, userID, connID) + if err != nil { + if err == sql.ErrNoRows { + return storage.OfflineSessions{}, storage.ErrNotFound + } + + return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err) + } + + return sessions, nil } func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { @@ -709,10 +761,7 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { &o.UserID, &o.ConnID, decoder(&o.Refresh), ) if err != nil { - if err == sql.ErrNoRows { - return o, storage.ErrNotFound - } - return o, fmt.Errorf("select offline session: %v", err) + return o, err } return o, nil } @@ -738,7 +787,7 @@ func (c *conn) CreateConnector(connector storage.Connector) error { } func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { connector, err := getConnector(tx, id) if err != nil { return err @@ -748,9 +797,10 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto if err != nil { return err } + _, err = tx.Exec(` update connector - set + set type = $1, name = $2, resource_version = $3, @@ -759,15 +809,26 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto `, newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID, ) - if err != nil { - return fmt.Errorf("update connector: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update connector: %v", err) + } + + return nil } func (c *conn) GetConnector(id string) (storage.Connector, error) { - return getConnector(c, id) + connector, err := getConnector(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.Connector{}, storage.ErrNotFound + } + + return storage.Connector{}, fmt.Errorf("get connector: %s", err) + } + + return connector, nil } func getConnector(q querier, id string) (storage.Connector, error) { @@ -784,10 +845,7 @@ func scanConnector(s scanner) (c storage.Connector, err error) { &c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config, ) if err != nil { - if err == sql.ErrNoRows { - return c, storage.ErrNotFound - } - return c, fmt.Errorf("select connector: %v", err) + return c, err } return c, nil } @@ -805,12 +863,12 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) { for rows.Next() { conn, err := scanConnector(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan connector: %s", err) } connectors = append(connectors, conn) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return connectors, nil } diff --git a/storage/sql/sql.go b/storage/sql/sql.go index 7f20cf9d..b51f6fcc 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -44,13 +44,14 @@ var ( // The "github.com/lib/pq" driver is the default flavor. All others are // translations of this. flavorPostgres = flavor{ - // The default behavior for Postgres transactions is consistent reads, not consistent writes. - // For each transaction opened, ensure it has the correct isolation level. + // The default behavior for Postgres transactions is consistent reads, not + // consistent writes. For each transaction opened, ensure it has the + // correct isolation level. // // See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html // - // NOTE(ericchiang): For some reason using `SET SESSION CHARACTERISTICS AS TRANSACTION` at a - // session level didn't work for some edge cases. Might be something worth exploring. + // Be careful not to wrap sql errors in the callback 'fn', otherwise + // serialization failures will not be detected and retried. executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error { ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -66,6 +67,11 @@ var ( } if err := fn(tx); err != nil { + if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "serialization_failure" { + // serialization error; retry + continue + } + return err }