diff --git a/db/refresh.go b/db/refresh.go index 0475d7eb..d16f313f 100644 --- a/db/refresh.go +++ b/db/refresh.go @@ -146,6 +146,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, if len(record.Scopes) > 0 { scopes = strings.Split(record.Scopes, " ") } + return record.UserID, scopes, nil } diff --git a/integration/user_api_test.go b/integration/user_api_test.go index c53c1330..fa9c4400 100644 --- a/integration/user_api_test.go +++ b/integration/user_api_test.go @@ -147,7 +147,6 @@ func makeUserAPITestFixtures() *userAPITestFixtures { } refreshRepo := db.NewRefreshTokenRepo(dbMap) - fmt.Println("DEFAULT: ", oidc.DefaultScope) for _, user := range userUsers { if _, err := refreshRepo.Create(user.User.ID, testClientID, append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil { diff --git a/refresh/repo.go b/refresh/repo.go index 7d94cd22..df81a426 100644 --- a/refresh/repo.go +++ b/refresh/repo.go @@ -41,7 +41,9 @@ func DefaultRefreshTokenGenerator() ([]byte, error) { type RefreshTokenRepo interface { // Create generates and returns a new refresh token for the given client-user pair. - // On success the token will be return. + // The scopes will be stored with the refresh token, and used to verify + // against future OIDC refresh requests' scopes. + // On success the token will be returned. Create(userID, clientID string, scope []string) (string, error) // Verify verifies that a token belongs to the client. diff --git a/server/cross_client_test.go b/server/cross_client_test.go index 4d1d4120..22d5b252 100644 --- a/server/cross_client_test.go +++ b/server/cross_client_test.go @@ -14,29 +14,24 @@ import ( "github.com/kylelemons/godebug/pretty" "github.com/coreos/dex/client" - clientmanager "github.com/coreos/dex/client/manager" "github.com/coreos/dex/connector" "github.com/coreos/dex/scope" ) func makeCrossClientTestFixtures() (*testFixtures, error) { - f, err := makeTestFixtures() - if err != nil { - return nil, fmt.Errorf("couldn't make test fixtures: %v", err) - } - + xClients := []client.LoadableClient{} for _, cliData := range []struct { - id string - authorized []string + id string + trustedPeers []string }{ { id: "client_a", }, { - id: "client_b", - authorized: []string{"client_a"}, + id: "client_b", + trustedPeers: []string{"client_a"}, }, { - id: "client_c", - authorized: []string{"client_a", "client_b"}, + id: "client_c", + trustedPeers: []string{"client_a", "client_b"}, }, } { u := url.URL{ @@ -44,20 +39,27 @@ func makeCrossClientTestFixtures() (*testFixtures, error) { Path: cliData.id, Host: cliData.id, } - cliCreds, err := f.clientManager.New(client.Client{ - Credentials: oidc.ClientCredentials{ - ID: cliData.id, + xClients = append(xClients, client.LoadableClient{ + Client: client.Client{ + Credentials: oidc.ClientCredentials{ + ID: cliData.id, + Secret: base64.URLEncoding.EncodeToString( + []byte(cliData.id + "_secret")), + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{u}, + }, }, - Metadata: oidc.ClientMetadata{ - RedirectURIs: []url.URL{u}, - }, - }, &clientmanager.ClientOptions{ - TrustedPeers: cliData.authorized, + TrustedPeers: cliData.trustedPeers, }) - if err != nil { - return nil, fmt.Errorf("Unexpected error creating clients: %v", err) - } - f.clientCreds[cliData.id] = *cliCreds + } + + xClients = append(xClients, testClients...) + f, err := makeTestFixturesWithOptions(testFixtureOptions{ + clients: xClients, + }) + if err != nil { + return nil, fmt.Errorf("couldn't make test fixtures: %v", err) } return f, nil } diff --git a/server/server.go b/server/server.go index c68759bc..a151f0b3 100644 --- a/server/server.go +++ b/server/server.go @@ -445,35 +445,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo claims := ses.Claims(s.IssuerURL.String()) user.AddToClaims(claims) - crossClientIDs := ses.Scope.CrossClientIDs() - if len(crossClientIDs) > 0 { - var aud []string - for _, id := range crossClientIDs { - if ses.ClientID == id { - aud = append(aud, id) - continue - } - allowed, err := s.CrossClientAuthAllowed(ses.ClientID, id) - if err != nil { - log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", ses.ClientID, id, err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) - } - if !allowed { - err := oauth2.NewError(oauth2.ErrorInvalidRequest) - err.Description = fmt.Sprintf( - "%q is not authorized to perform cross-client requests for %q", - ses.ClientID, id) - return nil, "", err - } - aud = append(aud, id) - } - if len(aud) == 1 { - claims.Add("aud", aud[0]) - } else { - claims.Add("aud", aud) - } - claims.Add("azp", ses.ClientID) - } + s.addClaimsFromScope(claims, ses.Scope, ses.ClientID) jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { @@ -555,6 +527,8 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt) user.AddToClaims(claims) + s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID) + jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) @@ -596,6 +570,41 @@ func (s *Server) JWTVerifierFactory() JWTVerifierFactory { } } +// addClaimsFromScope adds claims that are based on the scopes that the client requested. +// Currently, these include cross-client claims (aud, azp). +func (s *Server) addClaimsFromScope(claims jose.Claims, scopes scope.Scopes, clientID string) error { + crossClientIDs := scopes.CrossClientIDs() + if len(crossClientIDs) > 0 { + var aud []string + for _, id := range crossClientIDs { + if clientID == id { + aud = append(aud, id) + continue + } + allowed, err := s.CrossClientAuthAllowed(clientID, id) + if err != nil { + log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", clientID, id, err) + return oauth2.NewError(oauth2.ErrorServerError) + } + if !allowed { + err := oauth2.NewError(oauth2.ErrorInvalidRequest) + err.Description = fmt.Sprintf( + "%q is not authorized to perform cross-client requests for %q", + clientID, id) + return err + } + aud = append(aud, id) + } + if len(aud) == 1 { + claims.Add("aud", aud[0]) + } else { + claims.Add("aud", aud) + } + claims.Add("azp", clientID) + } + return nil +} + type sortableIDPCs []connector.Connector func (s sortableIDPCs) Len() int { diff --git a/server/server_test.go b/server/server_test.go index 73ff72fd..44a7f2ed 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -18,6 +18,7 @@ import ( "github.com/coreos/dex/client" "github.com/coreos/dex/db" "github.com/coreos/dex/refresh/refreshtest" + "github.com/coreos/dex/scope" "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" ) @@ -490,117 +491,191 @@ func TestServerRefreshToken(t *testing.T) { signer jose.Signer createScopes []string refreshScopes []string + expectedAud []string err error }{ // Everything is good. { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - testClientCredentials, - signerFixture, - []string{"openid", "profile"}, - []string{"openid", "profile"}, - nil, + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: testClientCredentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, }, // Asking for a scope not originally granted to you. { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - testClientCredentials, - signerFixture, - []string{"openid", "profile"}, - []string{"openid", "profile", "extra_scope"}, - oauth2.NewError(oauth2.ErrorInvalidRequest), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: testClientCredentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile", "extra_scope"}, + err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(malformatted). { - "invalid-token", - testClientID, - testClientCredentials, - signerFixture, - []string{"openid", "profile"}, - []string{"openid", "profile"}, - oauth2.NewError(oauth2.ErrorInvalidRequest), + token: "invalid-token", + clientID: testClientID, + creds: testClientCredentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(invalid payload content). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), - testClientID, - testClientCredentials, - signerFixture, - []string{"openid", "profile"}, - []string{"openid", "profile"}, - oauth2.NewError(oauth2.ErrorInvalidRequest), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), + clientID: testClientID, + creds: testClientCredentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(invalid ID content). { - fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - testClientCredentials, - signerFixture, - []string{"openid", "profile"}, - []string{"openid", "profile"}, - oauth2.NewError(oauth2.ErrorInvalidRequest), + token: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: testClientCredentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid client(client is not associated with the token). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - clientB.Credentials, - signerFixture, - []string{"openid", "profile"}, - []string{"openid", "profile"}, - oauth2.NewError(oauth2.ErrorInvalidClient), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: clientB.Credentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no client ID). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - oidc.ClientCredentials{ID: "", Secret: "aaa"}, - signerFixture, - []string{"openid", "profile"}, - []string{"openid", "profile"}, - oauth2.NewError(oauth2.ErrorInvalidClient), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: oidc.ClientCredentials{ID: "", Secret: "aaa"}, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no such client). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, - signerFixture, - []string{"openid", "profile"}, - []string{"openid", "profile"}, - oauth2.NewError(oauth2.ErrorInvalidClient), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no secrets). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - oidc.ClientCredentials{ID: testClientID}, - signerFixture, - []string{"openid", "profile"}, - []string{"openid", "profile"}, - oauth2.NewError(oauth2.ErrorInvalidClient), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: oidc.ClientCredentials{ID: testClientID}, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(invalid secret). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"}, - signerFixture, - []string{"openid", "profile"}, - []string{"openid", "profile"}, - oauth2.NewError(oauth2.ErrorInvalidClient), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"}, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Signing operation fails. { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - testClientCredentials, - &StaticSigner{sig: nil, err: errors.New("fail")}, - []string{"openid", "profile"}, - []string{"openid", "profile"}, - oauth2.NewError(oauth2.ErrorServerError), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: testClientCredentials, + signer: &StaticSigner{sig: nil, err: errors.New("fail")}, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorServerError), + }, + // Valid Cross-Client + { + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: "client_a", + creds: oidc.ClientCredentials{ + ID: "client_a", + Secret: base64.URLEncoding.EncodeToString( + []byte("client_a_secret")), + }, + signer: signerFixture, + createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, + refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, + expectedAud: []string{"client_b"}, + }, + // Valid Cross-Client - but this time we leave out the scopes in the + // refresh request, which should result in the original stored scopes + // being used. + { + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: "client_a", + creds: oidc.ClientCredentials{ + ID: "client_a", + Secret: base64.URLEncoding.EncodeToString( + []byte("client_a_secret")), + }, + signer: signerFixture, + createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, + refreshScopes: []string{}, + expectedAud: []string{"client_b"}, + }, + // Valid Cross-Client - asking for fewer scopes than originally used + // when creating the refresh token, which is ok. + { + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: "client_a", + creds: oidc.ClientCredentials{ + ID: "client_a", + Secret: base64.URLEncoding.EncodeToString( + []byte("client_a_secret")), + }, + signer: signerFixture, + createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"}, + refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, + expectedAud: []string{"client_b"}, + }, + // Valid Cross-Client - asking for multiple clients in the audience. + { + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: "client_a", + creds: oidc.ClientCredentials{ + ID: "client_a", + Secret: base64.URLEncoding.EncodeToString( + []byte("client_a_secret")), + }, + signer: signerFixture, + createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"}, + refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"}, + expectedAud: []string{"client_b", "client_c"}, + }, + // Invalid Cross-Client - didn't orignally request cross-client when + // refresh token was created. + { + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: "client_a", + creds: oidc.ClientCredentials{ + ID: "client_a", + Secret: base64.URLEncoding.EncodeToString( + []byte("client_a_secret")), + }, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, + err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, } @@ -608,7 +683,7 @@ func TestServerRefreshToken(t *testing.T) { km := &StaticKeyManager{ signer: tt.signer, } - f, err := makeTestFixtures() + f, err := makeCrossClientTestFixtures() if err != nil { t.Fatalf("error making test fixtures: %v", err) } @@ -637,8 +712,27 @@ func TestServerRefreshToken(t *testing.T) { if err != nil { t.Errorf("Case %d: unexpected error: %v", i, err) } - if claims["iss"] != testIssuerURL.String() || claims["sub"] != testUserID1 || claims["aud"] != testClientID { - t.Errorf("Case %d: invalid claims: %v", i, claims) + + var expectedAud interface{} + if tt.expectedAud == nil { + expectedAud = testClientID + } else if len(tt.expectedAud) == 1 { + expectedAud = tt.expectedAud[0] + } else { + expectedAud = tt.expectedAud + } + + if claims["iss"] != testIssuerURL.String() { + t.Errorf("Case %d: want=%v, got=%v", i, + testIssuerURL.String(), claims["iss"]) + } + if claims["sub"] != testUserID1 { + t.Errorf("Case %d: want=%v, got=%v", i, + testUserID1, claims["sub"]) + } + if diff := pretty.Compare(claims["aud"], expectedAud); diff != "" { + t.Errorf("Case %d: want=%v, got=%v", i, + expectedAud, claims["aud"]) } } } diff --git a/server/testutil.go b/server/testutil.go index 06979b03..2155d70d 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -39,6 +39,18 @@ var ( ID: testClientID, Secret: clientTestSecret, } + testClients = []client.LoadableClient{ + { + Client: client.Client{ + Credentials: testClientCredentials, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{ + testRedirectURL, + }, + }, + }, + }, + } testConnectorID1 = "IDPC-1" @@ -169,18 +181,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err var clients []client.LoadableClient if options.clients == nil { - clients = []client.LoadableClient{ - { - Client: client.Client{ - Credentials: testClientCredentials, - Metadata: oidc.ClientMetadata{ - RedirectURIs: []url.URL{ - testRedirectURL, - }, - }, - }, - }, - } + clients = testClients } else { clients = options.clients } @@ -247,6 +248,10 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err srv.absURL(httpPathAcceptInvitation), ) + clientCreds := map[string]oidc.ClientCredentials{} + for _, c := range clients { + clientCreds[c.Client.Credentials.ID] = c.Client.Credentials + } return &testFixtures{ srv: srv, redirectURL: testRedirectURL, @@ -255,9 +260,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err emailer: emailer, clientRepo: clientRepo, clientManager: clientManager, - clientCreds: map[string]oidc.ClientCredentials{ - testClientID: testClientCreds, - }, + clientCreds: clientCreds, }, nil }