forked from mystiq/dex
9c699b1028
Extracted test cases from OAuth2Code flow tests to reuse in device flow deviceHandler unit tests to test specific device endpoints Include client secret as an optional parameter for standards compliance Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
639 lines
19 KiB
Go
639 lines
19 KiB
Go
package etcd
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"go.etcd.io/etcd/clientv3"
|
|
|
|
"github.com/dexidp/dex/pkg/log"
|
|
"github.com/dexidp/dex/storage"
|
|
)
|
|
|
|
// Key prefixes under which each kind of object is stored in etcd. An object
// lives at prefix+identifier (see keyID, keyEmail and keySession).
const (
	clientPrefix         = "client/"
	authCodePrefix       = "auth_code/"
	refreshTokenPrefix   = "refresh_token/"
	authRequestPrefix    = "auth_req/"
	passwordPrefix       = "password/"
	offlineSessionPrefix = "offline_session/"
	connectorPrefix      = "connector/"
	deviceRequestPrefix  = "device_req/"
	deviceTokenPrefix    = "device_token/"
)

// keysName is the single, un-prefixed key holding the storage.Keys value
// (read by GetKeys, written by UpdateKeys).
const keysName = "openid-connect-keys"

// defaultStorageTimeout will be applied to all storage's operations.
const defaultStorageTimeout = 5 * time.Second
|
|
|
|
// conn is the etcd-backed storage implementation; every storage method in
// this package is defined on *conn.
type conn struct {
	// db is the etcd v3 client used for all reads, writes and transactions.
	db *clientv3.Client
	// logger receives diagnostics, e.g. per-key failures during GarbageCollect.
	logger log.Logger
}
|
|
|
|
func (c *conn) Close() error {
|
|
return c.db.Close()
|
|
}
|
|
|
|
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
authRequests, err := c.listAuthRequests(ctx)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
var delErr error
|
|
for _, authRequest := range authRequests {
|
|
if now.After(authRequest.Expiry) {
|
|
if err := c.deleteKey(ctx, keyID(authRequestPrefix, authRequest.ID)); err != nil {
|
|
c.logger.Errorf("failed to delete auth request: %v", err)
|
|
delErr = fmt.Errorf("failed to delete auth request: %v", err)
|
|
}
|
|
result.AuthRequests++
|
|
}
|
|
}
|
|
if delErr != nil {
|
|
return result, delErr
|
|
}
|
|
|
|
authCodes, err := c.listAuthCodes(ctx)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
for _, authCode := range authCodes {
|
|
if now.After(authCode.Expiry) {
|
|
if err := c.deleteKey(ctx, keyID(authCodePrefix, authCode.ID)); err != nil {
|
|
c.logger.Errorf("failed to delete auth code %v", err)
|
|
delErr = fmt.Errorf("failed to delete auth code: %v", err)
|
|
}
|
|
result.AuthCodes++
|
|
}
|
|
}
|
|
|
|
deviceRequests, err := c.listDeviceRequests(ctx)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
for _, deviceRequest := range deviceRequests {
|
|
if now.After(deviceRequest.Expiry) {
|
|
if err := c.deleteKey(ctx, keyID(deviceRequestPrefix, deviceRequest.UserCode)); err != nil {
|
|
c.logger.Errorf("failed to delete device request %v", err)
|
|
delErr = fmt.Errorf("failed to delete device request: %v", err)
|
|
}
|
|
result.DeviceRequests++
|
|
}
|
|
}
|
|
|
|
deviceTokens, err := c.listDeviceTokens(ctx)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
for _, deviceToken := range deviceTokens {
|
|
if now.After(deviceToken.Expiry) {
|
|
if err := c.deleteKey(ctx, keyID(deviceTokenPrefix, deviceToken.DeviceCode)); err != nil {
|
|
c.logger.Errorf("failed to delete device token %v", err)
|
|
delErr = fmt.Errorf("failed to delete device token: %v", err)
|
|
}
|
|
result.DeviceTokens++
|
|
}
|
|
}
|
|
return result, delErr
|
|
}
|
|
|
|
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(authRequestPrefix, a.ID), fromStorageAuthRequest(a))
|
|
}
|
|
|
|
func (c *conn) GetAuthRequest(id string) (a storage.AuthRequest, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
var req AuthRequest
|
|
if err = c.getKey(ctx, keyID(authRequestPrefix, id), &req); err != nil {
|
|
return
|
|
}
|
|
return toStorageAuthRequest(req), nil
|
|
}
|
|
|
|
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keyID(authRequestPrefix, id), func(currentValue []byte) ([]byte, error) {
|
|
var current AuthRequest
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(toStorageAuthRequest(current))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(fromStorageAuthRequest(updated))
|
|
})
|
|
}
|
|
|
|
func (c *conn) DeleteAuthRequest(id string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyID(authRequestPrefix, id))
|
|
}
|
|
|
|
func (c *conn) CreateAuthCode(a storage.AuthCode) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(authCodePrefix, a.ID), fromStorageAuthCode(a))
|
|
}
|
|
|
|
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
err = c.getKey(ctx, keyID(authCodePrefix, id), &a)
|
|
return a, err
|
|
}
|
|
|
|
func (c *conn) DeleteAuthCode(id string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyID(authCodePrefix, id))
|
|
}
|
|
|
|
func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(refreshTokenPrefix, r.ID), fromStorageRefreshToken(r))
|
|
}
|
|
|
|
func (c *conn) GetRefresh(id string) (r storage.RefreshToken, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
var token RefreshToken
|
|
if err = c.getKey(ctx, keyID(refreshTokenPrefix, id), &token); err != nil {
|
|
return
|
|
}
|
|
return toStorageRefreshToken(token), nil
|
|
}
|
|
|
|
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keyID(refreshTokenPrefix, id), func(currentValue []byte) ([]byte, error) {
|
|
var current RefreshToken
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(toStorageRefreshToken(current))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(fromStorageRefreshToken(updated))
|
|
})
|
|
}
|
|
|
|
func (c *conn) DeleteRefresh(id string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyID(refreshTokenPrefix, id))
|
|
}
|
|
|
|
func (c *conn) ListRefreshTokens() (tokens []storage.RefreshToken, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
res, err := c.db.Get(ctx, refreshTokenPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return tokens, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var token RefreshToken
|
|
if err = json.Unmarshal(v.Value, &token); err != nil {
|
|
return tokens, err
|
|
}
|
|
tokens = append(tokens, toStorageRefreshToken(token))
|
|
}
|
|
return tokens, nil
|
|
}
|
|
|
|
func (c *conn) CreateClient(cli storage.Client) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(clientPrefix, cli.ID), cli)
|
|
}
|
|
|
|
func (c *conn) GetClient(id string) (cli storage.Client, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
err = c.getKey(ctx, keyID(clientPrefix, id), &cli)
|
|
return cli, err
|
|
}
|
|
|
|
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keyID(clientPrefix, id), func(currentValue []byte) ([]byte, error) {
|
|
var current storage.Client
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(current)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(updated)
|
|
})
|
|
}
|
|
|
|
func (c *conn) DeleteClient(id string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyID(clientPrefix, id))
|
|
}
|
|
|
|
func (c *conn) ListClients() (clients []storage.Client, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
res, err := c.db.Get(ctx, clientPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return clients, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var cli storage.Client
|
|
if err = json.Unmarshal(v.Value, &cli); err != nil {
|
|
return clients, err
|
|
}
|
|
clients = append(clients, cli)
|
|
}
|
|
return clients, nil
|
|
}
|
|
|
|
func (c *conn) CreatePassword(p storage.Password) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, passwordPrefix+strings.ToLower(p.Email), p)
|
|
}
|
|
|
|
func (c *conn) GetPassword(email string) (p storage.Password, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
err = c.getKey(ctx, keyEmail(passwordPrefix, email), &p)
|
|
return p, err
|
|
}
|
|
|
|
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keyEmail(passwordPrefix, email), func(currentValue []byte) ([]byte, error) {
|
|
var current storage.Password
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(current)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(updated)
|
|
})
|
|
}
|
|
|
|
func (c *conn) DeletePassword(email string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyEmail(passwordPrefix, email))
|
|
}
|
|
|
|
func (c *conn) ListPasswords() (passwords []storage.Password, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
res, err := c.db.Get(ctx, passwordPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return passwords, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var p storage.Password
|
|
if err = json.Unmarshal(v.Value, &p); err != nil {
|
|
return passwords, err
|
|
}
|
|
passwords = append(passwords, p)
|
|
}
|
|
return passwords, nil
|
|
}
|
|
|
|
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keySession(offlineSessionPrefix, s.UserID, s.ConnID), fromStorageOfflineSessions(s))
|
|
}
|
|
|
|
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keySession(offlineSessionPrefix, userID, connID), func(currentValue []byte) ([]byte, error) {
|
|
var current OfflineSessions
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(toStorageOfflineSessions(current))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(fromStorageOfflineSessions(updated))
|
|
})
|
|
}
|
|
|
|
func (c *conn) GetOfflineSessions(userID string, connID string) (s storage.OfflineSessions, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
var os OfflineSessions
|
|
if err = c.getKey(ctx, keySession(offlineSessionPrefix, userID, connID), &os); err != nil {
|
|
return
|
|
}
|
|
return toStorageOfflineSessions(os), nil
|
|
}
|
|
|
|
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keySession(offlineSessionPrefix, userID, connID))
|
|
}
|
|
|
|
func (c *conn) CreateConnector(connector storage.Connector) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector)
|
|
}
|
|
|
|
func (c *conn) GetConnector(id string) (conn storage.Connector, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
err = c.getKey(ctx, keyID(connectorPrefix, id), &conn)
|
|
return conn, err
|
|
}
|
|
|
|
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keyID(connectorPrefix, id), func(currentValue []byte) ([]byte, error) {
|
|
var current storage.Connector
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(current)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(updated)
|
|
})
|
|
}
|
|
|
|
func (c *conn) DeleteConnector(id string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyID(connectorPrefix, id))
|
|
}
|
|
|
|
func (c *conn) ListConnectors() (connectors []storage.Connector, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
res, err := c.db.Get(ctx, connectorPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var c storage.Connector
|
|
if err = json.Unmarshal(v.Value, &c); err != nil {
|
|
return nil, err
|
|
}
|
|
connectors = append(connectors, c)
|
|
}
|
|
return connectors, nil
|
|
}
|
|
|
|
func (c *conn) GetKeys() (keys storage.Keys, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
res, err := c.db.Get(ctx, keysName)
|
|
if err != nil {
|
|
return keys, err
|
|
}
|
|
if res.Count > 0 && len(res.Kvs) > 0 {
|
|
err = json.Unmarshal(res.Kvs[0].Value, &keys)
|
|
}
|
|
return keys, err
|
|
}
|
|
|
|
func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keysName, func(currentValue []byte) ([]byte, error) {
|
|
var current storage.Keys
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(current)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(updated)
|
|
})
|
|
}
|
|
|
|
func (c *conn) deleteKey(ctx context.Context, key string) error {
|
|
res, err := c.db.Delete(ctx, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if res.Deleted == 0 {
|
|
return storage.ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *conn) getKey(ctx context.Context, key string, value interface{}) error {
|
|
r, err := c.db.Get(ctx, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if r.Count == 0 {
|
|
return storage.ErrNotFound
|
|
}
|
|
return json.Unmarshal(r.Kvs[0].Value, value)
|
|
}
|
|
|
|
func (c *conn) listAuthRequests(ctx context.Context) (reqs []AuthRequest, err error) {
|
|
res, err := c.db.Get(ctx, authRequestPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return reqs, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var r AuthRequest
|
|
if err = json.Unmarshal(v.Value, &r); err != nil {
|
|
return reqs, err
|
|
}
|
|
reqs = append(reqs, r)
|
|
}
|
|
return reqs, nil
|
|
}
|
|
|
|
func (c *conn) listAuthCodes(ctx context.Context) (codes []AuthCode, err error) {
|
|
res, err := c.db.Get(ctx, authCodePrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return codes, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var c AuthCode
|
|
if err = json.Unmarshal(v.Value, &c); err != nil {
|
|
return codes, err
|
|
}
|
|
codes = append(codes, c)
|
|
}
|
|
return codes, nil
|
|
}
|
|
|
|
func (c *conn) txnCreate(ctx context.Context, key string, value interface{}) error {
|
|
b, err := json.Marshal(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
txn := c.db.Txn(ctx)
|
|
res, err := txn.
|
|
If(clientv3.Compare(clientv3.CreateRevision(key), "=", 0)).
|
|
Then(clientv3.OpPut(key, string(b))).
|
|
Commit()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !res.Succeeded {
|
|
return storage.ErrAlreadyExists
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *conn) txnUpdate(ctx context.Context, key string, update func(current []byte) ([]byte, error)) error {
|
|
getResp, err := c.db.Get(ctx, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var currentValue []byte
|
|
var modRev int64
|
|
if len(getResp.Kvs) > 0 {
|
|
currentValue = getResp.Kvs[0].Value
|
|
modRev = getResp.Kvs[0].ModRevision
|
|
}
|
|
|
|
updatedValue, err := update(currentValue)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
txn := c.db.Txn(ctx)
|
|
updateResp, err := txn.
|
|
If(clientv3.Compare(clientv3.ModRevision(key), "=", modRev)).
|
|
Then(clientv3.OpPut(key, string(updatedValue))).
|
|
Commit()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !updateResp.Succeeded {
|
|
return fmt.Errorf("failed to update key=%q: concurrent conflicting update happened", key)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// keyID builds the etcd key for an object identified by an opaque ID.
func keyID(prefix, id string) string {
	return prefix + id
}

// keyEmail builds the etcd key for an object identified by email; the email is
// lower-cased so lookups are case-insensitive.
func keyEmail(prefix, email string) string {
	normalized := strings.ToLower(email)
	return prefix + normalized
}

// keySession builds the etcd key for an offline session from the user ID and
// connector ID, joined by "|" and lower-cased as a whole.
func keySession(prefix, userID, connID string) string {
	joined := userID + "|" + connID
	return prefix + strings.ToLower(joined)
}
|
|
|
|
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
|
|
}
|
|
|
|
func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &r)
|
|
return r, err
|
|
}
|
|
|
|
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
|
|
res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return requests, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var r DeviceRequest
|
|
if err = json.Unmarshal(v.Value, &r); err != nil {
|
|
return requests, err
|
|
}
|
|
requests = append(requests, r)
|
|
}
|
|
return requests, nil
|
|
}
|
|
|
|
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(deviceTokenPrefix, t.DeviceCode), fromStorageDeviceToken(t))
|
|
}
|
|
|
|
func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t)
|
|
return t, err
|
|
}
|
|
|
|
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
|
|
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return deviceTokens, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var dt DeviceToken
|
|
if err = json.Unmarshal(v.Value, &dt); err != nil {
|
|
return deviceTokens, err
|
|
}
|
|
deviceTokens = append(deviceTokens, dt)
|
|
}
|
|
return deviceTokens, nil
|
|
}
|
|
|
|
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) {
|
|
var current DeviceToken
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(toStorageDeviceToken(current))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(fromStorageDeviceToken(updated))
|
|
})
|
|
}
|