Add support for IDPs that do not send ID tokens in the reply when using a refresh grant. Add tests for the aforementioned functionality.

Signed-off-by: Anthony Brandelli <abrandel@cisco.com>
This commit is contained in:
Anthony Brandelli 2022-05-19 22:13:10 -06:00
parent 9cd29bdee0
commit 7c335e9337
2 changed files with 129 additions and 24 deletions

View file

@ -226,6 +226,13 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}
type caller uint
const (
createCaller caller = iota
refreshCaller
)
func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
@ -235,8 +242,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
if err != nil {
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}
return c.createIdentity(r.Context(), identity, token)
return c.createIdentity(r.Context(), identity, token, createCaller)
}
// Refresh is used to refresh a session with the refresh token provided by the IdP
@ -255,23 +261,25 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit
if err != nil {
return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err)
}
return c.createIdentity(ctx, identity, token)
return c.createIdentity(ctx, identity, token, refreshCaller)
}
func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return identity, errors.New("oidc: no id_token in token response")
}
idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil {
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
}
func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, caller caller) (connector.Identity, error) {
var claims map[string]interface{}
if err := idToken.Claims(&claims); err != nil {
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
rawIDToken, ok := token.Extra("id_token").(string)
if ok {
idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil {
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
}
if err := idToken.Claims(&claims); err != nil {
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
}
} else if caller != refreshCaller {
// ID tokens aren't mandatory in the reply when using a refresh_token grant
return identity, errors.New("oidc: no id_token in token response")
}
// We immediately want to run getUserInfo if configured before we validate the claims
@ -285,6 +293,12 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I
}
}
const subjectClaimKey = "sub"
subject, found := claims[subjectClaimKey].(string)
if !found {
return identity, fmt.Errorf("missing \"%s\" claim", subjectClaimKey)
}
userNameKey := "name"
if c.userNameKey != "" {
userNameKey = c.userNameKey
@ -358,7 +372,7 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I
}
identity = connector.Identity{
UserID: idToken.Subject,
UserID: subject,
Username: name,
PreferredUsername: preferredUsername,
Email: email,

View file

@ -275,7 +275,8 @@ func TestHandleCallback(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
testServer, err := setupServer(tc.token)
idTokenDesired := true
testServer, err := setupServer(tc.token, idTokenDesired)
if err != nil {
t.Fatal("failed to setup test server", err)
}
@ -331,7 +332,87 @@ func TestHandleCallback(t *testing.T) {
}
}
func setupServer(tok map[string]interface{}) (*httptest.Server, error) {
func TestRefresh(t *testing.T) {
t.Helper()
tests := []struct {
name string
expectUserID string
expectUserName string
idTokenDesired bool
token map[string]interface{}
}{
{
name: "IDTokenOnRefresh",
expectUserID: "subvalue",
expectUserName: "namevalue",
idTokenDesired: true,
token: map[string]interface{}{
"sub": "subvalue",
"name": "namevalue",
},
},
{
name: "NoIDTokenOnRefresh",
expectUserID: "subvalue",
expectUserName: "namevalue",
idTokenDesired: false,
token: map[string]interface{}{
"sub": "subvalue",
"name": "namevalue",
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
testServer, err := setupServer(tc.token, tc.idTokenDesired)
if err != nil {
t.Fatal("failed to setup test server", err)
}
defer testServer.Close()
scopes := []string{"openid", "offline_access"}
serverURL := testServer.URL
config := Config{
Issuer: serverURL,
ClientID: "clientID",
ClientSecret: "clientSecret",
Scopes: scopes,
RedirectURI: fmt.Sprintf("%s/callback", serverURL),
GetUserInfo: true,
}
conn, err := newConnector(config)
if err != nil {
t.Fatal("failed to create new connector", err)
}
req, err := newRequestWithAuthCode(testServer.URL, "someCode")
if err != nil {
t.Fatal("failed to create request", err)
}
refreshTokenStr := "{\"RefreshToken\":\"asdf\"}"
refreshToken := []byte(refreshTokenStr)
identity := connector.Identity{
UserID: tc.expectUserID,
Username: tc.expectUserName,
ConnectorData: refreshToken,
}
refreshIdentity, err := conn.Refresh(req.Context(), connector.Scopes{OfflineAccess: true}, identity)
if err != nil {
t.Fatal("Refresh failed", err)
}
expectEquals(t, refreshIdentity.UserID, tc.expectUserID)
expectEquals(t, refreshIdentity.Username, tc.expectUserName)
})
}
}
func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
return nil, fmt.Errorf("failed to generate rsa key: %v", err)
@ -368,11 +449,21 @@ func setupServer(tok map[string]interface{}) (*httptest.Server, error) {
}
w.Header().Add("Content-Type", "application/json")
json.NewEncoder(w).Encode(&map[string]string{
"access_token": token,
"id_token": token,
"token_type": "Bearer",
})
if idTokenDesired {
json.NewEncoder(w).Encode(&map[string]string{
"access_token": token,
"id_token": token,
"token_type": "Bearer"})
} else {
json.NewEncoder(w).Encode(&map[string]string{
"access_token": token,
"token_type": "Bearer"})
}
})
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
json.NewEncoder(w).Encode(tok)
})
mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {