Merge pull request #855 from ericchiang/static-storage-fallthrough
storage: make static storages query real storages for some actions
This commit is contained in:
commit
95d237003a
3 changed files with 264 additions and 77 deletions
|
@ -1,55 +0,0 @@
|
||||||
package memory
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/Sirupsen/logrus"
|
|
||||||
"github.com/coreos/dex/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStaticClients(t *testing.T) {
|
|
||||||
logger := &logrus.Logger{
|
|
||||||
Out: os.Stderr,
|
|
||||||
Formatter: &logrus.TextFormatter{DisableColors: true},
|
|
||||||
Level: logrus.DebugLevel,
|
|
||||||
}
|
|
||||||
s := New(logger)
|
|
||||||
|
|
||||||
c1 := storage.Client{ID: "foo", Secret: "foo_secret"}
|
|
||||||
c2 := storage.Client{ID: "bar", Secret: "bar_secret"}
|
|
||||||
s.CreateClient(c1)
|
|
||||||
s2 := storage.WithStaticClients(s, []storage.Client{c2})
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
id string
|
|
||||||
s storage.Storage
|
|
||||||
wantErr bool
|
|
||||||
wantClient storage.Client
|
|
||||||
}{
|
|
||||||
{"foo", s, false, c1},
|
|
||||||
{"bar", s, true, storage.Client{}},
|
|
||||||
{"foo", s2, true, storage.Client{}},
|
|
||||||
{"bar", s2, false, c2},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tc := range tests {
|
|
||||||
gotClient, err := tc.s.GetClient(tc.id)
|
|
||||||
if err != nil {
|
|
||||||
if !tc.wantErr {
|
|
||||||
t.Errorf("case %d: GetClient(%q) %v", i, tc.id, err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if tc.wantErr {
|
|
||||||
t.Errorf("case %d: GetClient(%q) expected error", i, tc.id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(tc.wantClient, gotClient) {
|
|
||||||
t.Errorf("case %d: expected=%#v got=%#v", i, tc.wantClient, gotClient)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
192
storage/memory/static_test.go
Normal file
192
storage/memory/static_test.go
Normal file
|
@ -0,0 +1,192 @@
|
||||||
|
package memory
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Sirupsen/logrus"
|
||||||
|
"github.com/coreos/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStaticClients(t *testing.T) {
|
||||||
|
logger := &logrus.Logger{
|
||||||
|
Out: os.Stderr,
|
||||||
|
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||||
|
Level: logrus.DebugLevel,
|
||||||
|
}
|
||||||
|
backing := New(logger)
|
||||||
|
|
||||||
|
c1 := storage.Client{ID: "foo", Secret: "foo_secret"}
|
||||||
|
c2 := storage.Client{ID: "bar", Secret: "bar_secret"}
|
||||||
|
c3 := storage.Client{ID: "spam", Secret: "spam_secret"}
|
||||||
|
|
||||||
|
backing.CreateClient(c1)
|
||||||
|
s := storage.WithStaticClients(backing, []storage.Client{c2})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action func() error
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "get client from static storage",
|
||||||
|
action: func() error {
|
||||||
|
_, err := s.GetClient(c2.ID)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "get client from backing storage",
|
||||||
|
action: func() error {
|
||||||
|
_, err := s.GetClient(c1.ID)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update static client",
|
||||||
|
action: func() error {
|
||||||
|
updater := func(c storage.Client) (storage.Client, error) {
|
||||||
|
c.Secret = "new_" + c.Secret
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
return s.UpdateClient(c2.ID, updater)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update non-static client",
|
||||||
|
action: func() error {
|
||||||
|
updater := func(c storage.Client) (storage.Client, error) {
|
||||||
|
c.Secret = "new_" + c.Secret
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
return s.UpdateClient(c1.ID, updater)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "list clients",
|
||||||
|
action: func() error {
|
||||||
|
clients, err := s.ListClients()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n := len(clients); n != 2 {
|
||||||
|
return fmt.Errorf("expected 2 clients got %d", n)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "create client",
|
||||||
|
action: func() error {
|
||||||
|
return s.CreateClient(c3)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
err := tc.action()
|
||||||
|
if err != nil && !tc.wantErr {
|
||||||
|
t.Errorf("%s: %v", tc.name, err)
|
||||||
|
}
|
||||||
|
if err == nil && tc.wantErr {
|
||||||
|
t.Errorf("%s: expected error, didn't get one", tc.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaticPasswords(t *testing.T) {
|
||||||
|
logger := &logrus.Logger{
|
||||||
|
Out: os.Stderr,
|
||||||
|
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||||
|
Level: logrus.DebugLevel,
|
||||||
|
}
|
||||||
|
backing := New(logger)
|
||||||
|
|
||||||
|
p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"}
|
||||||
|
p2 := storage.Password{Email: "bar@example.com", Username: "bar_secret"}
|
||||||
|
p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"}
|
||||||
|
|
||||||
|
backing.CreatePassword(p1)
|
||||||
|
s := storage.WithStaticPasswords(backing, []storage.Password{p2})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action func() error
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "get password from static storage",
|
||||||
|
action: func() error {
|
||||||
|
_, err := s.GetPassword(p2.Email)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "get password from backing storage",
|
||||||
|
action: func() error {
|
||||||
|
_, err := s.GetPassword(p1.Email)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "get password from static storage with casing",
|
||||||
|
action: func() error {
|
||||||
|
_, err := s.GetPassword(strings.ToUpper(p2.Email))
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update static password",
|
||||||
|
action: func() error {
|
||||||
|
updater := func(p storage.Password) (storage.Password, error) {
|
||||||
|
p.Username = "new_" + p.Username
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
return s.UpdatePassword(p2.Email, updater)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update non-static password",
|
||||||
|
action: func() error {
|
||||||
|
updater := func(p storage.Password) (storage.Password, error) {
|
||||||
|
p.Username = "new_" + p.Username
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
return s.UpdatePassword(p1.Email, updater)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "list passwords",
|
||||||
|
action: func() error {
|
||||||
|
passwords, err := s.ListPasswords()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n := len(passwords); n != 2 {
|
||||||
|
return fmt.Errorf("expected 2 passwords got %d", n)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "create password",
|
||||||
|
action: func() error {
|
||||||
|
return s.CreatePassword(p3)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
err := tc.action()
|
||||||
|
if err != nil && !tc.wantErr {
|
||||||
|
t.Errorf("%s: %v", tc.name, err)
|
||||||
|
}
|
||||||
|
if err == nil && tc.wantErr {
|
||||||
|
t.Errorf("%s: expected error, didn't get one", tc.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -19,11 +19,7 @@ type staticClientsStorage struct {
|
||||||
clientsByID map[string]Client
|
clientsByID map[string]Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithStaticClients returns a storage with a read-only set of clients. Write actions,
|
// WithStaticClients adds a read-only set of clients to the underlying storages.
|
||||||
// such as creating other clients, will fail.
|
|
||||||
//
|
|
||||||
// In the future the returned storage may allow creating and storing additional clients
|
|
||||||
// in the underlying storage.
|
|
||||||
func WithStaticClients(s Storage, staticClients []Client) Storage {
|
func WithStaticClients(s Storage, staticClients []Client) Storage {
|
||||||
clientsByID := make(map[string]Client, len(staticClients))
|
clientsByID := make(map[string]Client, len(staticClients))
|
||||||
for _, client := range staticClients {
|
for _, client := range staticClients {
|
||||||
|
@ -36,25 +32,50 @@ func (s staticClientsStorage) GetClient(id string) (Client, error) {
|
||||||
if client, ok := s.clientsByID[id]; ok {
|
if client, ok := s.clientsByID[id]; ok {
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
return Client{}, ErrNotFound
|
return s.Storage.GetClient(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s staticClientsStorage) isStatic(id string) bool {
|
||||||
|
_, ok := s.clientsByID[id]
|
||||||
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticClientsStorage) ListClients() ([]Client, error) {
|
func (s staticClientsStorage) ListClients() ([]Client, error) {
|
||||||
clients := make([]Client, len(s.clients))
|
clients, err := s.Storage.ListClients()
|
||||||
copy(clients, s.clients)
|
if err != nil {
|
||||||
return clients, nil
|
return nil, err
|
||||||
|
}
|
||||||
|
n := 0
|
||||||
|
for _, client := range clients {
|
||||||
|
// If a client in the backing storage has the same ID as a static client
|
||||||
|
// prefer the static client.
|
||||||
|
if !s.isStatic(client.ID) {
|
||||||
|
clients[n] = client
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return append(clients[:n], s.clients...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticClientsStorage) CreateClient(c Client) error {
|
func (s staticClientsStorage) CreateClient(c Client) error {
|
||||||
return errors.New("static clients: read-only cannot create client")
|
if s.isStatic(c.ID) {
|
||||||
|
return errors.New("static clients: read-only cannot create client")
|
||||||
|
}
|
||||||
|
return s.Storage.CreateClient(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticClientsStorage) DeleteClient(id string) error {
|
func (s staticClientsStorage) DeleteClient(id string) error {
|
||||||
return errors.New("static clients: read-only cannot delete client")
|
if s.isStatic(id) {
|
||||||
|
return errors.New("static clients: read-only cannot delete client")
|
||||||
|
}
|
||||||
|
return s.Storage.DeleteClient(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticClientsStorage) UpdateClient(id string, updater func(old Client) (Client, error)) error {
|
func (s staticClientsStorage) UpdateClient(id string, updater func(old Client) (Client, error)) error {
|
||||||
return errors.New("static clients: read-only cannot update client")
|
if s.isStatic(id) {
|
||||||
|
return errors.New("static clients: read-only cannot update client")
|
||||||
|
}
|
||||||
|
return s.Storage.UpdateClient(id, updater)
|
||||||
}
|
}
|
||||||
|
|
||||||
type staticPasswordsStorage struct {
|
type staticPasswordsStorage struct {
|
||||||
|
@ -76,27 +97,56 @@ func WithStaticPasswords(s Storage, staticPasswords []Password) Storage {
|
||||||
return staticPasswordsStorage{s, staticPasswords, passwordsByEmail}
|
return staticPasswordsStorage{s, staticPasswords, passwordsByEmail}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s staticPasswordsStorage) isStatic(email string) bool {
|
||||||
|
_, ok := s.passwordsByEmail[strings.ToLower(email)]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func (s staticPasswordsStorage) GetPassword(email string) (Password, error) {
|
func (s staticPasswordsStorage) GetPassword(email string) (Password, error) {
|
||||||
if password, ok := s.passwordsByEmail[strings.ToLower(email)]; ok {
|
// TODO(ericchiang): BLAH. We really need to figure out how to handle
|
||||||
|
// lower cased emails better.
|
||||||
|
email = strings.ToLower(email)
|
||||||
|
if password, ok := s.passwordsByEmail[email]; ok {
|
||||||
return password, nil
|
return password, nil
|
||||||
}
|
}
|
||||||
return Password{}, ErrNotFound
|
return s.Storage.GetPassword(email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticPasswordsStorage) ListPasswords() ([]Password, error) {
|
func (s staticPasswordsStorage) ListPasswords() ([]Password, error) {
|
||||||
passwords := make([]Password, len(s.passwords))
|
passwords, err := s.Storage.ListPasswords()
|
||||||
copy(passwords, s.passwords)
|
if err != nil {
|
||||||
return passwords, nil
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n := 0
|
||||||
|
for _, password := range passwords {
|
||||||
|
// If an entry has the same email as those provided in the static
|
||||||
|
// values, prefer the static value.
|
||||||
|
if !s.isStatic(password.Email) {
|
||||||
|
passwords[n] = password
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return append(passwords[:n], s.passwords...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticPasswordsStorage) CreatePassword(p Password) error {
|
func (s staticPasswordsStorage) CreatePassword(p Password) error {
|
||||||
return errors.New("static passwords: read-only cannot create password")
|
if s.isStatic(p.Email) {
|
||||||
|
return errors.New("static passwords: read-only cannot create password")
|
||||||
|
}
|
||||||
|
return s.Storage.CreatePassword(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticPasswordsStorage) DeletePassword(id string) error {
|
func (s staticPasswordsStorage) DeletePassword(email string) error {
|
||||||
return errors.New("static passwords: read-only cannot create password")
|
if s.isStatic(email) {
|
||||||
|
return errors.New("static passwords: read-only cannot create password")
|
||||||
|
}
|
||||||
|
return s.Storage.DeletePassword(email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticPasswordsStorage) UpdatePassword(id string, updater func(old Password) (Password, error)) error {
|
func (s staticPasswordsStorage) UpdatePassword(email string, updater func(old Password) (Password, error)) error {
|
||||||
return errors.New("static passwords: read-only cannot update password")
|
if s.isStatic(email) {
|
||||||
|
return errors.New("static passwords: read-only cannot update password")
|
||||||
|
}
|
||||||
|
return s.Storage.UpdatePassword(email, updater)
|
||||||
}
|
}
|
||||||
|
|
Reference in a new issue