diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index f2221c71..86d20603 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -144,7 +144,7 @@ func serve(cmd *cobra.Command, args []string) error { for i, p := range c.StaticPasswords { passwords[i] = storage.Password(p) } - s = storage.WithStaticPasswords(s, passwords) + s = storage.WithStaticPasswords(s, passwords, logger) } storageConnectors := make([]storage.Connector, len(c.StaticConnectors)) diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 90d3d85a..ed80778b 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -128,12 +128,12 @@ func (s *memStorage) CreateAuthRequest(a storage.AuthRequest) (err error) { } func (s *memStorage) CreatePassword(p storage.Password) (err error) { - p.Email = strings.ToLower(p.Email) + lowerEmail := strings.ToLower(p.Email) s.tx(func() { - if _, ok := s.passwords[p.Email]; ok { + if _, ok := s.passwords[lowerEmail]; ok { err = storage.ErrAlreadyExists } else { - s.passwords[p.Email] = p + s.passwords[lowerEmail] = p } }) return diff --git a/storage/memory/static_test.go b/storage/memory/static_test.go index 17ddf7c8..df990ebb 100644 --- a/storage/memory/static_test.go +++ b/storage/memory/static_test.go @@ -108,9 +108,10 @@ func TestStaticPasswords(t *testing.T) { p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"} p2 := storage.Password{Email: "bar@example.com", Username: "bar_secret"} p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"} + p4 := storage.Password{Email: "Spam@example.com", Username: "Spam_secret"} backing.CreatePassword(p1) - s := storage.WithStaticPasswords(backing, []storage.Password{p2}) + s := storage.WithStaticPasswords(backing, []storage.Password{p2}, logger) tests := []struct { name string @@ -159,6 +160,29 @@ func TestStaticPasswords(t *testing.T) { return s.UpdatePassword(p1.Email, updater) }, }, + { + name: "create passwords", + action: func() error { + if err := s.CreatePassword(p4); err != nil { + return err + } + return s.CreatePassword(p3) + }, + wantErr: true, + }, + { + name: "get password", + action: func() error { + p, err := s.GetPassword(p4.Email) + if err != nil { + return err + } + if strings.Compare(p.Email, p4.Email) != 0 { + return fmt.Errorf("expected %s passwords got %s", p4.Email, p.Email) + } + return nil + }, + }, { name: "list passwords", action: func() error { @@ -166,18 +190,12 @@ func TestStaticPasswords(t *testing.T) { if err != nil { return err } - if n := len(passwords); n != 2 { - return fmt.Errorf("expected 2 passwords got %d", n) + if n := len(passwords); n != 3 { + return fmt.Errorf("expected 3 passwords got %d", n) } return nil }, }, - { - name: "create password", - action: func() error { - return s.CreatePassword(p3) - }, - }, } for _, tc := range tests { diff --git a/storage/static.go b/storage/static.go index 53bd9bfe..5ae4f783 100644 --- a/storage/static.go +++ b/storage/static.go @@ -3,6 +3,8 @@ package storage import ( "errors" "strings" + + "github.com/sirupsen/logrus" ) // Tests for this code are in the "memory" package, since this package doesn't @@ -25,6 +27,7 @@ func WithStaticClients(s Storage, staticClients []Client) Storage { for _, client := range staticClients { clientsByID[client.ID] = client } + return staticClientsStorage{s, staticClients, clientsByID} } @@ -82,19 +85,26 @@ type staticPasswordsStorage struct { Storage // A read-only set of passwords. - passwords []Password + passwords []Password + // A map of passwords that is indexed by lower-case email ids passwordsByEmail map[string]Password + + logger logrus.FieldLogger } -// WithStaticPasswords returns a storage with a read-only set of passwords. Write actions, -// such as creating other passwords, will fail. -func WithStaticPasswords(s Storage, staticPasswords []Password) Storage { +// WithStaticPasswords returns a storage with a read-only set of passwords. +func WithStaticPasswords(s Storage, staticPasswords []Password, logger logrus.FieldLogger) Storage { passwordsByEmail := make(map[string]Password, len(staticPasswords)) for _, p := range staticPasswords { - p.Email = strings.ToLower(p.Email) - passwordsByEmail[p.Email] = p + //Enable case insensitive email comparison. + lowerEmail := strings.ToLower(p.Email) + if _, ok := passwordsByEmail[lowerEmail]; ok { + logger.Errorf("Attempting to create StaticPasswords with the same email id: %s", p.Email) + } + passwordsByEmail[lowerEmail] = p } - return staticPasswordsStorage{s, staticPasswords, passwordsByEmail} + + return staticPasswordsStorage{s, staticPasswords, passwordsByEmail, logger} } func (s staticPasswordsStorage) isStatic(email string) bool {