make userID configurable
This commit is contained in:
parent
59560c9919
commit
9650836851
3 changed files with 247 additions and 13 deletions
|
@ -66,6 +66,12 @@ connectors:
|
|||
# all the claims requested.
|
||||
# https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
# getUserInfo: true
|
||||
|
||||
# The set claim is used as user id.
|
||||
# Default: sub
|
||||
# Claims list at https://openid.net/specs/openid-connect-core-1_0.html#Claims
|
||||
#
|
||||
# userIdKey: nickname
|
||||
```
|
||||
|
||||
[oidc-doc]: openid-connect.md
|
||||
|
|
|
@ -44,6 +44,9 @@ type Config struct {
|
|||
// the token. This is especially useful where upstreams return "thin"
|
||||
// id tokens
|
||||
GetUserInfo bool `json:"getUserInfo"`
|
||||
|
||||
// Configurable key which contains the user id claim
|
||||
UserIDKey string `json:"userIDKey"`
|
||||
}
|
||||
|
||||
// Domains that don't support basic auth. golang.org/x/oauth2 has an internal
|
||||
|
@ -127,6 +130,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
|
|||
hostedDomains: c.HostedDomains,
|
||||
insecureSkipEmailVerified: c.InsecureSkipEmailVerified,
|
||||
getUserInfo: c.GetUserInfo,
|
||||
userIDKey: c.UserIDKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -146,6 +150,7 @@ type oidcConnector struct {
|
|||
hostedDomains []string
|
||||
insecureSkipEmailVerified bool
|
||||
getUserInfo bool
|
||||
userIDKey string
|
||||
}
|
||||
|
||||
func (c *oidcConnector) Close() error {
|
||||
|
@ -199,33 +204,41 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
|
|||
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Username string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
HostedDomain string `json:"hd"`
|
||||
}
|
||||
var claims map[string]interface{}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
|
||||
}
|
||||
|
||||
name, found := claims["name"].(string)
|
||||
if !found {
|
||||
return identity, errors.New("missing \"name\" claim")
|
||||
}
|
||||
email, found := claims["email"].(string)
|
||||
if !found {
|
||||
return identity, errors.New("missing \"email\" claim")
|
||||
}
|
||||
emailVerified, found := claims["email_verified"].(bool)
|
||||
if !found {
|
||||
return identity, errors.New("missing \"email_verified\" claim")
|
||||
}
|
||||
hostedDomain, _ := claims["hd"].(string)
|
||||
|
||||
if len(c.hostedDomains) > 0 {
|
||||
found := false
|
||||
for _, domain := range c.hostedDomains {
|
||||
if claims.HostedDomain == domain {
|
||||
if hostedDomain == domain {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return identity, fmt.Errorf("oidc: unexpected hd claim %v", claims.HostedDomain)
|
||||
return identity, fmt.Errorf("oidc: unexpected hd claim %v", hostedDomain)
|
||||
}
|
||||
}
|
||||
|
||||
if c.insecureSkipEmailVerified {
|
||||
claims.EmailVerified = true
|
||||
|
||||
emailVerified = true
|
||||
}
|
||||
|
||||
if c.getUserInfo {
|
||||
|
@ -240,10 +253,19 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
|
|||
|
||||
identity = connector.Identity{
|
||||
UserID: idToken.Subject,
|
||||
Username: claims.Username,
|
||||
Email: claims.Email,
|
||||
EmailVerified: claims.EmailVerified,
|
||||
Username: name,
|
||||
Email: email,
|
||||
EmailVerified: emailVerified,
|
||||
}
|
||||
|
||||
if c.userIDKey != "" {
|
||||
userID, found := claims[c.userIDKey].(string)
|
||||
if !found {
|
||||
return identity, fmt.Errorf("oidc: not found %v claim", c.userIDKey)
|
||||
}
|
||||
identity.UserID = userID
|
||||
}
|
||||
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,24 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/connector"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func TestKnownBrokenAuthHeaderProvider(t *testing.T) {
|
||||
|
@ -23,3 +40,192 @@ func TestKnownBrokenAuthHeaderProvider(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleCallback(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userIDKey string
|
||||
expectUserID string
|
||||
}{
|
||||
{"simpleCase", "", "sub"},
|
||||
{"withUserIDKey", "name", "name"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
testServer, err := setupServer()
|
||||
if err != nil {
|
||||
t.Fatal("failed to setup test server", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
serverURL := testServer.URL
|
||||
config := Config{
|
||||
Issuer: serverURL,
|
||||
ClientID: "clientID",
|
||||
ClientSecret: "clientSecret",
|
||||
Scopes: []string{"groups"},
|
||||
RedirectURI: fmt.Sprintf("%s/callback", serverURL),
|
||||
UserIDKey: tc.userIDKey,
|
||||
}
|
||||
|
||||
conn, err := newConnector(config)
|
||||
if err != nil {
|
||||
t.Fatal("failed to create new connector", err)
|
||||
}
|
||||
|
||||
req, err := newRequestWithAuthCode(testServer.URL, "someCode")
|
||||
if err != nil {
|
||||
t.Fatal("failed to create request", err)
|
||||
}
|
||||
|
||||
identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req)
|
||||
if err != nil {
|
||||
t.Fatal("handle callback failed", err)
|
||||
}
|
||||
|
||||
expectEquals(t, identity.UserID, tc.expectUserID)
|
||||
expectEquals(t, identity.Username, "name")
|
||||
expectEquals(t, identity.Email, "email")
|
||||
expectEquals(t, identity.EmailVerified, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupServer() (*httptest.Server, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate rsa key: %v", err)
|
||||
}
|
||||
|
||||
jwk := jose.JSONWebKey{
|
||||
Key: key,
|
||||
KeyID: "keyId",
|
||||
Algorithm: "RSA",
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/keys", func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(&map[string]interface{}{
|
||||
"keys": []map[string]interface{}{{
|
||||
"alg": jwk.Algorithm,
|
||||
"kty": jwk.Algorithm,
|
||||
"kid": jwk.KeyID,
|
||||
"n": n(&key.PublicKey),
|
||||
"e": e(&key.PublicKey),
|
||||
}},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
url := fmt.Sprintf("http://%s", r.Host)
|
||||
|
||||
token, err := newToken(&jwk, map[string]interface{}{
|
||||
"iss": url,
|
||||
"aud": "clientID",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"sub": "sub",
|
||||
"name": "name",
|
||||
"email": "email",
|
||||
"email_verified": true,
|
||||
})
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(&map[string]string{
|
||||
"access_token": token,
|
||||
"id_token": token,
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||
url := fmt.Sprintf("http://%s", r.Host)
|
||||
|
||||
json.NewEncoder(w).Encode(&map[string]string{
|
||||
"issuer": url,
|
||||
"token_endpoint": fmt.Sprintf("%s/token", url),
|
||||
"authorization_endpoint": fmt.Sprintf("%s/authorize", url),
|
||||
"userinfo_endpoint": fmt.Sprintf("%s/userinfo", url),
|
||||
"jwks_uri": fmt.Sprintf("%s/keys", url),
|
||||
})
|
||||
})
|
||||
|
||||
return httptest.NewServer(mux), nil
|
||||
}
|
||||
|
||||
func newToken(key *jose.JSONWebKey, claims map[string]interface{}) (string, error) {
|
||||
signingKey := jose.SigningKey{
|
||||
Key: key,
|
||||
Algorithm: jose.RS256,
|
||||
}
|
||||
|
||||
signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create new signer: %v", err)
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal claims: %v", err)
|
||||
}
|
||||
|
||||
signature, err := signer.Sign(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign: %v", err)
|
||||
}
|
||||
return signature.CompactSerialize()
|
||||
}
|
||||
|
||||
func newConnector(config Config) (*oidcConnector, error) {
|
||||
logger := logrus.New()
|
||||
conn, err := config.Open("id", logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open: %v", err)
|
||||
}
|
||||
|
||||
oidcConn, ok := conn.(*oidcConnector)
|
||||
if !ok {
|
||||
return nil, errors.New("failed to convert to oidcConnector")
|
||||
}
|
||||
|
||||
return oidcConn, nil
|
||||
}
|
||||
|
||||
func newRequestWithAuthCode(serverURL string, code string) (*http.Request, error) {
|
||||
req, err := http.NewRequest("GET", serverURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
values := req.URL.Query()
|
||||
values.Add("code", code)
|
||||
req.URL.RawQuery = values.Encode()
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func n(pub *rsa.PublicKey) string {
|
||||
return encode(pub.N.Bytes())
|
||||
}
|
||||
|
||||
func e(pub *rsa.PublicKey) string {
|
||||
data := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(data, uint64(pub.E))
|
||||
return encode(bytes.TrimLeft(data, "\x00"))
|
||||
}
|
||||
|
||||
func encode(payload []byte) string {
|
||||
result := base64.URLEncoding.EncodeToString(payload)
|
||||
return strings.TrimRight(result, "=")
|
||||
}
|
||||
|
||||
func expectEquals(t *testing.T, a interface{}, b interface{}) {
|
||||
if !reflect.DeepEqual(a, b) {
|
||||
t.Errorf("Expected %+v to equal %+v", a, b)
|
||||
}
|
||||
}
|
||||
|
|
Reference in a new issue