diff --git a/server/handlers.go b/server/handlers.go index ae98452c..ff81460e 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -232,7 +232,16 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } scopes := parseScopes(authReq.Scopes) - showBacklink := len(s.connectors) > 1 + + // Work out where the "Select another login method" link should go. + backLink := "" + if len(s.connectors) > 1 { + backLinkURL := url.URL{ + Path: s.absPath("/auth"), + RawQuery: r.Form.Encode(), + } + backLink = backLinkURL.String() + } switch r.Method { case http.MethodGet: @@ -249,7 +258,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } http.Redirect(w, r, callbackURL, http.StatusFound) case connector.PasswordConnector: - if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(conn), false, showBacklink); err != nil { + if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(conn), false, backLink); err != nil { s.logger.Errorf("Server template error: %v", err) } case connector.SAMLConnector: @@ -297,7 +306,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } if !ok { - if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(passwordConnector), true, showBacklink); err != nil { + if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(passwordConnector), true, backLink); err != nil { s.logger.Errorf("Server template error: %v", err) } return diff --git a/server/templates.go b/server/templates.go index ac484301..2712c9c3 100644 --- a/server/templates.go +++ b/server/templates.go @@ -266,15 +266,15 @@ func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []c return renderTemplate(w, t.loginTmpl, data) } -func (t *templates) password(r *http.Request, w http.ResponseWriter, postURL, lastUsername, usernamePrompt string, lastWasInvalid, showBacklink bool) error { +func (t *templates) password(r *http.Request, w http.ResponseWriter, postURL, lastUsername, usernamePrompt string, lastWasInvalid bool, backLink string) error { data := struct { PostURL string - BackLink bool + BackLink string Username string UsernamePrompt string Invalid bool ReqPath string - }{postURL, showBacklink, lastUsername, usernamePrompt, lastWasInvalid, r.URL.Path} + }{postURL, backLink, lastUsername, usernamePrompt, lastWasInvalid, r.URL.Path} return renderTemplate(w, t.passwordTmpl, data) } diff --git a/web/templates/password.html b/web/templates/password.html index 5b585b4e..8c77b26e 100644 --- a/web/templates/password.html +++ b/web/templates/password.html @@ -27,7 +27,7 @@ {{ if .BackLink }} {{ end }}