dex/storage/sql/crud.go
Justin Slowik 9c699b1028 Server integration test for Device Flow (#3)
Extracted test cases from OAuth2Code flow tests to reuse in device flow

deviceHandler unit tests to test specific device endpoints

Include client secret as an optional parameter for standards compliance

Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
2020-07-08 16:25:05 -04:00

997 lines
25 KiB
Go

package sql
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/dexidp/dex/storage"
)
// TODO(ericchiang): The update, insert, and select queries are all very
// repetitive. Consider generating them programmatically.
// keysRowID is the ID of the only row we expect to populate the "keys" table.
// It is the fixed id value used when the keys row is inserted, updated, or
// selected by the functions in this file.
const keysRowID = "keys"
// encoder returns a driver.Valuer around i whose Value method serializes i as
// JSON, so the database/sql package stores the JSON bytes when the wrapper is
// passed as a query argument.
//
//	s := []string{"planes", "bears"}
//	_, err := db.Exec(`insert into t1 (id, things) values (1, $1)`, encoder(s))
//	if err != nil {
//		// handle error
//	}
//
//	var r []byte
//	err = db.QueryRow(`select things from t1 where id = 1;`).Scan(&r)
//	if err != nil {
//		// handle error
//	}
//	fmt.Printf("%s\n", r) // ["planes","bears"]
func encoder(i interface{}) driver.Valuer {
	return jsonEncoder{i: i}
}
// decoder returns a sql.Scanner around i that parses a scanned column as JSON
// into i. The result is meant to be handed to a database Scan() call; callers
// in this file always pass a pointer to the destination value.
func decoder(i interface{}) sql.Scanner {
	return jsonDecoder{i: i}
}
// jsonEncoder adapts an arbitrary value to driver.Valuer by serializing it as
// JSON.
type jsonEncoder struct {
	i interface{}
}

// Value marshals the wrapped value with encoding/json and returns the encoded
// bytes. A marshalling failure is reported with a "marshal: " prefix and a nil
// value.
func (j jsonEncoder) Value() (driver.Value, error) {
	encoded, err := json.Marshal(j.i)
	if err == nil {
		return encoded, nil
	}
	return nil, fmt.Errorf("marshal: %v", err)
}
// jsonDecoder adapts a destination value to sql.Scanner by parsing the scanned
// column as JSON.
type jsonDecoder struct {
	i interface{}
}

