diff --git a/server/handlers.go b/server/handlers.go index babd5417..f7330440 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -283,7 +283,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { connID := mux.Vars(r)["connector"] conn, err := s.getConnector(connID) if err != nil { - s.logger.Errorf("Failed to create authorization request: %v", err) + s.logger.Errorf("Failed to get connector: %v", err) s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist") return } @@ -304,6 +304,9 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Set the connector being used for the login. if authReq.ConnectorID != connID { updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { + if a.ConnectorID != "" { + return a, fmt.Errorf("connector is already set for this auth request") + } a.ConnectorID = connID return a, nil } diff --git a/server/handlers_test.go b/server/handlers_test.go index b30076dd..12dce3a8 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -8,8 +8,12 @@ import ( "net/http" "net/http/httptest" "testing" + "time" + + "github.com/gorilla/mux" "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/memory" ) func TestHandleHealth(t *testing.T) { @@ -119,3 +123,84 @@ func TestHandleInvalidSAMLCallbacks(t *testing.T) { } } } + +func TestConnectorLoginDoesNotAllowToChangeConnectorForAuthRequest(t *testing.T) { + memStorage := memory.New(logger) + + templates, err := loadTemplates(webConfig{}, "../web/templates") + if err != nil { + t.Fatal("failed to load tempalates") + } + + s := &Server{ + storage: memStorage, + logger: logger, + templates: templates, + supportedResponseTypes: map[string]bool{"code": true}, + now: time.Now, + connectors: make(map[string]Connector), + } + + r := mux.NewRouter() + r.HandleFunc("/auth/{connector}", s.handleConnectorLogin) + s.mux = r + + clientID := "clientID" + clientSecret := "secret" + redirectURL := "localhost:5555" + "/callback" + client := storage.Client{ + ID: clientID, + Secret: clientSecret, + RedirectURIs: []string{redirectURL}, + } + if err := memStorage.CreateClient(client); err != nil { + t.Fatal("failed to create client") + } + + createConnector := func(t *testing.T, id string) storage.Connector { + connector := storage.Connector{ + ID: id, + Type: "mockCallback", + Name: "Mock", + ResourceVersion: "1", + } + if err := memStorage.CreateConnector(connector); err != nil { + t.Fatalf("failed to create connector %v", id) + } + + return connector + } + + connector1 := createConnector(t, "mock1") + connector2 := createConnector(t, "mock2") + + authReq := storage.AuthRequest{ + ID: storage.NewID(), + } + if err := memStorage.CreateAuthRequest(authReq); err != nil { + t.Fatal("failed to create auth request") + } + + createConnectorLoginRequest := func(connID string) *http.Request { + req := httptest.NewRequest("GET", "/auth/"+connID, nil) + q := req.URL.Query() + q.Add("req", authReq.ID) + q.Add("redirect_uri", redirectURL) + q.Add("scope", "openid") + q.Add("response_type", "code") + req.URL.RawQuery = q.Encode() + return req + } + + recorder := httptest.NewRecorder() + s.ServeHTTP(recorder, createConnectorLoginRequest(connector1.ID)) + if recorder.Code != 302 { + t.Fatal("failed to process request") + } + + recorder2 := httptest.NewRecorder() + s.ServeHTTP(recorder2, createConnectorLoginRequest(connector2.ID)) + if recorder2.Code != 500 { + t.Error("attempt to overwrite connector on auth request should fail") + } +}