diff --git a/server/handlers.go b/server/handlers.go index e7facca6..4144dd1f 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -124,6 +124,66 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) { // handleAuthorization handles the OAuth2 auth endpoint. func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { + // Extract the arguments + if err := r.ParseForm(); err != nil { + s.logger.Errorf("Failed to parse arguments: %v", err) + + s.renderError(r, w, http.StatusBadRequest, err.Error()) + return + } + + connectorID := r.Form.Get("connector_id") + + connectors, err := s.storage.ListConnectors() + if err != nil { + s.logger.Errorf("Failed to get list of connectors: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.") + return + } + + // We don't need connector_id any more + r.Form.Del("connector_id") + + // Construct a URL with all of the arguments in its query + connURL := url.URL{ + RawQuery: r.Form.Encode(), + } + + // Redirect if a client chooses a specific connector_id + if connectorID != "" { + for _, c := range connectors { + if c.ID == connectorID { + connURL.Path = s.absPath("/auth", c.ID) + http.Redirect(w, r, connURL.String(), http.StatusFound) + return + } + } + s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound) + return + } + + if len(connectors) == 1 && !s.alwaysShowLogin { + connURL.Path = s.absPath("/auth", connectors[0].ID) + http.Redirect(w, r, connURL.String(), http.StatusFound) + } + + connectorInfos := make([]connectorInfo, len(connectors)) + for index, conn := range connectors { + connURL.Path = s.absPath("/auth", conn.ID) + connectorInfos[index] = connectorInfo{ + ID: conn.ID, + Name: conn.Name, + Type: conn.Type, + URL: connURL.String(), + } + } + + if err := s.templates.login(r, w, connectorInfos); err != nil { + s.logger.Errorf("Server template error: %v", err) + } +} + +func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { authReq, err := s.parseAuthorizationRequest(r) if err != nil { s.logger.Errorf("Failed to parse authorization request: %v", err) @@ -145,64 +205,6 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { return } - // TODO(ericchiang): Create this authorization request later in the login flow - // so users don't hit "not found" database errors if they wait at the login - // screen too long. - // - // See: https://github.com/dexidp/dex/issues/646 - authReq.Expiry = s.now().Add(s.authRequestsValidFor) - if err := s.storage.CreateAuthRequest(*authReq); err != nil { - s.logger.Errorf("Failed to create authorization request: %v", err) - s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.") - return - } - - connectors, err := s.storage.ListConnectors() - if err != nil { - s.logger.Errorf("Failed to get list of connectors: %v", err) - s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.") - return - } - - // Redirect if a client chooses a specific connector_id - if authReq.ConnectorID != "" { - for _, c := range connectors { - if c.ID == authReq.ConnectorID { - http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound) - return - } - } - s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound) - return - } - - if len(connectors) == 1 && !s.alwaysShowLogin { - for _, c := range connectors { - // TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter - // on create the auth request. - http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound) - return - } - } - - connectorInfos := make([]connectorInfo, len(connectors)) - for index, conn := range connectors { - connectorInfos[index] = connectorInfo{ - ID: conn.ID, - Name: conn.Name, - Type: conn.Type, - // TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter - // on create the auth request. - URL: s.absPath("/auth", conn.ID) + "?req=" + authReq.ID, - } - } - - if err := s.templates.login(r, w, connectorInfos); err != nil { - s.logger.Errorf("Server template error: %v", err) - } -} - -func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { connID := mux.Vars(r)["connector"] conn, err := s.getConnector(connID) if err != nil { @@ -211,37 +213,35 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } - authReqID := r.FormValue("req") - - authReq, err := s.storage.GetAuthRequest(authReqID) - if err != nil { - s.logger.Errorf("Failed to get auth request: %v", err) - if err == storage.ErrNotFound { - s.renderError(r, w, http.StatusBadRequest, "Login session expired.") - } else { - s.renderError(r, w, http.StatusInternalServerError, "Database error.") - } + // Set the connector being used for the login. + if authReq.ConnectorID != "" && authReq.ConnectorID != connID { + s.logger.Errorf("Mismatched connector ID in auth request: %s vs %s", + authReq.ConnectorID, connID) + s.renderError(r, w, http.StatusBadRequest, "Bad connector ID") return } - // Set the connector being used for the login. - if authReq.ConnectorID != connID { - updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { - if a.ConnectorID != "" { - return a, fmt.Errorf("connector is already set for this auth request") - } - a.ConnectorID = connID - return a, nil - } - if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil { - s.logger.Errorf("Failed to set connector ID on auth request: %v", err) - s.renderError(r, w, http.StatusInternalServerError, "Database error.") - return - } + authReq.ConnectorID = connID + + // Actually create the auth request + authReq.Expiry = s.now().Add(s.authRequestsValidFor) + if err := s.storage.CreateAuthRequest(*authReq); err != nil { + s.logger.Errorf("Failed to create authorization request: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.") + return } scopes := parseScopes(authReq.Scopes) - showBacklink := len(s.connectors) > 1 + + // Work out where the "Select another login method" link should go. + backLink := "" + if len(s.connectors) > 1 { + backLinkURL := url.URL{ + Path: s.absPath("/auth"), + RawQuery: r.Form.Encode(), + } + backLink = backLinkURL.String() + } switch r.Method { case http.MethodGet: @@ -250,7 +250,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Use the auth request ID as the "state" token. // // TODO(ericchiang): Is this appropriate or should we also be using a nonce? - callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID) + callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID) if err != nil { s.logger.Errorf("Connector %q returned error when creating callback: %v", connID, err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") @@ -258,11 +258,17 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } http.Redirect(w, r, callbackURL, http.StatusFound) case connector.PasswordConnector: - if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(conn), false, showBacklink); err != nil { - s.logger.Errorf("Server template error: %v", err) + loginURL := url.URL{ + Path: s.absPath("/auth", connID, "login"), } + q := loginURL.Query() + q.Set("state", authReq.ID) + q.Set("back", backLink) + loginURL.RawQuery = q.Encode() + + http.Redirect(w, r, loginURL.String(), http.StatusFound) case connector.SAMLConnector: - action, value, err := conn.POSTData(scopes, authReqID) + action, value, err := conn.POSTData(scopes, authReq.ID) if err != nil { s.logger.Errorf("Creating SAML data: %v", err) s.renderError(r, w, http.StatusInternalServerError, "Connector Login Error") @@ -285,28 +291,74 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { document.forms[0].submit(); - `, action, value, authReqID) + `, action, value, authReq.ID) default: s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") } - case http.MethodPost: - passwordConnector, ok := conn.Connector.(connector.PasswordConnector) - if !ok { + default: + s.renderError(r, w, http.StatusBadRequest, "Unsupported request method.") + } +} + +func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { + authID := r.URL.Query().Get("state") + if authID == "" { + s.renderError(r, w, http.StatusBadRequest, "User session error.") + return + } + + backLink := r.URL.Query().Get("back") + + authReq, err := s.storage.GetAuthRequest(authID) + if err != nil { + if err == storage.ErrNotFound { + s.logger.Errorf("Invalid 'state' parameter provided: %v", err) s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") return } + s.logger.Errorf("Failed to get auth request: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "Database error.") + return + } + if connID := mux.Vars(r)["connector"]; connID != "" && connID != authReq.ConnectorID { + s.logger.Errorf("Connector mismatch: authentication started with id %q, but password login for id %q was triggered", authReq.ConnectorID, connID) + s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") + return + } + + conn, err := s.getConnector(authReq.ConnectorID) + if err != nil { + s.logger.Errorf("Failed to get connector with id %q : %v", authReq.ConnectorID, err) + s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") + return + } + + pwConn, ok := conn.Connector.(connector.PasswordConnector) + if !ok { + s.logger.Errorf("Expected password connector in handlePasswordLogin(), but got %v", pwConn) + s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") + return + } + + switch r.Method { + case http.MethodGet: + if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(pwConn), false, backLink); err != nil { + s.logger.Errorf("Server template error: %v", err) + } + case http.MethodPost: username := r.FormValue("login") password := r.FormValue("password") + scopes := parseScopes(authReq.Scopes) - identity, ok, err := passwordConnector.Login(r.Context(), scopes, username, password) + identity, ok, err := pwConn.Login(r.Context(), scopes, username, password) if err != nil { s.logger.Errorf("Failed to login user: %v", err) s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err)) return } if !ok { - if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(passwordConnector), true, showBacklink); err != nil { + if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(pwConn), true, backLink); err != nil { s.logger.Errorf("Server template error: %v", err) } return diff --git a/server/handlers_test.go b/server/handlers_test.go index a1c60102..a808e5a1 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -7,19 +7,16 @@ import ( "errors" "net/http" "net/http/httptest" - "os" "testing" "time" gosundheit "github.com/AppsFlyer/go-sundheit" "github.com/AppsFlyer/go-sundheit/checks" "github.com/coreos/go-oidc/v3/oidc" - "github.com/gorilla/mux" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "github.com/dexidp/dex/storage" - "github.com/dexidp/dex/storage/memory" ) func TestHandleHealth(t *testing.T) { @@ -133,87 +130,6 @@ func TestHandleInvalidSAMLCallbacks(t *testing.T) { } } -func TestConnectorLoginDoesNotAllowToChangeConnectorForAuthRequest(t *testing.T) { - memStorage := memory.New(logger) - - templates, err := loadTemplates(webConfig{webFS: os.DirFS("../web")}, "templates") - if err != nil { - t.Fatal("failed to load templates") - } - - s := &Server{ - storage: memStorage, - logger: logger, - templates: templates, - supportedResponseTypes: map[string]bool{"code": true}, - now: time.Now, - connectors: make(map[string]Connector), - } - - r := mux.NewRouter() - r.HandleFunc("/auth/{connector}", s.handleConnectorLogin) - s.mux = r - - clientID := "clientID" - clientSecret := "secret" - redirectURL := "localhost:5555" + "/callback" - client := storage.Client{ - ID: clientID, - Secret: clientSecret, - RedirectURIs: []string{redirectURL}, - } - if err := memStorage.CreateClient(client); err != nil { - t.Fatal("failed to create client") - } - - createConnector := func(t *testing.T, id string) storage.Connector { - connector := storage.Connector{ - ID: id, - Type: "mockCallback", - Name: "Mock", - ResourceVersion: "1", - } - if err := memStorage.CreateConnector(connector); err != nil { - t.Fatalf("failed to create connector %v", id) - } - - return connector - } - - connector1 := createConnector(t, "mock1") - connector2 := createConnector(t, "mock2") - - authReq := storage.AuthRequest{ - ID: storage.NewID(), - } - if err := memStorage.CreateAuthRequest(authReq); err != nil { - t.Fatal("failed to create auth request") - } - - createConnectorLoginRequest := func(connID string) *http.Request { - req := httptest.NewRequest("GET", "/auth/"+connID, nil) - q := req.URL.Query() - q.Add("req", authReq.ID) - q.Add("redirect_uri", redirectURL) - q.Add("scope", "openid") - q.Add("response_type", "code") - req.URL.RawQuery = q.Encode() - return req - } - - recorder := httptest.NewRecorder() - s.ServeHTTP(recorder, createConnectorLoginRequest(connector1.ID)) - if recorder.Code != 302 { - t.Fatal("failed to process request") - } - - recorder2 := httptest.NewRecorder() - s.ServeHTTP(recorder2, createConnectorLoginRequest(connector2.ID)) - if recorder2.Code != 500 { - t.Error("attempt to overwrite connector on auth request should fail") - } -} - // TestHandleAuthCode checks that it is forbidden to use same code twice func TestHandleAuthCode(t *testing.T) { tests := []struct { diff --git a/server/server.go b/server/server.go index 84c3a82f..957b62dc 100644 --- a/server/server.go +++ b/server/server.go @@ -341,6 +341,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleWithCORS("/userinfo", s.handleUserInfo) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) + handleFunc("/auth/{connector}/login", s.handlePasswordLogin) handleFunc("/device", s.handleDeviceExchange) handleFunc("/device/auth/verify_code", s.verifyUserCode) handleFunc("/device/code", s.handleDeviceCode) diff --git a/server/templates.go b/server/templates.go index 4eff1e75..e46855b1 100644 --- a/server/templates.go +++ b/server/templates.go @@ -284,15 +284,15 @@ func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []c return renderTemplate(w, t.loginTmpl, data) } -func (t *templates) password(r *http.Request, w http.ResponseWriter, postURL, lastUsername, usernamePrompt string, lastWasInvalid, showBacklink bool) error { +func (t *templates) password(r *http.Request, w http.ResponseWriter, postURL, lastUsername, usernamePrompt string, lastWasInvalid bool, backLink string) error { data := struct { PostURL string - BackLink bool + BackLink string Username string UsernamePrompt string Invalid bool ReqPath string - }{postURL, showBacklink, lastUsername, usernamePrompt, lastWasInvalid, r.URL.Path} + }{postURL, backLink, lastUsername, usernamePrompt, lastWasInvalid, r.URL.Path} return renderTemplate(w, t.passwordTmpl, data) } diff --git a/web/templates/password.html b/web/templates/password.html index 5b585b4e..8c77b26e 100644 --- a/web/templates/password.html +++ b/web/templates/password.html @@ -27,7 +27,7 @@ {{ if .BackLink }} {{ end }}