diff --git a/server/handlers.go b/server/handlers.go index e7e15c27..70ef1321 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -233,6 +233,18 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { return } + // Redirect if a client chooses a specific connector_id + if authReq.ConnectorID != "" { + for _, c := range connectors { + if c.ID == authReq.ConnectorID { + http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound) + return + } + } + s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound) + return + } + if len(connectors) == 1 { for _, c := range connectors { // TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter diff --git a/server/oauth2.go b/server/oauth2.go index 68a72f66..5b7a421c 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -100,6 +100,7 @@ const ( errUnsupportedGrantType = "unsupported_grant_type" errInvalidGrant = "invalid_grant" errInvalidClient = "invalid_client" + errInvalidConnectorID = "invalid_connector_id" ) const ( @@ -391,6 +392,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq clientID := q.Get("client_id") state := q.Get("state") nonce := q.Get("nonce") + connectorID := q.Get("connector_id") // Some clients, like the old go-oidc, provide extra whitespace. Tolerate this. scopes := strings.Fields(q.Get("scope")) responseTypes := strings.Fields(q.Get("response_type")) @@ -405,6 +407,16 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq return req, &authErr{"", "", errServerError, ""} } + if connectorID != "" { + connectors, err := s.storage.ListConnectors() + if err != nil { + return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"} + } + if !validateConnectorID(connectors, connectorID) { + return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"} + } + } + if !validateRedirectURI(client, redirectURI) { description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI) return req, &authErr{"", "", errInvalidRequest, description} @@ -509,6 +521,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq Scopes: scopes, RedirectURI: redirectURI, ResponseTypes: responseTypes, + ConnectorID: connectorID, }, nil } @@ -568,6 +581,15 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool { return err == nil && host == "localhost" } +func validateConnectorID(connectors []storage.Connector, connectorID string) bool { + for _, c := range connectors { + if c.ID == connectorID { + return true + } + } + return false +} + // storageKeySet implements the oidc.KeySet interface backed by Dex storage type storageKeySet struct { storage.Storage diff --git a/server/oauth2_test.go b/server/oauth2_test.go index bb8d2723..ad122055 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -10,7 +10,7 @@ import ( "strings" "testing" - jose "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/memory" @@ -145,6 +145,58 @@ func TestParseAuthorizationRequest(t *testing.T) { }, wantErr: true, }, + { + name: "choose connector_id", + clients: []storage.Client{ + { + ID: "bar", + RedirectURIs: []string{"https://example.com/bar"}, + }, + }, + supportedResponseTypes: []string{"code", "id_token", "token"}, + queryParams: map[string]string{ + "connector_id": "mock", + "client_id": "bar", + "redirect_uri": "https://example.com/bar", + "response_type": "code id_token", + "scope": "openid email profile", + }, + }, + { + name: "choose second connector_id", + clients: []storage.Client{ + { + ID: "bar", + RedirectURIs: []string{"https://example.com/bar"}, + }, + }, + supportedResponseTypes: []string{"code", "id_token", "token"}, + queryParams: map[string]string{ + "connector_id": "mock2", + "client_id": "bar", + "redirect_uri": "https://example.com/bar", + "response_type": "code id_token", + "scope": "openid email profile", + }, + }, + { + name: "choose invalid connector_id", + clients: []storage.Client{ + { + ID: "bar", + RedirectURIs: []string{"https://example.com/bar"}, + }, + }, + supportedResponseTypes: []string{"code", "id_token", "token"}, + queryParams: map[string]string{ + "connector_id": "bogus", + "client_id": "bar", + "redirect_uri": "https://example.com/bar", + "response_type": "code id_token", + "scope": "openid email profile", + }, + wantErr: true, + }, } for _, tc := range tests { @@ -152,7 +204,7 @@ func TestParseAuthorizationRequest(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - httpServer, server := newTestServer(ctx, t, func(c *Config) { + httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) { c.SupportedResponseTypes = tc.supportedResponseTypes c.Storage = storage.WithStaticClients(c.Storage, tc.clients) }) @@ -162,7 +214,6 @@ func TestParseAuthorizationRequest(t *testing.T) { for k, v := range tc.queryParams { params.Set(k, v) } - var req *http.Request if tc.usePOST { body := strings.NewReader(params.Encode()) diff --git a/server/server_test.go b/server/server_test.go index 1cc145a0..2b4c6453 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -117,6 +117,53 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi return s, server } +func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { + var server *Server + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server.ServeHTTP(w, r) + })) + + config := Config{ + Issuer: s.URL, + Storage: memory.New(logger), + Web: WebConfig{ + Dir: "../web", + }, + Logger: logger, + PrometheusRegistry: prometheus.NewRegistry(), + } + if updateConfig != nil { + updateConfig(&config) + } + s.URL = config.Issuer + + connector := storage.Connector{ + ID: "mock", + Type: "mockCallback", + Name: "Mock", + ResourceVersion: "1", + } + connector2 := storage.Connector{ + ID: "mock2", + Type: "mockCallback", + Name: "Mock", + ResourceVersion: "1", + } + if err := config.Storage.CreateConnector(connector); err != nil { + t.Fatalf("create connector: %v", err) + } + if err := config.Storage.CreateConnector(connector2); err != nil { + t.Fatalf("create connector: %v", err) + } + + var err error + if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil { + t.Fatal(err) + } + server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. + return s, server +} + func TestNewTestServer(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel()