diff --git a/storage/memory/static_clients_test.go b/storage/memory/static_clients_test.go deleted file mode 100644 index aab8597e..00000000 --- a/storage/memory/static_clients_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package memory - -import ( - "os" - "reflect" - "testing" - - "github.com/Sirupsen/logrus" - "github.com/coreos/dex/storage" -) - -func TestStaticClients(t *testing.T) { - logger := &logrus.Logger{ - Out: os.Stderr, - Formatter: &logrus.TextFormatter{DisableColors: true}, - Level: logrus.DebugLevel, - } - s := New(logger) - - c1 := storage.Client{ID: "foo", Secret: "foo_secret"} - c2 := storage.Client{ID: "bar", Secret: "bar_secret"} - s.CreateClient(c1) - s2 := storage.WithStaticClients(s, []storage.Client{c2}) - - tests := []struct { - id string - s storage.Storage - wantErr bool - wantClient storage.Client - }{ - {"foo", s, false, c1}, - {"bar", s, true, storage.Client{}}, - {"foo", s2, true, storage.Client{}}, - {"bar", s2, false, c2}, - } - - for i, tc := range tests { - gotClient, err := tc.s.GetClient(tc.id) - if err != nil { - if !tc.wantErr { - t.Errorf("case %d: GetClient(%q) %v", i, tc.id, err) - } - continue - } - - if tc.wantErr { - t.Errorf("case %d: GetClient(%q) expected error", i, tc.id) - continue - } - - if !reflect.DeepEqual(tc.wantClient, gotClient) { - t.Errorf("case %d: expected=%#v got=%#v", i, tc.wantClient, gotClient) - } - } -} diff --git a/storage/memory/static_test.go b/storage/memory/static_test.go new file mode 100644 index 00000000..33140612 --- /dev/null +++ b/storage/memory/static_test.go @@ -0,0 +1,192 @@ +package memory + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/Sirupsen/logrus" + "github.com/coreos/dex/storage" +) + +func TestStaticClients(t *testing.T) { + logger := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + backing := New(logger) + + c1 := storage.Client{ID: "foo", Secret: "foo_secret"} + c2 := storage.Client{ID: "bar", Secret: "bar_secret"} + c3 := storage.Client{ID: "spam", Secret: "spam_secret"} + + backing.CreateClient(c1) + s := storage.WithStaticClients(backing, []storage.Client{c2}) + + tests := []struct { + name string + action func() error + wantErr bool + }{ + { + name: "get client from static storage", + action: func() error { + _, err := s.GetClient(c2.ID) + return err + }, + }, + { + name: "get client from backing storage", + action: func() error { + _, err := s.GetClient(c1.ID) + return err + }, + }, + { + name: "update static client", + action: func() error { + updater := func(c storage.Client) (storage.Client, error) { + c.Secret = "new_" + c.Secret + return c, nil + } + return s.UpdateClient(c2.ID, updater) + }, + wantErr: true, + }, + { + name: "update non-static client", + action: func() error { + updater := func(c storage.Client) (storage.Client, error) { + c.Secret = "new_" + c.Secret + return c, nil + } + return s.UpdateClient(c1.ID, updater) + }, + }, + { + name: "list clients", + action: func() error { + clients, err := s.ListClients() + if err != nil { + return err + } + if n := len(clients); n != 2 { + return fmt.Errorf("expected 2 clients got %d", n) + } + return nil + }, + }, + { + name: "create client", + action: func() error { + return s.CreateClient(c3) + }, + }, + } + + for _, tc := range tests { + err := tc.action() + if err != nil && !tc.wantErr { + t.Errorf("%s: %v", tc.name, err) + } + if err == nil && tc.wantErr { + t.Errorf("%s: expected error, didn't get one", tc.name) + } + } +} + +func TestStaticPasswords(t *testing.T) { + logger := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + backing := New(logger) + + p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"} + p2 := storage.Password{Email: "bar@example.com", Username: "bar_secret"} + p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"} + + backing.CreatePassword(p1) + s := storage.WithStaticPasswords(backing, []storage.Password{p2}) + + tests := []struct { + name string + action func() error + wantErr bool + }{ + { + name: "get password from static storage", + action: func() error { + _, err := s.GetPassword(p2.Email) + return err + }, + }, + { + name: "get password from backing storage", + action: func() error { + _, err := s.GetPassword(p1.Email) + return err + }, + }, + { + name: "get password from static storage with casing", + action: func() error { + _, err := s.GetPassword(strings.ToUpper(p2.Email)) + return err + }, + }, + { + name: "update static password", + action: func() error { + updater := func(p storage.Password) (storage.Password, error) { + p.Username = "new_" + p.Username + return p, nil + } + return s.UpdatePassword(p2.Email, updater) + }, + wantErr: true, + }, + { + name: "update non-static password", + action: func() error { + updater := func(p storage.Password) (storage.Password, error) { + p.Username = "new_" + p.Username + return p, nil + } + return s.UpdatePassword(p1.Email, updater) + }, + }, + { + name: "list passwords", + action: func() error { + passwords, err := s.ListPasswords() + if err != nil { + return err + } + if n := len(passwords); n != 2 { + return fmt.Errorf("expected 2 passwords got %d", n) + } + return nil + }, + }, + { + name: "create password", + action: func() error { + return s.CreatePassword(p3) + }, + }, + } + + for _, tc := range tests { + err := tc.action() + if err != nil && !tc.wantErr { + t.Errorf("%s: %v", tc.name, err) + } + if err == nil && tc.wantErr { + t.Errorf("%s: expected error, didn't get one", tc.name) + } + } +} diff --git a/storage/static.go b/storage/static.go index 4076a613..b2ab5d4f 100644 --- a/storage/static.go +++ b/storage/static.go @@ -19,11 +19,7 @@ type staticClientsStorage struct { clientsByID map[string]Client } -// WithStaticClients returns a storage with a read-only set of clients. Write actions, -// such as creating other clients, will fail. -// -// In the future the returned storage may allow creating and storing additional clients -// in the underlying storage. +// WithStaticClients adds a read-only set of clients to the underlying storages. func WithStaticClients(s Storage, staticClients []Client) Storage { clientsByID := make(map[string]Client, len(staticClients)) for _, client := range staticClients { @@ -36,25 +32,50 @@ func (s staticClientsStorage) GetClient(id string) (Client, error) { if client, ok := s.clientsByID[id]; ok { return client, nil } - return Client{}, ErrNotFound + return s.Storage.GetClient(id) +} + +func (s staticClientsStorage) isStatic(id string) bool { + _, ok := s.clientsByID[id] + return ok } func (s staticClientsStorage) ListClients() ([]Client, error) { - clients := make([]Client, len(s.clients)) - copy(clients, s.clients) - return clients, nil + clients, err := s.Storage.ListClients() + if err != nil { + return nil, err + } + n := 0 + for _, client := range clients { + // If a client in the backing storage has the same ID as a static client + // prefer the static client. + if !s.isStatic(client.ID) { + clients[n] = client + n++ + } + } + return append(clients[:n], s.clients...), nil } func (s staticClientsStorage) CreateClient(c Client) error { - return errors.New("static clients: read-only cannot create client") + if s.isStatic(c.ID) { + return errors.New("static clients: read-only cannot create client") + } + return s.Storage.CreateClient(c) } func (s staticClientsStorage) DeleteClient(id string) error { - return errors.New("static clients: read-only cannot delete client") + if s.isStatic(id) { + return errors.New("static clients: read-only cannot delete client") + } + return s.Storage.DeleteClient(id) } func (s staticClientsStorage) UpdateClient(id string, updater func(old Client) (Client, error)) error { - return errors.New("static clients: read-only cannot update client") + if s.isStatic(id) { + return errors.New("static clients: read-only cannot update client") + } + return s.Storage.UpdateClient(id, updater) } type staticPasswordsStorage struct { @@ -76,27 +97,56 @@ func WithStaticPasswords(s Storage, staticPasswords []Password) Storage { return staticPasswordsStorage{s, staticPasswords, passwordsByEmail} } +func (s staticPasswordsStorage) isStatic(email string) bool { + _, ok := s.passwordsByEmail[strings.ToLower(email)] + return ok +} + func (s staticPasswordsStorage) GetPassword(email string) (Password, error) { - if password, ok := s.passwordsByEmail[strings.ToLower(email)]; ok { + // TODO(ericchiang): BLAH. We really need to figure out how to handle + // lower cased emails better. + email = strings.ToLower(email) + if password, ok := s.passwordsByEmail[email]; ok { return password, nil } - return Password{}, ErrNotFound + return s.Storage.GetPassword(email) } func (s staticPasswordsStorage) ListPasswords() ([]Password, error) { - passwords := make([]Password, len(s.passwords)) - copy(passwords, s.passwords) - return passwords, nil + passwords, err := s.Storage.ListPasswords() + if err != nil { + return nil, err + } + + n := 0 + for _, password := range passwords { + // If an entry has the same email as those provided in the static + // values, prefer the static value. + if !s.isStatic(password.Email) { + passwords[n] = password + n++ + } + } + return append(passwords[:n], s.passwords...), nil } func (s staticPasswordsStorage) CreatePassword(p Password) error { - return errors.New("static passwords: read-only cannot create password") + if s.isStatic(p.Email) { + return errors.New("static passwords: read-only cannot create password") + } + return s.Storage.CreatePassword(p) } -func (s staticPasswordsStorage) DeletePassword(id string) error { - return errors.New("static passwords: read-only cannot create password") +func (s staticPasswordsStorage) DeletePassword(email string) error { + if s.isStatic(email) { + return errors.New("static passwords: read-only cannot create password") + } + return s.Storage.DeletePassword(email) } -func (s staticPasswordsStorage) UpdatePassword(id string, updater func(old Password) (Password, error)) error { - return errors.New("static passwords: read-only cannot update password") +func (s staticPasswordsStorage) UpdatePassword(email string, updater func(old Password) (Password, error)) error { + if s.isStatic(email) { + return errors.New("static passwords: read-only cannot update password") + } + return s.Storage.UpdatePassword(email, updater) }