storage: make static storages query real storages for some actions

If dex is configured with static passwords or clients, let the API
still add or modify objects in the backing storage, so long as
their IDs don't conflict with the static ones. List options now
aggregate resources from the static list and backing storage.
This commit is contained in:
Eric Chiang 2017-03-15 16:56:47 -07:00
parent d31bb1c8d5
commit 4c39bc20ae
3 changed files with 264 additions and 77 deletions

View file

@ -1,55 +0,0 @@
package memory
import (
"os"
"reflect"
"testing"
"github.com/Sirupsen/logrus"
"github.com/coreos/dex/storage"
)
func TestStaticClients(t *testing.T) {
logger := &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
s := New(logger)
c1 := storage.Client{ID: "foo", Secret: "foo_secret"}
c2 := storage.Client{ID: "bar", Secret: "bar_secret"}
s.CreateClient(c1)
s2 := storage.WithStaticClients(s, []storage.Client{c2})
tests := []struct {
id string
s storage.Storage
wantErr bool
wantClient storage.Client
}{
{"foo", s, false, c1},
{"bar", s, true, storage.Client{}},
{"foo", s2, true, storage.Client{}},
{"bar", s2, false, c2},
}
for i, tc := range tests {
gotClient, err := tc.s.GetClient(tc.id)
if err != nil {
if !tc.wantErr {
t.Errorf("case %d: GetClient(%q) %v", i, tc.id, err)
}
continue
}
if tc.wantErr {
t.Errorf("case %d: GetClient(%q) expected error", i, tc.id)
continue
}
if !reflect.DeepEqual(tc.wantClient, gotClient) {
t.Errorf("case %d: expected=%#v got=%#v", i, tc.wantClient, gotClient)
}
}
}

View file

@ -0,0 +1,192 @@
package memory
import (
"fmt"
"os"
"strings"
"testing"
"github.com/Sirupsen/logrus"
"github.com/coreos/dex/storage"
)
func TestStaticClients(t *testing.T) {
logger := &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
backing := New(logger)
c1 := storage.Client{ID: "foo", Secret: "foo_secret"}
c2 := storage.Client{ID: "bar", Secret: "bar_secret"}
c3 := storage.Client{ID: "spam", Secret: "spam_secret"}
backing.CreateClient(c1)
s := storage.WithStaticClients(backing, []storage.Client{c2})
tests := []struct {
name string
action func() error
wantErr bool
}{
{
name: "get client from static storage",
action: func() error {
_, err := s.GetClient(c2.ID)
return err
},
},
{
name: "get client from backing storage",
action: func() error {
_, err := s.GetClient(c1.ID)
return err
},
},
{
name: "update static client",
action: func() error {
updater := func(c storage.Client) (storage.Client, error) {
c.Secret = "new_" + c.Secret
return c, nil
}
return s.UpdateClient(c2.ID, updater)
},
wantErr: true,
},
{
name: "update non-static client",
action: func() error {
updater := func(c storage.Client) (storage.Client, error) {
c.Secret = "new_" + c.Secret
return c, nil
}
return s.UpdateClient(c1.ID, updater)
},
},
{
name: "list clients",
action: func() error {
clients, err := s.ListClients()
if err != nil {
return err
}
if n := len(clients); n != 2 {
return fmt.Errorf("expected 2 clients got %d", n)
}
return nil
},
},
{
name: "create client",
action: func() error {
return s.CreateClient(c3)
},
},
}
for _, tc := range tests {
err := tc.action()
if err != nil && !tc.wantErr {
t.Errorf("%s: %v", tc.name, err)
}
if err == nil && tc.wantErr {
t.Errorf("%s: expected error, didn't get one", tc.name)
}
}
}
func TestStaticPasswords(t *testing.T) {
logger := &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
backing := New(logger)
p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"}
p2 := storage.Password{Email: "bar@example.com", Username: "bar_secret"}
p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"}
backing.CreatePassword(p1)
s := storage.WithStaticPasswords(backing, []storage.Password{p2})
tests := []struct {
name string
action func() error
wantErr bool
}{
{
name: "get password from static storage",
action: func() error {
_, err := s.GetPassword(p2.Email)
return err
},
},
{
name: "get password from backing storage",
action: func() error {
_, err := s.GetPassword(p1.Email)
return err
},
},
{
name: "get password from static storage with casing",
action: func() error {
_, err := s.GetPassword(strings.ToUpper(p2.Email))
return err
},
},
{
name: "update static password",
action: func() error {
updater := func(p storage.Password) (storage.Password, error) {
p.Username = "new_" + p.Username
return p, nil
}
return s.UpdatePassword(p2.Email, updater)
},
wantErr: true,
},
{
name: "update non-static password",
action: func() error {
updater := func(p storage.Password) (storage.Password, error) {
p.Username = "new_" + p.Username
return p, nil
}
return s.UpdatePassword(p1.Email, updater)
},
},
{
name: "list passwords",
action: func() error {
passwords, err := s.ListPasswords()
if err != nil {
return err
}
if n := len(passwords); n != 2 {
return fmt.Errorf("expected 2 passwords got %d", n)
}
return nil
},
},
{
name: "create password",
action: func() error {
return s.CreatePassword(p3)
},
},
}
for _, tc := range tests {
err := tc.action()
if err != nil && !tc.wantErr {
t.Errorf("%s: %v", tc.name, err)
}
if err == nil && tc.wantErr {
t.Errorf("%s: expected error, didn't get one", tc.name)
}
}
}

