forked from mystiq/dex
postgres: refactor error handling to fix retrying
prior to this change, many of the functions in the ExecTx callback would wrap the error before returning it. this made it impossible to check for the error code. instead, the error wrapping has been moved to be external to the `ExecTx` callback, so that the error code can be checked and serialization failures can be retried.
This commit is contained in:
parent
5d67da1472
commit
587081a643
2 changed files with 152 additions and 88 deletions
|
@ -134,7 +134,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
|
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
|
||||||
return c.ExecTx(func(tx *trans) error {
|
err := c.ExecTx(func(tx *trans) error {
|
||||||
r, err := getAuthRequest(tx, id)
|
r, err := getAuthRequest(tx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -144,6 +144,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec(`
|
_, err = tx.Exec(`
|
||||||
update auth_request
|
update auth_request
|
||||||
set
|
set
|
||||||
|
@ -163,21 +164,31 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
|
||||||
a.ConnectorID, a.ConnectorData,
|
a.ConnectorID, a.ConnectorData,
|
||||||
a.Expiry, r.ID,
|
a.Expiry, r.ID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
return err
|
||||||
return fmt.Errorf("update auth request: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update auth request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||||
return getAuthRequest(c, id)
|
req, err := getAuthRequest(c, id)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return storage.AuthRequest{}, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
|
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
|
||||||
err = q.QueryRow(`
|
err = q.QueryRow(`
|
||||||
select
|
select
|
||||||
id, client_id, response_types, scopes, redirect_uri, nonce, state,
|
id, client_id, response_types, scopes, redirect_uri, nonce, state,
|
||||||
force_approval_prompt, logged_in,
|
force_approval_prompt, logged_in,
|
||||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||||
|
@ -192,10 +203,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
|
||||||
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
|
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
return a, err
|
||||||
return a, storage.ErrNotFound
|
|
||||||
}
|
|
||||||
return a, fmt.Errorf("select auth request: %v", err)
|
|
||||||
}
|
}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
@ -269,20 +277,22 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
||||||
if c.alreadyExistsCheck(err) {
|
if c.alreadyExistsCheck(err) {
|
||||||
return storage.ErrAlreadyExists
|
return storage.ErrAlreadyExists
|
||||||
}
|
}
|
||||||
return fmt.Errorf("insert refresh_token: %v", err)
|
return fmt.Errorf("insert refresh token: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||||
return c.ExecTx(func(tx *trans) error {
|
err := c.ExecTx(func(tx *trans) error {
|
||||||
r, err := getRefresh(tx, id)
|
r, err := getRefresh(tx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if r, err = updater(r); err != nil {
|
if r, err = updater(r); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec(`
|
_, err = tx.Exec(`
|
||||||
update refresh_token
|
update refresh_token
|
||||||
set
|
set
|
||||||
|
@ -308,15 +318,25 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
|
||||||
r.ConnectorID, r.ConnectorData,
|
r.ConnectorID, r.ConnectorData,
|
||||||
r.Token, r.CreatedAt, r.LastUsed, id,
|
r.Token, r.CreatedAt, r.LastUsed, id,
|
||||||
)
|
)
|
||||||
if err != nil {
|
return err
|
||||||
return fmt.Errorf("update refresh token: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update refresh token: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
|
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
|
||||||
return getRefresh(c, id)
|
req, err := getRefresh(c, id)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return storage.RefreshToken{}, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
|
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
|
||||||
|
@ -342,14 +362,15 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
||||||
from refresh_token;
|
from refresh_token;
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query: %v", err)
|
return nil, fmt.Errorf("select refresh tokens: %v", err)
|
||||||
}
|
}
|
||||||
var tokens []storage.RefreshToken
|
var tokens []storage.RefreshToken
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
r, err := scanRefresh(rows)
|
r, err := scanRefresh(rows)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("scan refresh token: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens = append(tokens, r)
|
tokens = append(tokens, r)
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
|
@ -367,10 +388,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
|
||||||
&r.Token, &r.CreatedAt, &r.LastUsed,
|
&r.Token, &r.CreatedAt, &r.LastUsed,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
return r, err
|
||||||
return r, storage.ErrNotFound
|
|
||||||
}
|
|
||||||
return r, fmt.Errorf("scan refresh_token: %v", err)
|
|
||||||
}
|
}
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
@ -381,12 +399,11 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
|
||||||
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
|
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
|
||||||
// server. Test this, and consider adding a COUNT() command beforehand.
|
// server. Test this, and consider adding a COUNT() command beforehand.
|
||||||
old, err := getKeys(tx)
|
old, err := getKeys(tx)
|
||||||
if err != nil {
|
if err == sql.ErrNoRows {
|
||||||
if err != storage.ErrNotFound {
|
|
||||||
return fmt.Errorf("get keys: %v", err)
|
|
||||||
}
|
|
||||||
firstUpdate = true
|
firstUpdate = true
|
||||||
old = storage.Keys{}
|
old = storage.Keys{}
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nk, err := updater(old)
|
nk, err := updater(old)
|
||||||
|
@ -405,12 +422,12 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
|
||||||
encoder(nk.SigningKeyPub), nk.NextRotation,
|
encoder(nk.SigningKeyPub), nk.NextRotation,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("insert: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_, err = tx.Exec(`
|
_, err = tx.Exec(`
|
||||||
update keys
|
update keys
|
||||||
set
|
set
|
||||||
verification_keys = $1,
|
verification_keys = $1,
|
||||||
signing_key = $2,
|
signing_key = $2,
|
||||||
signing_key_pub = $3,
|
signing_key_pub = $3,
|
||||||
|
@ -421,15 +438,24 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
|
||||||
encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
|
encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("update: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) GetKeys() (keys storage.Keys, err error) {
|
func (c *conn) GetKeys() (storage.Keys, error) {
|
||||||
return getKeys(c)
|
keys, err := getKeys(c)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return storage.Keys{}, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage.Keys{}, fmt.Errorf("select keys: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getKeys(q querier) (keys storage.Keys, err error) {
|
func getKeys(q querier) (keys storage.Keys, err error) {
|
||||||
|
@ -443,20 +469,18 @@ func getKeys(q querier) (keys storage.Keys, err error) {
|
||||||
decoder(&keys.SigningKeyPub), &keys.NextRotation,
|
decoder(&keys.SigningKeyPub), &keys.NextRotation,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
return keys, err
|
||||||
return keys, storage.ErrNotFound
|
|
||||||
}
|
|
||||||
return keys, fmt.Errorf("query keys: %v", err)
|
|
||||||
}
|
}
|
||||||
return keys, nil
|
return keys, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
||||||
return c.ExecTx(func(tx *trans) error {
|
err := c.ExecTx(func(tx *trans) error {
|
||||||
cli, err := getClient(tx, id)
|
cli, err := getClient(tx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nc, err := updater(cli)
|
nc, err := updater(cli)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -474,11 +498,13 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage
|
||||||
where id = $7;
|
where id = $7;
|
||||||
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id,
|
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id,
|
||||||
)
|
)
|
||||||
if err != nil {
|
return err
|
||||||
return fmt.Errorf("update client: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) CreateClient(cli storage.Client) error {
|
func (c *conn) CreateClient(cli storage.Client) error {
|
||||||
|
@ -509,7 +535,16 @@ func getClient(q querier, id string) (storage.Client, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) GetClient(id string) (storage.Client, error) {
|
func (c *conn) GetClient(id string) (storage.Client, error) {
|
||||||
return getClient(c, id)
|
client, err := getClient(c, id)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return storage.Client{}, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage.Client{}, fmt.Errorf("select client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) ListClients() ([]storage.Client, error) {
|
func (c *conn) ListClients() ([]storage.Client, error) {
|
||||||
|
@ -525,12 +560,12 @@ func (c *conn) ListClients() ([]storage.Client, error) {
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
cli, err := scanClient(rows)
|
cli, err := scanClient(rows)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("scan client: %s", err)
|
||||||
}
|
}
|
||||||
clients = append(clients, cli)
|
clients = append(clients, cli)
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("scan: %s", err)
|
||||||
}
|
}
|
||||||
return clients, nil
|
return clients, nil
|
||||||
}
|
}
|
||||||
|
@ -541,10 +576,7 @@ func scanClient(s scanner) (cli storage.Client, err error) {
|
||||||
&cli.Public, &cli.Name, &cli.LogoURL,
|
&cli.Public, &cli.Name, &cli.LogoURL,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
return cli, err
|
||||||
return cli, storage.ErrNotFound
|
|
||||||
}
|
|
||||||
return cli, fmt.Errorf("get client: %v", err)
|
|
||||||
}
|
}
|
||||||
return cli, nil
|
return cli, nil
|
||||||
}
|
}
|
||||||
|
@ -571,7 +603,7 @@ func (c *conn) CreatePassword(p storage.Password) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
|
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
|
||||||
return c.ExecTx(func(tx *trans) error {
|
err := c.ExecTx(func(tx *trans) error {
|
||||||
p, err := getPassword(tx, email)
|
p, err := getPassword(tx, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -581,6 +613,7 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec(`
|
_, err = tx.Exec(`
|
||||||
update password
|
update password
|
||||||
set
|
set
|
||||||
|
@ -589,15 +622,25 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
|
||||||
`,
|
`,
|
||||||
np.Hash, np.Username, np.UserID, p.Email,
|
np.Hash, np.Username, np.UserID, p.Email,
|
||||||
)
|
)
|
||||||
if err != nil {
|
return err
|
||||||
return fmt.Errorf("update password: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update password: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) GetPassword(email string) (storage.Password, error) {
|
func (c *conn) GetPassword(email string) (storage.Password, error) {
|
||||||
return getPassword(c, email)
|
pass, err := getPassword(c, email)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return storage.Password{}, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage.Password{}, fmt.Errorf("get password: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return pass, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPassword(q querier, email string) (p storage.Password, err error) {
|
func getPassword(q querier, email string) (p storage.Password, err error) {
|
||||||
|
@ -622,12 +665,12 @@ func (c *conn) ListPasswords() ([]storage.Password, error) {
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
p, err := scanPassword(rows)
|
p, err := scanPassword(rows)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("scan password: %s", err)
|
||||||
}
|
}
|
||||||
passwords = append(passwords, p)
|
passwords = append(passwords, p)
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("scan: %s", err)
|
||||||
}
|
}
|
||||||
return passwords, nil
|
return passwords, nil
|
||||||
}
|
}
|
||||||
|
@ -637,10 +680,7 @@ func scanPassword(s scanner) (p storage.Password, err error) {
|
||||||
&p.Email, &p.Hash, &p.Username, &p.UserID,
|
&p.Email, &p.Hash, &p.Username, &p.UserID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
return p, err
|
||||||
return p, storage.ErrNotFound
|
|
||||||
}
|
|
||||||
return p, fmt.Errorf("select password: %v", err)
|
|
||||||
}
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
@ -666,7 +706,7 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||||
return c.ExecTx(func(tx *trans) error {
|
err := c.ExecTx(func(tx *trans) error {
|
||||||
s, err := getOfflineSessions(tx, userID, connID)
|
s, err := getOfflineSessions(tx, userID, connID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -676,6 +716,7 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec(`
|
_, err = tx.Exec(`
|
||||||
update offline_session
|
update offline_session
|
||||||
set
|
set
|
||||||
|
@ -684,15 +725,26 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
|
||||||
`,
|
`,
|
||||||
encoder(newSession.Refresh), s.UserID, s.ConnID,
|
encoder(newSession.Refresh), s.UserID, s.ConnID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
return err
|
||||||
return fmt.Errorf("update offline session: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update offline session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
|
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
|
||||||
return getOfflineSessions(c, userID, connID)
|
sessions, err := getOfflineSessions(c, userID, connID)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return storage.OfflineSessions{}, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
|
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
|
||||||
|
@ -709,10 +761,7 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
|
||||||
&o.UserID, &o.ConnID, decoder(&o.Refresh),
|
&o.UserID, &o.ConnID, decoder(&o.Refresh),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
return o, err
|
||||||
return o, storage.ErrNotFound
|
|
||||||
}
|
|
||||||
return o, fmt.Errorf("select offline session: %v", err)
|
|
||||||
}
|
}
|
||||||
return o, nil
|
return o, nil
|
||||||
}
|
}
|
||||||
|
@ -738,7 +787,7 @@ func (c *conn) CreateConnector(connector storage.Connector) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
|
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
|
||||||
return c.ExecTx(func(tx *trans) error {
|
err := c.ExecTx(func(tx *trans) error {
|
||||||
connector, err := getConnector(tx, id)
|
connector, err := getConnector(tx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -748,9 +797,10 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec(`
|
_, err = tx.Exec(`
|
||||||
update connector
|
update connector
|
||||||
set
|
set
|
||||||
type = $1,
|
type = $1,
|
||||||
name = $2,
|
name = $2,
|
||||||
resource_version = $3,
|
resource_version = $3,
|
||||||
|
@ -759,15 +809,26 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
|
||||||
`,
|
`,
|
||||||
newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID,
|
newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
return err
|
||||||
return fmt.Errorf("update connector: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update connector: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) GetConnector(id string) (storage.Connector, error) {
|
func (c *conn) GetConnector(id string) (storage.Connector, error) {
|
||||||
return getConnector(c, id)
|
connector, err := getConnector(c, id)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return storage.Connector{}, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage.Connector{}, fmt.Errorf("get connector: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return connector, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getConnector(q querier, id string) (storage.Connector, error) {
|
func getConnector(q querier, id string) (storage.Connector, error) {
|
||||||
|
@ -784,10 +845,7 @@ func scanConnector(s scanner) (c storage.Connector, err error) {
|
||||||
&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config,
|
&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
return c, err
|
||||||
return c, storage.ErrNotFound
|
|
||||||
}
|
|
||||||
return c, fmt.Errorf("select connector: %v", err)
|
|
||||||
}
|
}
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
@ -805,12 +863,12 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) {
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
conn, err := scanConnector(rows)
|
conn, err := scanConnector(rows)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("scan connector: %s", err)
|
||||||
}
|
}
|
||||||
connectors = append(connectors, conn)
|
connectors = append(connectors, conn)
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("scan: %s", err)
|
||||||
}
|
}
|
||||||
return connectors, nil
|
return connectors, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,13 +44,14 @@ var (
|
||||||
// The "github.com/lib/pq" driver is the default flavor. All others are
|
// The "github.com/lib/pq" driver is the default flavor. All others are
|
||||||
// translations of this.
|
// translations of this.
|
||||||
flavorPostgres = flavor{
|
flavorPostgres = flavor{
|
||||||
// The default behavior for Postgres transactions is consistent reads, not consistent writes.
|
// The default behavior for Postgres transactions is consistent reads, not
|
||||||
// For each transaction opened, ensure it has the correct isolation level.
|
// consistent writes. For each transaction opened, ensure it has the
|
||||||
|
// correct isolation level.
|
||||||
//
|
//
|
||||||
// See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html
|
// See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html
|
||||||
//
|
//
|
||||||
// NOTE(ericchiang): For some reason using `SET SESSION CHARACTERISTICS AS TRANSACTION` at a
|
// Be careful not to wrap sql errors in the callback 'fn', otherwise
|
||||||
// session level didn't work for some edge cases. Might be something worth exploring.
|
// serialization failures will not be detected and retried.
|
||||||
executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error {
|
executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -66,6 +67,11 @@ var (
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fn(tx); err != nil {
|
if err := fn(tx); err != nil {
|
||||||
|
if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "serialization_failure" {
|
||||||
|
// serialization error; retry
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue