dex/storage/etcd/etcd.go
Justin Slowik 9c699b1028 Server integration test for Device Flow (#3)
Extracted test cases from OAuth2Code flow tests to reuse in device flow

deviceHandler unit tests to test specific device endpoints

Include client secret as an optional parameter for standards compliance

Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
2020-07-08 16:25:05 -04:00

639 lines
19 KiB
Go

package etcd
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"go.etcd.io/etcd/clientv3"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage"
)
const (
	// Key prefixes for each kind of stored object. An object's etcd key is its
	// prefix followed by its identifier.
	clientPrefix         = "client/"
	authCodePrefix       = "auth_code/"
	refreshTokenPrefix   = "refresh_token/"
	authRequestPrefix    = "auth_req/"
	passwordPrefix       = "password/"
	offlineSessionPrefix = "offline_session/"
	connectorPrefix      = "connector/"
	deviceRequestPrefix  = "device_req/"
	deviceTokenPrefix    = "device_token/"

	// keysName is the single key under which the storage.Keys value lives.
	keysName = "openid-connect-keys"

	// defaultStorageTimeout will be applied to all storage's operations.
	defaultStorageTimeout = 5 * time.Second
)
// conn provides dex's storage operations on top of an etcd v3 client, keeping
// each object as a JSON value under a type-specific key prefix.
type conn struct {
	// db is the etcd client through which every storage operation is issued.
	db *clientv3.Client
	// logger reports non-fatal problems, such as deletions that fail during
	// garbage collection.
	logger log.Logger
}
// Close shuts down the underlying etcd client and returns its error, if any.
func (c *conn) Close() error {
	client := c.db
	return client.Close()
}
// GarbageCollect deletes every auth request, auth code, device request and
// device token whose Expiry is before now. It reports how many of each this
// call actually removed.
//
// All listing and deleting shares one defaultStorageTimeout deadline. A
// failure to list any category aborts collection and returns that error.
//
// A failure to delete an individual object is logged but does not stop
// collection: every category is still processed, and the most recent delete
// failure is returned at the end.
//
// An object that is already gone when its delete runs (storage.ErrNotFound,
// e.g. because another process collected it first) is neither counted nor
// treated as a failure.
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()

	var delErr error
	// collect deletes key and reports whether this call removed it. Unexpected
	// failures are logged and remembered in delErr rather than aborting, so a
	// single bad key cannot block collection of the remaining ones.
	collect := func(key, kind string) bool {
		switch derr := c.deleteKey(ctx, key); derr {
		case nil:
			return true
		case storage.ErrNotFound:
			return false
		default:
			c.logger.Errorf("failed to delete %s: %v", kind, derr)
			delErr = fmt.Errorf("failed to delete %s: %v", kind, derr)
			return false
		}
	}

	authRequests, err := c.listAuthRequests(ctx)
	if err != nil {
		return result, err
	}
	for _, authRequest := range authRequests {
		if now.After(authRequest.Expiry) && collect(keyID(authRequestPrefix, authRequest.ID), "auth request") {
			result.AuthRequests++
		}
	}

	authCodes, err := c.listAuthCodes(ctx)
	if err != nil {
		return result, err
	}
	for _, authCode := range authCodes {
		if now.After(authCode.Expiry) && collect(keyID(authCodePrefix, authCode.ID), "auth code") {
			result.AuthCodes++
		}
	}

	deviceRequests, err := c.listDeviceRequests(ctx)
	if err != nil {
		return result, err
	}
	for _, deviceRequest := range deviceRequests {
		if now.After(deviceRequest.Expiry) && collect(keyID(deviceRequestPrefix, deviceRequest.UserCode), "device request") {
			result.DeviceRequests++
		}
	}

	deviceTokens, err := c.listDeviceTokens(ctx)
	if err != nil {
		return result, err
	}
	for _, deviceToken := range deviceTokens {
		if now.After(deviceToken.Expiry) && collect(keyID(deviceTokenPrefix, deviceToken.DeviceCode), "device token") {
			result.DeviceTokens++
		}
	}

	return result, delErr
}
// CreateAuthRequest stores a under its ID. It returns storage.ErrAlreadyExists
// if an auth request with that ID is already present.
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
	key := keyID(authRequestPrefix, a.ID)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnCreate(ctx, key, fromStorageAuthRequest(a))
}
// GetAuthRequest loads the auth request with the given id. It returns
// storage.ErrNotFound if none exists.
func (c *conn) GetAuthRequest(id string) (a storage.AuthRequest, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	var stored AuthRequest
	err = c.getKey(ctx, keyID(authRequestPrefix, id), &stored)
	if err == nil {
		a = toStorageAuthRequest(stored)
	}
	return a, err
}
// UpdateAuthRequest applies updater to the auth request stored under id and
// writes the result back through txnUpdate. That write fails if the record
// changed concurrently.
//
// If no such auth request exists, it returns storage.ErrNotFound instead of
// running updater on a zero value and creating a new record. Errors returned
// by updater are propagated unchanged.
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnUpdate(ctx, keyID(authRequestPrefix, id), func(currentValue []byte) ([]byte, error) {
		// txnUpdate passes an empty value when the key does not exist.
		if len(currentValue) == 0 {
			return nil, storage.ErrNotFound
		}
		var current AuthRequest
		if err := json.Unmarshal(currentValue, &current); err != nil {
			return nil, err
		}
		updated, err := updater(toStorageAuthRequest(current))
		if err != nil {
			return nil, err
		}
		return json.Marshal(fromStorageAuthRequest(updated))
	})
}
// DeleteAuthRequest removes the auth request with the given id. It returns
// storage.ErrNotFound if nothing was deleted.
func (c *conn) DeleteAuthRequest(id string) error {
	key := keyID(authRequestPrefix, id)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.deleteKey(ctx, key)
}
// CreateAuthCode stores a under its ID. It returns storage.ErrAlreadyExists
// if that ID is already in use.
func (c *conn) CreateAuthCode(a storage.AuthCode) error {
	key := keyID(authCodePrefix, a.ID)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnCreate(ctx, key, fromStorageAuthCode(a))
}
// GetAuthCode loads the auth code with the given id. It returns
// storage.ErrNotFound if none exists.
//
// NOTE(review): CreateAuthCode stores fromStorageAuthCode(a), but this decodes
// straight into storage.AuthCode. That only round-trips while every JSON tag on
// the etcd AuthCode type matches a storage.AuthCode field name (encoding/json
// matches names case-insensitively). Confirm against types.go.
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	key := keyID(authCodePrefix, id)
	if err = c.getKey(ctx, key, &a); err != nil {
		return a, err
	}
	return a, nil
}
// DeleteAuthCode removes the auth code with the given id. It returns
// storage.ErrNotFound if nothing was deleted.
func (c *conn) DeleteAuthCode(id string) error {
	key := keyID(authCodePrefix, id)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.deleteKey(ctx, key)
}
// CreateRefresh stores refresh token r under its ID. It returns
// storage.ErrAlreadyExists if that ID is already in use.
func (c *conn) CreateRefresh(r storage.RefreshToken) error {
	key := keyID(refreshTokenPrefix, r.ID)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnCreate(ctx, key, fromStorageRefreshToken(r))
}
// GetRefresh loads the refresh token with the given id. It returns
// storage.ErrNotFound if none exists.
func (c *conn) GetRefresh(id string) (r storage.RefreshToken, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	var stored RefreshToken
	err = c.getKey(ctx, keyID(refreshTokenPrefix, id), &stored)
	if err == nil {
		r = toStorageRefreshToken(stored)
	}
	return r, err
}
// UpdateRefreshToken applies updater to the refresh token stored under id and
// writes the result back through txnUpdate. That write fails if the token
// changed concurrently.
//
// If no such token exists, it returns storage.ErrNotFound instead of running
// updater on a zero value and creating a new record. Errors returned by
// updater are propagated unchanged.
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnUpdate(ctx, keyID(refreshTokenPrefix, id), func(currentValue []byte) ([]byte, error) {
		// txnUpdate passes an empty value when the key does not exist.
		if len(currentValue) == 0 {
			return nil, storage.ErrNotFound
		}
		var current RefreshToken
		if err := json.Unmarshal(currentValue, &current); err != nil {
			return nil, err
		}
		updated, err := updater(toStorageRefreshToken(current))
		if err != nil {
			return nil, err
		}
		return json.Marshal(fromStorageRefreshToken(updated))
	})
}
// DeleteRefresh removes the refresh token with the given id. It returns
// storage.ErrNotFound if nothing was deleted.
func (c *conn) DeleteRefresh(id string) error {
	key := keyID(refreshTokenPrefix, id)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.deleteKey(ctx, key)
}
// ListRefreshTokens returns every refresh token stored under
// refreshTokenPrefix. If a stored value cannot be decoded, it returns the
// tokens decoded so far together with that error.
func (c *conn) ListRefreshTokens() (tokens []storage.RefreshToken, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	res, err := c.db.Get(ctx, refreshTokenPrefix, clientv3.WithPrefix())
	if err != nil {
		return tokens, err
	}
	for i := range res.Kvs {
		var stored RefreshToken
		if err = json.Unmarshal(res.Kvs[i].Value, &stored); err != nil {
			return tokens, err
		}
		tokens = append(tokens, toStorageRefreshToken(stored))
	}
	return tokens, nil
}
// CreateClient stores cli as-is under its ID. It returns
// storage.ErrAlreadyExists if that ID is already in use.
func (c *conn) CreateClient(cli storage.Client) error {
	key := keyID(clientPrefix, cli.ID)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnCreate(ctx, key, cli)
}
// GetClient loads the client with the given id. It returns
// storage.ErrNotFound if none exists.
func (c *conn) GetClient(id string) (cli storage.Client, err error) {
	key := keyID(clientPrefix, id)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	if err = c.getKey(ctx, key, &cli); err != nil {
		return cli, err
	}
	return cli, nil
}
// UpdateClient applies updater to the client stored under id and writes the
// result back through txnUpdate. That write fails if the client changed
// concurrently.
//
// If no such client exists, it returns storage.ErrNotFound instead of running
// updater on a zero value and creating a new record. Errors returned by
// updater are propagated unchanged.
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnUpdate(ctx, keyID(clientPrefix, id), func(currentValue []byte) ([]byte, error) {
		// txnUpdate passes an empty value when the key does not exist.
		if len(currentValue) == 0 {
			return nil, storage.ErrNotFound
		}
		var current storage.Client
		if err := json.Unmarshal(currentValue, &current); err != nil {
			return nil, err
		}
		updated, err := updater(current)
		if err != nil {
			return nil, err
		}
		return json.Marshal(updated)
	})
}
// DeleteClient removes the client with the given id. It returns
// storage.ErrNotFound if nothing was deleted.
func (c *conn) DeleteClient(id string) error {
	key := keyID(clientPrefix, id)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.deleteKey(ctx, key)
}
// ListClients returns every client stored under clientPrefix. If a stored
// value cannot be decoded, it returns the clients decoded so far together with
// that error.
func (c *conn) ListClients() (clients []storage.Client, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	res, err := c.db.Get(ctx, clientPrefix, clientv3.WithPrefix())
	if err != nil {
		return clients, err
	}
	for _, kv := range res.Kvs {
		var client storage.Client
		if err = json.Unmarshal(kv.Value, &client); err != nil {
			return clients, err
		}
		clients = append(clients, client)
	}
	return clients, nil
}
// CreatePassword stores p under its lower-cased email. It returns
// storage.ErrAlreadyExists if a password for that email already exists.
// The key is derived with keyEmail, the same helper GetPassword,
// UpdatePassword and DeletePassword use, so the keys can never drift apart.
func (c *conn) CreatePassword(p storage.Password) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnCreate(ctx, keyEmail(passwordPrefix, p.Email), p)
}
// GetPassword loads the password for email, matched case-insensitively. It
// returns storage.ErrNotFound if none exists.
func (c *conn) GetPassword(email string) (p storage.Password, err error) {
	key := keyEmail(passwordPrefix, email)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	if err = c.getKey(ctx, key, &p); err != nil {
		return p, err
	}
	return p, nil
}
// UpdatePassword applies updater to the password stored for email (matched
// case-insensitively) and writes the result back through txnUpdate. That write
// fails if the record changed concurrently.
//
// If no such password exists, it returns storage.ErrNotFound instead of
// running updater on a zero value and creating a record with an empty Email.
// Errors returned by updater are propagated unchanged.
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnUpdate(ctx, keyEmail(passwordPrefix, email), func(currentValue []byte) ([]byte, error) {
		// txnUpdate passes an empty value when the key does not exist.
		if len(currentValue) == 0 {
			return nil, storage.ErrNotFound
		}
		var current storage.Password
		if err := json.Unmarshal(currentValue, &current); err != nil {
			return nil, err
		}
		updated, err := updater(current)
		if err != nil {
			return nil, err
		}
		return json.Marshal(updated)
	})
}
// DeletePassword removes the password for email, matched case-insensitively.
// It returns storage.ErrNotFound if nothing was deleted.
func (c *conn) DeletePassword(email string) error {
	key := keyEmail(passwordPrefix, email)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.deleteKey(ctx, key)
}
// ListPasswords returns every password stored under passwordPrefix. If a
// stored value cannot be decoded, it returns the passwords decoded so far
// together with that error.
func (c *conn) ListPasswords() (passwords []storage.Password, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	res, err := c.db.Get(ctx, passwordPrefix, clientv3.WithPrefix())
	if err != nil {
		return passwords, err
	}
	for i := range res.Kvs {
		var password storage.Password
		if err = json.Unmarshal(res.Kvs[i].Value, &password); err != nil {
			return passwords, err
		}
		passwords = append(passwords, password)
	}
	return passwords, nil
}
// CreateOfflineSessions stores s under its (UserID, ConnID) pair. It returns
// storage.ErrAlreadyExists if that pair already has a record.
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
	key := keySession(offlineSessionPrefix, s.UserID, s.ConnID)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnCreate(ctx, key, fromStorageOfflineSessions(s))
}
// UpdateOfflineSessions applies updater to the offline sessions stored for
// (userID, connID) and writes the result back through txnUpdate. That write
// fails if the record changed concurrently.
//
// If no such record exists, it returns storage.ErrNotFound instead of running
// updater on a zero value and creating a new record. Errors returned by
// updater are propagated unchanged.
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnUpdate(ctx, keySession(offlineSessionPrefix, userID, connID), func(currentValue []byte) ([]byte, error) {
		// txnUpdate passes an empty value when the key does not exist.
		if len(currentValue) == 0 {
			return nil, storage.ErrNotFound
		}
		var current OfflineSessions
		if err := json.Unmarshal(currentValue, &current); err != nil {
			return nil, err
		}
		updated, err := updater(toStorageOfflineSessions(current))
		if err != nil {
			return nil, err
		}
		return json.Marshal(fromStorageOfflineSessions(updated))
	})
}
// GetOfflineSessions loads the offline sessions for (userID, connID). It
// returns storage.ErrNotFound if none exist.
func (c *conn) GetOfflineSessions(userID string, connID string) (s storage.OfflineSessions, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	var stored OfflineSessions
	err = c.getKey(ctx, keySession(offlineSessionPrefix, userID, connID), &stored)
	if err == nil {
		s = toStorageOfflineSessions(stored)
	}
	return s, err
}
// DeleteOfflineSessions removes the offline sessions for (userID, connID). It
// returns storage.ErrNotFound if nothing was deleted.
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
	key := keySession(offlineSessionPrefix, userID, connID)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.deleteKey(ctx, key)
}
// CreateConnector stores connector as-is under its ID. It returns
// storage.ErrAlreadyExists if that ID is already in use.
func (c *conn) CreateConnector(connector storage.Connector) error {
	key := keyID(connectorPrefix, connector.ID)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnCreate(ctx, key, connector)
}
// GetConnector loads the connector with the given id. It returns
// storage.ErrNotFound if none exists.
func (c *conn) GetConnector(id string) (conn storage.Connector, err error) {
	key := keyID(connectorPrefix, id)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	if err = c.getKey(ctx, key, &conn); err != nil {
		return conn, err
	}
	return conn, nil
}
// UpdateConnector applies updater to the connector stored under id and writes
// the result back through txnUpdate. That write fails if the connector changed
// concurrently.
//
// If no such connector exists, it returns storage.ErrNotFound instead of
// running updater on a zero value and creating a new record. Errors returned
// by updater are propagated unchanged.
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnUpdate(ctx, keyID(connectorPrefix, id), func(currentValue []byte) ([]byte, error) {
		// txnUpdate passes an empty value when the key does not exist.
		if len(currentValue) == 0 {
			return nil, storage.ErrNotFound
		}
		var current storage.Connector
		if err := json.Unmarshal(currentValue, &current); err != nil {
			return nil, err
		}
		updated, err := updater(current)
		if err != nil {
			return nil, err
		}
		return json.Marshal(updated)
	})
}
// DeleteConnector removes the connector with the given id. It returns
// storage.ErrNotFound if nothing was deleted.
func (c *conn) DeleteConnector(id string) error {
	key := keyID(connectorPrefix, id)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.deleteKey(ctx, key)
}
// ListConnectors returns every connector stored under connectorPrefix. Unlike
// the other List methods, it returns a nil slice (not a partial result) on any
// error.
func (c *conn) ListConnectors() (connectors []storage.Connector, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	res, err := c.db.Get(ctx, connectorPrefix, clientv3.WithPrefix())
	if err != nil {
		return nil, err
	}
	for _, kv := range res.Kvs {
		// Named so it does not shadow the receiver c.
		var connector storage.Connector
		if err = json.Unmarshal(kv.Value, &connector); err != nil {
			return nil, err
		}
		connectors = append(connectors, connector)
	}
	return connectors, nil
}
// GetKeys loads the value stored under keysName. If nothing has been stored
// yet, it returns the zero storage.Keys and a nil error rather than
// storage.ErrNotFound.
func (c *conn) GetKeys() (keys storage.Keys, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	res, err := c.db.Get(ctx, keysName)
	switch {
	case err != nil:
		return keys, err
	case res.Count == 0 || len(res.Kvs) == 0:
		return keys, nil
	}
	// Decode in a separate statement: in `return keys, json.Unmarshal(..., &keys)`
	// the read of keys is not ordered relative to the call.
	err = json.Unmarshal(res.Kvs[0].Value, &keys)
	return keys, err
}
// UpdateKeys applies updater to the stored keys and writes the result back
// through txnUpdate. If no keys have been stored yet, updater receives the zero
// storage.Keys and its result is created, which is intentional for this key.
// Errors returned by updater are propagated unchanged.
func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	apply := func(raw []byte) ([]byte, error) {
		var old storage.Keys
		if len(raw) != 0 {
			if err := json.Unmarshal(raw, &old); err != nil {
				return nil, err
			}
		}
		next, err := updater(old)
		if err != nil {
			return nil, err
		}
		return json.Marshal(next)
	}
	return c.txnUpdate(ctx, keysName, apply)
}
// deleteKey removes key. It returns storage.ErrNotFound when etcd reports that
// nothing was deleted.
func (c *conn) deleteKey(ctx context.Context, key string) error {
	res, err := c.db.Delete(ctx, key)
	switch {
	case err != nil:
		return err
	case res.Deleted == 0:
		return storage.ErrNotFound
	default:
		return nil
	}
}
// getKey reads key and JSON-decodes its value into value, which must be a
// non-nil pointer. It returns storage.ErrNotFound when the key does not exist.
func (c *conn) getKey(ctx context.Context, key string, value interface{}) error {
	r, err := c.db.Get(ctx, key)
	if err != nil {
		return err
	}
	// Check the returned KVs themselves rather than r.Count. Kvs[0] is what is
	// indexed below, so this guard can never let an out-of-range access through.
	// GetKeys and txnUpdate use the same check.
	if len(r.Kvs) == 0 {
		return storage.ErrNotFound
	}
	return json.Unmarshal(r.Kvs[0].Value, value)
}
// listAuthRequests decodes every value stored under authRequestPrefix. On a
// decode failure it returns the requests decoded so far together with the
// error.
func (c *conn) listAuthRequests(ctx context.Context) (reqs []AuthRequest, err error) {
	res, err := c.db.Get(ctx, authRequestPrefix, clientv3.WithPrefix())
	if err != nil {
		return reqs, err
	}
	for _, kv := range res.Kvs {
		var req AuthRequest
		if err = json.Unmarshal(kv.Value, &req); err != nil {
			return reqs, err
		}
		reqs = append(reqs, req)
	}
	return reqs, nil
}
// listAuthCodes decodes every value stored under authCodePrefix. On a decode
// failure it returns the codes decoded so far together with the error.
func (c *conn) listAuthCodes(ctx context.Context) (codes []AuthCode, err error) {
	res, err := c.db.Get(ctx, authCodePrefix, clientv3.WithPrefix())
	if err != nil {
		return codes, err
	}
	for i := range res.Kvs {
		var code AuthCode
		if err = json.Unmarshal(res.Kvs[i].Value, &code); err != nil {
			return codes, err
		}
		codes = append(codes, code)
	}
	return codes, nil
}
// txnCreate JSON-encodes value and, in a single etcd transaction, writes it to
// key only if key does not currently exist (its CreateRevision is 0). It
// returns storage.ErrAlreadyExists when the key is already present.
func (c *conn) txnCreate(ctx context.Context, key string, value interface{}) error {
	encoded, err := json.Marshal(value)
	if err != nil {
		return err
	}
	absent := clientv3.Compare(clientv3.CreateRevision(key), "=", 0)
	put := clientv3.OpPut(key, string(encoded))
	resp, err := c.db.Txn(ctx).If(absent).Then(put).Commit()
	if err != nil {
		return err
	}
	if resp.Succeeded {
		return nil
	}
	return storage.ErrAlreadyExists
}
// txnUpdate performs a read-modify-write of key:
//   - It reads the current value, which is nil if the key is absent.
//   - It passes that value to update.
//   - It writes the result back only if the key's ModRevision still equals
//     the one read (0 when absent, so the write then creates the key).
//
// A concurrent modification makes the write fail with an error; it is not
// retried. Errors returned by update are propagated unchanged.
func (c *conn) txnUpdate(ctx context.Context, key string, update func(current []byte) ([]byte, error)) error {
	getResp, err := c.db.Get(ctx, key)
	if err != nil {
		return err
	}
	var (
		current []byte
		rev     int64
	)
	if kvs := getResp.Kvs; len(kvs) > 0 {
		current, rev = kvs[0].Value, kvs[0].ModRevision
	}
	next, err := update(current)
	if err != nil {
		return err
	}
	unchanged := clientv3.Compare(clientv3.ModRevision(key), "=", rev)
	put := clientv3.OpPut(key, string(next))
	resp, err := c.db.Txn(ctx).If(unchanged).Then(put).Commit()
	if err != nil {
		return err
	}
	if resp.Succeeded {
		return nil
	}
	return fmt.Errorf("failed to update key=%q: concurrent conflicting update happened", key)
}
// keyID returns the etcd key for the object identified by id under prefix.
func keyID(prefix, id string) string {
	key := prefix + id
	return key
}
// keyEmail returns the etcd key for email under prefix. The email is
// lower-cased, so lookups are case-insensitive.
func keyEmail(prefix, email string) string {
	normalized := strings.ToLower(email)
	return prefix + normalized
}
// keySession returns the etcd key for the (userID, connID) pair under prefix.
// The two IDs are joined with "|" and lower-cased; the prefix is left as-is.
func keySession(prefix, userID, connID string) string {
	pair := strings.Join([]string{userID, connID}, "|")
	return prefix + strings.ToLower(pair)
}
// CreateDeviceRequest stores d under its UserCode. It returns
// storage.ErrAlreadyExists if that user code is already in use.
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
	key := keyID(deviceRequestPrefix, d.UserCode)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnCreate(ctx, key, fromStorageDeviceRequest(d))
}
// GetDeviceRequest loads the device request for userCode. It returns
// storage.ErrNotFound if none exists.
//
// CreateDeviceRequest stores the etcd-local DeviceRequest type, so the value is
// decoded into that same type and then copied into storage.DeviceRequest.
// Decoding straight into storage.DeviceRequest would silently leave empty any
// field whose JSON name differs from the storage field name.
func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	var stored DeviceRequest
	if err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &stored); err != nil {
		return r, err
	}
	// Assumes DeviceRequest mirrors storage.DeviceRequest field-for-field.
	// Keep in sync with fromStorageDeviceRequest.
	return storage.DeviceRequest{
		UserCode:     stored.UserCode,
		DeviceCode:   stored.DeviceCode,
		ClientID:     stored.ClientID,
		ClientSecret: stored.ClientSecret,
		Scopes:       stored.Scopes,
		Expiry:       stored.Expiry,
	}, nil
}
// listDeviceRequests decodes every value stored under deviceRequestPrefix. On
// a decode failure it returns the requests decoded so far together with the
// error.
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
	res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
	if err != nil {
		return requests, err
	}
	for _, kv := range res.Kvs {
		var request DeviceRequest
		if err = json.Unmarshal(kv.Value, &request); err != nil {
			return requests, err
		}
		requests = append(requests, request)
	}
	return requests, nil
}
// CreateDeviceToken stores t under its DeviceCode. It returns
// storage.ErrAlreadyExists if that device code is already in use.
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
	key := keyID(deviceTokenPrefix, t.DeviceCode)
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnCreate(ctx, key, fromStorageDeviceToken(t))
}
// GetDeviceToken loads the device token for deviceCode. It returns
// storage.ErrNotFound if none exists.
//
// CreateDeviceToken and UpdateDeviceToken store the etcd-local DeviceToken
// type, so the value is decoded into that same type and converted with
// toStorageDeviceToken, exactly as UpdateDeviceToken does. Decoding straight
// into storage.DeviceToken would silently leave empty any field whose JSON name
// differs from the storage field name.
func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	var stored DeviceToken
	if err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &stored); err != nil {
		return t, err
	}
	return toStorageDeviceToken(stored), nil
}
// listDeviceTokens decodes every value stored under deviceTokenPrefix. On a
// decode failure it returns the tokens decoded so far together with the error.
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
	res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
	if err != nil {
		return deviceTokens, err
	}
	for i := range res.Kvs {
		var token DeviceToken
		if err = json.Unmarshal(res.Kvs[i].Value, &token); err != nil {
			return deviceTokens, err
		}
		deviceTokens = append(deviceTokens, token)
	}
	return deviceTokens, nil
}
// UpdateDeviceToken applies updater to the device token stored for deviceCode
// and writes the result back through txnUpdate. That write fails if the token
// changed concurrently.
//
// If no such token exists, it returns storage.ErrNotFound instead of running
// updater on a zero value and creating a new record. Errors returned by
// updater are propagated unchanged.
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
	defer cancel()
	return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) {
		// txnUpdate passes an empty value when the key does not exist.
		if len(currentValue) == 0 {
			return nil, storage.ErrNotFound
		}
		var current DeviceToken
		if err := json.Unmarshal(currentValue, &current); err != nil {
			return nil, err
		}
		updated, err := updater(toStorageDeviceToken(current))
		if err != nil {
			return nil, err
		}
		return json.Marshal(fromStorageDeviceToken(updated))
	})
}