forked from mystiq/dex
Merge pull request #2522 from Blorpy/oidc_refresh_token
OIDC connector: Support cases where there is no id_token when using a refresh_token grant
This commit is contained in:
commit
c74ad3bb66
2 changed files with 131 additions and 24 deletions
|
@ -226,6 +226,13 @@ func (e *oauth2Error) Error() string {
|
||||||
return e.error + ": " + e.errorDescription
|
return e.error + ": " + e.errorDescription
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type caller uint
|
||||||
|
|
||||||
|
const (
|
||||||
|
createCaller caller = iota
|
||||||
|
refreshCaller
|
||||||
|
)
|
||||||
|
|
||||||
func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
|
func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
|
||||||
q := r.URL.Query()
|
q := r.URL.Query()
|
||||||
if errType := q.Get("error"); errType != "" {
|
if errType := q.Get("error"); errType != "" {
|
||||||
|
@ -235,8 +242,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
|
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
|
||||||
}
|
}
|
||||||
|
return c.createIdentity(r.Context(), identity, token, createCaller)
|
||||||
return c.createIdentity(r.Context(), identity, token)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh is used to refresh a session with the refresh token provided by the IdP
|
// Refresh is used to refresh a session with the refresh token provided by the IdP
|
||||||
|
@ -255,23 +261,25 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err)
|
return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err)
|
||||||
}
|
}
|
||||||
|
return c.createIdentity(ctx, identity, token, refreshCaller)
|
||||||
return c.createIdentity(ctx, identity, token)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) {
|
func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, caller caller) (connector.Identity, error) {
|
||||||
rawIDToken, ok := token.Extra("id_token").(string)
|
|
||||||
if !ok {
|
|
||||||
return identity, errors.New("oidc: no id_token in token response")
|
|
||||||
}
|
|
||||||
idToken, err := c.verifier.Verify(ctx, rawIDToken)
|
|
||||||
if err != nil {
|
|
||||||
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var claims map[string]interface{}
|
var claims map[string]interface{}
|
||||||
if err := idToken.Claims(&claims); err != nil {
|
|
||||||
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
|
if ok {
|
||||||
|
idToken, err := c.verifier.Verify(ctx, rawIDToken)
|
||||||
|
if err != nil {
|
||||||
|
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
|
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
|
||||||
|
}
|
||||||
|
} else if caller != refreshCaller {
|
||||||
|
// ID tokens aren't mandatory in the reply when using a refresh_token grant
|
||||||
|
return identity, errors.New("oidc: no id_token in token response")
|
||||||
}
|
}
|
||||||
|
|
||||||
// We immediately want to run getUserInfo if configured before we validate the claims
|
// We immediately want to run getUserInfo if configured before we validate the claims
|
||||||
|
@ -285,6 +293,12 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const subjectClaimKey = "sub"
|
||||||
|
subject, found := claims[subjectClaimKey].(string)
|
||||||
|
if !found {
|
||||||
|
return identity, fmt.Errorf("missing \"%s\" claim", subjectClaimKey)
|
||||||
|
}
|
||||||
|
|
||||||
userNameKey := "name"
|
userNameKey := "name"
|
||||||
if c.userNameKey != "" {
|
if c.userNameKey != "" {
|
||||||
userNameKey = c.userNameKey
|
userNameKey = c.userNameKey
|
||||||
|
@ -358,7 +372,7 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I
|
||||||
}
|
}
|
||||||
|
|
||||||
identity = connector.Identity{
|
identity = connector.Identity{
|
||||||
UserID: idToken.Subject,
|
UserID: subject,
|
||||||
Username: name,
|
Username: name,
|
||||||
PreferredUsername: preferredUsername,
|
PreferredUsername: preferredUsername,
|
||||||
Email: email,
|
Email: email,
|
||||||
|
|
|
@ -275,7 +275,8 @@ func TestHandleCallback(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
testServer, err := setupServer(tc.token)
|
idTokenDesired := true
|
||||||
|
testServer, err := setupServer(tc.token, idTokenDesired)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("failed to setup test server", err)
|
t.Fatal("failed to setup test server", err)
|
||||||
}
|
}
|
||||||
|
@ -331,7 +332,87 @@ func TestHandleCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupServer(tok map[string]interface{}) (*httptest.Server, error) {
|
func TestRefresh(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
expectUserID string
|
||||||
|
expectUserName string
|
||||||
|
idTokenDesired bool
|
||||||
|
token map[string]interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IDTokenOnRefresh",
|
||||||
|
expectUserID: "subvalue",
|
||||||
|
expectUserName: "namevalue",
|
||||||
|
idTokenDesired: true,
|
||||||
|
token: map[string]interface{}{
|
||||||
|
"sub": "subvalue",
|
||||||
|
"name": "namevalue",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NoIDTokenOnRefresh",
|
||||||
|
expectUserID: "subvalue",
|
||||||
|
expectUserName: "namevalue",
|
||||||
|
idTokenDesired: false,
|
||||||
|
token: map[string]interface{}{
|
||||||
|
"sub": "subvalue",
|
||||||
|
"name": "namevalue",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
testServer, err := setupServer(tc.token, tc.idTokenDesired)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to setup test server", err)
|
||||||
|
}
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
scopes := []string{"openid", "offline_access"}
|
||||||
|
serverURL := testServer.URL
|
||||||
|
config := Config{
|
||||||
|
Issuer: serverURL,
|
||||||
|
ClientID: "clientID",
|
||||||
|
ClientSecret: "clientSecret",
|
||||||
|
Scopes: scopes,
|
||||||
|
RedirectURI: fmt.Sprintf("%s/callback", serverURL),
|
||||||
|
GetUserInfo: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := newConnector(config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to create new connector", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := newRequestWithAuthCode(testServer.URL, "someCode")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to create request", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshTokenStr := "{\"RefreshToken\":\"asdf\"}"
|
||||||
|
refreshToken := []byte(refreshTokenStr)
|
||||||
|
|
||||||
|
identity := connector.Identity{
|
||||||
|
UserID: tc.expectUserID,
|
||||||
|
Username: tc.expectUserName,
|
||||||
|
ConnectorData: refreshToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshIdentity, err := conn.Refresh(req.Context(), connector.Scopes{OfflineAccess: true}, identity)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Refresh failed", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectEquals(t, refreshIdentity.UserID, tc.expectUserID)
|
||||||
|
expectEquals(t, refreshIdentity.Username, tc.expectUserName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) {
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate rsa key: %v", err)
|
return nil, fmt.Errorf("failed to generate rsa key: %v", err)
|
||||||
|
@ -368,11 +449,23 @@ func setupServer(tok map[string]interface{}) (*httptest.Server, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Add("Content-Type", "application/json")
|
w.Header().Add("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(&map[string]string{
|
if idTokenDesired {
|
||||||
"access_token": token,
|
json.NewEncoder(w).Encode(&map[string]string{
|
||||||
"id_token": token,
|
"access_token": token,
|
||||||
"token_type": "Bearer",
|
"id_token": token,
|
||||||
})
|
"token_type": "Bearer",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
json.NewEncoder(w).Encode(&map[string]string{
|
||||||
|
"access_token": token,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Add("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(tok)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
Loading…
Reference in a new issue