View file

@ -19,11 +19,7 @@ type staticClientsStorage struct {
clientsByID map[string]Client
}
// WithStaticClients returns a storage with a read-only set of clients. Write actions,
// such as creating other clients, will fail.
//
// In the future the returned storage may allow creating and storing additional clients
// in the underlying storage.
// WithStaticClients adds a read-only set of clients to the underlying storages.
func WithStaticClients(s Storage, staticClients []Client) Storage {
clientsByID := make(map[string]Client, len(staticClients))
for _, client := range staticClients {
@ -36,25 +32,50 @@ func (s staticClientsStorage) GetClient(id string) (Client, error) {
if client, ok := s.clientsByID[id]; ok {
return client, nil
}
return Client{}, ErrNotFound
return s.Storage.GetClient(id)
}
func (s staticClientsStorage) isStatic(id string) bool {
_, ok := s.clientsByID[id]
return ok
}
func (s staticClientsStorage) ListClients() ([]Client, error) {
clients := make([]Client, len(s.clients))
copy(clients, s.clients)
return clients, nil
clients, err := s.Storage.ListClients()
if err != nil {
return nil, err
}
n := 0
for _, client := range clients {
// If a client in the backing storage has the same ID as a static client
// prefer the static client.
if !s.isStatic(client.ID) {
clients[n] = client
n++
}
}
return append(clients[:n], s.clients...), nil
}
func (s staticClientsStorage) CreateClient(c Client) error {
if s.isStatic(c.ID) {
return errors.New("static clients: read-only cannot create client")
}
return s.Storage.CreateClient(c)
}
func (s staticClientsStorage) DeleteClient(id string) error {
if s.isStatic(id) {
return errors.New("static clients: read-only cannot delete client")
}
return s.Storage.DeleteClient(id)
}
func (s staticClientsStorage) UpdateClient(id string, updater func(old Client) (Client, error)) error {
if s.isStatic(id) {
return errors.New("static clients: read-only cannot update client")
}
return s.Storage.UpdateClient(id, updater)
}
type staticPasswordsStorage struct {
@ -76,27 +97,56 @@ func WithStaticPasswords(s Storage, staticPasswords []Password) Storage {
return staticPasswordsStorage{s, staticPasswords, passwordsByEmail}
}
func (s staticPasswordsStorage) isStatic(email string) bool {
_, ok := s.passwordsByEmail[strings.ToLower(email)]
return ok
}
func (s staticPasswordsStorage) GetPassword(email string) (Password, error) {
if password, ok := s.passwordsByEmail[strings.ToLower(email)]; ok {
// TODO(ericchiang): BLAH. We really need to figure out how to handle
// lower cased emails better.
email = strings.ToLower(email)
if password, ok := s.passwordsByEmail[email]; ok {
return password, nil
}
return Password{}, ErrNotFound
return s.Storage.GetPassword(email)
}
func (s staticPasswordsStorage) ListPasswords() ([]Password, error) {
passwords := make([]Password, len(s.passwords))
copy(passwords, s.passwords)
return passwords, nil
passwords, err := s.Storage.ListPasswords()
if err != nil {
return nil, err
}
n := 0
for _, password := range passwords {
// If an entry has the same email as those provided in the static
// values, prefer the static value.
if !s.isStatic(password.Email) {
passwords[n] = password
n++
}
}
return append(passwords[:n], s.passwords...), nil
}
func (s staticPasswordsStorage) CreatePassword(p Password) error {
if s.isStatic(p.Email) {
return errors.New("static passwords: read-only cannot create password")
}
return s.Storage.CreatePassword(p)
}
func (s staticPasswordsStorage) DeletePassword(id string) error {
func (s staticPasswordsStorage) DeletePassword(email string) error {
if s.isStatic(email) {
return errors.New("static passwords: read-only cannot create password")
}
return s.Storage.DeletePassword(email)
}
func (s staticPasswordsStorage) UpdatePassword(id string, updater func(old Password) (Password, error)) error {
func (s staticPasswordsStorage) UpdatePassword(email string, updater func(old Password) (Password, error)) error {
if s.isStatic(email) {
return errors.New("static passwords: read-only cannot update password")
}
return s.Storage.UpdatePassword(email, updater)
}