diff --git a/db/migrate_sqlite3.go b/db/migrate_sqlite3.go index 6e2142a4..523cce64 100644 --- a/db/migrate_sqlite3.go +++ b/db/migrate_sqlite3.go @@ -39,7 +39,8 @@ CREATE TABLE refresh_token ( id integer PRIMARY KEY, payload_hash blob, user_id text, - client_id text + client_id text, + scopes text ); CREATE TABLE remote_identity_mapping ( diff --git a/db/migrations/0013_add_scopes_to_refresh_tokens.sql b/db/migrations/0013_add_scopes_to_refresh_tokens.sql new file mode 100644 index 00000000..fa383ea9 --- /dev/null +++ b/db/migrations/0013_add_scopes_to_refresh_tokens.sql @@ -0,0 +1,4 @@ +-- +migrate Up +ALTER TABLE refresh_token ADD COLUMN "scopes" text; + +UPDATE refresh_token SET scopes = 'openid profile email offline_access'; diff --git a/db/migrations/assets.go b/db/migrations/assets.go index e0d995b4..6798acc1 100644 --- a/db/migrations/assets.go +++ b/db/migrations/assets.go @@ -78,5 +78,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{ "-- +migrate Up\nCREATE TABLE IF NOT EXISTS \"trusted_peers\" (\n \"client_id\" text not null,\n \"trusted_client_id\" text not null,\n primary key (\"client_id\", \"trusted_client_id\")) ;\n", }, }, + { + Id: "0013_add_scopes_to_refresh_tokens.sql", + Up: []string{ + "-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n", + }, + }, }, } diff --git a/db/refresh.go b/db/refresh.go index 8ebc9ce6..d16f313f 100644 --- a/db/refresh.go +++ b/db/refresh.go @@ -15,6 +15,7 @@ import ( "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/refresh" "github.com/coreos/dex/repo" + "github.com/coreos/dex/scope" ) const ( @@ -38,10 +39,9 @@ type refreshTokenRepo struct { type refreshTokenModel struct { ID int64 `db:"id"` PayloadHash []byte `db:"payload_hash"` - // TODO(yifan): Use some sort of foreign key to manage database level - // data integrity. - UserID string `db:"user_id"` - ClientID string `db:"client_id"` + UserID string `db:"user_id"` + ClientID string `db:"client_id"` + Scopes string `db:"scopes"` } // buildToken combines the token ID and token payload to create a new token. @@ -89,7 +89,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG } } -func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { +func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (string, error) { if userID == "" { return "", refresh.ErrorInvalidUserID } @@ -112,6 +112,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { PayloadHash: payloadHash, UserID: userID, ClientID: clientID, + Scopes: strings.Join(scopes, " "), } if err := r.executor(nil).Insert(record); err != nil { @@ -121,27 +122,32 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { return buildToken(record.ID, tokenPayload), nil } -func (r *refreshTokenRepo) Verify(clientID, token string) (string, error) { +func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, error) { tokenID, tokenPayload, err := parseToken(token) if err != nil { - return "", err + return "", nil, err } record, err := r.get(nil, tokenID) if err != nil { - return "", err + return "", nil, err } if record.ClientID != clientID { - return "", refresh.ErrorInvalidClientID + return "", nil, refresh.ErrorInvalidClientID } if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil { - return "", err + return "", nil, err } - return record.UserID, nil + var scopes []string + if len(record.Scopes) > 0 { + scopes = strings.Split(record.Scopes, " ") + } + + return record.UserID, scopes, nil } func (r *refreshTokenRepo) Revoke(userID, token string) error { @@ -190,7 +196,6 @@ func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Cli q := `SELECT c.* FROM %s as c INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;` q = fmt.Sprintf(q, r.quote(clientTableName), r.quote(refreshTokenTableName)) - var clients []clientModel if _, err := r.executor(nil).Select(&clients, q, userID); err != nil { return nil, err @@ -206,6 +211,7 @@ func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Cli // Do not share the secret. c[i].Credentials.Secret = "" } + return c, nil } diff --git a/functional/db_test.go b/functional/db_test.go index cc78d8a0..af6d0583 100644 --- a/functional/db_test.go +++ b/functional/db_test.go @@ -1,7 +1,6 @@ package functional import ( - "encoding/base64" "fmt" "net/url" "os" @@ -16,7 +15,6 @@ import ( "github.com/coreos/dex/client" "github.com/coreos/dex/client/manager" "github.com/coreos/dex/db" - "github.com/coreos/dex/refresh" "github.com/coreos/dex/session" ) @@ -411,207 +409,3 @@ func TestDBClientAll(t *testing.T) { t.Fatalf("Retrieved incorrect number of ClientIdentities: want=2 got=%d", count) } } - -// buildRefreshToken combines the token ID and token payload to create a new token. -// used in the tests to created a refresh token. -func buildRefreshToken(tokenID int64, tokenPayload []byte) string { - return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload)) -} - -func TestDBRefreshRepoCreate(t *testing.T) { - r := db.NewRefreshTokenRepo(connect(t)) - - tests := []struct { - userID string - clientID string - err error - }{ - { - "", - "client-foo", - refresh.ErrorInvalidUserID, - }, - { - "user-foo", - "", - refresh.ErrorInvalidClientID, - }, - { - "user-foo", - "client-foo", - nil, - }, - } - - for i, tt := range tests { - token, err := r.Create(tt.userID, tt.clientID) - if err != nil { - if tt.err == nil { - t.Errorf("case %d: create failed: %v", i, err) - } - continue - } - if tt.err != nil { - t.Errorf("case %d: expected error, didn't get one", i) - continue - } - userID, err := r.Verify(tt.clientID, token) - if err != nil { - t.Errorf("case %d: failed to verify good token: %v", i, err) - continue - } - if userID != tt.userID { - t.Errorf("case %d: want userID=%s, got userID=%s", i, tt.userID, userID) - } - } -} - -func TestDBRefreshRepoVerify(t *testing.T) { - r := db.NewRefreshTokenRepo(connect(t)) - - token, err := r.Create("user-foo", "client-foo") - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - badTokenPayload, err := refresh.DefaultRefreshTokenGenerator() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - tokenWithBadID := "404" + token[1:] - tokenWithBadPayload := buildRefreshToken(1, badTokenPayload) - - tests := []struct { - token string - creds oidc.ClientCredentials - err error - expected string - }{ - { - "invalid-token-format", - oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, - refresh.ErrorInvalidToken, - "", - }, - { - "b/invalid-base64-encoded-format", - oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, - refresh.ErrorInvalidToken, - "", - }, - { - "1/invalid-base64-encoded-format", - oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, - refresh.ErrorInvalidToken, - "", - }, - { - token + "corrupted-token-payload", - oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, - refresh.ErrorInvalidToken, - "", - }, - { - // The token's ID content is invalid. - tokenWithBadID, - oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, - refresh.ErrorInvalidToken, - "", - }, - { - // The token's payload content is invalid. - tokenWithBadPayload, - oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, - refresh.ErrorInvalidToken, - "", - }, - { - token, - oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"}, - refresh.ErrorInvalidClientID, - "", - }, - { - token, - oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, - nil, - "user-foo", - }, - } - - for i, tt := range tests { - result, err := r.Verify(tt.creds.ID, tt.token) - if err != tt.err { - t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) - } - if result != tt.expected { - t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result) - } - } -} - -func TestDBRefreshRepoRevoke(t *testing.T) { - r := db.NewRefreshTokenRepo(connect(t)) - - token, err := r.Create("user-foo", "client-foo") - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - badTokenPayload, err := refresh.DefaultRefreshTokenGenerator() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - tokenWithBadID := "404" + token[1:] - tokenWithBadPayload := buildRefreshToken(1, badTokenPayload) - - tests := []struct { - token string - userID string - err error - }{ - { - "invalid-token-format", - "user-foo", - refresh.ErrorInvalidToken, - }, - { - "1/invalid-base64-encoded-format", - "user-foo", - refresh.ErrorInvalidToken, - }, - { - token + "corrupted-token-payload", - "user-foo", - refresh.ErrorInvalidToken, - }, - { - // The token's ID is invalid. - tokenWithBadID, - "user-foo", - refresh.ErrorInvalidToken, - }, - { - // The token's payload is invalid. - tokenWithBadPayload, - "user-foo", - refresh.ErrorInvalidToken, - }, - { - token, - "invalid-user", - refresh.ErrorInvalidUserID, - }, - { - token, - "user-foo", - nil, - }, - } - - for i, tt := range tests { - if err := r.Revoke(tt.userID, tt.token); err != tt.err { - t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) - } - } -} diff --git a/functional/repo/refresh_repo_test.go b/functional/repo/refresh_repo_test.go index 9df9aa8a..f18d0f09 100644 --- a/functional/repo/refresh_repo_test.go +++ b/functional/repo/refresh_repo_test.go @@ -2,13 +2,13 @@ package repo import ( "encoding/base64" + "fmt" "net/url" - "os" + "sort" "testing" "time" "github.com/coreos/go-oidc/oidc" - "github.com/go-gorp/gorp" "github.com/kylelemons/godebug/pretty" "github.com/coreos/dex/client" @@ -17,40 +17,43 @@ import ( "github.com/coreos/dex/user" ) -func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []client.Client) refresh.RefreshTokenRepo { - var dbMap *gorp.DbMap - if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" { - dbMap = db.NewMemDB() - } else { - dbMap = connect(t) - } - if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil { - t.Fatalf("Unable to add users: %v", err) - } - - return db.NewRefreshTokenRepo(dbMap) -} - -func TestRefreshTokenRepo(t *testing.T) { - clientID := "client1" - userID := "user1" - clients := []client.Client{ +var ( + testRefreshClientID = "client1" + testRefreshClientID2 = "client2" + testRefreshClients = []client.LoadableClient{ { - Credentials: oidc.ClientCredentials{ - ID: clientID, - Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")), + Client: client.Client{ + Credentials: oidc.ClientCredentials{ + ID: testRefreshClientID, + Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")), + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{ + url.URL{Scheme: "https", Host: "client1.example.com", Path: "/callback"}, + }, + }, }, - Metadata: oidc.ClientMetadata{ - RedirectURIs: []url.URL{ - url.URL{Scheme: "https", Host: "client1.example.com", Path: "/callback"}, + }, + { + Client: client.Client{ + Credentials: oidc.ClientCredentials{ + ID: testRefreshClientID2, + Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")), + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{ + url.URL{Scheme: "https", Host: "client2.example.com", Path: "/callback"}, + }, }, }, }, } - users := []user.UserWithRemoteIdentities{ + + testRefreshUserID = "user1" + testRefreshUsers = []user.UserWithRemoteIdentities{ { User: user.User{ - ID: userID, + ID: testRefreshUserID, Email: "Email-1@example.com", CreatedAt: time.Now().Truncate(time.Second), }, @@ -62,31 +65,318 @@ func TestRefreshTokenRepo(t *testing.T) { }, }, } +) - repo := newRefreshRepo(t, users, clients) - tok, err := repo.Create(userID, clientID) - if err != nil { - t.Fatalf("failed to create refresh token: %v", err) - } - if tokUserID, err := repo.Verify(clientID, tok); err != nil { - t.Errorf("Could not verify token: %v", err) - } else if tokUserID != userID { - t.Errorf("Verified token returned wrong user id, want=%s, got=%s", userID, tokUserID) +func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []client.LoadableClient) refresh.RefreshTokenRepo { + dbMap := connect(t) + if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil { + t.Fatalf("Unable to add users: %v", err) } - if userClients, err := repo.ClientsWithRefreshTokens(userID); err != nil { - t.Errorf("Failed to get the list of clients the user was logged into: %v", err) - } else { - if diff := pretty.Compare(userClients, clients); diff == "" { - t.Errorf("Clients user logged into: want did not equal got %s", diff) + if _, err := db.NewClientRepoFromClients(dbMap, clients); err != nil { + t.Fatalf("Unable to add clients: %v", err) + } + + return db.NewRefreshTokenRepo(dbMap) +} + +func TestRefreshTokenRepoCreateVerify(t *testing.T) { + tests := []struct { + createScopes []string + verifyClientID string + wantVerifyErr bool + }{ + { + createScopes: []string{"openid", "profile"}, + verifyClientID: testRefreshClientID, + }, + { + createScopes: []string{}, + verifyClientID: testRefreshClientID, + }, + { + createScopes: []string{"openid", "profile"}, + verifyClientID: "not-a-client", + wantVerifyErr: true, + }, + } + + for i, tt := range tests { + repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) + tok, err := repo.Create(testRefreshUserID, testRefreshClientID, tt.createScopes) + if err != nil { + t.Fatalf("case %d: failed to create refresh token: %v", i, err) + } + + tokUserID, gotScopes, err := repo.Verify(tt.verifyClientID, tok) + if tt.wantVerifyErr { + if err == nil { + t.Errorf("case %d: want non-nil error.", i) + } + continue + } + + if diff := pretty.Compare(tt.createScopes, gotScopes); diff != "" { + t.Errorf("case %d: Compare(want, got): %v", i, diff) + } + + if err != nil { + t.Errorf("case %d: Could not verify token: %v", i, err) + } else if tokUserID != testRefreshUserID { + t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i, + testRefreshUserID, tokUserID) } } +} - if err := repo.RevokeTokensForClient(userID, clientID); err != nil { - t.Errorf("Failed to revoke refresh token: %v", err) +// buildRefreshToken combines the token ID and token payload to create a new token. +// used in the tests to created a refresh token. +func buildRefreshToken(tokenID int64, tokenPayload []byte) string { + return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload)) +} + +func TestRefreshRepoVerifyInvalidTokens(t *testing.T) { + r := db.NewRefreshTokenRepo(connect(t)) + + token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope) + if err != nil { + t.Fatalf("Unexpected error: %v", err) } - if _, err := repo.Verify(clientID, tok); err == nil { - t.Errorf("Token which should have been revoked was verified") + badTokenPayload, err := refresh.DefaultRefreshTokenGenerator() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + tokenWithBadID := "404" + token[1:] + tokenWithBadPayload := buildRefreshToken(1, badTokenPayload) + + tests := []struct { + token string + creds oidc.ClientCredentials + err error + expected string + }{ + { + "invalid-token-format", + oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, + refresh.ErrorInvalidToken, + "", + }, + { + "b/invalid-base64-encoded-format", + oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, + refresh.ErrorInvalidToken, + "", + }, + { + "1/invalid-base64-encoded-format", + oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, + refresh.ErrorInvalidToken, + "", + }, + { + token + "corrupted-token-payload", + oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, + refresh.ErrorInvalidToken, + "", + }, + { + // The token's ID content is invalid. + tokenWithBadID, + oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, + refresh.ErrorInvalidToken, + "", + }, + { + // The token's payload content is invalid. + tokenWithBadPayload, + oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, + refresh.ErrorInvalidToken, + "", + }, + { + token, + oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"}, + refresh.ErrorInvalidClientID, + "", + }, + { + token, + oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, + nil, + "user-foo", + }, + } + + for i, tt := range tests { + result, _, err := r.Verify(tt.creds.ID, tt.token) + if err != tt.err { + t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) + } + if result != tt.expected { + t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result) + } + } +} + +func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) { + tests := []struct { + clientIDs []string + }{ + {clientIDs: []string{"client1", "client2"}}, + {clientIDs: []string{"client1"}}, + {clientIDs: []string{}}, + } + + for i, tt := range tests { + repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) + + for _, clientID := range tt.clientIDs { + _, err := repo.Create(testRefreshUserID, clientID, []string{"openid"}) + if err != nil { + t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) + } + } + + clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID) + if err != nil { + t.Fatalf("case %d: unexpected error fetching clients %q", i, err) + } + var clientIDs []string + for _, client := range clients { + clientIDs = append(clientIDs, client.Credentials.ID) + } + sort.Strings(clientIDs) + + if diff := pretty.Compare(clientIDs, tt.clientIDs); diff != "" { + t.Errorf("case %d: Compare(want, got): %v", i, diff) + } + } +} + +func TestRefreshTokenRepoRevokeForClient(t *testing.T) { + tests := []struct { + createIDs []string + revokeID string + }{ + { + createIDs: []string{"client1", "client2"}, + revokeID: "client1", + }, + { + createIDs: []string{"client2"}, + revokeID: "client1", + }, + { + createIDs: []string{"client1"}, + revokeID: "client1", + }, + { + createIDs: []string{}, + revokeID: "oops", + }, + } + + for i, tt := range tests { + repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) + + for _, clientID := range tt.createIDs { + _, err := repo.Create(testRefreshUserID, clientID, []string{"openid"}) + if err != nil { + t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) + } + + if err := repo.RevokeTokensForClient(testRefreshUserID, tt.revokeID); err != nil { + t.Fatalf("case %d: couldn't revoke refresh token(s): %v", i, err) + } + } + + var wantIDs []string + for _, id := range tt.createIDs { + if id != tt.revokeID { + wantIDs = append(wantIDs, id) + } + } + + clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID) + if err != nil { + t.Fatalf("case %d: unexpected error fetching clients %q", i, err) + } + + var gotIDs []string + for _, client := range clients { + gotIDs = append(gotIDs, client.Credentials.ID) + } + sort.Strings(gotIDs) + + if diff := pretty.Compare(wantIDs, gotIDs); diff != "" { + t.Errorf("case %d: Compare(wantIDs, gotIDs): %v", i, diff) + } + } +} + +func TestRefreshRepoRevoke(t *testing.T) { + r := db.NewRefreshTokenRepo(connect(t)) + + token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + badTokenPayload, err := refresh.DefaultRefreshTokenGenerator() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + tokenWithBadID := "404" + token[1:] + tokenWithBadPayload := buildRefreshToken(1, badTokenPayload) + + tests := []struct { + token string + userID string + err error + }{ + { + "invalid-token-format", + "user-foo", + refresh.ErrorInvalidToken, + }, + { + "1/invalid-base64-encoded-format", + "user-foo", + refresh.ErrorInvalidToken, + }, + { + token + "corrupted-token-payload", + "user-foo", + refresh.ErrorInvalidToken, + }, + { + // The token's ID is invalid. + tokenWithBadID, + "user-foo", + refresh.ErrorInvalidToken, + }, + { + // The token's payload is invalid. + tokenWithBadPayload, + "user-foo", + refresh.ErrorInvalidToken, + }, + { + token, + "invalid-user", + refresh.ErrorInvalidUserID, + }, + { + token, + "user-foo", + nil, + }, + } + + for i, tt := range tests { + if err := r.Revoke(tt.userID, tt.token); err != tt.err { + t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) + } } } diff --git a/functional/repo/testutil.go b/functional/repo/testutil.go index 16ee2e60..b2fdbc98 100644 --- a/functional/repo/testutil.go +++ b/functional/repo/testutil.go @@ -12,7 +12,8 @@ import ( func connect(t *testing.T) *gorp.DbMap { dsn := os.Getenv("DEX_TEST_DSN") if dsn == "" { - t.Fatal("DEX_TEST_DSN environment variable not set") + return db.NewMemDB() + } c, err := db.NewConnection(db.Config{DSN: dsn}) if err != nil { diff --git a/integration/oidc_test.go b/integration/oidc_test.go index 4d8ac072..64c404b5 100644 --- a/integration/oidc_test.go +++ b/integration/oidc_test.go @@ -231,7 +231,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { // this will actually happen due to some interaction between the // end-user and a remote identity provider - sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"}) + sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access", "email", "profile"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/integration/user_api_test.go b/integration/user_api_test.go index 163a1382..fa9c4400 100644 --- a/integration/user_api_test.go +++ b/integration/user_api_test.go @@ -148,7 +148,8 @@ func makeUserAPITestFixtures() *userAPITestFixtures { refreshRepo := db.NewRefreshTokenRepo(dbMap) for _, user := range userUsers { - if _, err := refreshRepo.Create(user.User.ID, testClientID); err != nil { + if _, err := refreshRepo.Create(user.User.ID, testClientID, + append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil { panic("Failed to create refresh token: " + err.Error()) } } diff --git a/refresh/repo.go b/refresh/repo.go index 8f196a15..df81a426 100644 --- a/refresh/repo.go +++ b/refresh/repo.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/coreos/dex/client" + "github.com/coreos/dex/scope" ) const ( @@ -40,12 +41,15 @@ func DefaultRefreshTokenGenerator() ([]byte, error) { type RefreshTokenRepo interface { // Create generates and returns a new refresh token for the given client-user pair. - // On success the token will be return. - Create(userID, clientID string) (string, error) + // The scopes will be stored with the refresh token, and used to verify + // against future OIDC refresh requests' scopes. + // On success the token will be returned. + Create(userID, clientID string, scope []string) (string, error) - // Verify verifies that a token belongs to the client, and returns the corresponding user ID. - // Note that this assumes the client validation is currently done in the application layer, - Verify(clientID, token string) (string, error) + // Verify verifies that a token belongs to the client. + // It returns the user ID to which the token belongs, and the scopes stored + // with token. + Verify(clientID, token string) (string, scope.Scopes, error) // Revoke deletes the refresh token if the token belongs to the given userID. Revoke(userID, token string) error diff --git a/scope/scope.go b/scope/scope.go index 50fd266f..105c7b77 100644 --- a/scope/scope.go +++ b/scope/scope.go @@ -32,3 +32,17 @@ func (s Scopes) CrossClientIDs() []string { } return clients } + +func (s Scopes) Contains(other Scopes) bool { + rScopes := map[string]struct{}{} + for _, scope := range s { + rScopes[scope] = struct{}{} + } + + for _, scope := range other { + if _, ok := rScopes[scope]; !ok { + return false + } + } + return true +} diff --git a/server/cross_client_test.go b/server/cross_client_test.go index 4d1d4120..22d5b252 100644 --- a/server/cross_client_test.go +++ b/server/cross_client_test.go @@ -14,29 +14,24 @@ import ( "github.com/kylelemons/godebug/pretty" "github.com/coreos/dex/client" - clientmanager "github.com/coreos/dex/client/manager" "github.com/coreos/dex/connector" "github.com/coreos/dex/scope" ) func makeCrossClientTestFixtures() (*testFixtures, error) { - f, err := makeTestFixtures() - if err != nil { - return nil, fmt.Errorf("couldn't make test fixtures: %v", err) - } - + xClients := []client.LoadableClient{} for _, cliData := range []struct { - id string - authorized []string + id string + trustedPeers []string }{ { id: "client_a", }, { - id: "client_b", - authorized: []string{"client_a"}, + id: "client_b", + trustedPeers: []string{"client_a"}, }, { - id: "client_c", - authorized: []string{"client_a", "client_b"}, + id: "client_c", + trustedPeers: []string{"client_a", "client_b"}, }, } { u := url.URL{ @@ -44,20 +39,27 @@ func makeCrossClientTestFixtures() (*testFixtures, error) { Path: cliData.id, Host: cliData.id, } - cliCreds, err := f.clientManager.New(client.Client{ - Credentials: oidc.ClientCredentials{ - ID: cliData.id, + xClients = append(xClients, client.LoadableClient{ + Client: client.Client{ + Credentials: oidc.ClientCredentials{ + ID: cliData.id, + Secret: base64.URLEncoding.EncodeToString( + []byte(cliData.id + "_secret")), + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{u}, + }, }, - Metadata: oidc.ClientMetadata{ - RedirectURIs: []url.URL{u}, - }, - }, &clientmanager.ClientOptions{ - TrustedPeers: cliData.authorized, + TrustedPeers: cliData.trustedPeers, }) - if err != nil { - return nil, fmt.Errorf("Unexpected error creating clients: %v", err) - } - f.clientCreds[cliData.id] = *cliCreds + } + + xClients = append(xClients, testClients...) + f, err := makeTestFixturesWithOptions(testFixtureOptions{ + clients: xClients, + }) + if err != nil { + return nil, fmt.Errorf("couldn't make test fixtures: %v", err) } return f, nil } diff --git a/server/http.go b/server/http.go index 542442ad..4e4cd531 100644 --- a/server/http.go +++ b/server/http.go @@ -518,11 +518,12 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc { } case oauth2.GrantTypeRefreshToken: token := r.PostForm.Get("refresh_token") + scopes := r.PostForm.Get("scope") if token == "" { writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) return } - jwt, err = srv.RefreshToken(creds, token) + jwt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token) if err != nil { writeTokenError(w, err, state) return diff --git a/server/server.go b/server/server.go index 8c7bd361..a151f0b3 100644 --- a/server/server.go +++ b/server/server.go @@ -23,6 +23,7 @@ import ( "github.com/coreos/dex/connector" "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/refresh" + "github.com/coreos/dex/scope" "github.com/coreos/dex/session" sessionmanager "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" @@ -53,7 +54,7 @@ type OIDCServer interface { // RefreshToken takes a previously generated refresh token and returns a new ID token // if the token is valid. - RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) + RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) KillSession(string) error @@ -444,35 +445,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo claims := ses.Claims(s.IssuerURL.String()) user.AddToClaims(claims) - crossClientIDs := ses.Scope.CrossClientIDs() - if len(crossClientIDs) > 0 { - var aud []string - for _, id := range crossClientIDs { - if ses.ClientID == id { - aud = append(aud, id) - continue - } - allowed, err := s.CrossClientAuthAllowed(ses.ClientID, id) - if err != nil { - log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", ses.ClientID, id, err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) - } - if !allowed { - err := oauth2.NewError(oauth2.ErrorInvalidRequest) - err.Description = fmt.Sprintf( - "%q is not authorized to perform cross-client requests for %q", - ses.ClientID, id) - return nil, "", err - } - aud = append(aud, id) - } - if len(aud) == 1 { - claims.Add("aud", aud[0]) - } else { - claims.Add("aud", aud) - } - claims.Add("azp", ses.ClientID) - } + s.addClaimsFromScope(claims, ses.Scope, ses.ClientID) jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { @@ -487,7 +460,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo if scope == "offline_access" { log.Infof("Session %s requests offline access, will generate refresh token", sessionID) - refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID) + refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope) switch err { case nil: break @@ -503,7 +476,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo return jwt, refreshToken, nil } -func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) { +func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) { ok, err := s.ClientManager.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) @@ -514,7 +487,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose return nil, oauth2.NewError(oauth2.ErrorInvalidClient) } - userID, err := s.RefreshTokenRepo.Verify(creds.ID, token) + userID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token) switch err { case nil: break @@ -526,6 +499,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose return nil, oauth2.NewError(oauth2.ErrorServerError) } + if len(scopes) == 0 { + scopes = rtScopes + } else { + if !rtScopes.Contains(scopes) { + return nil, oauth2.NewError(oauth2.ErrorInvalidRequest) + } + } + user, err := s.UserRepo.Get(nil, userID) if err != nil { // The error can be user.ErrorNotFound, but we are not deleting @@ -546,6 +527,8 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt) user.AddToClaims(claims) + s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID) + jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) @@ -587,6 +570,41 @@ func (s *Server) JWTVerifierFactory() JWTVerifierFactory { } } +// addClaimsFromScope adds claims that are based on the scopes that the client requested. +// Currently, these include cross-client claims (aud, azp). +func (s *Server) addClaimsFromScope(claims jose.Claims, scopes scope.Scopes, clientID string) error { + crossClientIDs := scopes.CrossClientIDs() + if len(crossClientIDs) > 0 { + var aud []string + for _, id := range crossClientIDs { + if clientID == id { + aud = append(aud, id) + continue + } + allowed, err := s.CrossClientAuthAllowed(clientID, id) + if err != nil { + log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", clientID, id, err) + return oauth2.NewError(oauth2.ErrorServerError) + } + if !allowed { + err := oauth2.NewError(oauth2.ErrorInvalidRequest) + err.Description = fmt.Sprintf( + "%q is not authorized to perform cross-client requests for %q", + clientID, id) + return err + } + aud = append(aud, id) + } + if len(aud) == 1 { + claims.Add("aud", aud[0]) + } else { + claims.Add("aud", aud) + } + claims.Add("azp", clientID) + } + return nil +} + type sortableIDPCs []connector.Connector func (s sortableIDPCs) Len() int { diff --git a/server/server_test.go b/server/server_test.go index 1650d37f..44a7f2ed 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -18,6 +18,7 @@ import ( "github.com/coreos/dex/client" "github.com/coreos/dex/db" "github.com/coreos/dex/refresh/refreshtest" + "github.com/coreos/dex/scope" "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" ) @@ -484,91 +485,197 @@ func TestServerRefreshToken(t *testing.T) { // NOTE(ericchiang): These tests assume that the database ID of the first // refresh token will be "1". tests := []struct { - token string - clientID string // The client that associates with the token. - creds oidc.ClientCredentials - signer jose.Signer - err error + token string + clientID string // The client that associates with the token. + creds oidc.ClientCredentials + signer jose.Signer + createScopes []string + refreshScopes []string + expectedAud []string + err error }{ // Everything is good. { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - testClientCredentials, - signerFixture, - nil, + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: testClientCredentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + }, + // Asking for a scope not originally granted to you. + { + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: testClientCredentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile", "extra_scope"}, + err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(malformatted). { - "invalid-token", - testClientID, - testClientCredentials, - signerFixture, - oauth2.NewError(oauth2.ErrorInvalidRequest), + token: "invalid-token", + clientID: testClientID, + creds: testClientCredentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(invalid payload content). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), - testClientID, - testClientCredentials, - signerFixture, - oauth2.NewError(oauth2.ErrorInvalidRequest), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), + clientID: testClientID, + creds: testClientCredentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid refresh token(invalid ID content). { - fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - testClientCredentials, - signerFixture, - oauth2.NewError(oauth2.ErrorInvalidRequest), + token: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: testClientCredentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, // Invalid client(client is not associated with the token). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - clientB.Credentials, - signerFixture, - oauth2.NewError(oauth2.ErrorInvalidClient), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: clientB.Credentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no client ID). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - oidc.ClientCredentials{ID: "", Secret: "aaa"}, - signerFixture, - oauth2.NewError(oauth2.ErrorInvalidClient), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: oidc.ClientCredentials{ID: "", Secret: "aaa"}, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no such client). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, - signerFixture, - oauth2.NewError(oauth2.ErrorInvalidClient), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(no secrets). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - oidc.ClientCredentials{ID: testClientID}, - signerFixture, - oauth2.NewError(oauth2.ErrorInvalidClient), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: oidc.ClientCredentials{ID: testClientID}, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Invalid client(invalid secret). { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"}, - signerFixture, - oauth2.NewError(oauth2.ErrorInvalidClient), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"}, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorInvalidClient), }, // Signing operation fails. { - fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - testClientID, - testClientCredentials, - &StaticSigner{sig: nil, err: errors.New("fail")}, - oauth2.NewError(oauth2.ErrorServerError), + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: testClientID, + creds: testClientCredentials, + signer: &StaticSigner{sig: nil, err: errors.New("fail")}, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, + err: oauth2.NewError(oauth2.ErrorServerError), + }, + // Valid Cross-Client + { + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: "client_a", + creds: oidc.ClientCredentials{ + ID: "client_a", + Secret: base64.URLEncoding.EncodeToString( + []byte("client_a_secret")), + }, + signer: signerFixture, + createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, + refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, + expectedAud: []string{"client_b"}, + }, + // Valid Cross-Client - but this time we leave out the scopes in the + // refresh request, which should result in the original stored scopes + // being used. + { + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: "client_a", + creds: oidc.ClientCredentials{ + ID: "client_a", + Secret: base64.URLEncoding.EncodeToString( + []byte("client_a_secret")), + }, + signer: signerFixture, + createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, + refreshScopes: []string{}, + expectedAud: []string{"client_b"}, + }, + // Valid Cross-Client - asking for fewer scopes than originally used + // when creating the refresh token, which is ok. + { + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: "client_a", + creds: oidc.ClientCredentials{ + ID: "client_a", + Secret: base64.URLEncoding.EncodeToString( + []byte("client_a_secret")), + }, + signer: signerFixture, + createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"}, + refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, + expectedAud: []string{"client_b"}, + }, + // Valid Cross-Client - asking for multiple clients in the audience. + { + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: "client_a", + creds: oidc.ClientCredentials{ + ID: "client_a", + Secret: base64.URLEncoding.EncodeToString( + []byte("client_a_secret")), + }, + signer: signerFixture, + createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"}, + refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"}, + expectedAud: []string{"client_b", "client_c"}, + }, + // Invalid Cross-Client - didn't orignally request cross-client when + // refresh token was created. + { + token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + clientID: "client_a", + creds: oidc.ClientCredentials{ + ID: "client_a", + Secret: base64.URLEncoding.EncodeToString( + []byte("client_a_secret")), + }, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"}, + err: oauth2.NewError(oauth2.ErrorInvalidRequest), }, } @@ -576,7 +683,7 @@ func TestServerRefreshToken(t *testing.T) { km := &StaticKeyManager{ signer: tt.signer, } - f, err := makeTestFixtures() + f, err := makeCrossClientTestFixtures() if err != nil { t.Fatalf("error making test fixtures: %v", err) } @@ -587,11 +694,12 @@ func TestServerRefreshToken(t *testing.T) { t.Errorf("case %d: error creating other client: %v", i, err) } - if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID); err != nil { + if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, + tt.createScopes); err != nil { t.Fatalf("Unexpected error: %v", err) } - jwt, err := f.srv.RefreshToken(tt.creds, tt.token) + jwt, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token) if !reflect.DeepEqual(err, tt.err) { t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err) } @@ -604,8 +712,27 @@ func TestServerRefreshToken(t *testing.T) { if err != nil { t.Errorf("Case %d: unexpected error: %v", i, err) } - if claims["iss"] != testIssuerURL.String() || claims["sub"] != testUserID1 || claims["aud"] != testClientID { - t.Errorf("Case %d: invalid claims: %v", i, claims) + + var expectedAud interface{} + if tt.expectedAud == nil { + expectedAud = testClientID + } else if len(tt.expectedAud) == 1 { + expectedAud = tt.expectedAud[0] + } else { + expectedAud = tt.expectedAud + } + + if claims["iss"] != testIssuerURL.String() { + t.Errorf("Case %d: want=%v, got=%v", i, + testIssuerURL.String(), claims["iss"]) + } + if claims["sub"] != testUserID1 { + t.Errorf("Case %d: want=%v, got=%v", i, + testUserID1, claims["sub"]) + } + if diff := pretty.Compare(claims["aud"], expectedAud); diff != "" { + t.Errorf("Case %d: want=%v, got=%v", i, + expectedAud, claims["aud"]) } } } diff --git a/server/testutil.go b/server/testutil.go index 06979b03..2155d70d 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -39,6 +39,18 @@ var ( ID: testClientID, Secret: clientTestSecret, } + testClients = []client.LoadableClient{ + { + Client: client.Client{ + Credentials: testClientCredentials, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{ + testRedirectURL, + }, + }, + }, + }, + } testConnectorID1 = "IDPC-1" @@ -169,18 +181,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err var clients []client.LoadableClient if options.clients == nil { - clients = []client.LoadableClient{ - { - Client: client.Client{ - Credentials: testClientCredentials, - Metadata: oidc.ClientMetadata{ - RedirectURIs: []url.URL{ - testRedirectURL, - }, - }, - }, - }, - } + clients = testClients } else { clients = options.clients } @@ -247,6 +248,10 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err srv.absURL(httpPathAcceptInvitation), ) + clientCreds := map[string]oidc.ClientCredentials{} + for _, c := range clients { + clientCreds[c.Client.Credentials.ID] = c.Client.Credentials + } return &testFixtures{ srv: srv, redirectURL: testRedirectURL, @@ -255,9 +260,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err emailer: emailer, clientRepo: clientRepo, clientManager: clientManager, - clientCreds: map[string]oidc.ClientCredentials{ - testClientID: testClientCreds, - }, + clientCreds: clientCreds, }, nil } diff --git a/user/api/api_test.go b/user/api/api_test.go index 2fe97189..90d68687 100644 --- a/user/api/api_test.go +++ b/user/api/api_test.go @@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { } refreshRepo := db.NewRefreshTokenRepo(dbMap) for _, token := range refreshTokens { - if _, err := refreshRepo.Create(token.userID, token.clientID); err != nil { + if _, err := refreshRepo.Create(token.userID, token.clientID, []string{"openid"}); err != nil { panic("Failed to create refresh token: " + err.Error()) } }