forked from mystiq/dex
ca114f7812
This patch adds an etcd storage implementation. This should be useful in environments where: - we don't want to depend on a separate, hard-to-maintain SQL cluster; - we don't want to incur the overhead of talking to Kubernetes API servers; - Kubernetes is not available yet, or Kubernetes depends on dex to perform authentication and the operator would like to remove that circular dependency if possible.
532 lines
16 KiB
Go
532 lines
16 KiB
Go
package etcd
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/coreos/dex/storage"
|
|
"github.com/coreos/etcd/clientv3"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Key prefixes under which each kind of storage object is kept in etcd.
// Objects are addressed as <prefix><identifier>; see keyID, keyEmail and
// keySession.
const (
	authRequestPrefix    = "auth_req/"
	authCodePrefix       = "auth_code/"
	clientPrefix         = "client/"
	connectorPrefix      = "connector/"
	offlineSessionPrefix = "offline_session/"
	passwordPrefix       = "password/"
	refreshTokenPrefix   = "refresh_token/"
)

const (
	// keysName is the single etcd key under which storage.Keys is kept as JSON.
	keysName = "openid-connect-keys"

	// defaultStorageTimeout bounds the context of every storage operation.
	defaultStorageTimeout = 5 * time.Second
)
|
|
|
|
type conn struct {
|
|
db *clientv3.Client
|
|
logger logrus.FieldLogger
|
|
}
|
|
|
|
func (c *conn) Close() error {
|
|
return c.db.Close()
|
|
}
|
|
|
|
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
authRequests, err := c.listAuthRequests(ctx)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
var delErr error
|
|
for _, authRequest := range authRequests {
|
|
if now.After(authRequest.Expiry) {
|
|
if err := c.deleteKey(ctx, keyID(authRequestPrefix, authRequest.ID)); err != nil {
|
|
c.logger.Errorf("failed to delete auth request: %v", err)
|
|
delErr = fmt.Errorf("failed to delete auth request: %v", err)
|
|
}
|
|
result.AuthRequests++
|
|
}
|
|
}
|
|
if delErr != nil {
|
|
return result, delErr
|
|
}
|
|
|
|
authCodes, err := c.listAuthCodes(ctx)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
for _, authCode := range authCodes {
|
|
if now.After(authCode.Expiry) {
|
|
if err := c.deleteKey(ctx, keyID(authCodePrefix, authCode.ID)); err != nil {
|
|
c.logger.Errorf("failed to delete auth code %v", err)
|
|
delErr = fmt.Errorf("failed to delete auth code: %v", err)
|
|
}
|
|
result.AuthCodes++
|
|
}
|
|
}
|
|
return result, delErr
|
|
}
|
|
|
|
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(authRequestPrefix, a.ID), fromStorageAuthRequest(a))
|
|
}
|
|
|
|
func (c *conn) GetAuthRequest(id string) (a storage.AuthRequest, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
var req AuthRequest
|
|
if err = c.getKey(ctx, keyID(authRequestPrefix, id), &req); err != nil {
|
|
return
|
|
}
|
|
return toStorageAuthRequest(req), nil
|
|
}
|
|
|
|
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keyID(authRequestPrefix, id), func(currentValue []byte) ([]byte, error) {
|
|
var current AuthRequest
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(toStorageAuthRequest(current))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(fromStorageAuthRequest(updated))
|
|
})
|
|
}
|
|
|
|
func (c *conn) DeleteAuthRequest(id string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyID(authRequestPrefix, id))
|
|
}
|
|
|
|
func (c *conn) CreateAuthCode(a storage.AuthCode) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(authCodePrefix, a.ID), fromStorageAuthCode(a))
|
|
}
|
|
|
|
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
err = c.getKey(ctx, keyID(authCodePrefix, id), &a)
|
|
return a, err
|
|
}
|
|
|
|
func (c *conn) DeleteAuthCode(id string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyID(authCodePrefix, id))
|
|
}
|
|
|
|
func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(refreshTokenPrefix, r.ID), fromStorageRefreshToken(r))
|
|
}
|
|
|
|
func (c *conn) GetRefresh(id string) (r storage.RefreshToken, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
var token RefreshToken
|
|
if err = c.getKey(ctx, keyID(refreshTokenPrefix, id), &token); err != nil {
|
|
return
|
|
}
|
|
return toStorageRefreshToken(token), nil
|
|
}
|
|
|
|
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keyID(refreshTokenPrefix, id), func(currentValue []byte) ([]byte, error) {
|
|
var current RefreshToken
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal([]byte(currentValue), ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(toStorageRefreshToken(current))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(fromStorageRefreshToken(updated))
|
|
})
|
|
}
|
|
|
|
func (c *conn) DeleteRefresh(id string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyID(refreshTokenPrefix, id))
|
|
}
|
|
|
|
func (c *conn) ListRefreshTokens() (tokens []storage.RefreshToken, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
res, err := c.db.Get(ctx, refreshTokenPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return tokens, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var token RefreshToken
|
|
if err = json.Unmarshal(v.Value, &token); err != nil {
|
|
return tokens, err
|
|
}
|
|
tokens = append(tokens, toStorageRefreshToken(token))
|
|
}
|
|
return tokens, nil
|
|
}
|
|
|
|
func (c *conn) CreateClient(cli storage.Client) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(clientPrefix, cli.ID), cli)
|
|
}
|
|
|
|
func (c *conn) GetClient(id string) (cli storage.Client, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
err = c.getKey(ctx, keyID(clientPrefix, id), &cli)
|
|
return cli, err
|
|
}
|
|
|
|
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keyID(clientPrefix, id), func(currentValue []byte) ([]byte, error) {
|
|
var current storage.Client
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(current)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(updated)
|
|
})
|
|
}
|
|
|
|
func (c *conn) DeleteClient(id string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyID(clientPrefix, id))
|
|
}
|
|
|
|
func (c *conn) ListClients() (clients []storage.Client, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
res, err := c.db.Get(ctx, clientPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return clients, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var cli storage.Client
|
|
if err = json.Unmarshal(v.Value, &cli); err != nil {
|
|
return clients, err
|
|
}
|
|
clients = append(clients, cli)
|
|
}
|
|
return clients, nil
|
|
}
|
|
|
|
func (c *conn) CreatePassword(p storage.Password) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, passwordPrefix+strings.ToLower(p.Email), p)
|
|
}
|
|
|
|
func (c *conn) GetPassword(email string) (p storage.Password, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
err = c.getKey(ctx, keyEmail(passwordPrefix, email), &p)
|
|
return p, err
|
|
}
|
|
|
|
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keyEmail(passwordPrefix, email), func(currentValue []byte) ([]byte, error) {
|
|
var current storage.Password
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(current)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(updated)
|
|
})
|
|
}
|
|
|
|
func (c *conn) DeletePassword(email string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyEmail(passwordPrefix, email))
|
|
}
|
|
|
|
func (c *conn) ListPasswords() (passwords []storage.Password, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
res, err := c.db.Get(ctx, passwordPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return passwords, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var p storage.Password
|
|
if err = json.Unmarshal(v.Value, &p); err != nil {
|
|
return passwords, err
|
|
}
|
|
passwords = append(passwords, p)
|
|
}
|
|
return passwords, nil
|
|
}
|
|
|
|
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keySession(offlineSessionPrefix, s.UserID, s.ConnID), fromStorageOfflineSessions(s))
|
|
}
|
|
|
|
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keySession(offlineSessionPrefix, userID, connID), func(currentValue []byte) ([]byte, error) {
|
|
var current OfflineSessions
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(toStorageOfflineSessions(current))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(fromStorageOfflineSessions(updated))
|
|
})
|
|
}
|
|
|
|
func (c *conn) GetOfflineSessions(userID string, connID string) (s storage.OfflineSessions, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
var os OfflineSessions
|
|
if err = c.getKey(ctx, keySession(offlineSessionPrefix, userID, connID), &os); err != nil {
|
|
return
|
|
}
|
|
return toStorageOfflineSessions(os), nil
|
|
}
|
|
|
|
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keySession(offlineSessionPrefix, userID, connID))
|
|
}
|
|
|
|
func (c *conn) CreateConnector(connector storage.Connector) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector)
|
|
}
|
|
|
|
func (c *conn) GetConnector(id string) (conn storage.Connector, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
err = c.getKey(ctx, keyID(connectorPrefix, id), &conn)
|
|
return conn, err
|
|
}
|
|
|
|
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keyID(connectorPrefix, id), func(currentValue []byte) ([]byte, error) {
|
|
var current storage.Connector
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(current)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(updated)
|
|
})
|
|
}
|
|
|
|
func (c *conn) DeleteConnector(id string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.deleteKey(ctx, keyID(connectorPrefix, id))
|
|
}
|
|
|
|
func (c *conn) ListConnectors() (connectors []storage.Connector, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
res, err := c.db.Get(ctx, connectorPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var c storage.Connector
|
|
if err = json.Unmarshal(v.Value, &c); err != nil {
|
|
return nil, err
|
|
}
|
|
connectors = append(connectors, c)
|
|
}
|
|
return connectors, nil
|
|
}
|
|
|
|
func (c *conn) GetKeys() (keys storage.Keys, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
res, err := c.db.Get(ctx, keysName)
|
|
if err != nil {
|
|
return keys, err
|
|
}
|
|
if res.Count > 0 && len(res.Kvs) > 0 {
|
|
err = json.Unmarshal(res.Kvs[0].Value, &keys)
|
|
}
|
|
return keys, err
|
|
}
|
|
|
|
func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
|
defer cancel()
|
|
return c.txnUpdate(ctx, keysName, func(currentValue []byte) ([]byte, error) {
|
|
var current storage.Keys
|
|
if len(currentValue) > 0 {
|
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
updated, err := updater(current)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(updated)
|
|
})
|
|
}
|
|
|
|
func (c *conn) deleteKey(ctx context.Context, key string) error {
|
|
res, err := c.db.Delete(ctx, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if res.Deleted == 0 {
|
|
return storage.ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *conn) getKey(ctx context.Context, key string, value interface{}) error {
|
|
r, err := c.db.Get(ctx, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if r.Count == 0 {
|
|
return storage.ErrNotFound
|
|
}
|
|
return json.Unmarshal(r.Kvs[0].Value, value)
|
|
}
|
|
|
|
func (c *conn) listAuthRequests(ctx context.Context) (reqs []AuthRequest, err error) {
|
|
res, err := c.db.Get(ctx, authRequestPrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return reqs, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var r AuthRequest
|
|
if err = json.Unmarshal(v.Value, &r); err != nil {
|
|
return reqs, err
|
|
}
|
|
reqs = append(reqs, r)
|
|
}
|
|
return reqs, nil
|
|
}
|
|
|
|
func (c *conn) listAuthCodes(ctx context.Context) (codes []AuthCode, err error) {
|
|
res, err := c.db.Get(ctx, authCodePrefix, clientv3.WithPrefix())
|
|
if err != nil {
|
|
return codes, err
|
|
}
|
|
for _, v := range res.Kvs {
|
|
var c AuthCode
|
|
if err = json.Unmarshal(v.Value, &c); err != nil {
|
|
return codes, err
|
|
}
|
|
codes = append(codes, c)
|
|
}
|
|
return codes, nil
|
|
}
|
|
|
|
func (c *conn) txnCreate(ctx context.Context, key string, value interface{}) error {
|
|
b, err := json.Marshal(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
txn := c.db.Txn(ctx)
|
|
res, err := txn.
|
|
If(clientv3.Compare(clientv3.CreateRevision(key), "=", 0)).
|
|
Then(clientv3.OpPut(key, string(b))).
|
|
Commit()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !res.Succeeded {
|
|
return storage.ErrAlreadyExists
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *conn) txnUpdate(ctx context.Context, key string, update func(current []byte) ([]byte, error)) error {
|
|
getResp, err := c.db.Get(ctx, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var currentValue []byte
|
|
var modRev int64
|
|
if len(getResp.Kvs) > 0 {
|
|
currentValue = getResp.Kvs[0].Value
|
|
modRev = getResp.Kvs[0].ModRevision
|
|
}
|
|
|
|
updatedValue, err := update(currentValue)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
txn := c.db.Txn(ctx)
|
|
updateResp, err := txn.
|
|
If(clientv3.Compare(clientv3.ModRevision(key), "=", modRev)).
|
|
Then(clientv3.OpPut(key, string(updatedValue))).
|
|
Commit()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !updateResp.Succeeded {
|
|
return fmt.Errorf("failed to update key=%q: concurrent conflicting update happened", key)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// keyID builds an etcd key by appending id, unchanged, to prefix.
func keyID(prefix, id string) string {
	key := prefix + id
	return key
}
|
|
// keyEmail builds an etcd key from prefix and the lower-cased email, so
// lookups by email ignore case.
func keyEmail(prefix, email string) string {
	lowered := strings.ToLower(email)
	return prefix + lowered
}
|
|
// keySession builds an etcd key from prefix and "userID|connID", with the
// joined pair lower-cased.
func keySession(prefix, userID, connID string) string {
	joined := userID + "|" + connID
	return prefix + strings.ToLower(joined)
}
|