diff --git a/cmd/dex-worker/main.go b/cmd/dex-worker/main.go index cf21cf3e..234c96f6 100644 --- a/cmd/dex-worker/main.go +++ b/cmd/dex-worker/main.go @@ -46,7 +46,9 @@ func main() { emailFrom := fs.String("email-from", "", "emails sent from dex will come from this address") emailConfig := fs.String("email-cfg", "./static/fixtures/emailer.json", "configures emailer.") - enableRegistration := fs.Bool("enable-registration", false, "Allows users to self-register") + enableRegistration := fs.Bool("enable-registration", false, "Allows users to self-register. This flag cannot be used in combination with --enable-automatic-registration.") + registerOnFirstLogin := fs.Bool("enable-automatic-registration", false, "When a user logs in through a federated identity service, automatically register them if they don't have an account. This flag cannot be used in combination with --enable-registration.") + enableClientRegistration := fs.Bool("enable-client-registration", false, "Allow dynamic registration of clients") noDB := fs.Bool("no-db", false, "manage entities in-process w/o any encryption, used only for single-node testing") @@ -90,6 +92,11 @@ func main() { os.Exit(0) } + if (*enableRegistration) && (*registerOnFirstLogin) { + fmt.Fprintln(os.Stderr, "The flags --enable-registration and --enable-automatic-login cannot both be true.") + os.Exit(1) + } + if *logDebug { log.EnableDebug() log.Infof("Debug logging enabled.") @@ -135,6 +142,7 @@ func main() { IssuerLogoURL: *issuerLogoURL, EnableRegistration: *enableRegistration, EnableClientRegistration: *enableClientRegistration, + RegisterOnFirstLogin: *registerOnFirstLogin, } if *noDB { diff --git a/server/config.go b/server/config.go index 06232528..0f329d04 100644 --- a/server/config.go +++ b/server/config.go @@ -38,6 +38,7 @@ type ServerConfig struct { StateConfig StateConfigurer EnableRegistration bool EnableClientRegistration bool + RegisterOnFirstLogin bool } type StateConfigurer interface { @@ -78,6 +79,7 @@ func (cfg *ServerConfig) Server() (*Server, error) { EnableRegistration: cfg.EnableRegistration, EnableClientRegistration: cfg.EnableClientRegistration, + RegisterOnFirstLogin: cfg.RegisterOnFirstLogin, } err = cfg.StateConfig.Configure(&srv) diff --git a/server/server.go b/server/server.go index 22cf42d7..172998b5 100644 --- a/server/server.go +++ b/server/server.go @@ -64,28 +64,35 @@ type OIDCServer interface { type JWTVerifierFactory func(clientID string) oidc.JWTVerifier type Server struct { - IssuerURL url.URL - KeyManager key.PrivateKeyManager - KeySetRepo key.PrivateKeySetRepo - SessionManager *sessionmanager.SessionManager - ClientRepo client.ClientRepo - ConnectorConfigRepo connector.ConnectorConfigRepo + IssuerURL url.URL + Templates *template.Template LoginTemplate *template.Template RegisterTemplate *template.Template VerifyEmailTemplate *template.Template SendResetPasswordEmailTemplate *template.Template ResetPasswordTemplate *template.Template - HealthChecks []health.Checkable - Connectors []connector.Connector - UserRepo user.UserRepo - UserManager *usermanager.UserManager - ClientManager *clientmanager.ClientManager - PasswordInfoRepo user.PasswordInfoRepo - RefreshTokenRepo refresh.RefreshTokenRepo - UserEmailer *useremail.UserEmailer - EnableRegistration bool - EnableClientRegistration bool + + HealthChecks []health.Checkable + Connectors []connector.Connector + + ClientRepo client.ClientRepo + ConnectorConfigRepo connector.ConnectorConfigRepo + KeySetRepo key.PrivateKeySetRepo + RefreshTokenRepo refresh.RefreshTokenRepo + UserRepo user.UserRepo + PasswordInfoRepo user.PasswordInfoRepo + + ClientManager *clientmanager.ClientManager + KeyManager key.PrivateKeyManager + SessionManager *sessionmanager.SessionManager + UserManager *usermanager.UserManager + + UserEmailer *useremail.UserEmailer + + EnableRegistration bool + EnableClientRegistration bool + RegisterOnFirstLogin bool dbMap *gorp.DbMap localConnectorID string @@ -323,42 +330,72 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) { return ru.String(), nil } - usr, err := s.UserRepo.GetByRemoteIdentity(nil, user.RemoteIdentity{ - ConnectorID: ses.ConnectorID, - ID: ses.Identity.ID, - }) + remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID} + + // Get the connector used to log the user in. + var conn connector.Connector + for _, c := range s.Connectors { + if c.ID() == ses.ConnectorID { + conn = c + break + } + } + if conn == nil { + return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID) + } + + usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity) if err == user.ErrorNotFound { - // Does the user have an existing account with a different connector? - if ses.Identity.Email != "" { - connID, err := getConnectorForUserByEmail(s.UserRepo, ses.Identity.Email) - if err == nil { - // Ask user to sign in through existing account. - u := newLoginURLFromSession(s.IssuerURL, ses, false, []string{connID}, "wrong-connector") - return u.String(), nil - } + if ses.Identity.Email == "" { + // User doesn't have an existing account. Ask them to register. + u := newLoginURLFromSession(s.IssuerURL, ses, true, []string{ses.ConnectorID}, "register-maybe") + return u.String(), nil } - // User doesn't have an existing account. Ask them to register. - u := newLoginURLFromSession(s.IssuerURL, ses, true, []string{ses.ConnectorID}, "register-maybe") - return u.String(), nil - } - if err != nil { - return "", err + // Does the user have an existing account with a different connector? + if connID, err := getConnectorForUserByEmail(s.UserRepo, ses.Identity.Email); err == nil { + // Ask user to sign in through existing account. + u := newLoginURLFromSession(s.IssuerURL, ses, false, []string{connID}, "wrong-connector") + return u.String(), nil + } + + // RegisterOnFirstLogin doesn't work for the local connector + tryToRegister := s.RegisterOnFirstLogin && (ses.ConnectorID != s.localConnectorID) + + if !tryToRegister { + // User doesn't have an existing account. Ask them to register. + u := newLoginURLFromSession(s.IssuerURL, ses, true, []string{ses.ConnectorID}, "register-maybe") + return u.String(), nil + } + + // First time logging in through a remote connector. Attempt to register. + emailVerified := conn.TrustedEmailProvider() + usrID, err := s.UserManager.RegisterWithRemoteIdentity(ses.Identity.Email, emailVerified, remoteIdentity) + if err != nil { + return "", fmt.Errorf("failed to register user: %v", err) + } + usr, err = s.UserManager.Get(usrID) + if err != nil { + return "", fmt.Errorf("getting created user: %v", err) + } + } else if err != nil { + return "", fmt.Errorf("getting user: %v", err) } if usr.Disabled { + log.Errorf("user %s disabled", ses.Identity.Email) return "", user.ErrorNotFound } ses, err = s.SessionManager.AttachUser(sessionID, usr.ID) if err != nil { - return "", err + return "", fmt.Errorf("attaching user to session: %v", err) } log.Infof("Session %s user identified: clientID=%s user=%#v", sessionID, ses.ClientID, usr) code, err := s.SessionManager.NewSessionKey(sessionID) if err != nil { - return "", err + return "", fmt.Errorf("creating new session key: %v", err) } ru := ses.RedirectURL diff --git a/server/server_test.go b/server/server_test.go index 44a7f2ed..8834ad99 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "reflect" + "strings" "testing" "time" @@ -185,31 +186,121 @@ func TestServerNewSession(t *testing.T) { } func TestServerLogin(t *testing.T) { - f, err := makeTestFixtures() - if err != nil { - t.Fatalf("error making test fixtures: %v", err) + + tests := []struct { + testCase string + connectorID string + clientID string + userID string + remoteUserID string + email string + configure func(s *Server) + + wantError bool // should server.Login fail? + wantLogin bool // should server.Login redirect back to the app? + }{ + { + testCase: "good user", + connectorID: testConnectorID1, + clientID: testClientID, + userID: testUserID1, + remoteUserID: testUserRemoteID1, + email: testUserEmail1, + wantLogin: true, + }, + { + testCase: "user has remote identity with another connector", + connectorID: testConnectorIDOpenID, + clientID: testClientID, + userID: testUserID1, + remoteUserID: testUserRemoteID1, + email: testUserEmail1, + wantLogin: false, + }, + { + testCase: "unknown connector id", + connectorID: "bad connector id", + clientID: testClientID, + userID: testUserID1, + remoteUserID: testUserRemoteID1, + email: testUserEmail1, + wantError: true, + }, + { + testCase: "unregistered user", + connectorID: testConnectorIDOpenID, + clientID: testClientID, + userID: testUserID1, + remoteUserID: "unregistered-user-id", + email: "newemail@example.com", + wantLogin: false, + }, + { + testCase: "unregistered user with register on first login", + connectorID: testConnectorIDOpenID, + clientID: testClientID, + userID: testUserID1, + remoteUserID: "unregistered-user-id", + email: "newemail@example.com", + configure: func(srv *Server) { srv.RegisterOnFirstLogin = true }, + wantLogin: true, + }, + { + testCase: "unregistered user through local connector with register on first login", + connectorID: testConnectorLocalID, + clientID: testClientID, + userID: testUserID1, + remoteUserID: "unregistered-user-id", + email: "newemail@example.com", + configure: func(srv *Server) { srv.RegisterOnFirstLogin = true }, + wantLogin: false, + }, } - sm := f.sessionManager - sessionID, err := sm.NewSession("IDPC-1", testClientID, "bogus", testRedirectURL, "", false, []string{"openid"}) + for _, tt := range tests { + f, err := makeTestFixtures() + if err != nil { + t.Fatalf("error making test fixtures: %v", err) + } - ident := oidc.Identity{ID: testUserRemoteID1, Name: "elroy", Email: testUserEmail1} - key, err := sm.NewSessionKey(sessionID) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + if tt.configure != nil { + tt.configure(f.srv) + } - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - redirectURL, err := f.srv.Login(ident, key) - if err != nil { - t.Fatalf("Unexpected err from Server.Login: %v", err) - } + sm := f.sessionManager + sessionID, err := sm.NewSession(tt.connectorID, tt.clientID, "bogus", testRedirectURL, "", false, []string{"openid"}) + if err != nil { + t.Errorf("case %s: new session: %v", tt.testCase, err) + continue + } - wantRedirectURL := "http://client.example.com/callback?code=code-3&state=bogus" - if wantRedirectURL != redirectURL { - t.Fatalf("Unexpected redirectURL: want=%q, got=%q", wantRedirectURL, redirectURL) + key, err := sm.NewSessionKey(sessionID) + if err != nil { + t.Errorf("case %s: new session key: %v", tt.testCase, err) + continue + } + + ident := oidc.Identity{ID: tt.remoteUserID, Name: "elroy", Email: tt.email} + redirectURL, err := f.srv.Login(ident, key) + if err != nil { + if !tt.wantError { + t.Errorf("case %s: server.Login: %v", tt.testCase, err) + } + continue + } + if tt.wantError { + t.Errorf("case %s: expected server.Login to fail", tt.testCase) + continue + } + + // Did the server redirect back to the client app or display an error to the user? + gotRedirectURL := strings.HasPrefix(redirectURL, testRedirectURL.String()) + if gotRedirectURL && !tt.wantLogin { + t.Errorf("case %s: should not have logged in", tt.testCase) + } + if !gotRedirectURL && tt.wantLogin { + t.Errorf("case %s: failed to log in. expected redirect url got: %s", tt.testCase, redirectURL) + } } } diff --git a/server/testutil.go b/server/testutil_test.go similarity index 92% rename from server/testutil.go rename to server/testutil_test.go index 2155d70d..16317a03 100644 --- a/server/testutil.go +++ b/server/testutil_test.go @@ -54,6 +54,10 @@ var ( testConnectorID1 = "IDPC-1" + testConnectorIDOpenID = "oidc" + testConnectorIDOpenIDTrusted = "oidc-trusted" + testConnectorLocalID = "local" + testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"} testUsers = []user.UserWithRemoteIdentities{ @@ -143,20 +147,27 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err connConfigs := []connector.ConnectorConfig{ &connector.OIDCConnectorConfig{ - ID: "oidc", + ID: testConnectorIDOpenID, IssuerURL: testIssuerURL.String(), ClientID: "12345", ClientSecret: "567789", }, &connector.OIDCConnectorConfig{ - ID: "oidc-trusted", + ID: testConnectorIDOpenIDTrusted, IssuerURL: testIssuerURL.String(), ClientID: "12345-trusted", ClientSecret: "567789-trusted", TrustedEmailProvider: true, }, + &connector.OIDCConnectorConfig{ + ID: testConnectorID1, + IssuerURL: testIssuerURL.String(), + ClientID: testConnectorID1 + "_client_id", + ClientSecret: testConnectorID1 + "_client_secret", + TrustedEmailProvider: true, + }, &connector.LocalConnectorConfig{ - ID: "local", + ID: testConnectorLocalID, }, } connCfgRepo := db.NewConnectorConfigRepo(dbMap)