forked from mystiq/dex
Merge pull request #1481 from LanceH/master
Added "connector_id" to skip straight to a connector (similar to when len(connector) is 1.
This commit is contained in:
commit
421c26fdf5
4 changed files with 135 additions and 3 deletions
|
@ -233,6 +233,18 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Redirect if a client chooses a specific connector_id
|
||||||
|
if authReq.ConnectorID != "" {
|
||||||
|
for _, c := range connectors {
|
||||||
|
if c.ID == authReq.ConnectorID {
|
||||||
|
http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if len(connectors) == 1 {
|
if len(connectors) == 1 {
|
||||||
for _, c := range connectors {
|
for _, c := range connectors {
|
||||||
// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter
|
// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter
|
||||||
|
|
|
@ -100,6 +100,7 @@ const (
|
||||||
errUnsupportedGrantType = "unsupported_grant_type"
|
errUnsupportedGrantType = "unsupported_grant_type"
|
||||||
errInvalidGrant = "invalid_grant"
|
errInvalidGrant = "invalid_grant"
|
||||||
errInvalidClient = "invalid_client"
|
errInvalidClient = "invalid_client"
|
||||||
|
errInvalidConnectorID = "invalid_connector_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -391,6 +392,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
|
||||||
clientID := q.Get("client_id")
|
clientID := q.Get("client_id")
|
||||||
state := q.Get("state")
|
state := q.Get("state")
|
||||||
nonce := q.Get("nonce")
|
nonce := q.Get("nonce")
|
||||||
|
connectorID := q.Get("connector_id")
|
||||||
// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this.
|
// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this.
|
||||||
scopes := strings.Fields(q.Get("scope"))
|
scopes := strings.Fields(q.Get("scope"))
|
||||||
responseTypes := strings.Fields(q.Get("response_type"))
|
responseTypes := strings.Fields(q.Get("response_type"))
|
||||||
|
@ -405,6 +407,16 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
|
||||||
return req, &authErr{"", "", errServerError, ""}
|
return req, &authErr{"", "", errServerError, ""}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if connectorID != "" {
|
||||||
|
connectors, err := s.storage.ListConnectors()
|
||||||
|
if err != nil {
|
||||||
|
return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
|
||||||
|
}
|
||||||
|
if !validateConnectorID(connectors, connectorID) {
|
||||||
|
return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !validateRedirectURI(client, redirectURI) {
|
if !validateRedirectURI(client, redirectURI) {
|
||||||
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
|
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
|
||||||
return req, &authErr{"", "", errInvalidRequest, description}
|
return req, &authErr{"", "", errInvalidRequest, description}
|
||||||
|
@ -509,6 +521,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
|
||||||
Scopes: scopes,
|
Scopes: scopes,
|
||||||
RedirectURI: redirectURI,
|
RedirectURI: redirectURI,
|
||||||
ResponseTypes: responseTypes,
|
ResponseTypes: responseTypes,
|
||||||
|
ConnectorID: connectorID,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -568,6 +581,15 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
|
||||||
return err == nil && host == "localhost"
|
return err == nil && host == "localhost"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateConnectorID(connectors []storage.Connector, connectorID string) bool {
|
||||||
|
for _, c := range connectors {
|
||||||
|
if c.ID == connectorID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
|
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
|
||||||
type storageKeySet struct {
|
type storageKeySet struct {
|
||||||
storage.Storage
|
storage.Storage
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
jose "gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
"github.com/dexidp/dex/storage/memory"
|
"github.com/dexidp/dex/storage/memory"
|
||||||
|
@ -145,6 +145,58 @@ func TestParseAuthorizationRequest(t *testing.T) {
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "choose connector_id",
|
||||||
|
clients: []storage.Client{
|
||||||
|
{
|
||||||
|
ID: "bar",
|
||||||
|
RedirectURIs: []string{"https://example.com/bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
supportedResponseTypes: []string{"code", "id_token", "token"},
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"connector_id": "mock",
|
||||||
|
"client_id": "bar",
|
||||||
|
"redirect_uri": "https://example.com/bar",
|
||||||
|
"response_type": "code id_token",
|
||||||
|
"scope": "openid email profile",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "choose second connector_id",
|
||||||
|
clients: []storage.Client{
|
||||||
|
{
|
||||||
|
ID: "bar",
|
||||||
|
RedirectURIs: []string{"https://example.com/bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
supportedResponseTypes: []string{"code", "id_token", "token"},
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"connector_id": "mock2",
|
||||||
|
"client_id": "bar",
|
||||||
|
"redirect_uri": "https://example.com/bar",
|
||||||
|
"response_type": "code id_token",
|
||||||
|
"scope": "openid email profile",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "choose invalid connector_id",
|
||||||
|
clients: []storage.Client{
|
||||||
|
{
|
||||||
|
ID: "bar",
|
||||||
|
RedirectURIs: []string{"https://example.com/bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
supportedResponseTypes: []string{"code", "id_token", "token"},
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"connector_id": "bogus",
|
||||||
|
"client_id": "bar",
|
||||||
|
"redirect_uri": "https://example.com/bar",
|
||||||
|
"response_type": "code id_token",
|
||||||
|
"scope": "openid email profile",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
|
@ -152,7 +204,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
httpServer, server := newTestServer(ctx, t, func(c *Config) {
|
httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) {
|
||||||
c.SupportedResponseTypes = tc.supportedResponseTypes
|
c.SupportedResponseTypes = tc.supportedResponseTypes
|
||||||
c.Storage = storage.WithStaticClients(c.Storage, tc.clients)
|
c.Storage = storage.WithStaticClients(c.Storage, tc.clients)
|
||||||
})
|
})
|
||||||
|
@ -162,7 +214,6 @@ func TestParseAuthorizationRequest(t *testing.T) {
|
||||||
for k, v := range tc.queryParams {
|
for k, v := range tc.queryParams {
|
||||||
params.Set(k, v)
|
params.Set(k, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
var req *http.Request
|
var req *http.Request
|
||||||
if tc.usePOST {
|
if tc.usePOST {
|
||||||
body := strings.NewReader(params.Encode())
|
body := strings.NewReader(params.Encode())
|
||||||
|
|
|
@ -117,6 +117,53 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
|
||||||
return s, server
|
return s, server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
|
||||||
|
var server *Server
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
server.ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
|
||||||
|
config := Config{
|
||||||
|
Issuer: s.URL,
|
||||||
|
Storage: memory.New(logger),
|
||||||
|
Web: WebConfig{
|
||||||
|
Dir: "../web",
|
||||||
|
},
|
||||||
|
Logger: logger,
|
||||||
|
PrometheusRegistry: prometheus.NewRegistry(),
|
||||||
|
}
|
||||||
|
if updateConfig != nil {
|
||||||
|
updateConfig(&config)
|
||||||
|
}
|
||||||
|
s.URL = config.Issuer
|
||||||
|
|
||||||
|
connector := storage.Connector{
|
||||||
|
ID: "mock",
|
||||||
|
Type: "mockCallback",
|
||||||
|
Name: "Mock",
|
||||||
|
ResourceVersion: "1",
|
||||||
|
}
|
||||||
|
connector2 := storage.Connector{
|
||||||
|
ID: "mock2",
|
||||||
|
Type: "mockCallback",
|
||||||
|
Name: "Mock",
|
||||||
|
ResourceVersion: "1",
|
||||||
|
}
|
||||||
|
if err := config.Storage.CreateConnector(connector); err != nil {
|
||||||
|
t.Fatalf("create connector: %v", err)
|
||||||
|
}
|
||||||
|
if err := config.Storage.CreateConnector(connector2); err != nil {
|
||||||
|
t.Fatalf("create connector: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
|
||||||
|
return s, server
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewTestServer(t *testing.T) {
|
func TestNewTestServer(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
Loading…
Reference in a new issue