diff --git a/server/handlers.go b/server/handlers.go index f39db575..ae98452c 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -124,6 +124,66 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) { // handleAuthorization handles the OAuth2 auth endpoint. func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { + // Extract the arguments + if err := r.ParseForm(); err != nil { + s.logger.Errorf("Failed to parse arguments: %v", err) + + s.renderError(r, w, http.StatusBadRequest, "Bad query/form arguments") + return + } + + connectorID := r.Form.Get("connector_id") + + connectors, err := s.storage.ListConnectors() + if err != nil { + s.logger.Errorf("Failed to get list of connectors: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.") + return + } + + // Construct a URL with all of the arguments in its query + connURL := url.URL{ + RawQuery: r.Form.Encode(), + } + + // Redirect if a client chooses a specific connector_id + if connectorID != "" { + for _, c := range connectors { + if c.ID == connectorID { + connURL.Path = s.absPath("/auth", c.ID) + http.Redirect(w, r, connURL.String(), http.StatusFound) + return + } + } + s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound) + return + } + + if len(connectors) == 1 && !s.alwaysShowLogin { + for _, c := range connectors { + connURL.Path = s.absPath("/auth", c.ID) + http.Redirect(w, r, connURL.String(), http.StatusFound) + return + } + } + + connectorInfos := make([]connectorInfo, len(connectors)) + for index, conn := range connectors { + connURL.Path = s.absPath("/auth", conn.ID) + connectorInfos[index] = connectorInfo{ + ID: conn.ID, + Name: conn.Name, + Type: conn.Type, + URL: connURL.String(), + } + } + + if err := s.templates.login(r, w, connectorInfos); err != nil { + s.logger.Errorf("Server template error: %v", err) + } +} + +func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { authReq, err := s.parseAuthorizationRequest(r) if err != nil { s.logger.Errorf("Failed to parse authorization request: %v", err) @@ -145,64 +205,6 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { return } - // TODO(ericchiang): Create this authorization request later in the login flow - // so users don't hit "not found" database errors if they wait at the login - // screen too long. - // - // See: https://github.com/dexidp/dex/issues/646 - authReq.Expiry = s.now().Add(s.authRequestsValidFor) - if err := s.storage.CreateAuthRequest(*authReq); err != nil { - s.logger.Errorf("Failed to create authorization request: %v", err) - s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.") - return - } - - connectors, err := s.storage.ListConnectors() - if err != nil { - s.logger.Errorf("Failed to get list of connectors: %v", err) - s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.") - return - } - - // Redirect if a client chooses a specific connector_id - if authReq.ConnectorID != "" { - for _, c := range connectors { - if c.ID == authReq.ConnectorID { - http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound) - return - } - } - s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound) - return - } - - if len(connectors) == 1 && !s.alwaysShowLogin { - for _, c := range connectors { - // TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter - // on create the auth request. - http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound) - return - } - } - - connectorInfos := make([]connectorInfo, len(connectors)) - for index, conn := range connectors { - connectorInfos[index] = connectorInfo{ - ID: conn.ID, - Name: conn.Name, - Type: conn.Type, - // TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter - // on create the auth request. - URL: s.absPath("/auth", conn.ID) + "?req=" + authReq.ID, - } - } - - if err := s.templates.login(r, w, connectorInfos); err != nil { - s.logger.Errorf("Server template error: %v", err) - } -} - -func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { connID := mux.Vars(r)["connector"] conn, err := s.getConnector(connID) if err != nil { @@ -211,33 +213,22 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } - authReqID := r.FormValue("req") - - authReq, err := s.storage.GetAuthRequest(authReqID) - if err != nil { - s.logger.Errorf("Failed to get auth request: %v", err) - if err == storage.ErrNotFound { - s.renderError(r, w, http.StatusBadRequest, "Login session expired.") - } else { - s.renderError(r, w, http.StatusInternalServerError, "Database error.") - } + // Set the connector being used for the login. + if authReq.ConnectorID != "" && authReq.ConnectorID != connID { + s.logger.Errorf("Mismatched connector ID in auth request: %s vs %s", + authReq.ConnectorID, connID) + s.renderError(r, w, http.StatusBadRequest, "Bad connector ID") return } - // Set the connector being used for the login. - if authReq.ConnectorID != connID { - updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { - if a.ConnectorID != "" { - return a, fmt.Errorf("connector is already set for this auth request") - } - a.ConnectorID = connID - return a, nil - } - if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil { - s.logger.Errorf("Failed to set connector ID on auth request: %v", err) - s.renderError(r, w, http.StatusInternalServerError, "Database error.") - return - } + authReq.ConnectorID = connID + + // Actually create the auth request + authReq.Expiry = s.now().Add(s.authRequestsValidFor) + if err := s.storage.CreateAuthRequest(*authReq); err != nil { + s.logger.Errorf("Failed to create authorization request: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.") + return } scopes := parseScopes(authReq.Scopes) @@ -250,7 +241,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Use the auth request ID as the "state" token. // // TODO(ericchiang): Is this appropriate or should we also be using a nonce? - callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID) + callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID) if err != nil { s.logger.Errorf("Connector %q returned error when creating callback: %v", connID, err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") @@ -262,7 +253,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { s.logger.Errorf("Server template error: %v", err) } case connector.SAMLConnector: - action, value, err := conn.POSTData(scopes, authReqID) + action, value, err := conn.POSTData(scopes, authReq.ID) if err != nil { s.logger.Errorf("Creating SAML data: %v", err) s.renderError(r, w, http.StatusInternalServerError, "Connector Login Error") @@ -285,7 +276,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { document.forms[0].submit(); - `, action, value, authReqID) + `, action, value, authReq.ID) default: s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") } @@ -311,7 +302,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } return } - redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) + redirectURL, err := s.finalizeLogin(identity, *authReq, conn.Connector) if err != nil { s.logger.Errorf("Failed to finalize login: %v", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") diff --git a/server/handlers_test.go b/server/handlers_test.go index d195af64..fd4f5147 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -19,7 +19,6 @@ import ( "golang.org/x/oauth2" "github.com/dexidp/dex/storage" - "github.com/dexidp/dex/storage/memory" ) func TestHandleHealth(t *testing.T) { @@ -133,87 +132,6 @@ func TestHandleInvalidSAMLCallbacks(t *testing.T) { } } -func TestConnectorLoginDoesNotAllowToChangeConnectorForAuthRequest(t *testing.T) { - memStorage := memory.New(logger) - - templates, err := loadTemplates(webConfig{webFS: os.DirFS("../web")}, "templates") - if err != nil { - t.Fatal("failed to load templates") - } - - s := &Server{ - storage: memStorage, - logger: logger, - templates: templates, - supportedResponseTypes: map[string]bool{"code": true}, - now: time.Now, - connectors: make(map[string]Connector), - } - - r := mux.NewRouter() - r.HandleFunc("/auth/{connector}", s.handleConnectorLogin) - s.mux = r - - clientID := "clientID" - clientSecret := "secret" - redirectURL := "localhost:5555" + "/callback" - client := storage.Client{ - ID: clientID, - Secret: clientSecret, - RedirectURIs: []string{redirectURL}, - } - if err := memStorage.CreateClient(client); err != nil { - t.Fatal("failed to create client") - } - - createConnector := func(t *testing.T, id string) storage.Connector { - connector := storage.Connector{ - ID: id, - Type: "mockCallback", - Name: "Mock", - ResourceVersion: "1", - } - if err := memStorage.CreateConnector(connector); err != nil { - t.Fatalf("failed to create connector %v", id) - } - - return connector - } - - connector1 := createConnector(t, "mock1") - connector2 := createConnector(t, "mock2") - - authReq := storage.AuthRequest{ - ID: storage.NewID(), - } - if err := memStorage.CreateAuthRequest(authReq); err != nil { - t.Fatal("failed to create auth request") - } - - createConnectorLoginRequest := func(connID string) *http.Request { - req := httptest.NewRequest("GET", "/auth/"+connID, nil) - q := req.URL.Query() - q.Add("req", authReq.ID) - q.Add("redirect_uri", redirectURL) - q.Add("scope", "openid") - q.Add("response_type", "code") - req.URL.RawQuery = q.Encode() - return req - } - - recorder := httptest.NewRecorder() - s.ServeHTTP(recorder, createConnectorLoginRequest(connector1.ID)) - if recorder.Code != 302 { - t.Fatal("failed to process request") - } - - recorder2 := httptest.NewRecorder() - s.ServeHTTP(recorder2, createConnectorLoginRequest(connector2.ID)) - if recorder2.Code != 500 { - t.Error("attempt to overwrite connector on auth request should fail") - } -} - // TestHandleAuthCode checks that it is forbidden to use same code twice func TestHandleAuthCode(t *testing.T) { tests := []struct {