storage: add etcd storage
This patch adds an etcd storage implementation. It should be useful in environments where:
- we don't want to depend on a separate, hard-to-maintain SQL cluster;
- we don't want to incur the overhead of talking to the Kubernetes API servers;
- Kubernetes is not available yet, or Kubernetes depends on dex for authentication and the operator would like to avoid the circular dependency.

parent 2b13bdd12d
commit ca114f7812
6 changed files with 1058 additions and 0 deletions

@@ -11,6 +11,7 @@ import (
 	"github.com/coreos/dex/server"
 	"github.com/coreos/dex/storage"
+	"github.com/coreos/dex/storage/etcd"
 	"github.com/coreos/dex/storage/kubernetes"
 	"github.com/coreos/dex/storage/memory"
 	"github.com/coreos/dex/storage/sql"

@@ -124,6 +125,7 @@ type StorageConfig interface {
 }
 
 var storages = map[string]func() StorageConfig{
+	"etcd":       func() StorageConfig { return new(etcd.Etcd) },
 	"kubernetes": func() StorageConfig { return new(kubernetes.Config) },
 	"memory":     func() StorageConfig { return new(memory.Config) },
 	"sqlite3":    func() StorageConfig { return new(sql.SQLite3) },

storage/etcd/config.go (new file, 92 lines)

package etcd

import (
    "time"

    "github.com/coreos/dex/storage"
    "github.com/coreos/etcd/clientv3"
    "github.com/coreos/etcd/clientv3/namespace"
    "github.com/coreos/etcd/pkg/transport"
    "github.com/sirupsen/logrus"
)

var (
    defaultDialTimeout = 2 * time.Second
)

// SSL represents SSL options for etcd databases.
type SSL struct {
    ServerName string
    CAFile     string
    KeyFile    string
    CertFile   string
}

// Etcd holds options for connecting to etcd databases.
// If you are using a shared etcd cluster for storage, it might be useful to
// configure an etcd namespace, either via the Namespace field or by running
// `etcd grpc-proxy --namespace=<prefix>`.
type Etcd struct {
    Endpoints []string
    Namespace string
    Username  string
    Password  string
    SSL       SSL
}

// Open creates a new storage implementation backed by Etcd.
func (p *Etcd) Open(logger logrus.FieldLogger) (storage.Storage, error) {
    return p.open(logger)
}

func (p *Etcd) open(logger logrus.FieldLogger) (*conn, error) {
    cfg := clientv3.Config{
        Endpoints:   p.Endpoints,
        DialTimeout: defaultDialTimeout,
        Username:    p.Username,
        Password:    p.Password,
    }

    var cfgtls *transport.TLSInfo
    tlsinfo := transport.TLSInfo{}
    if p.SSL.CertFile != "" {
        tlsinfo.CertFile = p.SSL.CertFile
        cfgtls = &tlsinfo
    }

    if p.SSL.KeyFile != "" {
        tlsinfo.KeyFile = p.SSL.KeyFile
        cfgtls = &tlsinfo
    }

    if p.SSL.CAFile != "" {
        tlsinfo.CAFile = p.SSL.CAFile
        cfgtls = &tlsinfo
    }

    if p.SSL.ServerName != "" {
        tlsinfo.ServerName = p.SSL.ServerName
        cfgtls = &tlsinfo
    }

    if cfgtls != nil {
        clientTLS, err := cfgtls.ClientConfig()
        if err != nil {
            return nil, err
        }
        cfg.TLS = clientTLS
    }

    db, err := clientv3.New(cfg)
    if err != nil {
        return nil, err
    }
    if len(p.Namespace) > 0 {
        db.KV = namespace.NewKV(db.KV, p.Namespace)
    }
    c := &conn{
        db:     db,
        logger: logger,
    }
    return c, nil
}
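
Opening the new storage outside of dex's config plumbing might look like the following sketch. The endpoint URL and namespace prefix are illustrative values, not part of the patch.

package main

import (
    "github.com/coreos/dex/storage/etcd"
    "github.com/sirupsen/logrus"
)

func main() {
    logger := logrus.New() // *logrus.Logger satisfies logrus.FieldLogger

    // Illustrative values; point these at a real etcd cluster.
    cfg := etcd.Etcd{
        Endpoints: []string{"http://127.0.0.1:2379"},
        Namespace: "dex/",
    }

    s, err := cfg.Open(logger)
    if err != nil {
        logger.Fatalf("failed to open etcd storage: %v", err)
    }
    defer s.Close()

    logger.Info("etcd storage ready")
}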

storage/etcd/etcd.go (new file, 532 lines)

package etcd

import (
    "context"
    "encoding/json"
    "fmt"
    "strings"
    "time"

    "github.com/coreos/dex/storage"
    "github.com/coreos/etcd/clientv3"
    "github.com/sirupsen/logrus"
)

const (
    clientPrefix         = "client/"
    authCodePrefix       = "auth_code/"
    refreshTokenPrefix   = "refresh_token/"
    authRequestPrefix    = "auth_req/"
    passwordPrefix       = "password/"
    offlineSessionPrefix = "offline_session/"
    connectorPrefix      = "connector/"
    keysName             = "openid-connect-keys"

    // defaultStorageTimeout will be applied to all storage's operations.
    defaultStorageTimeout = 5 * time.Second
)

type conn struct {
    db     *clientv3.Client
    logger logrus.FieldLogger
}

func (c *conn) Close() error {
    return c.db.Close()
}

func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    authRequests, err := c.listAuthRequests(ctx)
    if err != nil {
        return result, err
    }

    var delErr error
    for _, authRequest := range authRequests {
        if now.After(authRequest.Expiry) {
            if err := c.deleteKey(ctx, keyID(authRequestPrefix, authRequest.ID)); err != nil {
                c.logger.Errorf("failed to delete auth request: %v", err)
                delErr = fmt.Errorf("failed to delete auth request: %v", err)
            }
            result.AuthRequests++
        }
    }
    if delErr != nil {
        return result, delErr
    }

    authCodes, err := c.listAuthCodes(ctx)
    if err != nil {
        return result, err
    }

    for _, authCode := range authCodes {
        if now.After(authCode.Expiry) {
            if err := c.deleteKey(ctx, keyID(authCodePrefix, authCode.ID)); err != nil {
                c.logger.Errorf("failed to delete auth code: %v", err)
                delErr = fmt.Errorf("failed to delete auth code: %v", err)
            }
            result.AuthCodes++
        }
    }
    return result, delErr
}

func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnCreate(ctx, keyID(authRequestPrefix, a.ID), fromStorageAuthRequest(a))
}

func (c *conn) GetAuthRequest(id string) (a storage.AuthRequest, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    var req AuthRequest
    if err = c.getKey(ctx, keyID(authRequestPrefix, id), &req); err != nil {
        return
    }
    return toStorageAuthRequest(req), nil
}

func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnUpdate(ctx, keyID(authRequestPrefix, id), func(currentValue []byte) ([]byte, error) {
        var current AuthRequest
        if len(currentValue) > 0 {
            if err := json.Unmarshal(currentValue, &current); err != nil {
                return nil, err
            }
        }
        updated, err := updater(toStorageAuthRequest(current))
        if err != nil {
            return nil, err
        }
        return json.Marshal(fromStorageAuthRequest(updated))
    })
}

func (c *conn) DeleteAuthRequest(id string) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.deleteKey(ctx, keyID(authRequestPrefix, id))
}

func (c *conn) CreateAuthCode(a storage.AuthCode) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnCreate(ctx, keyID(authCodePrefix, a.ID), fromStorageAuthCode(a))
}

func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    err = c.getKey(ctx, keyID(authCodePrefix, id), &a)
    return a, err
}

func (c *conn) DeleteAuthCode(id string) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.deleteKey(ctx, keyID(authCodePrefix, id))
}

func (c *conn) CreateRefresh(r storage.RefreshToken) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnCreate(ctx, keyID(refreshTokenPrefix, r.ID), fromStorageRefreshToken(r))
}

func (c *conn) GetRefresh(id string) (r storage.RefreshToken, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    var token RefreshToken
    if err = c.getKey(ctx, keyID(refreshTokenPrefix, id), &token); err != nil {
        return
    }
    return toStorageRefreshToken(token), nil
}

func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnUpdate(ctx, keyID(refreshTokenPrefix, id), func(currentValue []byte) ([]byte, error) {
        var current RefreshToken
        if len(currentValue) > 0 {
            if err := json.Unmarshal(currentValue, &current); err != nil {
                return nil, err
            }
        }
        updated, err := updater(toStorageRefreshToken(current))
        if err != nil {
            return nil, err
        }
        return json.Marshal(fromStorageRefreshToken(updated))
    })
}

func (c *conn) DeleteRefresh(id string) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.deleteKey(ctx, keyID(refreshTokenPrefix, id))
}

func (c *conn) ListRefreshTokens() (tokens []storage.RefreshToken, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    res, err := c.db.Get(ctx, refreshTokenPrefix, clientv3.WithPrefix())
    if err != nil {
        return tokens, err
    }
    for _, v := range res.Kvs {
        var token RefreshToken
        if err = json.Unmarshal(v.Value, &token); err != nil {
            return tokens, err
        }
        tokens = append(tokens, toStorageRefreshToken(token))
    }
    return tokens, nil
}

func (c *conn) CreateClient(cli storage.Client) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnCreate(ctx, keyID(clientPrefix, cli.ID), cli)
}

func (c *conn) GetClient(id string) (cli storage.Client, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    err = c.getKey(ctx, keyID(clientPrefix, id), &cli)
    return cli, err
}

func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnUpdate(ctx, keyID(clientPrefix, id), func(currentValue []byte) ([]byte, error) {
        var current storage.Client
        if len(currentValue) > 0 {
            if err := json.Unmarshal(currentValue, &current); err != nil {
                return nil, err
            }
        }
        updated, err := updater(current)
        if err != nil {
            return nil, err
        }
        return json.Marshal(updated)
    })
}

func (c *conn) DeleteClient(id string) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.deleteKey(ctx, keyID(clientPrefix, id))
}

func (c *conn) ListClients() (clients []storage.Client, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    res, err := c.db.Get(ctx, clientPrefix, clientv3.WithPrefix())
    if err != nil {
        return clients, err
    }
    for _, v := range res.Kvs {
        var cli storage.Client
        if err = json.Unmarshal(v.Value, &cli); err != nil {
            return clients, err
        }
        clients = append(clients, cli)
    }
    return clients, nil
}

func (c *conn) CreatePassword(p storage.Password) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnCreate(ctx, passwordPrefix+strings.ToLower(p.Email), p)
}

func (c *conn) GetPassword(email string) (p storage.Password, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    err = c.getKey(ctx, keyEmail(passwordPrefix, email), &p)
    return p, err
}

func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnUpdate(ctx, keyEmail(passwordPrefix, email), func(currentValue []byte) ([]byte, error) {
        var current storage.Password
        if len(currentValue) > 0 {
            if err := json.Unmarshal(currentValue, &current); err != nil {
                return nil, err
            }
        }
        updated, err := updater(current)
        if err != nil {
            return nil, err
        }
        return json.Marshal(updated)
    })
}

func (c *conn) DeletePassword(email string) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.deleteKey(ctx, keyEmail(passwordPrefix, email))
}

func (c *conn) ListPasswords() (passwords []storage.Password, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    res, err := c.db.Get(ctx, passwordPrefix, clientv3.WithPrefix())
    if err != nil {
        return passwords, err
    }
    for _, v := range res.Kvs {
        var p storage.Password
        if err = json.Unmarshal(v.Value, &p); err != nil {
            return passwords, err
        }
        passwords = append(passwords, p)
    }
    return passwords, nil
}

func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnCreate(ctx, keySession(offlineSessionPrefix, s.UserID, s.ConnID), fromStorageOfflineSessions(s))
}

func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnUpdate(ctx, keySession(offlineSessionPrefix, userID, connID), func(currentValue []byte) ([]byte, error) {
        var current OfflineSessions
        if len(currentValue) > 0 {
            if err := json.Unmarshal(currentValue, &current); err != nil {
                return nil, err
            }
        }
        updated, err := updater(toStorageOfflineSessions(current))
        if err != nil {
            return nil, err
        }
        return json.Marshal(fromStorageOfflineSessions(updated))
    })
}

func (c *conn) GetOfflineSessions(userID string, connID string) (s storage.OfflineSessions, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    var os OfflineSessions
    if err = c.getKey(ctx, keySession(offlineSessionPrefix, userID, connID), &os); err != nil {
        return
    }
    return toStorageOfflineSessions(os), nil
}

func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.deleteKey(ctx, keySession(offlineSessionPrefix, userID, connID))
}

func (c *conn) CreateConnector(connector storage.Connector) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector)
}

func (c *conn) GetConnector(id string) (conn storage.Connector, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    err = c.getKey(ctx, keyID(connectorPrefix, id), &conn)
    return conn, err
}

func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnUpdate(ctx, keyID(connectorPrefix, id), func(currentValue []byte) ([]byte, error) {
        var current storage.Connector
        if len(currentValue) > 0 {
            if err := json.Unmarshal(currentValue, &current); err != nil {
                return nil, err
            }
        }
        updated, err := updater(current)
        if err != nil {
            return nil, err
        }
        return json.Marshal(updated)
    })
}

func (c *conn) DeleteConnector(id string) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.deleteKey(ctx, keyID(connectorPrefix, id))
}

func (c *conn) ListConnectors() (connectors []storage.Connector, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    res, err := c.db.Get(ctx, connectorPrefix, clientv3.WithPrefix())
    if err != nil {
        return nil, err
    }
    for _, v := range res.Kvs {
        var c storage.Connector
        if err = json.Unmarshal(v.Value, &c); err != nil {
            return nil, err
        }
        connectors = append(connectors, c)
    }
    return connectors, nil
}

func (c *conn) GetKeys() (keys storage.Keys, err error) {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    res, err := c.db.Get(ctx, keysName)
    if err != nil {
        return keys, err
    }
    if res.Count > 0 && len(res.Kvs) > 0 {
        err = json.Unmarshal(res.Kvs[0].Value, &keys)
    }
    return keys, err
}

func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
    ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
    defer cancel()
    return c.txnUpdate(ctx, keysName, func(currentValue []byte) ([]byte, error) {
        var current storage.Keys
        if len(currentValue) > 0 {
            if err := json.Unmarshal(currentValue, &current); err != nil {
                return nil, err
            }
        }
        updated, err := updater(current)
        if err != nil {
            return nil, err
        }
        return json.Marshal(updated)
    })
}

func (c *conn) deleteKey(ctx context.Context, key string) error {
    res, err := c.db.Delete(ctx, key)
    if err != nil {
        return err
    }
    if res.Deleted == 0 {
        return storage.ErrNotFound
    }
    return nil
}

func (c *conn) getKey(ctx context.Context, key string, value interface{}) error {
    r, err := c.db.Get(ctx, key)
    if err != nil {
        return err
    }
    if r.Count == 0 {
        return storage.ErrNotFound
    }
    return json.Unmarshal(r.Kvs[0].Value, value)
}

func (c *conn) listAuthRequests(ctx context.Context) (reqs []AuthRequest, err error) {
    res, err := c.db.Get(ctx, authRequestPrefix, clientv3.WithPrefix())
    if err != nil {
        return reqs, err
    }
    for _, v := range res.Kvs {
        var r AuthRequest
        if err = json.Unmarshal(v.Value, &r); err != nil {
            return reqs, err
        }
        reqs = append(reqs, r)
    }
    return reqs, nil
}

func (c *conn) listAuthCodes(ctx context.Context) (codes []AuthCode, err error) {
    res, err := c.db.Get(ctx, authCodePrefix, clientv3.WithPrefix())
    if err != nil {
        return codes, err
    }
    for _, v := range res.Kvs {
        var c AuthCode
        if err = json.Unmarshal(v.Value, &c); err != nil {
            return codes, err
        }
        codes = append(codes, c)
    }
    return codes, nil
}

func (c *conn) txnCreate(ctx context.Context, key string, value interface{}) error {
    b, err := json.Marshal(value)
    if err != nil {
        return err
    }
    txn := c.db.Txn(ctx)
    res, err := txn.
        If(clientv3.Compare(clientv3.CreateRevision(key), "=", 0)).
        Then(clientv3.OpPut(key, string(b))).
        Commit()
    if err != nil {
        return err
    }
    if !res.Succeeded {
        return storage.ErrAlreadyExists
    }
    return nil
}

func (c *conn) txnUpdate(ctx context.Context, key string, update func(current []byte) ([]byte, error)) error {
    getResp, err := c.db.Get(ctx, key)
    if err != nil {
        return err
    }
    var currentValue []byte
    var modRev int64
    if len(getResp.Kvs) > 0 {
        currentValue = getResp.Kvs[0].Value
        modRev = getResp.Kvs[0].ModRevision
    }

    updatedValue, err := update(currentValue)
    if err != nil {
        return err
    }

    txn := c.db.Txn(ctx)
    updateResp, err := txn.
        If(clientv3.Compare(clientv3.ModRevision(key), "=", modRev)).
        Then(clientv3.OpPut(key, string(updatedValue))).
        Commit()
    if err != nil {
        return err
    }
    if !updateResp.Succeeded {
        return fmt.Errorf("failed to update key=%q: concurrent conflicting update happened", key)
    }
    return nil
}

func keyID(prefix, id string) string       { return prefix + id }
func keyEmail(prefix, email string) string { return prefix + strings.ToLower(email) }
func keySession(prefix, userID, connID string) string {
    return prefix + strings.ToLower(userID+"|"+connID)
}
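
txnCreate and txnUpdate implement optimistic concurrency with etcd transactions: a create only succeeds while the key's CreateRevision is 0 (the key does not exist yet), and an update only succeeds if the ModRevision observed at read time is unchanged at commit time. The following is a self-contained sketch of the same pattern written directly against clientv3; the endpoint and key values are illustrative.

// Minimal sketch of the compare-and-swap pattern used by txnCreate/txnUpdate,
// written directly against clientv3. Endpoint and key names are illustrative.
package main

import (
    "context"
    "fmt"
    "time"

    "github.com/coreos/etcd/clientv3"
)

func main() {
    cli, err := clientv3.New(clientv3.Config{
        Endpoints:   []string{"http://127.0.0.1:2379"},
        DialTimeout: 2 * time.Second,
    })
    if err != nil {
        panic(err)
    }
    defer cli.Close()

    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    // Create-if-absent: CreateRevision == 0 means the key does not exist yet.
    created, err := cli.Txn(ctx).
        If(clientv3.Compare(clientv3.CreateRevision("client/example-app"), "=", 0)).
        Then(clientv3.OpPut("client/example-app", `{"id":"example-app"}`)).
        Commit()
    if err != nil {
        panic(err)
    }
    fmt.Println("created:", created.Succeeded) // false if the key already existed

    // Update-if-unchanged: re-read the key, then write only if its ModRevision
    // has not moved since the read.
    get, err := cli.Get(ctx, "client/example-app")
    if err != nil {
        panic(err)
    }
    if len(get.Kvs) == 0 {
        panic("key not found")
    }
    modRev := get.Kvs[0].ModRevision
    updated, err := cli.Txn(ctx).
        If(clientv3.Compare(clientv3.ModRevision("client/example-app"), "=", modRev)).
        Then(clientv3.OpPut("client/example-app", `{"id":"example-app","name":"Example App"}`)).
        Commit()
    if err != nil {
        panic(err)
    }
    fmt.Println("updated:", updated.Succeeded) // false on a concurrent conflicting write
}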

storage/etcd/etcd_test.go (new file, 94 lines)

package etcd

import (
    "context"
    "fmt"
    "os"
    "runtime"
    "strings"
    "testing"
    "time"

    "github.com/coreos/dex/storage"
    "github.com/coreos/dex/storage/conformance"
    "github.com/coreos/etcd/clientv3"
    "github.com/sirupsen/logrus"
)

func withTimeout(t time.Duration, f func()) {
    c := make(chan struct{})
    defer close(c)

    go func() {
        select {
        case <-c:
        case <-time.After(t):
            // Dump a stack trace of the program. Useful for debugging deadlocks.
            buf := make([]byte, 2<<20)
            fmt.Fprintf(os.Stderr, "%s\n", buf[:runtime.Stack(buf, true)])
            panic("test took too long")
        }
    }()

    f()
}

func cleanDB(c *conn) error {
    ctx := context.TODO()
    for _, prefix := range []string{
        clientPrefix,
        authCodePrefix,
        refreshTokenPrefix,
        authRequestPrefix,
        passwordPrefix,
        offlineSessionPrefix,
        connectorPrefix,
    } {
        _, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix())
        if err != nil {
            return err
        }
    }
    return nil
}

var logger = &logrus.Logger{
    Out:       os.Stderr,
    Formatter: &logrus.TextFormatter{DisableColors: true},
    Level:     logrus.DebugLevel,
}

func TestEtcd(t *testing.T) {
    testEtcdEnv := "DEX_ETCD_ENDPOINTS"
    endpointsStr := os.Getenv(testEtcdEnv)
    if endpointsStr == "" {
        t.Skipf("test environment variable %q not set, skipping", testEtcdEnv)
        return
    }
    endpoints := strings.Split(endpointsStr, ",")

    newStorage := func() storage.Storage {
        s := &Etcd{
            Endpoints: endpoints,
        }
        conn, err := s.open(logger)
        if err != nil {
            fmt.Fprintln(os.Stdout, err)
            t.Fatal(err)
        }

        if err := cleanDB(conn); err != nil {
            fmt.Fprintln(os.Stdout, err)
            t.Fatal(err)
        }
        return conn
    }

    withTimeout(time.Second*10, func() {
        conformance.RunTests(t, newStorage)
    })

    withTimeout(time.Minute*1, func() {
        conformance.RunTransactionTests(t, newStorage)
    })
}

storage/etcd/standup.sh (new executable file, 109 lines)

#!/bin/bash

if [ "$EUID" -ne 0 ]
  then echo "Please run as root"
  exit
fi

function usage {
  cat << EOF >> /dev/stderr
Usage: sudo ./standup.sh [create|destroy] [etcd]

This is a script for standing up test databases. It uses systemd to daemonize
rkt containers running on a local loopback IP.

The general workflow is to create a daemonized container, use the output to set
the test environment variables, run the tests, then destroy the container.

    sudo ./standup.sh create etcd
    # Copy environment variables and run tests.
    go test -v -i # always install test dependencies
    go test -v
    sudo ./standup.sh destroy etcd

EOF
  exit 2
}

function main {
  if [ "$#" -ne 2 ]; then
    usage
    exit 2
  fi

  case "$1" in
  "create")
    case "$2" in
    "etcd")
      create_etcd;;
    *)
      usage
      exit 2
      ;;
    esac
    ;;
  "destroy")
    case "$2" in
    "etcd")
      destroy_etcd;;
    *)
      usage
      exit 2
      ;;
    esac
    ;;
  *)
    usage
    exit 2
    ;;
  esac
}

function wait_for_file {
  while [ ! -f $1 ]; do
    sleep 1
  done
}

function wait_for_container {
  while [ -z "$( rkt list --full | grep $1 | grep running )" ]; do
    sleep 1
  done
}

function create_etcd {
  UUID_FILE=/tmp/dex-etcd-uuid
  if [ -f $UUID_FILE ]; then
    echo "etcd database already exists, try ./standup.sh destroy etcd"
    exit 2
  fi

  echo "Starting etcd. To view progress run:"
  echo ""
  echo "  journalctl -fu dex-etcd"
  echo ""
  UNIFIED_CGROUP_HIERARCHY=no \
    systemd-run --unit=dex-etcd \
    rkt run --uuid-file-save=$UUID_FILE --insecure-options=image \
    --net=host \
    docker://quay.io/coreos/etcd:v3.2.9

  wait_for_file $UUID_FILE

  UUID=$( cat $UUID_FILE )
  wait_for_container $UUID
  echo "To run tests export the following environment variables:"
  echo ""
  echo "  export DEX_ETCD_ENDPOINTS=http://localhost:2379"
  echo ""
}

function destroy_etcd {
  UUID_FILE=/tmp/dex-etcd-uuid
  systemctl stop dex-etcd
  rkt rm --uuid-file=$UUID_FILE
  rm $UUID_FILE
}

main $@

storage/etcd/types.go (new file, 229 lines)

package etcd

import (
    "time"

    "github.com/coreos/dex/storage"
    jose "gopkg.in/square/go-jose.v2"
)

// AuthCode is a mirrored struct from storage with JSON struct tags
type AuthCode struct {
    ID          string   `json:"ID"`
    ClientID    string   `json:"clientID"`
    RedirectURI string   `json:"redirectURI"`
    Nonce       string   `json:"nonce,omitempty"`
    Scopes      []string `json:"scopes,omitempty"`

    ConnectorID   string `json:"connectorID,omitempty"`
    ConnectorData []byte `json:"connectorData,omitempty"`
    Claims        Claims `json:"claims,omitempty"`

    Expiry time.Time `json:"expiry"`
}

func fromStorageAuthCode(a storage.AuthCode) AuthCode {
    return AuthCode{
        ID:            a.ID,
        ClientID:      a.ClientID,
        RedirectURI:   a.RedirectURI,
        ConnectorID:   a.ConnectorID,
        ConnectorData: a.ConnectorData,
        Nonce:         a.Nonce,
        Scopes:        a.Scopes,
        Claims:        fromStorageClaims(a.Claims),
        Expiry:        a.Expiry,
    }
}

// AuthRequest is a mirrored struct from storage with JSON struct tags
type AuthRequest struct {
    ID       string `json:"id"`
    ClientID string `json:"client_id"`

    ResponseTypes []string `json:"response_types"`
    Scopes        []string `json:"scopes"`
    RedirectURI   string   `json:"redirect_uri"`
    Nonce         string   `json:"nonce"`
    State         string   `json:"state"`

    ForceApprovalPrompt bool `json:"force_approval_prompt"`

    Expiry time.Time `json:"expiry"`

    LoggedIn bool `json:"logged_in"`

    Claims Claims `json:"claims"`

    ConnectorID   string `json:"connector_id"`
    ConnectorData []byte `json:"connector_data"`
}

func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
    return AuthRequest{
        ID:                  a.ID,
        ClientID:            a.ClientID,
        ResponseTypes:       a.ResponseTypes,
        Scopes:              a.Scopes,
        RedirectURI:         a.RedirectURI,
        Nonce:               a.Nonce,
        State:               a.State,
        ForceApprovalPrompt: a.ForceApprovalPrompt,
        Expiry:              a.Expiry,
        LoggedIn:            a.LoggedIn,
        Claims:              fromStorageClaims(a.Claims),
        ConnectorID:         a.ConnectorID,
        ConnectorData:       a.ConnectorData,
    }
}

func toStorageAuthRequest(a AuthRequest) storage.AuthRequest {
    return storage.AuthRequest{
        ID:                  a.ID,
        ClientID:            a.ClientID,
        ResponseTypes:       a.ResponseTypes,
        Scopes:              a.Scopes,
        RedirectURI:         a.RedirectURI,
        Nonce:               a.Nonce,
        State:               a.State,
        ForceApprovalPrompt: a.ForceApprovalPrompt,
        LoggedIn:            a.LoggedIn,
        ConnectorID:         a.ConnectorID,
        ConnectorData:       a.ConnectorData,
        Expiry:              a.Expiry,
        Claims:              toStorageClaims(a.Claims),
    }
}

// RefreshToken is a mirrored struct from storage with JSON struct tags
type RefreshToken struct {
    ID string `json:"id"`

    Token string `json:"token"`

    CreatedAt time.Time `json:"created_at"`
    LastUsed  time.Time `json:"last_used"`

    ClientID string `json:"client_id"`

    ConnectorID   string `json:"connector_id"`
    ConnectorData []byte `json:"connector_data"`
    Claims        Claims `json:"claims"`

    Scopes []string `json:"scopes"`

    Nonce string `json:"nonce"`
}

func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
    return storage.RefreshToken{
        ID:            r.ID,
        Token:         r.Token,
        CreatedAt:     r.CreatedAt,
        LastUsed:      r.LastUsed,
        ClientID:      r.ClientID,
        ConnectorID:   r.ConnectorID,
        ConnectorData: r.ConnectorData,
        Scopes:        r.Scopes,
        Nonce:         r.Nonce,
        Claims:        toStorageClaims(r.Claims),
    }
}

func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
    return RefreshToken{
        ID:            r.ID,
        Token:         r.Token,
        CreatedAt:     r.CreatedAt,
        LastUsed:      r.LastUsed,
        ClientID:      r.ClientID,
        ConnectorID:   r.ConnectorID,
        ConnectorData: r.ConnectorData,
        Scopes:        r.Scopes,
        Nonce:         r.Nonce,
        Claims:        fromStorageClaims(r.Claims),
    }
}

// Claims is a mirrored struct from storage with JSON struct tags.
type Claims struct {
    UserID        string   `json:"userID"`
    Username      string   `json:"username"`
    Email         string   `json:"email"`
    EmailVerified bool     `json:"emailVerified"`
    Groups        []string `json:"groups,omitempty"`
}

func fromStorageClaims(i storage.Claims) Claims {
    return Claims{
        UserID:        i.UserID,
        Username:      i.Username,
        Email:         i.Email,
        EmailVerified: i.EmailVerified,
        Groups:        i.Groups,
    }
}

func toStorageClaims(i Claims) storage.Claims {
    return storage.Claims{
        UserID:        i.UserID,
        Username:      i.Username,
        Email:         i.Email,
        EmailVerified: i.EmailVerified,
        Groups:        i.Groups,
    }
}

// Keys is a mirrored struct from storage with JSON struct tags
type Keys struct {
    SigningKey       *jose.JSONWebKey          `json:"signing_key,omitempty"`
    SigningKeyPub    *jose.JSONWebKey          `json:"signing_key_pub,omitempty"`
    VerificationKeys []storage.VerificationKey `json:"verification_keys"`
    NextRotation     time.Time                 `json:"next_rotation"`
}

func fromStorageKeys(keys storage.Keys) Keys {
    return Keys{
        SigningKey:       keys.SigningKey,
        SigningKeyPub:    keys.SigningKeyPub,
        VerificationKeys: keys.VerificationKeys,
        NextRotation:     keys.NextRotation,
    }
}

func toStorageKeys(keys Keys) storage.Keys {
    return storage.Keys{
        SigningKey:       keys.SigningKey,
        SigningKeyPub:    keys.SigningKeyPub,
        VerificationKeys: keys.VerificationKeys,
        NextRotation:     keys.NextRotation,
    }
}

// OfflineSessions is a mirrored struct from storage with JSON struct tags
type OfflineSessions struct {
    UserID  string                              `json:"user_id,omitempty"`
    ConnID  string                              `json:"conn_id,omitempty"`
    Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
}

func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
    return OfflineSessions{
        UserID:  o.UserID,
        ConnID:  o.ConnID,
        Refresh: o.Refresh,
    }
}

func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
    s := storage.OfflineSessions{
        UserID:  o.UserID,
        ConnID:  o.ConnID,
        Refresh: o.Refresh,
    }
    if s.Refresh == nil {
        // Server code assumes this will be non-nil.
        s.Refresh = make(map[string]*storage.RefreshTokenRef)
    }
    return s
}
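
The mirrored structs pin down the JSON that is persisted in etcd, so later renames or tag changes in the storage package do not silently change the stored format. A small illustrative sketch (not part of the patch) of a Claims round trip and the resulting JSON:

package etcd

import (
    "encoding/json"
    "fmt"

    "github.com/coreos/dex/storage"
)

// claimsRoundTrip is an illustrative sketch: it prints the JSON produced by
// the mirrored Claims type and checks that a round trip through it preserves
// the storage values. The claim values are made up.
func claimsRoundTrip() {
    in := storage.Claims{
        UserID:        "123",
        Username:      "jane",
        Email:         "jane@example.com",
        EmailVerified: true,
        Groups:        []string{"admins"},
    }

    b, err := json.Marshal(fromStorageClaims(in))
    if err != nil {
        panic(err)
    }
    fmt.Println(string(b))
    // {"userID":"123","username":"jane","email":"jane@example.com","emailVerified":true,"groups":["admins"]}

    var mirrored Claims
    if err := json.Unmarshal(b, &mirrored); err != nil {
        panic(err)
    }
    out := toStorageClaims(mirrored)
    fmt.Println(out.UserID == in.UserID && out.Email == in.Email) // true
}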