forked from mystiq/dex
storage/static.go: storage backend should not explicitly lower-case email ids.
This commit is contained in:
parent
e40c01ec39
commit
fd4f57b5f3
4 changed files with 48 additions and 20 deletions
|
@ -144,7 +144,7 @@ func serve(cmd *cobra.Command, args []string) error {
|
||||||
for i, p := range c.StaticPasswords {
|
for i, p := range c.StaticPasswords {
|
||||||
passwords[i] = storage.Password(p)
|
passwords[i] = storage.Password(p)
|
||||||
}
|
}
|
||||||
s = storage.WithStaticPasswords(s, passwords)
|
s = storage.WithStaticPasswords(s, passwords, logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
storageConnectors := make([]storage.Connector, len(c.StaticConnectors))
|
storageConnectors := make([]storage.Connector, len(c.StaticConnectors))
|
||||||
|
|
|
@ -128,12 +128,12 @@ func (s *memStorage) CreateAuthRequest(a storage.AuthRequest) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *memStorage) CreatePassword(p storage.Password) (err error) {
|
func (s *memStorage) CreatePassword(p storage.Password) (err error) {
|
||||||
p.Email = strings.ToLower(p.Email)
|
lowerEmail := strings.ToLower(p.Email)
|
||||||
s.tx(func() {
|
s.tx(func() {
|
||||||
if _, ok := s.passwords[p.Email]; ok {
|
if _, ok := s.passwords[lowerEmail]; ok {
|
||||||
err = storage.ErrAlreadyExists
|
err = storage.ErrAlreadyExists
|
||||||
} else {
|
} else {
|
||||||
s.passwords[p.Email] = p
|
s.passwords[lowerEmail] = p
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|
|
@ -108,9 +108,10 @@ func TestStaticPasswords(t *testing.T) {
|
||||||
p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"}
|
p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"}
|
||||||
p2 := storage.Password{Email: "bar@example.com", Username: "bar_secret"}
|
p2 := storage.Password{Email: "bar@example.com", Username: "bar_secret"}
|
||||||
p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"}
|
p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"}
|
||||||
|
p4 := storage.Password{Email: "Spam@example.com", Username: "Spam_secret"}
|
||||||
|
|
||||||
backing.CreatePassword(p1)
|
backing.CreatePassword(p1)
|
||||||
s := storage.WithStaticPasswords(backing, []storage.Password{p2})
|
s := storage.WithStaticPasswords(backing, []storage.Password{p2}, logger)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -159,6 +160,29 @@ func TestStaticPasswords(t *testing.T) {
|
||||||
return s.UpdatePassword(p1.Email, updater)
|
return s.UpdatePassword(p1.Email, updater)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "create passwords",
|
||||||
|
action: func() error {
|
||||||
|
if err := s.CreatePassword(p4); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.CreatePassword(p3)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "get password",
|
||||||
|
action: func() error {
|
||||||
|
p, err := s.GetPassword(p4.Email)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if strings.Compare(p.Email, p4.Email) != 0 {
|
||||||
|
return fmt.Errorf("expected %s passwords got %s", p4.Email, p.Email)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "list passwords",
|
name: "list passwords",
|
||||||
action: func() error {
|
action: func() error {
|
||||||
|
@ -166,18 +190,12 @@ func TestStaticPasswords(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if n := len(passwords); n != 2 {
|
if n := len(passwords); n != 3 {
|
||||||
return fmt.Errorf("expected 2 passwords got %d", n)
|
return fmt.Errorf("expected 3 passwords got %d", n)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "create password",
|
|
||||||
action: func() error {
|
|
||||||
return s.CreatePassword(p3)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
|
|
|
@ -3,6 +3,8 @@ package storage
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Tests for this code are in the "memory" package, since this package doesn't
|
// Tests for this code are in the "memory" package, since this package doesn't
|
||||||
|
@ -25,6 +27,7 @@ func WithStaticClients(s Storage, staticClients []Client) Storage {
|
||||||
for _, client := range staticClients {
|
for _, client := range staticClients {
|
||||||
clientsByID[client.ID] = client
|
clientsByID[client.ID] = client
|
||||||
}
|
}
|
||||||
|
|
||||||
return staticClientsStorage{s, staticClients, clientsByID}
|
return staticClientsStorage{s, staticClients, clientsByID}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,19 +85,26 @@ type staticPasswordsStorage struct {
|
||||||
Storage
|
Storage
|
||||||
|
|
||||||
// A read-only set of passwords.
|
// A read-only set of passwords.
|
||||||
passwords []Password
|
passwords []Password
|
||||||
|
// A map of passwords that is indexed by lower-case email ids
|
||||||
passwordsByEmail map[string]Password
|
passwordsByEmail map[string]Password
|
||||||
|
|
||||||
|
logger logrus.FieldLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithStaticPasswords returns a storage with a read-only set of passwords. Write actions,
|
// WithStaticPasswords returns a storage with a read-only set of passwords.
|
||||||
// such as creating other passwords, will fail.
|
func WithStaticPasswords(s Storage, staticPasswords []Password, logger logrus.FieldLogger) Storage {
|
||||||
func WithStaticPasswords(s Storage, staticPasswords []Password) Storage {
|
|
||||||
passwordsByEmail := make(map[string]Password, len(staticPasswords))
|
passwordsByEmail := make(map[string]Password, len(staticPasswords))
|
||||||
for _, p := range staticPasswords {
|
for _, p := range staticPasswords {
|
||||||
p.Email = strings.ToLower(p.Email)
|
//Enable case insensitive email comparison.
|
||||||
passwordsByEmail[p.Email] = p
|
lowerEmail := strings.ToLower(p.Email)
|
||||||
|
if _, ok := passwordsByEmail[lowerEmail]; ok {
|
||||||
|
logger.Errorf("Attempting to create StaticPasswords with the same email id: %s", p.Email)
|
||||||
|
}
|
||||||
|
passwordsByEmail[lowerEmail] = p
|
||||||
}
|
}
|
||||||
return staticPasswordsStorage{s, staticPasswords, passwordsByEmail}
|
|
||||||
|
return staticPasswordsStorage{s, staticPasswords, passwordsByEmail, logger}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticPasswordsStorage) isStatic(email string) bool {
|
func (s staticPasswordsStorage) isStatic(email string) bool {
|
||||||
|
|
Loading…
Reference in a new issue