forked from mystiq/dex
Merge pull request #463 from ericchiang/register-on-first-login
*: add --enable-automatic-registration flag to worker
This commit is contained in:
commit
e92b6a5908
5 changed files with 209 additions and 60 deletions
|
@ -46,7 +46,9 @@ func main() {
|
||||||
emailFrom := fs.String("email-from", "", "emails sent from dex will come from this address")
|
emailFrom := fs.String("email-from", "", "emails sent from dex will come from this address")
|
||||||
emailConfig := fs.String("email-cfg", "./static/fixtures/emailer.json", "configures emailer.")
|
emailConfig := fs.String("email-cfg", "./static/fixtures/emailer.json", "configures emailer.")
|
||||||
|
|
||||||
enableRegistration := fs.Bool("enable-registration", false, "Allows users to self-register")
|
enableRegistration := fs.Bool("enable-registration", false, "Allows users to self-register. This flag cannot be used in combination with --enable-automatic-registration.")
|
||||||
|
registerOnFirstLogin := fs.Bool("enable-automatic-registration", false, "When a user logs in through a federated identity service, automatically register them if they don't have an account. This flag cannot be used in combination with --enable-registration.")
|
||||||
|
|
||||||
enableClientRegistration := fs.Bool("enable-client-registration", false, "Allow dynamic registration of clients")
|
enableClientRegistration := fs.Bool("enable-client-registration", false, "Allow dynamic registration of clients")
|
||||||
|
|
||||||
noDB := fs.Bool("no-db", false, "manage entities in-process w/o any encryption, used only for single-node testing")
|
noDB := fs.Bool("no-db", false, "manage entities in-process w/o any encryption, used only for single-node testing")
|
||||||
|
@ -90,6 +92,11 @@ func main() {
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (*enableRegistration) && (*registerOnFirstLogin) {
|
||||||
|
fmt.Fprintln(os.Stderr, "The flags --enable-registration and --enable-automatic-login cannot both be true.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
if *logDebug {
|
if *logDebug {
|
||||||
log.EnableDebug()
|
log.EnableDebug()
|
||||||
log.Infof("Debug logging enabled.")
|
log.Infof("Debug logging enabled.")
|
||||||
|
@ -135,6 +142,7 @@ func main() {
|
||||||
IssuerLogoURL: *issuerLogoURL,
|
IssuerLogoURL: *issuerLogoURL,
|
||||||
EnableRegistration: *enableRegistration,
|
EnableRegistration: *enableRegistration,
|
||||||
EnableClientRegistration: *enableClientRegistration,
|
EnableClientRegistration: *enableClientRegistration,
|
||||||
|
RegisterOnFirstLogin: *registerOnFirstLogin,
|
||||||
}
|
}
|
||||||
|
|
||||||
if *noDB {
|
if *noDB {
|
||||||
|
|
|
@ -38,6 +38,7 @@ type ServerConfig struct {
|
||||||
StateConfig StateConfigurer
|
StateConfig StateConfigurer
|
||||||
EnableRegistration bool
|
EnableRegistration bool
|
||||||
EnableClientRegistration bool
|
EnableClientRegistration bool
|
||||||
|
RegisterOnFirstLogin bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateConfigurer interface {
|
type StateConfigurer interface {
|
||||||
|
@ -78,6 +79,7 @@ func (cfg *ServerConfig) Server() (*Server, error) {
|
||||||
|
|
||||||
EnableRegistration: cfg.EnableRegistration,
|
EnableRegistration: cfg.EnableRegistration,
|
||||||
EnableClientRegistration: cfg.EnableClientRegistration,
|
EnableClientRegistration: cfg.EnableClientRegistration,
|
||||||
|
RegisterOnFirstLogin: cfg.RegisterOnFirstLogin,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = cfg.StateConfig.Configure(&srv)
|
err = cfg.StateConfig.Configure(&srv)
|
||||||
|
|
109
server/server.go
109
server/server.go
|
@ -64,28 +64,35 @@ type OIDCServer interface {
|
||||||
type JWTVerifierFactory func(clientID string) oidc.JWTVerifier
|
type JWTVerifierFactory func(clientID string) oidc.JWTVerifier
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
IssuerURL url.URL
|
IssuerURL url.URL
|
||||||
KeyManager key.PrivateKeyManager
|
|
||||||
KeySetRepo key.PrivateKeySetRepo
|
|
||||||
SessionManager *sessionmanager.SessionManager
|
|
||||||
ClientRepo client.ClientRepo
|
|
||||||
ConnectorConfigRepo connector.ConnectorConfigRepo
|
|
||||||
Templates *template.Template
|
Templates *template.Template
|
||||||
LoginTemplate *template.Template
|
LoginTemplate *template.Template
|
||||||
RegisterTemplate *template.Template
|
RegisterTemplate *template.Template
|
||||||
VerifyEmailTemplate *template.Template
|
VerifyEmailTemplate *template.Template
|
||||||
SendResetPasswordEmailTemplate *template.Template
|
SendResetPasswordEmailTemplate *template.Template
|
||||||
ResetPasswordTemplate *template.Template
|
ResetPasswordTemplate *template.Template
|
||||||
HealthChecks []health.Checkable
|
|
||||||
Connectors []connector.Connector
|
HealthChecks []health.Checkable
|
||||||
UserRepo user.UserRepo
|
Connectors []connector.Connector
|
||||||
UserManager *usermanager.UserManager
|
|
||||||
ClientManager *clientmanager.ClientManager
|
ClientRepo client.ClientRepo
|
||||||
PasswordInfoRepo user.PasswordInfoRepo
|
ConnectorConfigRepo connector.ConnectorConfigRepo
|
||||||
RefreshTokenRepo refresh.RefreshTokenRepo
|
KeySetRepo key.PrivateKeySetRepo
|
||||||
UserEmailer *useremail.UserEmailer
|
RefreshTokenRepo refresh.RefreshTokenRepo
|
||||||
EnableRegistration bool
|
UserRepo user.UserRepo
|
||||||
EnableClientRegistration bool
|
PasswordInfoRepo user.PasswordInfoRepo
|
||||||
|
|
||||||
|
ClientManager *clientmanager.ClientManager
|
||||||
|
KeyManager key.PrivateKeyManager
|
||||||
|
SessionManager *sessionmanager.SessionManager
|
||||||
|
UserManager *usermanager.UserManager
|
||||||
|
|
||||||
|
UserEmailer *useremail.UserEmailer
|
||||||
|
|
||||||
|
EnableRegistration bool
|
||||||
|
EnableClientRegistration bool
|
||||||
|
RegisterOnFirstLogin bool
|
||||||
|
|
||||||
dbMap *gorp.DbMap
|
dbMap *gorp.DbMap
|
||||||
localConnectorID string
|
localConnectorID string
|
||||||
|
@ -323,42 +330,72 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
||||||
return ru.String(), nil
|
return ru.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
usr, err := s.UserRepo.GetByRemoteIdentity(nil, user.RemoteIdentity{
|
remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID}
|
||||||
ConnectorID: ses.ConnectorID,
|
|
||||||
ID: ses.Identity.ID,
|
// Get the connector used to log the user in.
|
||||||
})
|
var conn connector.Connector
|
||||||
|
for _, c := range s.Connectors {
|
||||||
|
if c.ID() == ses.ConnectorID {
|
||||||
|
conn = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
|
||||||
|
}
|
||||||
|
|
||||||
|
usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity)
|
||||||
if err == user.ErrorNotFound {
|
if err == user.ErrorNotFound {
|
||||||
// Does the user have an existing account with a different connector?
|
if ses.Identity.Email == "" {
|
||||||
if ses.Identity.Email != "" {
|
// User doesn't have an existing account. Ask them to register.
|
||||||
connID, err := getConnectorForUserByEmail(s.UserRepo, ses.Identity.Email)
|
u := newLoginURLFromSession(s.IssuerURL, ses, true, []string{ses.ConnectorID}, "register-maybe")
|
||||||
if err == nil {
|
return u.String(), nil
|
||||||
// Ask user to sign in through existing account.
|
|
||||||
u := newLoginURLFromSession(s.IssuerURL, ses, false, []string{connID}, "wrong-connector")
|
|
||||||
return u.String(), nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// User doesn't have an existing account. Ask them to register.
|
// Does the user have an existing account with a different connector?
|
||||||
u := newLoginURLFromSession(s.IssuerURL, ses, true, []string{ses.ConnectorID}, "register-maybe")
|
if connID, err := getConnectorForUserByEmail(s.UserRepo, ses.Identity.Email); err == nil {
|
||||||
return u.String(), nil
|
// Ask user to sign in through existing account.
|
||||||
}
|
u := newLoginURLFromSession(s.IssuerURL, ses, false, []string{connID}, "wrong-connector")
|
||||||
if err != nil {
|
return u.String(), nil
|
||||||
return "", err
|
}
|
||||||
|
|
||||||
|
// RegisterOnFirstLogin doesn't work for the local connector
|
||||||
|
tryToRegister := s.RegisterOnFirstLogin && (ses.ConnectorID != s.localConnectorID)
|
||||||
|
|
||||||
|
if !tryToRegister {
|
||||||
|
// User doesn't have an existing account. Ask them to register.
|
||||||
|
u := newLoginURLFromSession(s.IssuerURL, ses, true, []string{ses.ConnectorID}, "register-maybe")
|
||||||
|
return u.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// First time logging in through a remote connector. Attempt to register.
|
||||||
|
emailVerified := conn.TrustedEmailProvider()
|
||||||
|
usrID, err := s.UserManager.RegisterWithRemoteIdentity(ses.Identity.Email, emailVerified, remoteIdentity)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to register user: %v", err)
|
||||||
|
}
|
||||||
|
usr, err = s.UserManager.Get(usrID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("getting created user: %v", err)
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return "", fmt.Errorf("getting user: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if usr.Disabled {
|
if usr.Disabled {
|
||||||
|
log.Errorf("user %s disabled", ses.Identity.Email)
|
||||||
return "", user.ErrorNotFound
|
return "", user.ErrorNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
ses, err = s.SessionManager.AttachUser(sessionID, usr.ID)
|
ses, err = s.SessionManager.AttachUser(sessionID, usr.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("attaching user to session: %v", err)
|
||||||
}
|
}
|
||||||
log.Infof("Session %s user identified: clientID=%s user=%#v", sessionID, ses.ClientID, usr)
|
log.Infof("Session %s user identified: clientID=%s user=%#v", sessionID, ses.ClientID, usr)
|
||||||
|
|
||||||
code, err := s.SessionManager.NewSessionKey(sessionID)
|
code, err := s.SessionManager.NewSessionKey(sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("creating new session key: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ru := ses.RedirectURL
|
ru := ses.RedirectURL
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -185,31 +186,121 @@ func TestServerNewSession(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerLogin(t *testing.T) {
|
func TestServerLogin(t *testing.T) {
|
||||||
f, err := makeTestFixtures()
|
|
||||||
if err != nil {
|
tests := []struct {
|
||||||
t.Fatalf("error making test fixtures: %v", err)
|
testCase string
|
||||||
|
connectorID string
|
||||||
|
clientID string
|
||||||
|
userID string
|
||||||
|
remoteUserID string
|
||||||
|
email string
|
||||||
|
configure func(s *Server)
|
||||||
|
|
||||||
|
wantError bool // should server.Login fail?
|
||||||
|
wantLogin bool // should server.Login redirect back to the app?
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
testCase: "good user",
|
||||||
|
connectorID: testConnectorID1,
|
||||||
|
clientID: testClientID,
|
||||||
|
userID: testUserID1,
|
||||||
|
remoteUserID: testUserRemoteID1,
|
||||||
|
email: testUserEmail1,
|
||||||
|
wantLogin: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testCase: "user has remote identity with another connector",
|
||||||
|
connectorID: testConnectorIDOpenID,
|
||||||
|
clientID: testClientID,
|
||||||
|
userID: testUserID1,
|
||||||
|
remoteUserID: testUserRemoteID1,
|
||||||
|
email: testUserEmail1,
|
||||||
|
wantLogin: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testCase: "unknown connector id",
|
||||||
|
connectorID: "bad connector id",
|
||||||
|
clientID: testClientID,
|
||||||
|
userID: testUserID1,
|
||||||
|
remoteUserID: testUserRemoteID1,
|
||||||
|
email: testUserEmail1,
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testCase: "unregistered user",
|
||||||
|
connectorID: testConnectorIDOpenID,
|
||||||
|
clientID: testClientID,
|
||||||
|
userID: testUserID1,
|
||||||
|
remoteUserID: "unregistered-user-id",
|
||||||
|
email: "newemail@example.com",
|
||||||
|
wantLogin: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testCase: "unregistered user with register on first login",
|
||||||
|
connectorID: testConnectorIDOpenID,
|
||||||
|
clientID: testClientID,
|
||||||
|
userID: testUserID1,
|
||||||
|
remoteUserID: "unregistered-user-id",
|
||||||
|
email: "newemail@example.com",
|
||||||
|
configure: func(srv *Server) { srv.RegisterOnFirstLogin = true },
|
||||||
|
wantLogin: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testCase: "unregistered user through local connector with register on first login",
|
||||||
|
connectorID: testConnectorLocalID,
|
||||||
|
clientID: testClientID,
|
||||||
|
userID: testUserID1,
|
||||||
|
remoteUserID: "unregistered-user-id",
|
||||||
|
email: "newemail@example.com",
|
||||||
|
configure: func(srv *Server) { srv.RegisterOnFirstLogin = true },
|
||||||
|
wantLogin: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
sm := f.sessionManager
|
for _, tt := range tests {
|
||||||
sessionID, err := sm.NewSession("IDPC-1", testClientID, "bogus", testRedirectURL, "", false, []string{"openid"})
|
f, err := makeTestFixtures()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error making test fixtures: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
ident := oidc.Identity{ID: testUserRemoteID1, Name: "elroy", Email: testUserEmail1}
|
if tt.configure != nil {
|
||||||
key, err := sm.NewSessionKey(sessionID)
|
tt.configure(f.srv)
|
||||||
if err != nil {
|
}
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
sm := f.sessionManager
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
sessionID, err := sm.NewSession(tt.connectorID, tt.clientID, "bogus", testRedirectURL, "", false, []string{"openid"})
|
||||||
}
|
if err != nil {
|
||||||
redirectURL, err := f.srv.Login(ident, key)
|
t.Errorf("case %s: new session: %v", tt.testCase, err)
|
||||||
if err != nil {
|
continue
|
||||||
t.Fatalf("Unexpected err from Server.Login: %v", err)
|
}
|
||||||
}
|
|
||||||
|
|
||||||
wantRedirectURL := "http://client.example.com/callback?code=code-3&state=bogus"
|
key, err := sm.NewSessionKey(sessionID)
|
||||||
if wantRedirectURL != redirectURL {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected redirectURL: want=%q, got=%q", wantRedirectURL, redirectURL)
|
t.Errorf("case %s: new session key: %v", tt.testCase, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ident := oidc.Identity{ID: tt.remoteUserID, Name: "elroy", Email: tt.email}
|
||||||
|
redirectURL, err := f.srv.Login(ident, key)
|
||||||
|
if err != nil {
|
||||||
|
if !tt.wantError {
|
||||||
|
t.Errorf("case %s: server.Login: %v", tt.testCase, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if tt.wantError {
|
||||||
|
t.Errorf("case %s: expected server.Login to fail", tt.testCase)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Did the server redirect back to the client app or display an error to the user?
|
||||||
|
gotRedirectURL := strings.HasPrefix(redirectURL, testRedirectURL.String())
|
||||||
|
if gotRedirectURL && !tt.wantLogin {
|
||||||
|
t.Errorf("case %s: should not have logged in", tt.testCase)
|
||||||
|
}
|
||||||
|
if !gotRedirectURL && tt.wantLogin {
|
||||||
|
t.Errorf("case %s: failed to log in. expected redirect url got: %s", tt.testCase, redirectURL)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,10 @@ var (
|
||||||
|
|
||||||
testConnectorID1 = "IDPC-1"
|
testConnectorID1 = "IDPC-1"
|
||||||
|
|
||||||
|
testConnectorIDOpenID = "oidc"
|
||||||
|
testConnectorIDOpenIDTrusted = "oidc-trusted"
|
||||||
|
testConnectorLocalID = "local"
|
||||||
|
|
||||||
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
|
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
|
||||||
|
|
||||||
testUsers = []user.UserWithRemoteIdentities{
|
testUsers = []user.UserWithRemoteIdentities{
|
||||||
|
@ -143,20 +147,27 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
|
||||||
|
|
||||||
connConfigs := []connector.ConnectorConfig{
|
connConfigs := []connector.ConnectorConfig{
|
||||||
&connector.OIDCConnectorConfig{
|
&connector.OIDCConnectorConfig{
|
||||||
ID: "oidc",
|
ID: testConnectorIDOpenID,
|
||||||
IssuerURL: testIssuerURL.String(),
|
IssuerURL: testIssuerURL.String(),
|
||||||
ClientID: "12345",
|
ClientID: "12345",
|
||||||
ClientSecret: "567789",
|
ClientSecret: "567789",
|
||||||
},
|
},
|
||||||
&connector.OIDCConnectorConfig{
|
&connector.OIDCConnectorConfig{
|
||||||
ID: "oidc-trusted",
|
ID: testConnectorIDOpenIDTrusted,
|
||||||
IssuerURL: testIssuerURL.String(),
|
IssuerURL: testIssuerURL.String(),
|
||||||
ClientID: "12345-trusted",
|
ClientID: "12345-trusted",
|
||||||
ClientSecret: "567789-trusted",
|
ClientSecret: "567789-trusted",
|
||||||
TrustedEmailProvider: true,
|
TrustedEmailProvider: true,
|
||||||
},
|
},
|
||||||
|
&connector.OIDCConnectorConfig{
|
||||||
|
ID: testConnectorID1,
|
||||||
|
IssuerURL: testIssuerURL.String(),
|
||||||
|
ClientID: testConnectorID1 + "_client_id",
|
||||||
|
ClientSecret: testConnectorID1 + "_client_secret",
|
||||||
|
TrustedEmailProvider: true,
|
||||||
|
},
|
||||||
&connector.LocalConnectorConfig{
|
&connector.LocalConnectorConfig{
|
||||||
ID: "local",
|
ID: testConnectorLocalID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
connCfgRepo := db.NewConnectorConfigRepo(dbMap)
|
connCfgRepo := db.NewConnectorConfigRepo(dbMap)
|
Loading…
Reference in a new issue