diff --git a/admin/api.go b/admin/api.go index 2b875405..2329c3ba 100644 --- a/admin/api.go +++ b/admin/api.go @@ -6,17 +6,18 @@ import ( "github.com/coreos/dex/schema/adminschema" "github.com/coreos/dex/user" + "github.com/coreos/dex/user/manager" ) // AdminAPI provides the logic necessary to implement the Admin API. type AdminAPI struct { - userManager *user.Manager + userManager *manager.UserManager userRepo user.UserRepo passwordInfoRepo user.PasswordInfoRepo localConnectorID string } -func NewAdminAPI(userManager *user.Manager, userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, localConnectorID string) *AdminAPI { +func NewAdminAPI(userManager *manager.UserManager, userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, localConnectorID string) *AdminAPI { if localConnectorID == "" { panic("must specify non-blank localConnectorID") } diff --git a/cmd/dex-overlord/main.go b/cmd/dex-overlord/main.go index 8a3a9936..b8d48fc9 100644 --- a/cmd/dex-overlord/main.go +++ b/cmd/dex-overlord/main.go @@ -17,7 +17,7 @@ import ( "github.com/coreos/dex/pkg/log" ptime "github.com/coreos/dex/pkg/time" "github.com/coreos/dex/server" - "github.com/coreos/dex/user" + "github.com/coreos/dex/user/manager" ) var version = "DEV" @@ -99,8 +99,9 @@ func main() { userRepo := db.NewUserRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc) - userManager := user.NewManager(userRepo, - pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{}) + connCfgRepo := db.NewConnectorConfigRepo(dbc) + userManager := manager.NewUserManager(userRepo, + pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID) kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...) if err != nil { diff --git a/connector/config_repo.go b/connector/config_repo.go index 212a4cd7..dd2d2c09 100644 --- a/connector/config_repo.go +++ b/connector/config_repo.go @@ -4,6 +4,8 @@ import ( "encoding/json" "io" "os" + + "github.com/coreos/dex/repo" ) func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) { @@ -41,6 +43,19 @@ type memConnectorConfigRepo struct { configs []ConnectorConfig } +func NewConnectorConfigRepoFromConfigs(cfgs []ConnectorConfig) ConnectorConfigRepo { + return &memConnectorConfigRepo{configs: cfgs} +} + func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) { return r.configs, nil } + +func (r *memConnectorConfigRepo) GetConnectorByID(_ repo.Transaction, id string) (ConnectorConfig, error) { + for _, cfg := range r.configs { + if cfg.ConnectorID() == id { + return cfg, nil + } + } + return nil, ErrorNotFound +} diff --git a/connector/interface.go b/connector/interface.go index 44cbb5da..36d0dcc6 100644 --- a/connector/interface.go +++ b/connector/interface.go @@ -1,14 +1,18 @@ package connector import ( + "errors" "html/template" "net/http" "net/url" + "github.com/coreos/dex/repo" "github.com/coreos/go-oidc/oidc" "github.com/coreos/pkg/health" ) +var ErrorNotFound = errors.New("connector not found in repository") + type Connector interface { ID() string LoginURL(sessionKey, prompt string) (string, error) @@ -34,4 +38,5 @@ type ConnectorConfig interface { type ConnectorConfigRepo interface { All() ([]ConnectorConfig, error) + GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error) } diff --git a/db/connector_config.go b/db/connector_config.go index f9d8386e..c8044ec7 100644 --- a/db/connector_config.go +++ b/db/connector_config.go @@ -1,6 +1,7 @@ package db import ( + "database/sql" "encoding/json" "errors" "fmt" @@ -9,6 +10,7 @@ import ( "github.com/lib/pq" "github.com/coreos/dex/connector" + "github.com/coreos/dex/repo" ) const ( @@ -91,6 +93,18 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) { return cfgs, nil } +func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) { + qt := pq.QuoteIdentifier(connectorConfigTableName) + q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt) + var c connectorConfigModel + if err := r.executor(tx).SelectOne(&c, q, id); err != nil { + if err == sql.ErrNoRows { + return nil, connector.ErrorNotFound + } + } + return c.ConnectorConfig() +} + func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error { insert := make([]interface{}, len(cfgs)) for i, cfg := range cfgs { @@ -119,3 +133,15 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error { return tx.Commit() } + +func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor { + if tx == nil { + return r.dbMap + } + + gorpTx, ok := tx.(*gorp.Transaction) + if !ok { + panic("wrong kind of transaction passed to a DB repo") + } + return gorpTx +} diff --git a/functional/repo/connector_repo_test.go b/functional/repo/connector_repo_test.go new file mode 100644 index 00000000..7d5369d4 --- /dev/null +++ b/functional/repo/connector_repo_test.go @@ -0,0 +1,71 @@ +package repo + +import ( + "fmt" + "os" + "testing" + + "github.com/coreos/dex/connector" + "github.com/coreos/dex/db" +) + +type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo + +var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory + +func init() { + if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" { + makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs + } else { + makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn) + } +} + +func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory { + return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo { + dbMap := initDB(dsn) + + repo := db.NewConnectorConfigRepo(dbMap) + if err := repo.Set(cfgs); err != nil { + panic(fmt.Sprintf("Unable to set connector configs: %v", err)) + } + return repo + } +} + +func TestConnectorConfigRepoGetByID(t *testing.T) { + tests := []struct { + cfgs []connector.ConnectorConfig + id string + err error + }{ + { + cfgs: []connector.ConnectorConfig{ + &connector.LocalConnectorConfig{ID: "local"}, + }, + id: "local", + }, + { + cfgs: []connector.ConnectorConfig{ + &connector.LocalConnectorConfig{ID: "local1"}, + &connector.LocalConnectorConfig{ID: "local2"}, + }, + id: "local2", + }, + { + cfgs: []connector.ConnectorConfig{ + &connector.LocalConnectorConfig{ID: "local1"}, + &connector.LocalConnectorConfig{ID: "local2"}, + }, + id: "foo", + err: connector.ErrorNotFound, + }, + } + + for i, tt := range tests { + repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs) + if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err { + t.Errorf("case %d: want=%v, got=%v", i, tt.err, err) + } + } +} diff --git a/integration/common_test.go b/integration/common_test.go index f9e33447..a203cbf0 100644 --- a/integration/common_test.go +++ b/integration/common_test.go @@ -10,8 +10,10 @@ import ( "github.com/coreos/go-oidc/key" "github.com/jonboulle/clockwork" + "github.com/coreos/dex/connector" "github.com/coreos/dex/repo" "github.com/coreos/dex/user" + "github.com/coreos/dex/user/manager" ) var ( @@ -42,11 +44,14 @@ func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, erro return &resp, nil } -func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *user.Manager) { +func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) { ur := user.NewUserRepoFromUsers(users) pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords) - um := user.NewManager(ur, pwr, repo.InMemTransactionFactory, user.ManagerOptions{}) + ccr := connector.NewConnectorConfigRepoFromConfigs( + []connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}}, + ) + um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{}) um.Clock = clock return ur, pwr, um } diff --git a/server/config.go b/server/config.go index 3b775c5e..f70065a3 100644 --- a/server/config.go +++ b/server/config.go @@ -22,6 +22,7 @@ import ( "github.com/coreos/dex/session" "github.com/coreos/dex/user" useremail "github.com/coreos/dex/user/email" + "github.com/coreos/dex/user/manager" ) type ServerConfig struct { @@ -133,7 +134,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { refTokRepo := refresh.NewRefreshTokenRepo() txnFactory := repo.InMemTransactionFactory - userManager := user.NewManager(userRepo, pwiRepo, txnFactory, user.ManagerOptions{}) + userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{}) srv.ClientIdentityRepo = ciRepo srv.KeySetRepo = kRepo srv.ConnectorConfigRepo = cfgRepo @@ -171,7 +172,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error { cfgRepo := db.NewConnectorConfigRepo(dbc) userRepo := db.NewUserRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc) - userManager := user.NewManager(userRepo, pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{}) + userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) refreshTokenRepo := db.NewRefreshTokenRepo(dbc) sm := session.NewSessionManager(sRepo, skRepo) diff --git a/server/email_verification.go b/server/email_verification.go index b2c14723..fe3e9746 100644 --- a/server/email_verification.go +++ b/server/email_verification.go @@ -15,6 +15,7 @@ import ( "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/user" useremail "github.com/coreos/dex/user/email" + "github.com/coreos/dex/user/manager" ) // handleVerifyEmailResendFunc will resend an email-verification email given a valid JWT for the user and a redirect URL. @@ -190,7 +191,7 @@ type emailVerifiedTemplateData struct { } func handleEmailVerifyFunc(verifiedTpl *template.Template, issuer url.URL, keysFunc func() ([]key.PublicKey, - error), userManager *user.Manager) http.HandlerFunc { + error), userManager *manager.UserManager) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() @@ -217,12 +218,12 @@ func handleEmailVerifyFunc(verifiedTpl *template.Template, issuer url.URL, keysF cbURL, err := userManager.VerifyEmail(ev) if err != nil { switch err { - case user.ErrorEmailAlreadyVerified: + case manager.ErrorEmailAlreadyVerified: execTemplateWithStatus(w, verifiedTpl, emailVerifiedTemplateData{ Error: "Invalid Verification Link", Message: "Your email link has expired or has already been verified.", }, http.StatusBadRequest) - case user.ErrorEVEmailDoesntMatch: + case manager.ErrorEVEmailDoesntMatch: execTemplateWithStatus(w, verifiedTpl, emailVerifiedTemplateData{ Error: "Invalid Verification Link", Message: "Your email link does not match the email address on file. Perhaps you have a more recent verification link?", diff --git a/server/invitation.go b/server/invitation.go index 4acba885..c70e3291 100644 --- a/server/invitation.go +++ b/server/invitation.go @@ -7,6 +7,7 @@ import ( "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/user" + "github.com/coreos/dex/user/manager" "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/key" ) @@ -18,7 +19,7 @@ type invitationTemplateData struct { type InvitationHandler struct { issuerURL url.URL passwordResetURL url.URL - um *user.Manager + um *manager.UserManager keysFunc func() ([]key.PublicKey, error) signerFunc func() (jose.Signer, error) redirectValidityWindow time.Duration @@ -55,13 +56,13 @@ func (h *InvitationHandler) handleGET(w http.ResponseWriter, r *http.Request) { } _, err = h.um.VerifyEmail(invite) - if err != nil && err != user.ErrorEmailAlreadyVerified { + if err != nil && err != manager.ErrorEmailAlreadyVerified { // Allow AlreadyVerified folks to pass through- otherwise // folks who encounter an error after passing this point will // never be able to set their passwords. log.Debugf("error attempting to verify email: %v", err) switch err { - case user.ErrorEVEmailDoesntMatch: + case manager.ErrorEVEmailDoesntMatch: writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidRequest, "Your email does not match the email address on file")) return diff --git a/server/password.go b/server/password.go index 0e4e42ba..f2f9ffe6 100644 --- a/server/password.go +++ b/server/password.go @@ -12,6 +12,7 @@ import ( "github.com/coreos/dex/session" "github.com/coreos/dex/user" useremail "github.com/coreos/dex/user/email" + "github.com/coreos/dex/user/manager" ) type sendResetPasswordEmailData struct { @@ -181,7 +182,7 @@ type resetPasswordTemplateData struct { type ResetPasswordHandler struct { tpl *template.Template issuerURL url.URL - um *user.Manager + um *manager.UserManager keysFunc func() ([]key.PublicKey, error) } @@ -237,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() { cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext) if err != nil { switch err { - case user.ErrorPasswordAlreadyChanged: + case manager.ErrorPasswordAlreadyChanged: r.data.Error = "Link Expired" r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email." r.data.DontShowForm = true diff --git a/server/register.go b/server/register.go index 0413f7e1..a5eb1b15 100644 --- a/server/register.go +++ b/server/register.go @@ -11,6 +11,7 @@ import ( "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/session" "github.com/coreos/dex/user" + "github.com/coreos/dex/user/manager" "github.com/coreos/go-oidc/oidc" ) @@ -222,7 +223,7 @@ func handleRegisterFunc(s *Server) http.HandlerFunc { } } -func registerFromLocalConnector(userManager *user.Manager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) { +func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) { userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID) if err != nil { return "", err @@ -237,7 +238,7 @@ func registerFromLocalConnector(userManager *user.Manager, sessionManager *sessi return userID, nil } -func registerFromRemoteConnector(userManager *user.Manager, ses *session.Session, email string, emailVerified bool) (string, error) { +func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) { if ses.Identity.ID == "" { return "", errors.New("No Identity found in session.") } diff --git a/server/server.go b/server/server.go index f0a4cf97..e4d7e328 100644 --- a/server/server.go +++ b/server/server.go @@ -25,6 +25,7 @@ import ( "github.com/coreos/dex/user" usersapi "github.com/coreos/dex/user/api" useremail "github.com/coreos/dex/user/email" + "github.com/coreos/dex/user/manager" ) const ( @@ -68,7 +69,7 @@ type Server struct { HealthChecks []health.Checkable Connectors []connector.Connector UserRepo user.UserRepo - UserManager *user.Manager + UserManager *manager.UserManager PasswordInfoRepo user.PasswordInfoRepo RefreshTokenRepo refresh.RefreshTokenRepo UserEmailer *useremail.UserEmailer diff --git a/server/testutil.go b/server/testutil.go index 4f2aae07..99615539 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -15,6 +15,7 @@ import ( "github.com/coreos/dex/session" "github.com/coreos/dex/user" useremail "github.com/coreos/dex/user/email" + "github.com/coreos/dex/user/manager" ) const ( @@ -91,7 +92,6 @@ func sequentialGenerateCodeFunc() session.GenerateCodeFunc { func makeTestFixtures() (*testFixtures, error) { userRepo := user.NewUserRepoFromUsers(testUsers) pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos) - manager := user.NewManager(userRepo, pwRepo, repo.InMemTransactionFactory, user.ManagerOptions{}) connConfigs := []connector.ConnectorConfig{ &connector.OIDCConnectorConfig{ @@ -111,6 +111,9 @@ func makeTestFixtures() (*testFixtures, error) { ID: "local", }, } + connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs) + + manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{}) sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sessionManager.GenerateCode = sequentialGenerateCodeFunc() diff --git a/server/user.go b/server/user.go index ce2aecd6..713263f7 100644 --- a/server/user.go +++ b/server/user.go @@ -16,6 +16,7 @@ import ( schema "github.com/coreos/dex/schema/workerschema" "github.com/coreos/dex/user" "github.com/coreos/dex/user/api" + "github.com/coreos/dex/user/manager" ) const ( @@ -33,11 +34,11 @@ var ( type UserMgmtServer struct { api *api.UsersAPI jwtvFactory JWTVerifierFactory - um *user.Manager + um *manager.UserManager cir client.ClientIdentityRepo } -func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *user.Manager, cir client.ClientIdentityRepo) *UserMgmtServer { +func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *manager.UserManager, cir client.ClientIdentityRepo) *UserMgmtServer { return &UserMgmtServer{ api: userMgmtAPI, jwtvFactory: jwtvFactory, diff --git a/test b/test index 3d26c7ba..981ca110 100755 --- a/test +++ b/test @@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"} source ./build -TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api email" +TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email" FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log" # user has not provided PKG override diff --git a/user/api/api.go b/user/api/api.go index 687e9757..2eea7f5c 100644 --- a/user/api/api.go +++ b/user/api/api.go @@ -13,6 +13,7 @@ import ( "github.com/coreos/dex/pkg/log" schema "github.com/coreos/dex/schema/workerschema" "github.com/coreos/dex/user" + "github.com/coreos/dex/user/manager" ) var ( @@ -81,7 +82,7 @@ func (e Error) Error() string { // calling User. It is assumed that the clientID has already validated as an // admin app before calling. type UsersAPI struct { - manager *user.Manager + manager *manager.UserManager localConnectorID string clientIdentityRepo client.ClientIdentityRepo emailer Emailer @@ -96,7 +97,7 @@ type Creds struct { User user.User } -func NewUsersAPI(manager *user.Manager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI { +func NewUsersAPI(manager *manager.UserManager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI { return &UsersAPI{ manager: manager, clientIdentityRepo: cir, diff --git a/user/api/api_test.go b/user/api/api_test.go index f6b3255a..967db2ac 100644 --- a/user/api/api_test.go +++ b/user/api/api_test.go @@ -10,9 +10,11 @@ import ( "github.com/kylelemons/godebug/pretty" "github.com/coreos/dex/client" + "github.com/coreos/dex/connector" "github.com/coreos/dex/repo" schema "github.com/coreos/dex/schema/workerschema" "github.com/coreos/dex/user" + "github.com/coreos/dex/user/manager" ) type testEmailer struct { @@ -123,7 +125,10 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { Password: []byte("password-2"), }, }) - mgr := user.NewManager(ur, pwr, repo.InMemTransactionFactory, user.ManagerOptions{}) + ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{ + &connector.LocalConnectorConfig{ID: "local"}, + }) + mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{}) mgr.Clock = clock ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ diff --git a/user/manager.go b/user/manager/manager.go similarity index 59% rename from user/manager.go rename to user/manager/manager.go index aa8654af..a37ca8b1 100644 --- a/user/manager.go +++ b/user/manager/manager.go @@ -1,4 +1,4 @@ -package user +package manager import ( "errors" @@ -6,8 +6,10 @@ import ( "github.com/jonboulle/clockwork" + "github.com/coreos/dex/connector" "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/repo" + "github.com/coreos/dex/user" ) var ( @@ -19,13 +21,14 @@ var ( // Manager performs user-related "business-logic" functions on user and related objects. // This is in contrast to the Repos which perform little more than CRUD operations. -type Manager struct { +type UserManager struct { Clock clockwork.Clock - userRepo UserRepo - pwRepo PasswordInfoRepo + userRepo user.UserRepo + pwRepo user.PasswordInfoRepo + connCfgRepo connector.ConnectorConfigRepo begin repo.TransactionFactory - userIDGenerator UserIDGenerator + userIDGenerator user.UserIDGenerator } type ManagerOptions struct { @@ -34,58 +37,59 @@ type ManagerOptions struct { // variable policies } -func NewManager(userRepo UserRepo, pwRepo PasswordInfoRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *Manager { - return &Manager{ +func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, connCfgRepo connector.ConnectorConfigRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager { + return &UserManager{ Clock: clockwork.NewRealClock(), userRepo: userRepo, pwRepo: pwRepo, + connCfgRepo: connCfgRepo, begin: txnFactory, - userIDGenerator: DefaultUserIDGenerator, + userIDGenerator: user.DefaultUserIDGenerator, } } -func (m *Manager) Get(id string) (User, error) { +func (m *UserManager) Get(id string) (user.User, error) { return m.userRepo.Get(nil, id) } -func (m *Manager) List(filter UserFilter, maxResults int, nextPageToken string) ([]User, string, error) { +func (m *UserManager) List(filter user.UserFilter, maxResults int, nextPageToken string) ([]user.User, string, error) { return m.userRepo.List(nil, filter, maxResults, nextPageToken) } // CreateUser creates a new user with the given hashedPassword; the connID should be the ID of the local connector. // The userID of the created user is returned as the first argument. -func (m *Manager) CreateUser(user User, hashedPassword Password, connID string) (string, error) { +func (m *UserManager) CreateUser(usr user.User, hashedPassword user.Password, connID string) (string, error) { tx, err := m.begin() if err != nil { return "", err } - insertedUser, err := m.insertNewUser(tx, user.Email, user.EmailVerified) + insertedUser, err := m.insertNewUser(tx, usr.Email, usr.EmailVerified) if err != nil { rollback(tx) return "", err } - user.ID = insertedUser.ID - user.CreatedAt = insertedUser.CreatedAt - err = m.userRepo.Update(tx, user) + usr.ID = insertedUser.ID + usr.CreatedAt = insertedUser.CreatedAt + err = m.userRepo.Update(tx, usr) if err != nil { rollback(tx) return "", err } - rid := RemoteIdentity{ + rid := user.RemoteIdentity{ ConnectorID: connID, - ID: user.ID, + ID: usr.ID, } - if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil { + if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil { rollback(tx) return "", err } - pwi := PasswordInfo{ - UserID: user.ID, + pwi := user.PasswordInfo{ + UserID: usr.ID, Password: hashedPassword, } err = m.pwRepo.Create(tx, pwi) @@ -99,10 +103,10 @@ func (m *Manager) CreateUser(user User, hashedPassword Password, connID string) rollback(tx) return "", err } - return user.ID, nil + return usr.ID, nil } -func (m *Manager) Disable(userID string, disabled bool) error { +func (m *UserManager) Disable(userID string, disabled bool) error { tx, err := m.begin() if err = m.userRepo.Disable(tx, userID, disabled); err != nil { @@ -119,7 +123,7 @@ func (m *Manager) Disable(userID string, disabled bool) error { } // RegisterWithRemoteIdentity creates new user and attaches the given remote identity. -func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) { +func (m *UserManager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid user.RemoteIdentity) (string, error) { tx, err := m.begin() if err != nil { return "", err @@ -127,20 +131,20 @@ func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, r if _, err = m.userRepo.GetByRemoteIdentity(tx, rid); err == nil { rollback(tx) - return "", ErrorDuplicateRemoteIdentity + return "", user.ErrorDuplicateRemoteIdentity } - if err != ErrorNotFound { + if err != user.ErrorNotFound { rollback(tx) return "", err } - user, err := m.insertNewUser(tx, email, emailVerified) + usr, err := m.insertNewUser(tx, email, emailVerified) if err != nil { rollback(tx) return "", err } - if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil { + if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil { rollback(tx) return "", err } @@ -150,44 +154,44 @@ func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, r rollback(tx) return "", err } - return user.ID, nil + return usr.ID, nil } // RegisterWithPassword creates a new user with the given name and password. // connID is the connector ID of the ConnectorLocal connector. -func (m *Manager) RegisterWithPassword(email, plaintext, connID string) (string, error) { +func (m *UserManager) RegisterWithPassword(email, plaintext, connID string) (string, error) { tx, err := m.begin() if err != nil { return "", err } - if !ValidPassword(plaintext) { + if !user.ValidPassword(plaintext) { rollback(tx) - return "", ErrorInvalidPassword + return "", user.ErrorInvalidPassword } - user, err := m.insertNewUser(tx, email, false) + usr, err := m.insertNewUser(tx, email, false) if err != nil { rollback(tx) return "", err } - rid := RemoteIdentity{ + rid := user.RemoteIdentity{ ConnectorID: connID, - ID: user.ID, + ID: usr.ID, } - if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil { + if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil { rollback(tx) return "", err } - password, err := NewPasswordFromPlaintext(plaintext) + password, err := user.NewPasswordFromPlaintext(plaintext) if err != nil { rollback(tx) return "", err } - pwi := PasswordInfo{ - UserID: user.ID, + pwi := user.PasswordInfo{ + UserID: usr.ID, Password: password, } @@ -202,7 +206,7 @@ func (m *Manager) RegisterWithPassword(email, plaintext, connID string) (string, rollback(tx) return "", err } - return user.ID, nil + return usr.ID, nil } type EmailVerifiable interface { @@ -218,31 +222,31 @@ type EmailVerifiable interface { // create it, ensuring that the token was signed and that the JWT was not // expired. // The callback url (i.e. where to send the user after the verification) is returned. -func (m *Manager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) { +func (m *UserManager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) { tx, err := m.begin() if err != nil { return nil, err } - user, err := m.userRepo.Get(tx, ev.UserID()) + usr, err := m.userRepo.Get(tx, ev.UserID()) if err != nil { rollback(tx) return nil, err } - if user.Email != ev.Email() { + if usr.Email != ev.Email() { rollback(tx) return nil, ErrorEVEmailDoesntMatch } - if user.EmailVerified { + if usr.EmailVerified { rollback(tx) return nil, ErrorEmailAlreadyVerified } - user.EmailVerified = true + usr.EmailVerified = true - err = m.userRepo.Update(tx, user) + err = m.userRepo.Update(tx, usr) if err != nil { rollback(tx) return nil, err @@ -258,19 +262,19 @@ func (m *Manager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) { type PasswordChangeable interface { UserID() string - Password() Password + Password() user.Password Callback() *url.URL } -func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url.URL, error) { +func (m *UserManager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url.URL, error) { tx, err := m.begin() if err != nil { return nil, err } - if !ValidPassword(plaintext) { + if !user.ValidPassword(plaintext) { rollback(tx) - return nil, ErrorInvalidPassword + return nil, user.ErrorInvalidPassword } pwi, err := m.pwRepo.Get(tx, pwr.UserID()) @@ -284,7 +288,7 @@ func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url return nil, ErrorPasswordAlreadyChanged } - newPass, err := NewPasswordFromPlaintext(plaintext) + newPass, err := user.NewPasswordFromPlaintext(plaintext) if err != nil { rollback(tx) return nil, err @@ -305,36 +309,46 @@ func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url return pwr.Callback(), nil } -func (m *Manager) insertNewUser(tx repo.Transaction, email string, emailVerified bool) (User, error) { - if !ValidEmail(email) { - return User{}, ErrorInvalidEmail +func (m *UserManager) insertNewUser(tx repo.Transaction, email string, emailVerified bool) (user.User, error) { + if !user.ValidEmail(email) { + return user.User{}, user.ErrorInvalidEmail } var err error if _, err = m.userRepo.GetByEmail(tx, email); err == nil { - return User{}, ErrorDuplicateEmail + return user.User{}, user.ErrorDuplicateEmail } - if err != ErrorNotFound { - return User{}, err + if err != user.ErrorNotFound { + return user.User{}, err } userID, err := m.userIDGenerator() if err != nil { - return User{}, err + return user.User{}, err } - user := User{ + usr := user.User{ ID: userID, Email: email, EmailVerified: emailVerified, CreatedAt: m.Clock.Now(), } - err = m.userRepo.Create(tx, user) + err = m.userRepo.Create(tx, usr) if err != nil { - return User{}, err + return user.User{}, err } - return user, nil + return usr, nil +} + +func (m *UserManager) addRemoteIdentity(tx repo.Transaction, userID string, rid user.RemoteIdentity) error { + if _, err := m.connCfgRepo.GetConnectorByID(tx, rid.ConnectorID); err != nil { + return err + } + if err := m.userRepo.AddRemoteIdentity(tx, userID, rid); err != nil { + return err + } + return nil } func rollback(tx repo.Transaction) { diff --git a/user/manager_test.go b/user/manager/manager_test.go similarity index 72% rename from user/manager_test.go rename to user/manager/manager_test.go index 5e9db7b2..fbe0a4a3 100644 --- a/user/manager_test.go +++ b/user/manager/manager_test.go @@ -1,4 +1,4 @@ -package user +package manager import ( "net/url" @@ -9,13 +9,16 @@ import ( "github.com/jonboulle/clockwork" "github.com/kylelemons/godebug/pretty" + "github.com/coreos/dex/connector" "github.com/coreos/dex/repo" + "github.com/coreos/dex/user" ) type testFixtures struct { - ur UserRepo - pwr PasswordInfoRepo - mgr *Manager + ur user.UserRepo + pwr user.PasswordInfoRepo + ccr connector.ConnectorConfigRepo + mgr *UserManager clock clockwork.Clock } @@ -23,25 +26,25 @@ func makeTestFixtures() *testFixtures { f := &testFixtures{} f.clock = clockwork.NewFakeClock() - f.ur = NewUserRepoFromUsers([]UserWithRemoteIdentities{ + f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{ { - User: User{ + User: user.User{ ID: "ID-1", Email: "Email-1@example.com", }, - RemoteIdentities: []RemoteIdentity{ + RemoteIdentities: []user.RemoteIdentity{ { ConnectorID: "local", ID: "1", }, }, }, { - User: User{ + User: user.User{ ID: "ID-2", Email: "Email-2@example.com", EmailVerified: true, }, - RemoteIdentities: []RemoteIdentity{ + RemoteIdentities: []user.RemoteIdentity{ { ConnectorID: "local", ID: "2", @@ -49,7 +52,7 @@ func makeTestFixtures() *testFixtures { }, }, }) - f.pwr = NewPasswordInfoRepoFromPasswordInfos([]PasswordInfo{ + f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ { UserID: "ID-1", Password: []byte("password-1"), @@ -59,7 +62,10 @@ func makeTestFixtures() *testFixtures { Password: []byte("password-2"), }, }) - f.mgr = NewManager(f.ur, f.pwr, repo.InMemTransactionFactory, ManagerOptions{}) + f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{ + &connector.LocalConnectorConfig{ID: "local"}, + }) + f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{}) f.mgr.Clock = f.clock return f } @@ -68,13 +74,13 @@ func TestRegisterWithRemoteIdentity(t *testing.T) { tests := []struct { email string emailVerified bool - rid RemoteIdentity + rid user.RemoteIdentity err error }{ { email: "email@example.com", emailVerified: false, - rid: RemoteIdentity{ + rid: user.RemoteIdentity{ ConnectorID: "local", ID: "1234", }, @@ -82,20 +88,29 @@ func TestRegisterWithRemoteIdentity(t *testing.T) { }, { emailVerified: false, - rid: RemoteIdentity{ + rid: user.RemoteIdentity{ ConnectorID: "local", ID: "1234", }, - err: ErrorInvalidEmail, + err: user.ErrorInvalidEmail, }, { email: "email@example.com", emailVerified: false, - rid: RemoteIdentity{ + rid: user.RemoteIdentity{ ConnectorID: "local", ID: "1", }, - err: ErrorDuplicateRemoteIdentity, + err: user.ErrorDuplicateRemoteIdentity, + }, + { + email: "anotheremail@example.com", + emailVerified: false, + rid: user.RemoteIdentity{ + ConnectorID: "idonotexist", + ID: "1", + }, + err: connector.ErrorNotFound, }, } @@ -148,17 +163,17 @@ func TestRegisterWithPassword(t *testing.T) { }, { plaintext: "secretpassword123", - err: ErrorInvalidEmail, + err: user.ErrorInvalidEmail, }, { email: "email@example.com", - err: ErrorInvalidPassword, + err: user.ErrorInvalidPassword, }, } for i, tt := range tests { f := makeTestFixtures() - connID := "connID" + connID := "local" userID, err := f.mgr.RegisterWithPassword( tt.email, tt.plaintext, @@ -183,7 +198,7 @@ func TestRegisterWithPassword(t *testing.T) { t.Errorf("case %d: user.EmailVerified: want=%v, got=%v", i, false, usr.EmailVerified) } - ridUSR, err := f.ur.GetByRemoteIdentity(nil, RemoteIdentity{ + ridUSR, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{ ID: userID, ConnectorID: connID, }) @@ -220,12 +235,12 @@ func TestVerifyEmail(t *testing.T) { callback := "http://client.example.com/callback" expires := time.Hour * 3 - makeClaims := func(usr User) jose.Claims { + makeClaims := func(usr user.User) jose.Claims { return map[string]interface{}{ "iss": issuer.String(), "aud": clientID, - ClaimEmailVerificationCallback: callback, - ClaimEmailVerificationEmail: usr.Email, + user.ClaimEmailVerificationCallback: callback, + user.ClaimEmailVerificationEmail: usr.Email, "exp": float64(now.Add(expires).Unix()), "sub": usr.ID, "iat": float64(now.Unix()), @@ -238,28 +253,28 @@ func TestVerifyEmail(t *testing.T) { }{ { // happy path - evClaims: makeClaims(User{ID: "ID-1", Email: "Email-1@example.com"}), + evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-1@example.com"}), }, { // non-matching email - evClaims: makeClaims(User{ID: "ID-1", Email: "Email-2@example.com"}), + evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-2@example.com"}), wantErr: true, }, { // already verified email - evClaims: makeClaims(User{ID: "ID-2", Email: "Email-2@example.com"}), + evClaims: makeClaims(user.User{ID: "ID-2", Email: "Email-2@example.com"}), wantErr: true, }, { // non-existent user. - evClaims: makeClaims(User{ID: "ID-UNKNOWN", Email: "noone@example.com"}), + evClaims: makeClaims(user.User{ID: "ID-UNKNOWN", Email: "noone@example.com"}), wantErr: true, }, } for i, tt := range tests { f := makeTestFixtures() - cb, err := f.mgr.VerifyEmail(EmailVerification{tt.evClaims}) + cb, err := f.mgr.VerifyEmail(user.EmailVerification{tt.evClaims}) if tt.wantErr { if err == nil { t.Errorf("case %d: want non-nil err", i) @@ -271,9 +286,9 @@ func TestVerifyEmail(t *testing.T) { t.Errorf("case %d: want err=nil got=%q", i, err) } - if cb.String() != tt.evClaims[ClaimEmailVerificationCallback] { + if cb.String() != tt.evClaims[user.ClaimEmailVerificationCallback] { t.Errorf("case %d: want=%q, got=%q", i, cb.String(), - tt.evClaims[ClaimEmailVerificationCallback]) + tt.evClaims[user.ClaimEmailVerificationCallback]) } } } @@ -290,8 +305,8 @@ func TestChangePassword(t *testing.T) { return map[string]interface{}{ "iss": issuer.String(), "aud": clientID, - ClaimPasswordResetCallback: callback, - ClaimPasswordResetPassword: password, + user.ClaimPasswordResetCallback: callback, + user.ClaimPasswordResetPassword: password, "exp": float64(now.Add(expires).Unix()), "sub": usrID, "iat": float64(now.Unix()), @@ -329,7 +344,7 @@ func TestChangePassword(t *testing.T) { for i, tt := range tests { f := makeTestFixtures() - cb, err := f.mgr.ChangePassword(PasswordReset{tt.pwrClaims}, tt.newPassword) + cb, err := f.mgr.ChangePassword(user.PasswordReset{tt.pwrClaims}, tt.newPassword) if tt.wantErr { if err == nil { t.Errorf("case %d: want non-nil err", i) @@ -346,47 +361,61 @@ func TestChangePassword(t *testing.T) { if cb != nil { cbString = cb.String() } - if cbString != tt.pwrClaims[ClaimPasswordResetCallback] { + if cbString != tt.pwrClaims[user.ClaimPasswordResetCallback] { t.Errorf("case %d: want=%q, got=%q", i, cb.String(), - tt.pwrClaims[ClaimPasswordResetCallback]) + tt.pwrClaims[user.ClaimPasswordResetCallback]) } } } func TestCreateUser(t *testing.T) { tests := []struct { - usr User - hashedPW Password + usr user.User + hashedPW user.Password + localID string // defaults to "local" wantErr bool }{ { - usr: User{ + usr: user.User{ DisplayName: "Bob Exampleson", Email: "bob@example.com", }, - hashedPW: Password("I am a hash"), + hashedPW: user.Password("I am a hash"), }, { - usr: User{ + usr: user.User{ DisplayName: "Al Adminson", Email: "al@example.com", Admin: true, }, - hashedPW: Password("I am a hash"), + hashedPW: user.Password("I am a hash"), }, { - usr: User{ + usr: user.User{ DisplayName: "Ed Emailless", }, - hashedPW: Password("I am a hash"), + hashedPW: user.Password("I am a hash"), + wantErr: true, + }, + { + usr: user.User{ + DisplayName: "Eric Exampleson", + Email: "eric@example.com", + }, + hashedPW: user.Password("I am a hash"), + localID: "abadlocalid", wantErr: true, }, } for i, tt := range tests { f := makeTestFixtures() - id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, "local") + localID := "local" + if tt.localID != "" { + localID = tt.localID + } + id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, localID) if tt.wantErr { if err == nil { t.Errorf("case %d: want non-nil err", i) @@ -422,7 +451,7 @@ func TestCreateUser(t *testing.T) { t.Errorf("case %d: want=%q, got=%q", i, tt.hashedPW, pwi.Password) } - ridUser, err := f.ur.GetByRemoteIdentity(nil, RemoteIdentity{ + ridUser, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{ ID: id, ConnectorID: "local", })