diff --git a/server/server.go b/server/server.go index 81b86ec3..703af668 100644 --- a/server/server.go +++ b/server/server.go @@ -3,12 +3,15 @@ package server import ( "errors" "fmt" + "log" "net/http" "net/url" "path" "sync/atomic" "time" + "golang.org/x/crypto/bcrypt" + "github.com/gorilla/mux" "github.com/coreos/dex/connector" @@ -44,6 +47,8 @@ type Config struct { // If specified, the server will use this function for determining time. Now func() time.Time + EnablePasswordDB bool + TemplateConfig TemplateConfig } @@ -91,6 +96,14 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { if err != nil { return nil, fmt.Errorf("server: can't parse issuer URL") } + if c.EnablePasswordDB { + c.Connectors = append(c.Connectors, Connector{ + ID: "local", + DisplayName: "Email", + Connector: newPasswordDB(c.Storage), + }) + } + if len(c.Connectors) == 0 { return nil, errors.New("server: no connectors specified") } @@ -182,6 +195,38 @@ func (s *Server) absURL(pathItems ...string) string { return u.String() } +func newPasswordDB(s storage.Storage) interface { + connector.Connector + connector.PasswordConnector +} { + return passwordDB{s} +} + +type passwordDB struct { + s storage.Storage +} + +func (db passwordDB) Close() error { return nil } + +func (db passwordDB) Login(email, password string) (connector.Identity, bool, error) { + p, err := db.s.GetPassword(email) + if err != nil { + if err != storage.ErrNotFound { + log.Printf("get password: %v", err) + } + return connector.Identity{}, false, err + } + if err := bcrypt.CompareHashAndPassword(p.Hash, []byte(password)); err != nil { + return connector.Identity{}, false, nil + } + return connector.Identity{ + UserID: p.UserID, + Username: p.Username, + Email: p.Email, + EmailVerified: true, + }, true, nil +} + // newKeyCacher returns a storage which caches keys so long as the next func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage { if now == nil { diff --git a/server/server_test.go b/server/server_test.go index 296f22cc..46fcc710 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -16,9 +16,12 @@ import ( "time" "github.com/ericchiang/oidc" + "github.com/kylelemons/godebug/pretty" + "golang.org/x/crypto/bcrypt" "golang.org/x/net/context" "golang.org/x/oauth2" + "github.com/coreos/dex/connector" "github.com/coreos/dex/connector/mock" "github.com/coreos/dex/storage" "github.com/coreos/dex/storage/memory" @@ -381,6 +384,91 @@ func TestOAuth2ImplicitFlow(t *testing.T) { } } +func TestPasswordDB(t *testing.T) { + s := memory.New() + conn := newPasswordDB(s) + defer conn.Close() + + pw := "hi" + + h, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.MinCost) + if err != nil { + t.Fatal(err) + } + + s.CreatePassword(storage.Password{ + Email: "jane@example.com", + Username: "jane", + UserID: "foobar", + Hash: h, + }) + + tests := []struct { + name string + username string + password string + wantIdentity connector.Identity + wantInvalid bool + wantErr bool + }{ + { + name: "valid password", + username: "jane@example.com", + password: pw, + wantIdentity: connector.Identity{ + Email: "jane@example.com", + Username: "jane", + UserID: "foobar", + EmailVerified: true, + }, + }, + { + name: "unknown user", + username: "john@example.com", + password: pw, + wantErr: true, + }, + { + name: "invalid password", + username: "jane@example.com", + password: "not the correct password", + wantInvalid: true, + }, + } + + for _, tc := range tests { + ident, valid, err := conn.Login(tc.username, tc.password) + if err != nil { + if !tc.wantErr { + t.Errorf("%s: %v", tc.name, err) + } + continue + } + + if tc.wantErr { + t.Errorf("%s: expected error", tc.name) + continue + } + + if !valid { + if !tc.wantInvalid { + t.Errorf("%s: expected valid password", tc.name) + } + continue + } + + if tc.wantInvalid { + t.Errorf("%s: expected invalid password", tc.name) + continue + } + + if diff := pretty.Compare(tc.wantIdentity, ident); diff != "" { + t.Errorf("%s: %s", tc.name, diff) + } + } + +} + type storageWithKeysTrigger struct { storage.Storage f func()