forked from mystiq/dex
Merge pull request #1481 from LanceH/master
Added "connector_id" to skip straight to a connector (similar to when len(connector) is 1.
This commit is contained in:
commit
421c26fdf5
4 changed files with 135 additions and 3 deletions
|
@ -233,6 +233,18 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Redirect if a client chooses a specific connector_id
|
||||
if authReq.ConnectorID != "" {
|
||||
for _, c := range connectors {
|
||||
if c.ID == authReq.ConnectorID {
|
||||
http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound)
|
||||
return
|
||||
}
|
||||
}
|
||||
s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if len(connectors) == 1 {
|
||||
for _, c := range connectors {
|
||||
// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter
|
||||
|
|
|
@ -100,6 +100,7 @@ const (
|
|||
errUnsupportedGrantType = "unsupported_grant_type"
|
||||
errInvalidGrant = "invalid_grant"
|
||||
errInvalidClient = "invalid_client"
|
||||
errInvalidConnectorID = "invalid_connector_id"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -391,6 +392,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
|
|||
clientID := q.Get("client_id")
|
||||
state := q.Get("state")
|
||||
nonce := q.Get("nonce")
|
||||
connectorID := q.Get("connector_id")
|
||||
// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this.
|
||||
scopes := strings.Fields(q.Get("scope"))
|
||||
responseTypes := strings.Fields(q.Get("response_type"))
|
||||
|
@ -405,6 +407,16 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
|
|||
return req, &authErr{"", "", errServerError, ""}
|
||||
}
|
||||
|
||||
if connectorID != "" {
|
||||
connectors, err := s.storage.ListConnectors()
|
||||
if err != nil {
|
||||
return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
|
||||
}
|
||||
if !validateConnectorID(connectors, connectorID) {
|
||||
return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
|
||||
}
|
||||
}
|
||||
|
||||
if !validateRedirectURI(client, redirectURI) {
|
||||
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
|
||||
return req, &authErr{"", "", errInvalidRequest, description}
|
||||
|
@ -509,6 +521,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
|
|||
Scopes: scopes,
|
||||
RedirectURI: redirectURI,
|
||||
ResponseTypes: responseTypes,
|
||||
ConnectorID: connectorID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -568,6 +581,15 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
|
|||
return err == nil && host == "localhost"
|
||||
}
|
||||
|
||||
func validateConnectorID(connectors []storage.Connector, connectorID string) bool {
|
||||
for _, c := range connectors {
|
||||
if c.ID == connectorID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
|
||||
type storageKeySet struct {
|
||||
storage.Storage
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/memory"
|
||||
|
@ -145,6 +145,58 @@ func TestParseAuthorizationRequest(t *testing.T) {
|
|||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "choose connector_id",
|
||||
clients: []storage.Client{
|
||||
{
|
||||
ID: "bar",
|
||||
RedirectURIs: []string{"https://example.com/bar"},
|
||||
},
|
||||
},
|
||||
supportedResponseTypes: []string{"code", "id_token", "token"},
|
||||
queryParams: map[string]string{
|
||||
"connector_id": "mock",
|
||||
"client_id": "bar",
|
||||
"redirect_uri": "https://example.com/bar",
|
||||
"response_type": "code id_token",
|
||||
"scope": "openid email profile",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "choose second connector_id",
|
||||
clients: []storage.Client{
|
||||
{
|
||||
ID: "bar",
|
||||
RedirectURIs: []string{"https://example.com/bar"},
|
||||
},
|
||||
},
|
||||
supportedResponseTypes: []string{"code", "id_token", "token"},
|
||||
queryParams: map[string]string{
|
||||
"connector_id": "mock2",
|
||||
"client_id": "bar",
|
||||
"redirect_uri": "https://example.com/bar",
|
||||
"response_type": "code id_token",
|
||||
"scope": "openid email profile",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "choose invalid connector_id",
|
||||
clients: []storage.Client{
|
||||
{
|
||||
ID: "bar",
|
||||
RedirectURIs: []string{"https://example.com/bar"},
|
||||
},
|
||||
},
|
||||
supportedResponseTypes: []string{"code", "id_token", "token"},
|
||||
queryParams: map[string]string{
|
||||
"connector_id": "bogus",
|
||||
"client_id": "bar",
|
||||
"redirect_uri": "https://example.com/bar",
|
||||
"response_type": "code id_token",
|
||||
"scope": "openid email profile",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
@ -152,7 +204,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
httpServer, server := newTestServer(ctx, t, func(c *Config) {
|
||||
httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) {
|
||||
c.SupportedResponseTypes = tc.supportedResponseTypes
|
||||
c.Storage = storage.WithStaticClients(c.Storage, tc.clients)
|
||||
})
|
||||
|
@ -162,7 +214,6 @@ func TestParseAuthorizationRequest(t *testing.T) {
|
|||
for k, v := range tc.queryParams {
|
||||
params.Set(k, v)
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
if tc.usePOST {
|
||||
body := strings.NewReader(params.Encode())
|
||||
|
|
|
@ -117,6 +117,53 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
|
|||
return s, server
|
||||
}
|
||||
|
||||
func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
|
||||
var server *Server
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server.ServeHTTP(w, r)
|
||||
}))
|
||||
|
||||
config := Config{
|
||||
Issuer: s.URL,
|
||||
Storage: memory.New(logger),
|
||||
Web: WebConfig{
|
||||
Dir: "../web",
|
||||
},
|
||||
Logger: logger,
|
||||
PrometheusRegistry: prometheus.NewRegistry(),
|
||||
}
|
||||
if updateConfig != nil {
|
||||
updateConfig(&config)
|
||||
}
|
||||
s.URL = config.Issuer
|
||||
|
||||
connector := storage.Connector{
|
||||
ID: "mock",
|
||||
Type: "mockCallback",
|
||||
Name: "Mock",
|
||||
ResourceVersion: "1",
|
||||
}
|
||||
connector2 := storage.Connector{
|
||||
ID: "mock2",
|
||||
Type: "mockCallback",
|
||||
Name: "Mock",
|
||||
ResourceVersion: "1",
|
||||
}
|
||||
if err := config.Storage.CreateConnector(connector); err != nil {
|
||||
t.Fatalf("create connector: %v", err)
|
||||
}
|
||||
if err := config.Storage.CreateConnector(connector2); err != nil {
|
||||
t.Fatalf("create connector: %v", err)
|
||||
}
|
||||
|
||||
var err error
|
||||
if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
|
||||
return s, server
|
||||
}
|
||||
|
||||
func TestNewTestServer(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
|
Loading…
Reference in a new issue