From 4c39bc20ae15dc5dadc530a9da8936ed6fea90e3 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Wed, 15 Mar 2017 16:56:47 -0700 Subject: [PATCH] storage: make static storages query real storages for some actions If dex is configured with static passwords or clients, let the API still add or modify objects in the backing storage, so long as their IDs don't conflict with the static ones. List options now aggregate resources from the static list and backing storage. --- storage/memory/static_clients_test.go | 55 -------- storage/memory/static_test.go | 192 ++++++++++++++++++++++++++ storage/static.go | 94 ++++++++++--- 3 files changed, 264 insertions(+), 77 deletions(-) delete mode 100644 storage/memory/static_clients_test.go create mode 100644 storage/memory/static_test.go diff --git a/storage/memory/static_clients_test.go b/storage/memory/static_clients_test.go deleted file mode 100644 index aab8597e..00000000 --- a/storage/memory/static_clients_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package memory - -import ( - "os" - "reflect" - "testing" - - "github.com/Sirupsen/logrus" - "github.com/coreos/dex/storage" -) - -func TestStaticClients(t *testing.T) { - logger := &logrus.Logger{ - Out: os.Stderr, - Formatter: &logrus.TextFormatter{DisableColors: true}, - Level: logrus.DebugLevel, - } - s := New(logger) - - c1 := storage.Client{ID: "foo", Secret: "foo_secret"} - c2 := storage.Client{ID: "bar", Secret: "bar_secret"} - s.CreateClient(c1) - s2 := storage.WithStaticClients(s, []storage.Client{c2}) - - tests := []struct { - id string - s storage.Storage - wantErr bool - wantClient storage.Client - }{ - {"foo", s, false, c1}, - {"bar", s, true, storage.Client{}}, - {"foo", s2, true, storage.Client{}}, - {"bar", s2, false, c2}, - } - - for i, tc := range tests { - gotClient, err := tc.s.GetClient(tc.id) - if err != nil { - if !tc.wantErr { - t.Errorf("case %d: GetClient(%q) %v", i, tc.id, err) - } - continue - } - - if tc.wantErr { - t.Errorf("case %d: GetClient(%q) expected error", i, tc.id) - continue - } - - if !reflect.DeepEqual(tc.wantClient, gotClient) { - t.Errorf("case %d: expected=%#v got=%#v", i, tc.wantClient, gotClient) - } - } -} diff --git a/storage/memory/static_test.go b/storage/memory/static_test.go new file mode 100644 index 00000000..33140612 --- /dev/null +++ b/storage/memory/static_test.go @@ -0,0 +1,192 @@ +package memory + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/Sirupsen/logrus" + "github.com/coreos/dex/storage" +) + +func TestStaticClients(t *testing.T) { + logger := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + backing := New(logger) + + c1 := storage.Client{ID: "foo", Secret: "foo_secret"} + c2 := storage.Client{ID: "bar", Secret: "bar_secret"} + c3 := storage.Client{ID: "spam", Secret: "spam_secret"} + + backing.CreateClient(c1) + s := storage.WithStaticClients(backing, []storage.Client{c2}) + + tests := []struct { + name string + action func() error + wantErr bool + }{ + { + name: "get client from static storage", + action: func() error { + _, err := s.GetClient(c2.ID) + return err + }, + }, + { + name: "get client from backing storage", + action: func() error { + _, err := s.GetClient(c1.ID) + return err + }, + }, + { + name: "update static client", + action: func() error { + updater := func(c storage.Client) (storage.Client, error) { + c.Secret = "new_" + c.Secret + return c, nil + } + return s.UpdateClient(c2.ID, updater) + }, + wantErr: true, + }, + { + name: "update non-static client", + action: func() error { + updater := func(c storage.Client) (storage.Client, error) { + c.Secret = "new_" + c.Secret + return c, nil + } + return s.UpdateClient(c1.ID, updater) + }, + }, + { + name: "list clients", + action: func() error { + clients, err := s.ListClients() + if err != nil { + return err + } + if n := len(clients); n != 2 { + return fmt.Errorf("expected 2 clients got %d", n) + } + return nil + }, + }, + { + name: "create client", + action: func() error { + return s.CreateClient(c3) + }, + }, + } + + for _, tc := range tests { + err := tc.action() + if err != nil && !tc.wantErr { + t.Errorf("%s: %v", tc.name, err) + } + if err == nil && tc.wantErr { + t.Errorf("%s: expected error, didn't get one", tc.name) + } + } +} + +func TestStaticPasswords(t *testing.T) { + logger := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + backing := New(logger) + + p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"} + p2 := storage.Password{Email: "bar@example.com", Username: "bar_secret"} + p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"} + + backing.CreatePassword(p1) + s := storage.WithStaticPasswords(backing, []storage.Password{p2}) + + tests := []struct { + name string + action func() error + wantErr bool + }{ + { + name: "get password from static storage", + action: func() error { + _, err := s.GetPassword(p2.Email) + return err + }, + }, + { + name: "get password from backing storage", + action: func() error { + _, err := s.GetPassword(p1.Email) + return err + }, + }, + { + name: "get password from static storage with casing", + action: func() error { + _, err := s.GetPassword(strings.ToUpper(p2.Email)) + return err + }, + }, + { + name: "update static password", + action: func() error { + updater := func(p storage.Password) (storage.Password, error) { + p.Username = "new_" + p.Username + return p, nil + } + return s.UpdatePassword(p2.Email, updater) + }, + wantErr: true, + }, + { + name: "update non-static password", + action: func() error { + updater := func(p storage.Password) (storage.Password, error) { + p.Username = "new_" + p.Username + return p, nil + } + return s.UpdatePassword(p1.Email, updater) + }, + }, + { + name: "list passwords", + action: func() error { + passwords, err := s.ListPasswords() + if err != nil { + return err + } + if n := len(passwords); n != 2 { + return fmt.Errorf("expected 2 passwords got %d", n) + } + return nil + }, + }, + { + name: "create password", + action: func() error { + return s.CreatePassword(p3) + }, + }, + } + + for _, tc := range tests { + err := tc.action() + if err != nil && !tc.wantErr { + t.Errorf("%s: %v", tc.name, err) + } + if err == nil && tc.wantErr { + t.Errorf("%s: expected error, didn't get one", tc.name) + } + } +} diff --git a/storage/static.go b/storage/static.go index 4076a613..b2ab5d4f 100644 --- a/storage/static.go +++ b/storage/static.go @@ -19,11 +19,7 @@ type staticClientsStorage struct { clientsByID map[string]Client } -// WithStaticClients returns a storage with a read-only set of clients. Write actions, -// such as creating other clients, will fail. -// -// In the future the returned storage may allow creating and storing additional clients -// in the underlying storage. +// WithStaticClients adds a read-only set of clients to the underlying storages. func WithStaticClients(s Storage, staticClients []Client) Storage { clientsByID := make(map[string]Client, len(staticClients)) for _, client := range staticClients { @@ -36,25 +32,50 @@ func (s staticClientsStorage) GetClient(id string) (Client, error) { if client, ok := s.clientsByID[id]; ok { return client, nil } - return Client{}, ErrNotFound + return s.Storage.GetClient(id) +} + +func (s staticClientsStorage) isStatic(id string) bool { + _, ok := s.clientsByID[id] + return ok } func (s staticClientsStorage) ListClients() ([]Client, error) { - clients := make([]Client, len(s.clients)) - copy(clients, s.clients) - return clients, nil + clients, err := s.Storage.ListClients() + if err != nil { + return nil, err + } + n := 0 + for _, client := range clients { + // If a client in the backing storage has the same ID as a static client + // prefer the static client. + if !s.isStatic(client.ID) { + clients[n] = client + n++ + } + } + return append(clients[:n], s.clients...), nil } func (s staticClientsStorage) CreateClient(c Client) error { - return errors.New("static clients: read-only cannot create client") + if s.isStatic(c.ID) { + return errors.New("static clients: read-only cannot create client") + } + return s.Storage.CreateClient(c) } func (s staticClientsStorage) DeleteClient(id string) error { - return errors.New("static clients: read-only cannot delete client") + if s.isStatic(id) { + return errors.New("static clients: read-only cannot delete client") + } + return s.Storage.DeleteClient(id) } func (s staticClientsStorage) UpdateClient(id string, updater func(old Client) (Client, error)) error { - return errors.New("static clients: read-only cannot update client") + if s.isStatic(id) { + return errors.New("static clients: read-only cannot update client") + } + return s.Storage.UpdateClient(id, updater) } type staticPasswordsStorage struct { @@ -76,27 +97,56 @@ func WithStaticPasswords(s Storage, staticPasswords []Password) Storage { return staticPasswordsStorage{s, staticPasswords, passwordsByEmail} } +func (s staticPasswordsStorage) isStatic(email string) bool { + _, ok := s.passwordsByEmail[strings.ToLower(email)] + return ok +} + func (s staticPasswordsStorage) GetPassword(email string) (Password, error) { - if password, ok := s.passwordsByEmail[strings.ToLower(email)]; ok { + // TODO(ericchiang): BLAH. We really need to figure out how to handle + // lower cased emails better. + email = strings.ToLower(email) + if password, ok := s.passwordsByEmail[email]; ok { return password, nil } - return Password{}, ErrNotFound + return s.Storage.GetPassword(email) } func (s staticPasswordsStorage) ListPasswords() ([]Password, error) { - passwords := make([]Password, len(s.passwords)) - copy(passwords, s.passwords) - return passwords, nil + passwords, err := s.Storage.ListPasswords() + if err != nil { + return nil, err + } + + n := 0 + for _, password := range passwords { + // If an entry has the same email as those provided in the static + // values, prefer the static value. + if !s.isStatic(password.Email) { + passwords[n] = password + n++ + } + } + return append(passwords[:n], s.passwords...), nil } func (s staticPasswordsStorage) CreatePassword(p Password) error { - return errors.New("static passwords: read-only cannot create password") + if s.isStatic(p.Email) { + return errors.New("static passwords: read-only cannot create password") + } + return s.Storage.CreatePassword(p) } -func (s staticPasswordsStorage) DeletePassword(id string) error { - return errors.New("static passwords: read-only cannot create password") +func (s staticPasswordsStorage) DeletePassword(email string) error { + if s.isStatic(email) { + return errors.New("static passwords: read-only cannot create password") + } + return s.Storage.DeletePassword(email) } -func (s staticPasswordsStorage) UpdatePassword(id string, updater func(old Password) (Password, error)) error { - return errors.New("static passwords: read-only cannot update password") +func (s staticPasswordsStorage) UpdatePassword(email string, updater func(old Password) (Password, error)) error { + if s.isStatic(email) { + return errors.New("static passwords: read-only cannot update password") + } + return s.Storage.UpdatePassword(email, updater) }