diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 5e995d1b..78557a5a 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -226,6 +226,13 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } +type caller uint + +const ( + createCaller caller = iota + refreshCaller +) + func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { @@ -235,8 +242,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide if err != nil { return identity, fmt.Errorf("oidc: failed to get token: %v", err) } - - return c.createIdentity(r.Context(), identity, token) + return c.createIdentity(r.Context(), identity, token, createCaller) } // Refresh is used to refresh a session with the refresh token provided by the IdP @@ -255,23 +261,25 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit if err != nil { return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err) } - - return c.createIdentity(ctx, identity, token) + return c.createIdentity(ctx, identity, token, refreshCaller) } -func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) { - rawIDToken, ok := token.Extra("id_token").(string) - if !ok { - return identity, errors.New("oidc: no id_token in token response") - } - idToken, err := c.verifier.Verify(ctx, rawIDToken) - if err != nil { - return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) - } - +func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, caller caller) (connector.Identity, error) { var claims map[string]interface{} - if err := idToken.Claims(&claims); err != nil { - return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) + + rawIDToken, ok := token.Extra("id_token").(string) + if ok { + idToken, err := c.verifier.Verify(ctx, rawIDToken) + if err != nil { + return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) + } + + if err := idToken.Claims(&claims); err != nil { + return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) + } + } else if caller != refreshCaller { + // ID tokens aren't mandatory in the reply when using a refresh_token grant + return identity, errors.New("oidc: no id_token in token response") } // We immediately want to run getUserInfo if configured before we validate the claims @@ -285,6 +293,12 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I } } + const subjectClaimKey = "sub" + subject, found := claims[subjectClaimKey].(string) + if !found { + return identity, fmt.Errorf("missing \"%s\" claim", subjectClaimKey) + } + userNameKey := "name" if c.userNameKey != "" { userNameKey = c.userNameKey @@ -358,7 +372,7 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I } identity = connector.Identity{ - UserID: idToken.Subject, + UserID: subject, Username: name, PreferredUsername: preferredUsername, Email: email, diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 3038cebc..66764cd7 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -275,7 +275,8 @@ func TestHandleCallback(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - testServer, err := setupServer(tc.token) + idTokenDesired := true + testServer, err := setupServer(tc.token, idTokenDesired) if err != nil { t.Fatal("failed to setup test server", err) } @@ -331,7 +332,87 @@ func TestHandleCallback(t *testing.T) { } } -func setupServer(tok map[string]interface{}) (*httptest.Server, error) { +func TestRefresh(t *testing.T) { + t.Helper() + + tests := []struct { + name string + expectUserID string + expectUserName string + idTokenDesired bool + token map[string]interface{} + }{ + { + name: "IDTokenOnRefresh", + expectUserID: "subvalue", + expectUserName: "namevalue", + idTokenDesired: true, + token: map[string]interface{}{ + "sub": "subvalue", + "name": "namevalue", + }, + }, + { + name: "NoIDTokenOnRefresh", + expectUserID: "subvalue", + expectUserName: "namevalue", + idTokenDesired: false, + token: map[string]interface{}{ + "sub": "subvalue", + "name": "namevalue", + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + testServer, err := setupServer(tc.token, tc.idTokenDesired) + if err != nil { + t.Fatal("failed to setup test server", err) + } + defer testServer.Close() + + scopes := []string{"openid", "offline_access"} + serverURL := testServer.URL + config := Config{ + Issuer: serverURL, + ClientID: "clientID", + ClientSecret: "clientSecret", + Scopes: scopes, + RedirectURI: fmt.Sprintf("%s/callback", serverURL), + GetUserInfo: true, + } + + conn, err := newConnector(config) + if err != nil { + t.Fatal("failed to create new connector", err) + } + + req, err := newRequestWithAuthCode(testServer.URL, "someCode") + if err != nil { + t.Fatal("failed to create request", err) + } + + refreshTokenStr := "{\"RefreshToken\":\"asdf\"}" + refreshToken := []byte(refreshTokenStr) + + identity := connector.Identity{ + UserID: tc.expectUserID, + Username: tc.expectUserName, + ConnectorData: refreshToken, + } + + refreshIdentity, err := conn.Refresh(req.Context(), connector.Scopes{OfflineAccess: true}, identity) + if err != nil { + t.Fatal("Refresh failed", err) + } + + expectEquals(t, refreshIdentity.UserID, tc.expectUserID) + expectEquals(t, refreshIdentity.Username, tc.expectUserName) + }) + } +} + +func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) { key, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil { return nil, fmt.Errorf("failed to generate rsa key: %v", err) @@ -368,11 +449,21 @@ func setupServer(tok map[string]interface{}) (*httptest.Server, error) { } w.Header().Add("Content-Type", "application/json") - json.NewEncoder(w).Encode(&map[string]string{ - "access_token": token, - "id_token": token, - "token_type": "Bearer", - }) + if idTokenDesired { + json.NewEncoder(w).Encode(&map[string]string{ + "access_token": token, + "id_token": token, + "token_type": "Bearer"}) + } else { + json.NewEncoder(w).Encode(&map[string]string{ + "access_token": token, + "token_type": "Bearer"}) + } + }) + + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + json.NewEncoder(w).Encode(tok) }) mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {