diff --git a/server/http.go b/server/http.go index d74c350b..049ae778 100644 --- a/server/http.go +++ b/server/http.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "html/template" + "io" "net/http" "net/url" "strings" @@ -129,11 +130,15 @@ var connectorDisplayNameMap = map[string]string{ "bitbucket": "Bitbucket", } -func execTemplate(w http.ResponseWriter, tpl *template.Template, data interface{}) { +type Template interface { + Execute(io.Writer, interface{}) error +} + +func execTemplate(w http.ResponseWriter, tpl Template, data interface{}) { execTemplateWithStatus(w, tpl, data, http.StatusOK) } -func execTemplateWithStatus(w http.ResponseWriter, tpl *template.Template, data interface{}, status int) { +func execTemplateWithStatus(w http.ResponseWriter, tpl Template, data interface{}, status int) { w.WriteHeader(status) if err := tpl.Execute(w, data); err != nil { log.Errorf("Error loading page: %q", err) diff --git a/server/register.go b/server/register.go index a5eb1b15..d4d4dc14 100644 --- a/server/register.go +++ b/server/register.go @@ -20,14 +20,21 @@ type formError struct { Error string } +type remoteExistsData struct { + Login string + + Register string +} + type registerTemplateData struct { - Error bool - FormErrors []formError - Message string - Email string - Code string - Password string - Local bool + Error bool + FormErrors []formError + Message string + Email string + Code string + Password string + Local bool + RemoteExists *remoteExistsData } var ( @@ -47,8 +54,7 @@ var ( } ) -func handleRegisterFunc(s *Server) http.HandlerFunc { - tpl := s.RegisterTemplate +func handleRegisterFunc(s *Server, tpl Template) http.HandlerFunc { errPage := func(w http.ResponseWriter, msg string, code string, status int) { data := registerTemplateData{ @@ -92,6 +98,46 @@ func handleRegisterFunc(s *Server) http.HandlerFunc { return } + var exists bool + exists, err = remoteIdentityExists(s.UserRepo, ses.ConnectorID, ses.Identity.ID) + if err != nil { + internalError(w, err) + return + } + + if exists { + // we have to create a new session to be able to run the server.Login function + newSessionKey, err := s.NewSession(ses.ConnectorID, ses.ClientID, + ses.ClientState, ses.RedirectURL, ses.Nonce, false, ses.Scope) + if err != nil { + internalError(w, err) + return + } + // make sure to clean up the old session + if err = s.KillSession(code); err != nil { + internalError(w, err) + } + + // finally, we can create a valid redirect URL for them. + redirURL, err := s.Login(ses.Identity, newSessionKey) + if err != nil { + internalError(w, err) + return + } + + registerURL := newLoginURLFromSession( + s.IssuerURL, ses, true, []string{}, "") + + execTemplate(w, tpl, registerTemplateData{ + RemoteExists: &remoteExistsData{ + Login: redirURL, + Register: registerURL.String(), + }, + }) + + return + } + // determine whether or not this is a local or remote ID that is going // to be registered. idpc, ok := idx[ses.ConnectorID] @@ -175,7 +221,7 @@ func handleRegisterFunc(s *Server) http.HandlerFunc { log.Errorf("Error killing session: %v", err) } http.Redirect(w, r, loginURL.String(), http.StatusSeeOther) - + return } if err != nil { @@ -212,17 +258,22 @@ func handleRegisterFunc(s *Server) http.HandlerFunc { } } - ru := ses.RedirectURL - q := ru.Query() - q.Set("code", code) - q.Set("state", ses.ClientState) - ru.RawQuery = q.Encode() - w.Header().Set("Location", ru.String()) + w.Header().Set("Location", makeClientRedirectURL( + ses.RedirectURL, code, ses.ClientState).String()) w.WriteHeader(http.StatusSeeOther) return } } +func makeClientRedirectURL(baseRedirURL url.URL, code, clientState string) *url.URL { + ru := baseRedirURL + q := ru.Query() + q.Set("code", code) + q.Set("state", clientState) + ru.RawQuery = q.Encode() + return &ru +} + func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) { userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID) if err != nil { @@ -304,3 +355,20 @@ func newLoginURLFromSession(issuer url.URL, ses *session.Session, register bool, loginURL.RawQuery = v.Encode() return &loginURL } + +func remoteIdentityExists(ur user.UserRepo, connectorID, id string) (bool, error) { + _, err := ur.GetByRemoteIdentity(nil, user.RemoteIdentity{ + ConnectorID: connectorID, + ID: id, + }) + + if err == user.ErrorNotFound { + return false, nil + } + + if err == nil { + return true, nil + } + + return false, err +} diff --git a/server/register_test.go b/server/register_test.go index a3970f88..eec0ef9a 100644 --- a/server/register_test.go +++ b/server/register_test.go @@ -1,6 +1,8 @@ package server import ( + "errors" + "io" "net/http" "net/http/httptest" "net/url" @@ -14,8 +16,26 @@ import ( "github.com/coreos/go-oidc/oidc" ) +type testTemplate struct { + tpl Template + + data registerTemplateData +} + +func (t *testTemplate) Execute(w io.Writer, data interface{}) error { + dataMap, ok := data.(registerTemplateData) + if !ok { + return errors.New("could not cast to registerTemplateData") + } + t.data = dataMap + return t.tpl.Execute(w, data) +} + func TestHandleRegister(t *testing.T) { + testIssuerAuth := testIssuerURL + testIssuerAuth.Path = "/auth" + str := func(s string) []string { return []string{s} } @@ -25,11 +45,14 @@ func TestHandleRegister(t *testing.T) { connID string attachRemote bool remoteIdentityEmail string + remoteAlreadyExists bool // want - wantStatus int - wantFormValues url.Values - wantUserCreated bool + wantStatus int + wantFormValues url.Values + wantUserExists bool + wantRedirectURL url.URL + wantRegisterTemplateData *registerTemplateData }{ { // User comes in with a valid code, redirected from the connector, @@ -58,8 +81,38 @@ func TestHandleRegister(t *testing.T) { remoteIdentityEmail: "test@example.com", attachRemote: true, - wantStatus: http.StatusSeeOther, - wantUserCreated: true, + wantStatus: http.StatusSeeOther, + wantUserExists: true, + }, + { + // User comes in with a valid code, redirected from the connector. + // User is redirected to dex page with msg_code "login-maybe", + // because the remote identity already exists. + query: url.Values{ + "code": []string{"code-3"}, + }, + connID: "oidc-trusted", + remoteIdentityEmail: "test@example.com", + attachRemote: true, + remoteAlreadyExists: true, + + wantStatus: http.StatusOK, + wantUserExists: true, + wantRegisterTemplateData: ®isterTemplateData{ + RemoteExists: &remoteExistsData{ + Login: newURLWithParams(testRedirectURL, url.Values{ + "code": []string{"code-7"}, + "state": []string{""}, + }).String(), + Register: newURLWithParams(testIssuerAuth, url.Values{ + "client_id": []string{testClientID}, + "redirect_uri": []string{testRedirectURL.String()}, + "register": []string{"1"}, + "scope": []string{"openid"}, + "state": []string{""}, + }).String(), + }, + }, }, { // User comes in with a valid code, redirected from the connector, @@ -74,8 +127,8 @@ func TestHandleRegister(t *testing.T) { remoteIdentityEmail: "test@example.com", attachRemote: true, - wantStatus: http.StatusSeeOther, - wantUserCreated: true, + wantStatus: http.StatusSeeOther, + wantUserExists: true, }, { // User comes in with a valid code, redirected from the connector, @@ -88,8 +141,8 @@ func TestHandleRegister(t *testing.T) { remoteIdentityEmail: "", attachRemote: true, - wantStatus: http.StatusOK, - wantUserCreated: false, + wantStatus: http.StatusOK, + wantUserExists: false, wantFormValues: url.Values{ "code": str("code-4"), "email": str(""), @@ -107,8 +160,8 @@ func TestHandleRegister(t *testing.T) { remoteIdentityEmail: "notanemail", attachRemote: true, - wantStatus: http.StatusOK, - wantUserCreated: false, + wantStatus: http.StatusOK, + wantUserExists: false, wantFormValues: url.Values{ "code": str("code-4"), "email": str(""), @@ -142,9 +195,9 @@ func TestHandleRegister(t *testing.T) { "email": str("test@example.com"), "password": str("password"), }, - connID: "local", - wantStatus: http.StatusSeeOther, - wantUserCreated: true, + connID: "local", + wantStatus: http.StatusSeeOther, + wantUserExists: true, }, { // User comes in with spaces in their email, having submitted the @@ -155,9 +208,9 @@ func TestHandleRegister(t *testing.T) { "email": str("\t\ntest@example.com "), "password": str("password"), }, - connID: "local", - wantStatus: http.StatusSeeOther, - wantUserCreated: true, + connID: "local", + wantStatus: http.StatusSeeOther, + wantUserExists: true, }, { // User comes in with an invalid email, having submitted the form. @@ -185,9 +238,9 @@ func TestHandleRegister(t *testing.T) { "validate": []string{"1"}, "email": str("test@example.com"), }, - connID: "local", - wantStatus: http.StatusBadRequest, - wantUserCreated: false, + connID: "local", + wantStatus: http.StatusBadRequest, + wantUserExists: false, wantFormValues: url.Values{ "code": str("code-3"), "email": str("test@example.com"), @@ -204,10 +257,10 @@ func TestHandleRegister(t *testing.T) { "validate": []string{"1"}, "email": str("test@example.com"), }, - connID: "oidc", - attachRemote: true, - wantStatus: http.StatusSeeOther, - wantUserCreated: true, + connID: "oidc", + attachRemote: true, + wantStatus: http.StatusSeeOther, + wantUserExists: true, }, { // Same as before, but missing a code. @@ -215,10 +268,10 @@ func TestHandleRegister(t *testing.T) { "validate": []string{"1"}, "email": str("test@example.com"), }, - connID: "oidc", - attachRemote: true, - wantStatus: http.StatusUnauthorized, - wantUserCreated: false, + connID: "oidc", + attachRemote: true, + wantStatus: http.StatusUnauthorized, + wantUserExists: false, }, } @@ -228,6 +281,20 @@ func TestHandleRegister(t *testing.T) { t.Fatalf("case %d: could not make test fixtures: %v", i, err) } + if tt.remoteAlreadyExists { + f.userRepo.Create(nil, user.User{ + ID: "register-test-new-user", + Email: tt.remoteIdentityEmail, + EmailVerified: true, + }) + + f.userRepo.AddRemoteIdentity(nil, "register-test-new-user", + user.RemoteIdentity{ + ID: "remoteID", + ConnectorID: tt.connID, + }) + } + key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, []string{"openid"}) t.Logf("case %d: key for NewSession: %v", i, key) @@ -251,10 +318,10 @@ func TestHandleRegister(t *testing.T) { t.Fatalf("case %d: expected non-nil error: %v", i, err) } t.Logf("case %d: key for NewSession: %v", i, key) - } - hdlr := handleRegisterFunc(f.srv) + tpl := &testTemplate{tpl: f.srv.RegisterTemplate} + hdlr := handleRegisterFunc(f.srv, tpl) w := httptest.NewRecorder() u := "http://server.example.com" @@ -266,12 +333,25 @@ func TestHandleRegister(t *testing.T) { } hdlr.ServeHTTP(w, req) + + if tt.wantRedirectURL.String() != "" { + locationHdr := w.HeaderMap.Get("Location") + redirURL, err := url.Parse(locationHdr) + if err != nil { + t.Errorf("case %d: unexpected error parsing url %q: %q", i, locationHdr, err) + } else { + if diff := pretty.Compare(*redirURL, tt.wantRedirectURL); diff != "" { + t.Errorf("case %d: Compare(redirURL, tt.wantRedirectURL) = %v", i, diff) + } + } + } + if tt.wantStatus != w.Code { t.Errorf("case %d: wantStatus=%v, got=%v", i, tt.wantStatus, w.Code) } _, err = f.userRepo.GetByEmail(nil, "test@example.com") - if tt.wantUserCreated { + if tt.wantUserExists { if err != nil { t.Errorf("case %d: user not created: %v", i, err) } @@ -288,5 +368,17 @@ func TestHandleRegister(t *testing.T) { t.Errorf("case %d: Compare(want, got) = %v", i, diff) } + if tt.wantRegisterTemplateData != nil { + if diff := pretty.Compare(*tt.wantRegisterTemplateData, tpl.data); diff != "" { + t.Errorf("case %d: Compare(tt.wantRegisterTemplateData, tpl.data) = %v", + i, diff) + } + } } } + +func newURLWithParams(u url.URL, values url.Values) *url.URL { + newU := u + newU.RawQuery = values.Encode() + return &newU +} diff --git a/server/server.go b/server/server.go index e4d7e328..66bc2fd7 100644 --- a/server/server.go +++ b/server/server.go @@ -206,7 +206,7 @@ func (s *Server) HTTPHandler() http.Handler { mux.Handle(httpPathHealth, makeHealthHandler(checks)) if s.EnableRegistration { - mux.HandleFunc(httpPathRegister, handleRegisterFunc(s)) + mux.HandleFunc(httpPathRegister, handleRegisterFunc(s, s.RegisterTemplate)) } mux.HandleFunc(httpPathEmailVerify, handleEmailVerifyFunc(s.VerifyEmailTemplate, diff --git a/server/testutil.go b/server/testutil.go index 99615539..fed05f25 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -146,7 +146,8 @@ func makeTestFixtures() (*testFixtures, error) { return nil, err } - tpl, err := getTemplates("dex", "https://coreos.com/assets/images/brand/coreos-mark-30px.png", + tpl, err := getTemplates("dex", + "https://coreos.com/assets/images/brand/coreos-mark-30px.png", true, templatesLocation) if err != nil { return nil, err diff --git a/static/html/register.html b/static/html/register.html index 7c57cbd5..33a20c16 100644 --- a/static/html/register.html +++ b/static/html/register.html @@ -4,7 +4,32 @@

Create Your Account

{{ if .Error }} -
{{ .Message }}
+
{{ .Message }}
+ {{ else if .RemoteExists }} + {{ with .RemoteExists }} +
+ This account is already registered. + If you'd like to login with that account, click here: +
+
+ + + +
+
+ If you would like to register with a different account, click here: +
+
+ + + +
+ {{ end }} + {{ else }}