Merge pull request #1481 from LanceH/master

Added "connector_id" to skip straight to a connector (similar to when len(connector) is 1.
This commit is contained in:
Stephan Renatus 2019-07-23 11:31:25 +02:00 committed by GitHub
commit 421c26fdf5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 135 additions and 3 deletions

View file

@ -233,6 +233,18 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
return return
} }
// Redirect if a client chooses a specific connector_id
if authReq.ConnectorID != "" {
for _, c := range connectors {
if c.ID == authReq.ConnectorID {
http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound)
return
}
}
s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound)
return
}
if len(connectors) == 1 { if len(connectors) == 1 {
for _, c := range connectors { for _, c := range connectors {
// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter // TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter

View file

@ -100,6 +100,7 @@ const (
errUnsupportedGrantType = "unsupported_grant_type" errUnsupportedGrantType = "unsupported_grant_type"
errInvalidGrant = "invalid_grant" errInvalidGrant = "invalid_grant"
errInvalidClient = "invalid_client" errInvalidClient = "invalid_client"
errInvalidConnectorID = "invalid_connector_id"
) )
const ( const (
@ -391,6 +392,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
clientID := q.Get("client_id") clientID := q.Get("client_id")
state := q.Get("state") state := q.Get("state")
nonce := q.Get("nonce") nonce := q.Get("nonce")
connectorID := q.Get("connector_id")
// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this. // Some clients, like the old go-oidc, provide extra whitespace. Tolerate this.
scopes := strings.Fields(q.Get("scope")) scopes := strings.Fields(q.Get("scope"))
responseTypes := strings.Fields(q.Get("response_type")) responseTypes := strings.Fields(q.Get("response_type"))
@ -405,6 +407,16 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
return req, &authErr{"", "", errServerError, ""} return req, &authErr{"", "", errServerError, ""}
} }
if connectorID != "" {
connectors, err := s.storage.ListConnectors()
if err != nil {
return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
}
if !validateConnectorID(connectors, connectorID) {
return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
}
}
if !validateRedirectURI(client, redirectURI) { if !validateRedirectURI(client, redirectURI) {
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI) description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
return req, &authErr{"", "", errInvalidRequest, description} return req, &authErr{"", "", errInvalidRequest, description}
@ -509,6 +521,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
Scopes: scopes, Scopes: scopes,
RedirectURI: redirectURI, RedirectURI: redirectURI,
ResponseTypes: responseTypes, ResponseTypes: responseTypes,
ConnectorID: connectorID,
}, nil }, nil
} }
@ -568,6 +581,15 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
return err == nil && host == "localhost" return err == nil && host == "localhost"
} }
func validateConnectorID(connectors []storage.Connector, connectorID string) bool {
for _, c := range connectors {
if c.ID == connectorID {
return true
}
}
return false
}
// storageKeySet implements the oidc.KeySet interface backed by Dex storage // storageKeySet implements the oidc.KeySet interface backed by Dex storage
type storageKeySet struct { type storageKeySet struct {
storage.Storage storage.Storage

View file

@ -10,7 +10,7 @@ import (
"strings" "strings"
"testing" "testing"
jose "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory" "github.com/dexidp/dex/storage/memory"
@ -145,6 +145,58 @@ func TestParseAuthorizationRequest(t *testing.T) {
}, },
wantErr: true, wantErr: true,
}, },
{
name: "choose connector_id",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code", "id_token", "token"},
queryParams: map[string]string{
"connector_id": "mock",
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code id_token",
"scope": "openid email profile",
},
},
{
name: "choose second connector_id",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code", "id_token", "token"},
queryParams: map[string]string{
"connector_id": "mock2",
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code id_token",
"scope": "openid email profile",
},
},
{
name: "choose invalid connector_id",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code", "id_token", "token"},
queryParams: map[string]string{
"connector_id": "bogus",
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code id_token",
"scope": "openid email profile",
},
wantErr: true,
},
} }
for _, tc := range tests { for _, tc := range tests {
@ -152,7 +204,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, server := newTestServer(ctx, t, func(c *Config) { httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) {
c.SupportedResponseTypes = tc.supportedResponseTypes c.SupportedResponseTypes = tc.supportedResponseTypes
c.Storage = storage.WithStaticClients(c.Storage, tc.clients) c.Storage = storage.WithStaticClients(c.Storage, tc.clients)
}) })
@ -162,7 +214,6 @@ func TestParseAuthorizationRequest(t *testing.T) {
for k, v := range tc.queryParams { for k, v := range tc.queryParams {
params.Set(k, v) params.Set(k, v)
} }
var req *http.Request var req *http.Request
if tc.usePOST { if tc.usePOST {
body := strings.NewReader(params.Encode()) body := strings.NewReader(params.Encode())

View file

@ -117,6 +117,53 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
return s, server return s, server
} }
func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r)
}))
config := Config{
Issuer: s.URL,
Storage: memory.New(logger),
Web: WebConfig{
Dir: "../web",
},
Logger: logger,
PrometheusRegistry: prometheus.NewRegistry(),
}
if updateConfig != nil {
updateConfig(&config)
}
s.URL = config.Issuer
connector := storage.Connector{
ID: "mock",
Type: "mockCallback",
Name: "Mock",
ResourceVersion: "1",
}
connector2 := storage.Connector{
ID: "mock2",
Type: "mockCallback",
Name: "Mock",
ResourceVersion: "1",
}
if err := config.Storage.CreateConnector(connector); err != nil {
t.Fatalf("create connector: %v", err)
}
if err := config.Storage.CreateConnector(connector2); err != nil {
t.Fatalf("create connector: %v", err)
}
var err error
if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
t.Fatal(err)
}
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
return s, server
}
func TestNewTestServer(t *testing.T) { func TestNewTestServer(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()