server: better UX when remote ID already exists

Instead of cryptic message with nowhere to, give them the choice to
login with that account or register.
This commit is contained in:
Bobby Rullo 2015-12-23 10:48:20 -08:00
parent 1675acf21b
commit dc828825e6
6 changed files with 243 additions and 52 deletions

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"html/template" "html/template"
"io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -129,11 +130,15 @@ var connectorDisplayNameMap = map[string]string{
"bitbucket": "Bitbucket", "bitbucket": "Bitbucket",
} }
func execTemplate(w http.ResponseWriter, tpl *template.Template, data interface{}) { type Template interface {
Execute(io.Writer, interface{}) error
}
func execTemplate(w http.ResponseWriter, tpl Template, data interface{}) {
execTemplateWithStatus(w, tpl, data, http.StatusOK) execTemplateWithStatus(w, tpl, data, http.StatusOK)
} }
func execTemplateWithStatus(w http.ResponseWriter, tpl *template.Template, data interface{}, status int) { func execTemplateWithStatus(w http.ResponseWriter, tpl Template, data interface{}, status int) {
w.WriteHeader(status) w.WriteHeader(status)
if err := tpl.Execute(w, data); err != nil { if err := tpl.Execute(w, data); err != nil {
log.Errorf("Error loading page: %q", err) log.Errorf("Error loading page: %q", err)

View file

@ -20,6 +20,12 @@ type formError struct {
Error string Error string
} }
type remoteExistsData struct {
Login string
Register string
}
type registerTemplateData struct { type registerTemplateData struct {
Error bool Error bool
FormErrors []formError FormErrors []formError
@ -28,6 +34,7 @@ type registerTemplateData struct {
Code string Code string
Password string Password string
Local bool Local bool
RemoteExists *remoteExistsData
} }
var ( var (
@ -47,8 +54,7 @@ var (
} }
) )
func handleRegisterFunc(s *Server) http.HandlerFunc { func handleRegisterFunc(s *Server, tpl Template) http.HandlerFunc {
tpl := s.RegisterTemplate
errPage := func(w http.ResponseWriter, msg string, code string, status int) { errPage := func(w http.ResponseWriter, msg string, code string, status int) {
data := registerTemplateData{ data := registerTemplateData{
@ -92,6 +98,46 @@ func handleRegisterFunc(s *Server) http.HandlerFunc {
return return
} }
var exists bool
exists, err = remoteIdentityExists(s.UserRepo, ses.ConnectorID, ses.Identity.ID)
if err != nil {
internalError(w, err)
return
}
if exists {
// we have to create a new session to be able to run the server.Login function
newSessionKey, err := s.NewSession(ses.ConnectorID, ses.ClientID,
ses.ClientState, ses.RedirectURL, ses.Nonce, false, ses.Scope)
if err != nil {
internalError(w, err)
return
}
// make sure to clean up the old session
if err = s.KillSession(code); err != nil {
internalError(w, err)
}
// finally, we can create a valid redirect URL for them.
redirURL, err := s.Login(ses.Identity, newSessionKey)
if err != nil {
internalError(w, err)
return
}
registerURL := newLoginURLFromSession(
s.IssuerURL, ses, true, []string{}, "")
execTemplate(w, tpl, registerTemplateData{
RemoteExists: &remoteExistsData{
Login: redirURL,
Register: registerURL.String(),
},
})
return
}
// determine whether or not this is a local or remote ID that is going // determine whether or not this is a local or remote ID that is going
// to be registered. // to be registered.
idpc, ok := idx[ses.ConnectorID] idpc, ok := idx[ses.ConnectorID]
@ -175,7 +221,7 @@ func handleRegisterFunc(s *Server) http.HandlerFunc {
log.Errorf("Error killing session: %v", err) log.Errorf("Error killing session: %v", err)
} }
http.Redirect(w, r, loginURL.String(), http.StatusSeeOther) http.Redirect(w, r, loginURL.String(), http.StatusSeeOther)
return
} }
if err != nil { if err != nil {
@ -212,17 +258,22 @@ func handleRegisterFunc(s *Server) http.HandlerFunc {
} }
} }
ru := ses.RedirectURL w.Header().Set("Location", makeClientRedirectURL(
q := ru.Query() ses.RedirectURL, code, ses.ClientState).String())
q.Set("code", code)
q.Set("state", ses.ClientState)
ru.RawQuery = q.Encode()
w.Header().Set("Location", ru.String())
w.WriteHeader(http.StatusSeeOther) w.WriteHeader(http.StatusSeeOther)
return return
} }
} }
func makeClientRedirectURL(baseRedirURL url.URL, code, clientState string) *url.URL {
ru := baseRedirURL
q := ru.Query()
q.Set("code", code)
q.Set("state", clientState)
ru.RawQuery = q.Encode()
return &ru
}
func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) { func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) {
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID) userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
if err != nil { if err != nil {
@ -304,3 +355,20 @@ func newLoginURLFromSession(issuer url.URL, ses *session.Session, register bool,
loginURL.RawQuery = v.Encode() loginURL.RawQuery = v.Encode()
return &loginURL return &loginURL
} }
func remoteIdentityExists(ur user.UserRepo, connectorID, id string) (bool, error) {
_, err := ur.GetByRemoteIdentity(nil, user.RemoteIdentity{
ConnectorID: connectorID,
ID: id,
})
if err == user.ErrorNotFound {
return false, nil
}
if err == nil {
return true, nil
}
return false, err
}

View file

@ -1,6 +1,8 @@
package server package server
import ( import (
"errors"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -14,8 +16,26 @@ import (
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
type testTemplate struct {
tpl Template
data registerTemplateData
}
func (t *testTemplate) Execute(w io.Writer, data interface{}) error {
dataMap, ok := data.(registerTemplateData)
if !ok {
return errors.New("could not cast to registerTemplateData")
}
t.data = dataMap
return t.tpl.Execute(w, data)
}
func TestHandleRegister(t *testing.T) { func TestHandleRegister(t *testing.T) {
testIssuerAuth := testIssuerURL
testIssuerAuth.Path = "/auth"
str := func(s string) []string { str := func(s string) []string {
return []string{s} return []string{s}
} }
@ -25,11 +45,14 @@ func TestHandleRegister(t *testing.T) {
connID string connID string
attachRemote bool attachRemote bool
remoteIdentityEmail string remoteIdentityEmail string
remoteAlreadyExists bool
// want // want
wantStatus int wantStatus int
wantFormValues url.Values wantFormValues url.Values
wantUserCreated bool wantUserExists bool
wantRedirectURL url.URL
wantRegisterTemplateData *registerTemplateData
}{ }{
{ {
// User comes in with a valid code, redirected from the connector, // User comes in with a valid code, redirected from the connector,
@ -59,7 +82,37 @@ func TestHandleRegister(t *testing.T) {
attachRemote: true, attachRemote: true,
wantStatus: http.StatusSeeOther, wantStatus: http.StatusSeeOther,
wantUserCreated: true, wantUserExists: true,
},
{
// User comes in with a valid code, redirected from the connector.
// User is redirected to dex page with msg_code "login-maybe",
// because the remote identity already exists.
query: url.Values{
"code": []string{"code-3"},
},
connID: "oidc-trusted",
remoteIdentityEmail: "test@example.com",
attachRemote: true,
remoteAlreadyExists: true,
wantStatus: http.StatusOK,
wantUserExists: true,
wantRegisterTemplateData: &registerTemplateData{
RemoteExists: &remoteExistsData{
Login: newURLWithParams(testRedirectURL, url.Values{
"code": []string{"code-7"},
"state": []string{""},
}).String(),
Register: newURLWithParams(testIssuerAuth, url.Values{
"client_id": []string{testClientID},
"redirect_uri": []string{testRedirectURL.String()},
"register": []string{"1"},
"scope": []string{"openid"},
"state": []string{""},
}).String(),
},
},
}, },
{ {
// User comes in with a valid code, redirected from the connector, // User comes in with a valid code, redirected from the connector,
@ -75,7 +128,7 @@ func TestHandleRegister(t *testing.T) {
attachRemote: true, attachRemote: true,
wantStatus: http.StatusSeeOther, wantStatus: http.StatusSeeOther,
wantUserCreated: true, wantUserExists: true,
}, },
{ {
// User comes in with a valid code, redirected from the connector, // User comes in with a valid code, redirected from the connector,
@ -89,7 +142,7 @@ func TestHandleRegister(t *testing.T) {
attachRemote: true, attachRemote: true,
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantUserCreated: false, wantUserExists: false,
wantFormValues: url.Values{ wantFormValues: url.Values{
"code": str("code-4"), "code": str("code-4"),
"email": str(""), "email": str(""),
@ -108,7 +161,7 @@ func TestHandleRegister(t *testing.T) {
attachRemote: true, attachRemote: true,
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantUserCreated: false, wantUserExists: false,
wantFormValues: url.Values{ wantFormValues: url.Values{
"code": str("code-4"), "code": str("code-4"),
"email": str(""), "email": str(""),
@ -144,7 +197,7 @@ func TestHandleRegister(t *testing.T) {
}, },
connID: "local", connID: "local",
wantStatus: http.StatusSeeOther, wantStatus: http.StatusSeeOther,
wantUserCreated: true, wantUserExists: true,
}, },
{ {
// User comes in with spaces in their email, having submitted the // User comes in with spaces in their email, having submitted the
@ -157,7 +210,7 @@ func TestHandleRegister(t *testing.T) {
}, },
connID: "local", connID: "local",
wantStatus: http.StatusSeeOther, wantStatus: http.StatusSeeOther,
wantUserCreated: true, wantUserExists: true,
}, },
{ {
// User comes in with an invalid email, having submitted the form. // User comes in with an invalid email, having submitted the form.
@ -187,7 +240,7 @@ func TestHandleRegister(t *testing.T) {
}, },
connID: "local", connID: "local",
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantUserCreated: false, wantUserExists: false,
wantFormValues: url.Values{ wantFormValues: url.Values{
"code": str("code-3"), "code": str("code-3"),
"email": str("test@example.com"), "email": str("test@example.com"),
@ -207,7 +260,7 @@ func TestHandleRegister(t *testing.T) {
connID: "oidc", connID: "oidc",
attachRemote: true, attachRemote: true,
wantStatus: http.StatusSeeOther, wantStatus: http.StatusSeeOther,
wantUserCreated: true, wantUserExists: true,
}, },
{ {
// Same as before, but missing a code. // Same as before, but missing a code.
@ -218,7 +271,7 @@ func TestHandleRegister(t *testing.T) {
connID: "oidc", connID: "oidc",
attachRemote: true, attachRemote: true,
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantUserCreated: false, wantUserExists: false,
}, },
} }
@ -228,6 +281,20 @@ func TestHandleRegister(t *testing.T) {
t.Fatalf("case %d: could not make test fixtures: %v", i, err) t.Fatalf("case %d: could not make test fixtures: %v", i, err)
} }
if tt.remoteAlreadyExists {
f.userRepo.Create(nil, user.User{
ID: "register-test-new-user",
Email: tt.remoteIdentityEmail,
EmailVerified: true,
})
f.userRepo.AddRemoteIdentity(nil, "register-test-new-user",
user.RemoteIdentity{
ID: "remoteID",
ConnectorID: tt.connID,
})
}
key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, []string{"openid"}) key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, []string{"openid"})
t.Logf("case %d: key for NewSession: %v", i, key) t.Logf("case %d: key for NewSession: %v", i, key)
@ -251,10 +318,10 @@ func TestHandleRegister(t *testing.T) {
t.Fatalf("case %d: expected non-nil error: %v", i, err) t.Fatalf("case %d: expected non-nil error: %v", i, err)
} }
t.Logf("case %d: key for NewSession: %v", i, key) t.Logf("case %d: key for NewSession: %v", i, key)
} }
hdlr := handleRegisterFunc(f.srv) tpl := &testTemplate{tpl: f.srv.RegisterTemplate}
hdlr := handleRegisterFunc(f.srv, tpl)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u := "http://server.example.com" u := "http://server.example.com"
@ -266,12 +333,25 @@ func TestHandleRegister(t *testing.T) {
} }
hdlr.ServeHTTP(w, req) hdlr.ServeHTTP(w, req)
if tt.wantRedirectURL.String() != "" {
locationHdr := w.HeaderMap.Get("Location")
redirURL, err := url.Parse(locationHdr)
if err != nil {
t.Errorf("case %d: unexpected error parsing url %q: %q", i, locationHdr, err)
} else {
if diff := pretty.Compare(*redirURL, tt.wantRedirectURL); diff != "" {
t.Errorf("case %d: Compare(redirURL, tt.wantRedirectURL) = %v", i, diff)
}
}
}
if tt.wantStatus != w.Code { if tt.wantStatus != w.Code {
t.Errorf("case %d: wantStatus=%v, got=%v", i, tt.wantStatus, w.Code) t.Errorf("case %d: wantStatus=%v, got=%v", i, tt.wantStatus, w.Code)
} }
_, err = f.userRepo.GetByEmail(nil, "test@example.com") _, err = f.userRepo.GetByEmail(nil, "test@example.com")
if tt.wantUserCreated { if tt.wantUserExists {
if err != nil { if err != nil {
t.Errorf("case %d: user not created: %v", i, err) t.Errorf("case %d: user not created: %v", i, err)
} }
@ -288,5 +368,17 @@ func TestHandleRegister(t *testing.T) {
t.Errorf("case %d: Compare(want, got) = %v", i, diff) t.Errorf("case %d: Compare(want, got) = %v", i, diff)
} }
if tt.wantRegisterTemplateData != nil {
if diff := pretty.Compare(*tt.wantRegisterTemplateData, tpl.data); diff != "" {
t.Errorf("case %d: Compare(tt.wantRegisterTemplateData, tpl.data) = %v",
i, diff)
}
}
} }
} }
func newURLWithParams(u url.URL, values url.Values) *url.URL {
newU := u
newU.RawQuery = values.Encode()
return &newU
}

View file

@ -206,7 +206,7 @@ func (s *Server) HTTPHandler() http.Handler {
mux.Handle(httpPathHealth, makeHealthHandler(checks)) mux.Handle(httpPathHealth, makeHealthHandler(checks))
if s.EnableRegistration { if s.EnableRegistration {
mux.HandleFunc(httpPathRegister, handleRegisterFunc(s)) mux.HandleFunc(httpPathRegister, handleRegisterFunc(s, s.RegisterTemplate))
} }
mux.HandleFunc(httpPathEmailVerify, handleEmailVerifyFunc(s.VerifyEmailTemplate, mux.HandleFunc(httpPathEmailVerify, handleEmailVerifyFunc(s.VerifyEmailTemplate,

View file

@ -146,7 +146,8 @@ func makeTestFixtures() (*testFixtures, error) {
return nil, err return nil, err
} }
tpl, err := getTemplates("dex", "https://coreos.com/assets/images/brand/coreos-mark-30px.png", tpl, err := getTemplates("dex",
"https://coreos.com/assets/images/brand/coreos-mark-30px.png",
true, templatesLocation) true, templatesLocation)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -5,6 +5,31 @@
{{ if .Error }} {{ if .Error }}
<div class="error-box">{{ .Message }}</div> <div class="error-box">{{ .Message }}</div>
{{ else if .RemoteExists }}
{{ with .RemoteExists }}
<div class="instruction-block">
This account is already registered.
If you'd like to login with that account, click here:
</div>
<div>
<a href="{{ .Login }}" target="_self">
<button class="btn btn-provider">
<span class="btn-text">Login</span>
</button>
</a>
</div>
<div class="instruction-block">
If you would like to register with a different account, click here:
</div>
<div>
<a href="{{ .Register }}" target="_self">
<button class="btn btn-provider">
<span class="btn-text">Register</span>
</button>
</a>
</div>
{{ end }}
{{ else }} {{ else }}
<form id="registerForm" method="POST" action="/register"> <form id="registerForm" method="POST" action="/register">