// Scan JSON-decodes the column value into the wrapped destination. Despite its
// name, dest is the value handed over by the database: it must be a []byte.
// A nil value or a value of any other type is rejected with an error.
func (j jsonDecoder) Scan(dest interface{}) error {
	switch raw := dest.(type) {
	case nil:
		return errors.New("nil value")
	case []byte:
		if err := json.Unmarshal(raw, &j.i); err != nil {
			return fmt.Errorf("unmarshal: %v", err)
		}
		return nil
	default:
		return fmt.Errorf("expected []byte got %T", dest)
	}
}
// querier is the single-row query capability shared by conn and trans. Helpers
// in this file accept it so the same lookup can run either directly on the
// connection or inside a transaction.
type querier interface {
	QueryRow(stmt string, params ...interface{}) *sql.Row
}
// scanner is the Scan method common to a single result row and a multi-row
// cursor, so one scan helper can serve both single-row lookups and list
// queries.
type scanner interface {
	Scan(targets ...interface{}) error
}
// GarbageCollect deletes every auth request, auth code, device request, and
// device token whose expiry is earlier than now, and reports how many rows of
// each kind were removed.
//
// A failed delete aborts the sweep and returns the counts gathered so far
// together with the error. An error from RowsAffected is tolerated: the
// corresponding count is simply left at zero.
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
	// Expired auth requests.
	res, err := c.Exec(`delete from auth_request where expiry < $1`, now)
	if err != nil {
		return result, fmt.Errorf("gc auth_request: %v", err)
	}
	if deleted, cntErr := res.RowsAffected(); cntErr == nil {
		result.AuthRequests = deleted
	}

	// Expired auth codes.
	res, err = c.Exec(`delete from auth_code where expiry < $1`, now)
	if err != nil {
		return result, fmt.Errorf("gc auth_code: %v", err)
	}
	if deleted, cntErr := res.RowsAffected(); cntErr == nil {
		result.AuthCodes = deleted
	}

	// Expired device requests.
	res, err = c.Exec(`delete from device_request where expiry < $1`, now)
	if err != nil {
		return result, fmt.Errorf("gc device_request: %v", err)
	}
	if deleted, cntErr := res.RowsAffected(); cntErr == nil {
		result.DeviceRequests = deleted
	}

	// Expired device tokens.
	res, err = c.Exec(`delete from device_token where expiry < $1`, now)
	if err != nil {
		return result, fmt.Errorf("gc device_token: %v", err)
	}
	if deleted, cntErr := res.RowsAffected(); cntErr == nil {
		result.DeviceTokens = deleted
	}

	return result, nil
}
// CreateAuthRequest inserts a as a new row of the auth_request table. Slice
// fields (response types, scopes, claim groups) are stored JSON-encoded.
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes the
// insert failure, and a wrapped error for any other failure.
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
	_, err := c.Exec(`
insert into auth_request (
id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
expiry
)
values (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
);
`,
		// Request parameters ($1-$9).
		a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes),
		a.RedirectURI, a.Nonce, a.State,
		a.ForceApprovalPrompt, a.LoggedIn,
		// Claims ($10-$15).
		a.Claims.UserID, a.Claims.Username, a.Claims.PreferredUsername,
		a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
		// Connector and expiry ($16-$18).
		a.ConnectorID, a.ConnectorData,
		a.Expiry,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	default:
		return fmt.Errorf("insert auth request: %v", err)
	}
}
// UpdateAuthRequest loads the auth request with the given id, passes it to
// updater, and writes the returned value back, all within the callback given
// to c.ExecTx.
//
// Errors from the lookup or from updater are returned unchanged; a failed
// update statement is wrapped. The row is matched on the ID of the record that
// was loaded, so updater cannot move the record to a different id.
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
	return c.ExecTx(func(tx *trans) error {
		current, err := getAuthRequest(tx, id)
		if err != nil {
			return err
		}

		updated, err := updater(current)
		if err != nil {
			return err
		}

		if _, err := tx.Exec(`
update auth_request
set
client_id = $1, response_types = $2, scopes = $3, redirect_uri = $4,
nonce = $5, state = $6, force_approval_prompt = $7, logged_in = $8,
claims_user_id = $9, claims_username = $10, claims_preferred_username = $11,
claims_email = $12, claims_email_verified = $13,
claims_groups = $14,
connector_id = $15, connector_data = $16,
expiry = $17
where id = $18;
`,
			updated.ClientID, encoder(updated.ResponseTypes), encoder(updated.Scopes),
			updated.RedirectURI, updated.Nonce, updated.State,
			updated.ForceApprovalPrompt, updated.LoggedIn,
			updated.Claims.UserID, updated.Claims.Username, updated.Claims.PreferredUsername,
			updated.Claims.Email, updated.Claims.EmailVerified,
			encoder(updated.Claims.Groups),
			updated.ConnectorID, updated.ConnectorData,
			updated.Expiry, current.ID,
		); err != nil {
			return fmt.Errorf("update auth request: %v", err)
		}
		return nil
	})
}
// GetAuthRequest looks up the auth request with the given id directly on the
// connection. See getAuthRequest for the error contract.
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
	req, err := getAuthRequest(c, id)
	return req, err
}
// getAuthRequest selects the auth_request row with the given id through q and
// decodes it into a storage.AuthRequest. JSON columns (response types, scopes,
// claim groups) are decoded into their slice fields.
//
// It returns storage.ErrNotFound when no row matches and a wrapped error for
// any other query or scan failure.
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
	row := q.QueryRow(`
select
id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, expiry
from auth_request where id = $1;
`, id)

	err = row.Scan(
		&a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes),
		&a.RedirectURI, &a.Nonce, &a.State,
		&a.ForceApprovalPrompt, &a.LoggedIn,
		&a.Claims.UserID, &a.Claims.Username, &a.Claims.PreferredUsername,
		&a.Claims.Email, &a.Claims.EmailVerified,
		decoder(&a.Claims.Groups),
		&a.ConnectorID, &a.ConnectorData, &a.Expiry,
	)
	switch err {
	case nil:
		return a, nil
	case sql.ErrNoRows:
		return a, storage.ErrNotFound
	default:
		return a, fmt.Errorf("select auth request: %v", err)
	}
}
// CreateAuthCode inserts a as a new row of the auth_code table, storing the
// scopes and claim groups JSON-encoded.
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes the
// insert failure, and a wrapped error for any other failure.
func (c *conn) CreateAuthCode(a storage.AuthCode) error {
	_, err := c.Exec(`
insert into auth_code (
id, client_id, scopes, nonce, redirect_uri,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
expiry
)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);
`,
		a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI,
		a.Claims.UserID, a.Claims.Username, a.Claims.PreferredUsername,
		a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
		a.ConnectorID, a.ConnectorData,
		a.Expiry,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	default:
		return fmt.Errorf("insert auth code: %v", err)
	}
}
// GetAuthCode selects the auth_code row with the given id and decodes it into
// a storage.AuthCode, JSON-decoding the scopes and claim groups columns.
//
// It returns storage.ErrNotFound when no row matches and a wrapped error for
// any other query or scan failure.
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
	row := c.QueryRow(`
select
id, client_id, scopes, nonce, redirect_uri,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
expiry
from auth_code where id = $1;
`, id)

	err = row.Scan(
		&a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI,
		&a.Claims.UserID, &a.Claims.Username, &a.Claims.PreferredUsername,
		&a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups),
		&a.ConnectorID, &a.ConnectorData,
		&a.Expiry,
	)
	switch err {
	case nil:
		return a, nil
	case sql.ErrNoRows:
		return a, storage.ErrNotFound
	default:
		return a, fmt.Errorf("select auth code: %v", err)
	}
}
// CreateRefresh inserts r as a new row of the refresh_token table, storing the
// scopes and claim groups JSON-encoded.
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes the
// insert failure, and a wrapped error for any other failure.
func (c *conn) CreateRefresh(r storage.RefreshToken) error {
	_, err := c.Exec(`
insert into refresh_token (
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
token, created_at, last_used
)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15);
`,
		r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
		r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
		r.Claims.Email, r.Claims.EmailVerified, encoder(r.Claims.Groups),
		r.ConnectorID, r.ConnectorData,
		r.Token, r.CreatedAt, r.LastUsed,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	default:
		return fmt.Errorf("insert refresh_token: %v", err)
	}
}
// UpdateRefreshToken loads the refresh token with the given id, passes it to
// updater, and writes the returned value back, all within the callback given
// to c.ExecTx.
//
// Errors from the lookup or from updater are returned unchanged; a failed
// update statement is wrapped. The row is matched on the id argument.
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
	return c.ExecTx(func(tx *trans) error {
		stored, err := getRefresh(tx, id)
		if err != nil {
			return err
		}

		next, err := updater(stored)
		if err != nil {
			return err
		}

		if _, err := tx.Exec(`
update refresh_token
set
client_id = $1,
scopes = $2,
nonce = $3,
claims_user_id = $4,
claims_username = $5,
claims_preferred_username = $6,
claims_email = $7,
claims_email_verified = $8,
claims_groups = $9,
connector_id = $10,
connector_data = $11,
token = $12,
created_at = $13,
last_used = $14
where
id = $15
`,
			next.ClientID, encoder(next.Scopes), next.Nonce,
			next.Claims.UserID, next.Claims.Username, next.Claims.PreferredUsername,
			next.Claims.Email, next.Claims.EmailVerified,
			encoder(next.Claims.Groups),
			next.ConnectorID, next.ConnectorData,
			next.Token, next.CreatedAt, next.LastUsed, id,
		); err != nil {
			return fmt.Errorf("update refresh token: %v", err)
		}
		return nil
	})
}
// GetRefresh looks up the refresh token with the given id directly on the
// connection. See getRefresh for the error contract.
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
	token, err := getRefresh(c, id)
	return token, err
}
// getRefresh selects the refresh_token row with the given id through q and
// hands it to scanRefresh, which decodes the row and maps a missing row to
// storage.ErrNotFound.
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
	row := q.QueryRow(`
select
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data,
token, created_at, last_used
from refresh_token where id = $1;
`, id)
	return scanRefresh(row)
}
// ListRefreshTokens returns every row of the refresh_token table decoded into
// storage.RefreshToken values. The result is nil when the table is empty.
//
// A failed query is wrapped with "query:", a scan failure is returned as
// produced by scanRefresh, and an iteration error is wrapped with "scan:".
func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
	rows, err := c.Query(`
select
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
token, created_at, last_used
from refresh_token;
`)
	if err != nil {
		return nil, fmt.Errorf("query: %v", err)
	}
	// Release the result set on every return path. Without this, bailing out
	// on a scan error would leave the rows (and their connection) open.
	defer rows.Close()

	var tokens []storage.RefreshToken
	for rows.Next() {
		r, err := scanRefresh(rows)
		if err != nil {
			return nil, err
		}
		tokens = append(tokens, r)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("scan: %v", err)
	}
	return tokens, nil
}
// scanRefresh decodes one refresh_token row from s into a
// storage.RefreshToken, JSON-decoding the scopes and claim groups columns. The
// column order must match the select lists used by getRefresh and
// ListRefreshTokens.
//
// It returns storage.ErrNotFound when s reports sql.ErrNoRows and a wrapped
// error for any other scan failure.
func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
	err = s.Scan(
		&r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
		&r.Claims.UserID, &r.Claims.Username, &r.Claims.PreferredUsername,
		&r.Claims.Email, &r.Claims.EmailVerified, decoder(&r.Claims.Groups),
		&r.ConnectorID, &r.ConnectorData,
		&r.Token, &r.CreatedAt, &r.LastUsed,
	)
	switch err {
	case nil:
		return r, nil
	case sql.ErrNoRows:
		return r, storage.ErrNotFound
	default:
		return r, fmt.Errorf("scan refresh_token: %v", err)
	}
}
// UpdateKeys loads the stored keys, passes them to updater, and persists the
// returned value, all within the callback given to c.ExecTx.
//
// When no keys row exists yet, updater receives a zero storage.Keys and the
// result is inserted under keysRowID; otherwise the existing row is updated.
// A lookup failure other than storage.ErrNotFound is wrapped with "get keys:",
// an updater error is returned unchanged, and statement failures are wrapped
// with "insert:" or "update:".
func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
	return c.ExecTx(func(tx *trans) error {
		// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
		// server. Test this, and consider adding a COUNT() command beforehand.
		prev, err := getKeys(tx)
		rowMissing := err == storage.ErrNotFound
		if err != nil && !rowMissing {
			return fmt.Errorf("get keys: %v", err)
		}
		if rowMissing {
			// First write: start the updater from an empty key set.
			prev = storage.Keys{}
		}

		next, err := updater(prev)
		if err != nil {
			return err
		}

		if rowMissing {
			if _, err := tx.Exec(`
insert into keys (
id, verification_keys, signing_key, signing_key_pub, next_rotation
)
values ($1, $2, $3, $4, $5);
`,
				keysRowID, encoder(next.VerificationKeys), encoder(next.SigningKey),
				encoder(next.SigningKeyPub), next.NextRotation,
			); err != nil {
				return fmt.Errorf("insert: %v", err)
			}
			return nil
		}

		if _, err := tx.Exec(`
update keys
set
verification_keys = $1,
signing_key = $2,
signing_key_pub = $3,
next_rotation = $4
where id = $5;
`,
			encoder(next.VerificationKeys), encoder(next.SigningKey),
			encoder(next.SigningKeyPub), next.NextRotation, keysRowID,
		); err != nil {
			return fmt.Errorf("update: %v", err)
		}
		return nil
	})
}
// GetKeys reads the stored keys directly on the connection. See getKeys for
// the error contract.
func (c *conn) GetKeys() (keys storage.Keys, err error) {
	keys, err = getKeys(c)
	return keys, err
}
// getKeys selects the keys row identified by keysRowID through q and decodes
// it into a storage.Keys. The verification keys, signing key, and public
// signing key columns are JSON-decoded.
//
// It returns storage.ErrNotFound when the row does not exist and a wrapped
// error for any other query or scan failure.
func getKeys(q querier) (keys storage.Keys, err error) {
	row := q.QueryRow(`
select
verification_keys, signing_key, signing_key_pub, next_rotation
from keys
where id=$1
`, keysRowID)

	err = row.Scan(
		decoder(&keys.VerificationKeys),
		decoder(&keys.SigningKey),
		decoder(&keys.SigningKeyPub),
		&keys.NextRotation,
	)
	switch err {
	case nil:
		return keys, nil
	case sql.ErrNoRows:
		return keys, storage.ErrNotFound
	default:
		return keys, fmt.Errorf("query keys: %v", err)
	}
}
// UpdateClient loads the client with the given id, passes it to updater, and
// writes the returned value back, all within the callback given to c.ExecTx.
//
// Errors from the lookup or from updater are returned unchanged; a failed
// update statement is wrapped. The row is matched on the id argument, and the
// id column itself is never rewritten.
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
	return c.ExecTx(func(tx *trans) error {
		existing, err := getClient(tx, id)
		if err != nil {
			return err
		}

		changed, err := updater(existing)
		if err != nil {
			return err
		}

		if _, err := tx.Exec(`
update client
set
secret = $1,
redirect_uris = $2,
trusted_peers = $3,
public = $4,
name = $5,
logo_url = $6
where id = $7;
`,
			changed.Secret,
			encoder(changed.RedirectURIs),
			encoder(changed.TrustedPeers),
			changed.Public,
			changed.Name,
			changed.LogoURL,
			id,
		); err != nil {
			return fmt.Errorf("update client: %v", err)
		}
		return nil
	})
}
// CreateClient inserts cli as a new row of the client table, storing the
// redirect URIs and trusted peers JSON-encoded.
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes the
// insert failure, and a wrapped error for any other failure.
func (c *conn) CreateClient(cli storage.Client) error {
	_, err := c.Exec(`
insert into client (
id, secret, redirect_uris, trusted_peers, public, name, logo_url
)
values ($1, $2, $3, $4, $5, $6, $7);
`,
		cli.ID, cli.Secret,
		encoder(cli.RedirectURIs), encoder(cli.TrustedPeers),
		cli.Public, cli.Name, cli.LogoURL,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	default:
		return fmt.Errorf("insert client: %v", err)
	}
}
// getClient selects the client row with the given id through q and hands it to
// scanClient, which decodes the row and maps a missing row to
// storage.ErrNotFound.
func getClient(q querier, id string) (storage.Client, error) {
	row := q.QueryRow(`
select
id, secret, redirect_uris, trusted_peers, public, name, logo_url
from client where id = $1;
`, id)
	return scanClient(row)
}
// GetClient looks up the client with the given id directly on the connection.
// See getClient for the error contract.
func (c *conn) GetClient(id string) (storage.Client, error) {
	cli, err := getClient(c, id)
	return cli, err
}
// ListClients returns every row of the client table decoded into
// storage.Client values. The result is nil when the table is empty.
//
// Query and iteration errors are returned unchanged; a scan failure is
// returned as produced by scanClient.
func (c *conn) ListClients() ([]storage.Client, error) {
	rows, err := c.Query(`
select
id, secret, redirect_uris, trusted_peers, public, name, logo_url
from client;
`)
	if err != nil {
		return nil, err
	}
	// Release the result set on every return path. Without this, bailing out
	// on a scan error would leave the rows (and their connection) open.
	defer rows.Close()

	var clients []storage.Client
	for rows.Next() {
		cli, err := scanClient(rows)
		if err != nil {
			return nil, err
		}
		clients = append(clients, cli)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return clients, nil
}
// scanClient decodes one client row from s into a storage.Client,
// JSON-decoding the redirect URIs and trusted peers columns. The column order
// must match the select lists used by getClient and ListClients.
//
// It returns storage.ErrNotFound when s reports sql.ErrNoRows and a wrapped
// error for any other scan failure.
func scanClient(s scanner) (cli storage.Client, err error) {
	err = s.Scan(
		&cli.ID,
		&cli.Secret,
		decoder(&cli.RedirectURIs),
		decoder(&cli.TrustedPeers),
		&cli.Public,
		&cli.Name,
		&cli.LogoURL,
	)
	switch err {
	case nil:
		return cli, nil
	case sql.ErrNoRows:
		return cli, storage.ErrNotFound
	default:
		return cli, fmt.Errorf("get client: %v", err)
	}
}
// CreatePassword inserts p as a new row of the password table. The email is
// lower-cased before it is stored, matching the lower-casing applied by the
// lookup and delete paths in this file.
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes the
// insert failure, and a wrapped error for any other failure.
func (c *conn) CreatePassword(p storage.Password) error {
	email := strings.ToLower(p.Email)
	_, err := c.Exec(`
insert into password (
email, hash, username, user_id
)
values (
$1, $2, $3, $4
);
`,
		email, p.Hash, p.Username, p.UserID,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	default:
		return fmt.Errorf("insert password: %v", err)
	}
}
// UpdatePassword loads the password record for email, passes it to updater,
// and writes the returned hash, username, and user ID back, all within the
// callback given to c.ExecTx.
//
// Errors from the lookup or from updater are returned unchanged; a failed
// update statement is wrapped. The row is matched on the email of the record
// that was loaded, and the email column itself is never rewritten.
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
	return c.ExecTx(func(tx *trans) error {
		stored, err := getPassword(tx, email)
		if err != nil {
			return err
		}

		changed, err := updater(stored)
		if err != nil {
			return err
		}

		if _, err := tx.Exec(`
update password
set
hash = $1, username = $2, user_id = $3
where email = $4;
`,
			changed.Hash, changed.Username, changed.UserID, stored.Email,
		); err != nil {
			return fmt.Errorf("update password: %v", err)
		}
		return nil
	})
}
// GetPassword looks up the password record for email directly on the
// connection. See getPassword for the matching rules and error contract.
func (c *conn) GetPassword(email string) (storage.Password, error) {
	pw, err := getPassword(c, email)
	return pw, err
}
// getPassword selects the password row whose email equals the lower-cased
// email argument through q and hands it to scanPassword, which decodes the row
// and maps a missing row to storage.ErrNotFound.
func getPassword(q querier, email string) (p storage.Password, err error) {
	key := strings.ToLower(email)
	row := q.QueryRow(`
select
email, hash, username, user_id
from password where email = $1;
`, key)
	return scanPassword(row)
}
// ListPasswords returns every row of the password table decoded into
// storage.Password values. The result is nil when the table is empty.
//
// Query and iteration errors are returned unchanged; a scan failure is
// returned as produced by scanPassword.
func (c *conn) ListPasswords() ([]storage.Password, error) {
	rows, err := c.Query(`
select
email, hash, username, user_id
from password;
`)
	if err != nil {
		return nil, err
	}
	// Release the result set on every return path. Without this, bailing out
	// on a scan error would leave the rows (and their connection) open.
	defer rows.Close()

	var passwords []storage.Password
	for rows.Next() {
		p, err := scanPassword(rows)
		if err != nil {
			return nil, err
		}
		passwords = append(passwords, p)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return passwords, nil
}
// scanPassword decodes one password row (email, hash, username, user_id) from
// s into a storage.Password.
//
// It returns storage.ErrNotFound when s reports sql.ErrNoRows and a wrapped
// error for any other scan failure.
func scanPassword(s scanner) (p storage.Password, err error) {
	err = s.Scan(&p.Email, &p.Hash, &p.Username, &p.UserID)
	switch err {
	case nil:
		return p, nil
	case sql.ErrNoRows:
		return p, storage.ErrNotFound
	default:
		return p, fmt.Errorf("select password: %v", err)
	}
}
// CreateOfflineSessions inserts s as a new row of the offline_session table,
// storing the Refresh field JSON-encoded.
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes the
// insert failure, and a wrapped error for any other failure.
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
	_, err := c.Exec(`
insert into offline_session (
user_id, conn_id, refresh, connector_data
)
values (
$1, $2, $3, $4
);
`,
		s.UserID, s.ConnID, encoder(s.Refresh), s.ConnectorData,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	default:
		return fmt.Errorf("insert offline session: %v", err)
	}
}
// UpdateOfflineSessions loads the offline session for the given user and
// connector, passes it to updater, and writes the returned refresh data and
// connector data back, all within the callback given to c.ExecTx.
//
// Errors from the lookup or from updater are returned unchanged; a failed
// update statement is wrapped. The row is matched on the user and connector
// IDs of the record that was loaded.
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
	return c.ExecTx(func(tx *trans) error {
		stored, err := getOfflineSessions(tx, userID, connID)
		if err != nil {
			return err
		}

		changed, err := updater(stored)
		if err != nil {
			return err
		}

		if _, err := tx.Exec(`
update offline_session
set
refresh = $1,
connector_data = $2
where user_id = $3 AND conn_id = $4;
`,
			encoder(changed.Refresh), changed.ConnectorData,
			stored.UserID, stored.ConnID,
		); err != nil {
			return fmt.Errorf("update offline session: %v", err)
		}
		return nil
	})
}
// GetOfflineSessions looks up the offline session for the given user and
// connector directly on the connection. See getOfflineSessions for the error
// contract.
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
	session, err := getOfflineSessions(c, userID, connID)
	return session, err
}
// getOfflineSessions selects the offline_session row matching both userID and
// connID through q and hands it to scanOfflineSessions, which decodes the row
// and maps a missing row to storage.ErrNotFound.
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
	row := q.QueryRow(`
select
user_id, conn_id, refresh, connector_data
from offline_session
where user_id = $1 AND conn_id = $2;
`, userID, connID)
	return scanOfflineSessions(row)
}
// scanOfflineSessions decodes one offline_session row from s into a
// storage.OfflineSessions, JSON-decoding the refresh column.
//
// It returns storage.ErrNotFound when s reports sql.ErrNoRows and a wrapped
// error for any other scan failure.
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
	err = s.Scan(&o.UserID, &o.ConnID, decoder(&o.Refresh), &o.ConnectorData)
	switch err {
	case nil:
		return o, nil
	case sql.ErrNoRows:
		return o, storage.ErrNotFound
	default:
		return o, fmt.Errorf("select offline session: %v", err)
	}
}
// CreateConnector inserts connector as a new row of the connector table. The
// Config field is passed to the driver as-is, without JSON wrapping.
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes the
// insert failure, and a wrapped error for any other failure.
func (c *conn) CreateConnector(connector storage.Connector) error {
	_, err := c.Exec(`
insert into connector (
id, type, name, resource_version, config
)
values (
$1, $2, $3, $4, $5
);
`,
		connector.ID,
		connector.Type,
		connector.Name,
		connector.ResourceVersion,
		connector.Config,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	default:
		return fmt.Errorf("insert connector: %v", err)
	}
}
// UpdateConnector loads the connector with the given id, passes it to updater,
// and writes the returned type, name, resource version, and config back, all
// within the callback given to c.ExecTx.
//
// Errors from the lookup or from updater are returned unchanged; a failed
// update statement is wrapped. The row is matched on the ID of the record that
// was loaded, and the id column itself is never rewritten.
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
	return c.ExecTx(func(tx *trans) error {
		stored, err := getConnector(tx, id)
		if err != nil {
			return err
		}

		changed, err := updater(stored)
		if err != nil {
			return err
		}

		if _, err := tx.Exec(`
update connector
set
type = $1,
name = $2,
resource_version = $3,
config = $4
where id = $5;
`,
			changed.Type, changed.Name, changed.ResourceVersion, changed.Config,
			stored.ID,
		); err != nil {
			return fmt.Errorf("update connector: %v", err)
		}
		return nil
	})
}
// GetConnector looks up the connector with the given id directly on the
// connection. See getConnector for the error contract.
func (c *conn) GetConnector(id string) (storage.Connector, error) {
	connector, err := getConnector(c, id)
	return connector, err
}
// getConnector selects the connector row with the given id through q and hands
// it to scanConnector, which decodes the row and maps a missing row to
// storage.ErrNotFound.
func getConnector(q querier, id string) (storage.Connector, error) {
	row := q.QueryRow(`
select
id, type, name, resource_version, config
from connector
where id = $1;
`, id)
	return scanConnector(row)
}
// scanConnector decodes one connector row (id, type, name, resource_version,
// config) from s into a storage.Connector.
//
// It returns storage.ErrNotFound when s reports sql.ErrNoRows and a wrapped
// error for any other scan failure.
func scanConnector(s scanner) (c storage.Connector, err error) {
	err = s.Scan(&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config)
	switch err {
	case nil:
		return c, nil
	case sql.ErrNoRows:
		return c, storage.ErrNotFound
	default:
		return c, fmt.Errorf("select connector: %v", err)
	}
}
// ListConnectors returns every row of the connector table decoded into
// storage.Connector values. The result is nil when the table is empty.
//
// Query and iteration errors are returned unchanged; a scan failure is
// returned as produced by scanConnector.
func (c *conn) ListConnectors() ([]storage.Connector, error) {
	rows, err := c.Query(`
select
id, type, name, resource_version, config
from connector;
`)
	if err != nil {
		return nil, err
	}
	// Release the result set on every return path. Without this, bailing out
	// on a scan error would leave the rows (and their connection) open.
	defer rows.Close()

	var connectors []storage.Connector
	for rows.Next() {
		conn, err := scanConnector(rows)
		if err != nil {
			return nil, err
		}
		connectors = append(connectors, conn)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return connectors, nil
}
// DeleteAuthRequest removes the auth_request row whose id column equals id,
// delegating to c.delete.
func (c *conn) DeleteAuthRequest(id string) error {
	return c.delete("auth_request", "id", id)
}
// DeleteAuthCode removes the auth_code row whose id column equals id,
// delegating to c.delete.
func (c *conn) DeleteAuthCode(id string) error {
	return c.delete("auth_code", "id", id)
}
// DeleteClient removes the client row whose id column equals id, delegating to
// c.delete.
func (c *conn) DeleteClient(id string) error {
	return c.delete("client", "id", id)
}
// DeleteRefresh removes the refresh_token row whose id column equals id,
// delegating to c.delete.
func (c *conn) DeleteRefresh(id string) error {
	return c.delete("refresh_token", "id", id)
}
// DeletePassword removes the password row whose email column equals the
// lower-cased email argument, delegating to c.delete.
func (c *conn) DeletePassword(email string) error {
	key := strings.ToLower(email)
	return c.delete("password", "email", key)
}
// DeleteConnector removes the connector row whose id column equals id,
// delegating to c.delete.
func (c *conn) DeleteConnector(id string) error {
	return c.delete("connector", "id", id)
}
// DeleteOfflineSessions removes the offline_session row matching both userID
// and connID.
//
// It returns storage.ErrNotFound when no row was deleted. A failed delete
// statement is reported with the keys involved and the underlying driver
// error, so the cause is not lost.
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
	result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID)
	if err != nil {
		return fmt.Errorf("delete offline_session: user_id = %s, conn_id = %s: %v", userID, connID, err)
	}

	// For now mandate that the driver implements RowsAffected. If we ever need to support
	// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
	n, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("rows affected: %v", err)
	}
	if n < 1 {
		return storage.ErrNotFound
	}
	return nil
}
// delete removes the rows of table whose field column equals id.
//
// Do NOT call directly. Does not escape table or field: both are spliced into
// the statement text, so they must be trusted constants. Only id is passed as
// a bound parameter.
//
// It returns storage.ErrNotFound when no row was deleted. A failed delete
// statement is reported with the key involved and the underlying driver error,
// so the cause is not lost.
func (c *conn) delete(table, field, id string) error {
	result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id)
	if err != nil {
		return fmt.Errorf("delete %s: %s = %s: %v", table, field, id, err)
	}

	// For now mandate that the driver implements RowsAffected. If we ever need to support
	// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
	n, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("rows affected: %v", err)
	}
	if n < 1 {
		return storage.ErrNotFound
	}
	return nil
}
// CreateDeviceRequest inserts d as a new row of the device_request table,
// storing the scopes JSON-encoded.
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes the
// insert failure, and a wrapped error for any other failure.
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
	_, err := c.Exec(`
insert into device_request (
user_code, device_code, client_id, client_secret, scopes, expiry
)
values (
$1, $2, $3, $4, $5, $6
);`,
		d.UserCode, d.DeviceCode,
		d.ClientID, d.ClientSecret,
		encoder(d.Scopes), d.Expiry,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	default:
		return fmt.Errorf("insert device request: %v", err)
	}
}
// CreateDeviceToken inserts t as a new row of the device_token table.
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes the
// insert failure, and a wrapped error for any other failure.
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
	_, err := c.Exec(`
insert into device_token (
device_code, status, token, expiry, last_request, poll_interval
)
values (
$1, $2, $3, $4, $5, $6
);`,
		t.DeviceCode, t.Status, t.Token,
		t.Expiry, t.LastRequestTime, t.PollIntervalSeconds,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	default:
		return fmt.Errorf("insert device token: %v", err)
	}
}
// GetDeviceRequest looks up the device request with the given user code
// directly on the connection. See getDeviceRequest for the error contract.
func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
	req, err := getDeviceRequest(c, userCode)
	return req, err
}
// getDeviceRequest selects the device_request row with the given user code
// through q and decodes it into a storage.DeviceRequest, JSON-decoding the
// scopes column. The user code is not selected; on success it is filled in
// from the userCode argument.
//
// It returns storage.ErrNotFound when no row matches and a wrapped error for
// any other query or scan failure.
func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) {
	err = q.QueryRow(`
select
device_code, client_id, client_secret, scopes, expiry
from device_request where user_code = $1;
`, userCode).Scan(
		&d.DeviceCode, &d.ClientID, &d.ClientSecret, decoder(&d.Scopes), &d.Expiry,
	)
	if err != nil {
		if err == sql.ErrNoRows {
			return d, storage.ErrNotFound
		}
		// Name the table actually queried; this message previously said
		// "device token", which pointed at the wrong lookup.
		return d, fmt.Errorf("select device request: %v", err)
	}
	d.UserCode = userCode
	return d, nil
}
// GetDeviceToken looks up the device token with the given device code directly
// on the connection. See getDeviceToken for the error contract.
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
	token, err := getDeviceToken(c, deviceCode)
	return token, err
}
// getDeviceToken selects the device_token row with the given device code
// through q and decodes it into a storage.DeviceToken. The device code is not
// selected; on success it is filled in from the deviceCode argument.
//
// It returns storage.ErrNotFound when no row matches and a wrapped error for
// any other query or scan failure.
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
	row := q.QueryRow(`
select
status, token, expiry, last_request, poll_interval
from device_token where device_code = $1;
`, deviceCode)

	err = row.Scan(&a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds)
	switch err {
	case nil:
		a.DeviceCode = deviceCode
		return a, nil
	case sql.ErrNoRows:
		return a, storage.ErrNotFound
	default:
		return a, fmt.Errorf("select device token: %v", err)
	}
}
// UpdateDeviceToken loads the device token with the given device code, passes
// it to updater, and writes the returned status, token, last request time, and
// poll interval back, all within the callback given to c.ExecTx. The expiry
// and device code columns are not rewritten.
//
// Errors from the lookup or from updater are returned unchanged; a failed
// update statement is wrapped.
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
	return c.ExecTx(func(tx *trans) error {
		r, err := getDeviceToken(tx, deviceCode)
		if err != nil {
			return err
		}
		if r, err = updater(r); err != nil {
			return err
		}
		_, err = tx.Exec(`
update device_token
set
status = $1,
token = $2,
last_request = $3,
poll_interval = $4
where
device_code = $5
`,
			// Match on the key that was looked up, not on r.DeviceCode: the
			// updater may return a value with a changed or empty device code,
			// which would update a different row or none at all.
			r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, deviceCode,
		)
		if err != nil {
			return fmt.Errorf("update device token: %v", err)
		}
		return nil
	})
}