diff --git a/integration/user_api_test.go b/integration/user_api_test.go index f2f4758f..14a2f161 100644 --- a/integration/user_api_test.go +++ b/integration/user_api_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "sort" "strings" "testing" "time" @@ -99,10 +100,9 @@ var ( func makeUserAPITestFixtures() *userAPITestFixtures { f := &userAPITestFixtures{} - _, _, _, um := makeUserObjects(userUsers, userPasswords) - + dbMap, _, _, um := makeUserObjects(userUsers, userPasswords) cir := func() client.ClientIdentityRepo { - repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ + repo, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{ oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: testClientID, @@ -144,8 +144,16 @@ func makeUserAPITestFixtures() *userAPITestFixtures { return oidc.NewJWTVerifier(testIssuerURL.String(), clientID, noop, keysFunc) } + refreshRepo := db.NewRefreshTokenRepo(dbMap) + for _, user := range userUsers { + if _, err := refreshRepo.Create(user.User.ID, testClientID); err != nil { + panic("Failed to create refresh token: " + err.Error()) + } + } + f.emailer = &testEmailer{} - api := api.NewUsersAPI(um, cir, f.emailer, "local") + um.Clock = clock + api := api.NewUsersAPI(dbMap, um, f.emailer, "local") usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir) f.hSrv = httptest.NewServer(usrSrv.HTTPHandler()) @@ -584,6 +592,48 @@ func TestDisableUser(t *testing.T) { } } +func TestRefreshTokenEndpoints(t *testing.T) { + + tests := []struct { + userID string + clients []string + }{ + {"ID-1", []string{testClientID}}, + {"ID-2", []string{testClientID}}, + } + + for i, tt := range tests { + f := makeUserAPITestFixtures() + list, err := f.client.RefreshClient.List(tt.userID).Do() + if err != nil { + t.Errorf("case %d: list clients: %v", i, err) + continue + } + var ids []string + for _, client := range list.Clients { + ids = append(ids, client.ClientID) + } + sort.Strings(ids) + sort.Strings(tt.clients) + if diff := pretty.Compare(tt.clients, ids); diff != "" { + t.Errorf("case %d: expected client ids did not match actual: %s", i, diff) + } + for _, clientID := range ids { + if err := f.client.Clients.Revoke(tt.userID, clientID).Do(); err != nil { + t.Errorf("case %d: failed to revoke client: %v", i, err) + } + } + list, err = f.client.RefreshClient.List(tt.userID).Do() + if err != nil { + t.Errorf("case %d: list clients after revocation: %v", i, err) + continue + } + if n := len(list.Clients); n != 0 { + t.Errorf("case %d: expected no refresh tokens after revocation, got %d", i, n) + } + } +} + func TestResendEmailInvitation(t *testing.T) { tests := []struct { req schema.ResendEmailInvitationRequest diff --git a/user/api/api.go b/user/api/api.go index 9eca4b38..35e0e907 100644 --- a/user/api/api.go +++ b/user/api/api.go @@ -9,8 +9,12 @@ import ( "net/url" "time" + "github.com/go-gorp/gorp" + "github.com/coreos/dex/client" + "github.com/coreos/dex/db" "github.com/coreos/dex/pkg/log" + "github.com/coreos/dex/refresh" schema "github.com/coreos/dex/schema/workerschema" "github.com/coreos/dex/user" "github.com/coreos/dex/user/manager" @@ -87,6 +91,7 @@ type UsersAPI struct { manager *manager.UserManager localConnectorID string clientIdentityRepo client.ClientIdentityRepo + refreshRepo refresh.RefreshTokenRepo emailer Emailer } @@ -99,10 +104,12 @@ type Creds struct { User user.User } -func NewUsersAPI(manager *manager.UserManager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI { +// TODO(ericchiang): Don't pass a dbMap. See #385. +func NewUsersAPI(dbMap *gorp.DbMap, userManager *manager.UserManager, emailer Emailer, localConnectorID string) *UsersAPI { return &UsersAPI{ - manager: manager, - clientIdentityRepo: cir, + manager: userManager, + refreshRepo: db.NewRefreshTokenRepo(dbMap), + clientIdentityRepo: db.NewClientIdentityRepo(dbMap), localConnectorID: localConnectorID, emailer: emailer, } @@ -258,6 +265,47 @@ func (u *UsersAPI) ListUsers(creds Creds, maxResults int, nextPageToken string) return list, tok, nil } +// ListClientsWithRefreshTokens returns all clients issued refresh tokens +// for the authenticated user. +func (u *UsersAPI) ListClientsWithRefreshTokens(creds Creds, userID string) ([]*schema.RefreshClient, error) { + // Users must either be an admin or be requesting data associated with their own account. + if !creds.User.Admin && (creds.User.ID != userID) { + return nil, ErrorUnauthorized + } + clientIdentities, err := u.refreshRepo.ClientsWithRefreshTokens(userID) + if err != nil { + return nil, err + } + clients := make([]*schema.RefreshClient, len(clientIdentities)) + + urlToString := func(u *url.URL) string { + if u == nil { + return "" + } + return u.String() + } + + for i, identity := range clientIdentities { + clients[i] = &schema.RefreshClient{ + ClientID: identity.Credentials.ID, + ClientName: identity.Metadata.ClientName, + ClientURI: urlToString(identity.Metadata.ClientURI), + LogoURI: urlToString(identity.Metadata.LogoURI), + } + } + return clients, nil +} + +// RevokeClient revokes all refresh tokens issued to this client for the +// authenticiated user. +func (u *UsersAPI) RevokeRefreshTokensForClient(creds Creds, userID, clientID string) error { + // Users must either be an admin or be requesting data associated with their own account. + if !creds.User.Admin && (creds.User.ID != userID) { + return ErrorUnauthorized + } + return u.refreshRepo.RevokeTokensForClient(userID, clientID) +} + func (u *UsersAPI) Authorize(creds Creds) bool { return creds.User.Admin && !creds.User.Disabled } diff --git a/user/api/api_test.go b/user/api/api_test.go index 5412cb2d..de5b62d0 100644 --- a/user/api/api_test.go +++ b/user/api/api_test.go @@ -3,6 +3,7 @@ package api import ( "encoding/base64" "net/url" + "sort" "testing" "time" @@ -10,7 +11,6 @@ import ( "github.com/jonboulle/clockwork" "github.com/kylelemons/godebug/pretty" - "github.com/coreos/dex/client" "github.com/coreos/dex/connector" "github.com/coreos/dex/db" schema "github.com/coreos/dex/schema/workerschema" @@ -166,16 +166,27 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { }, }, } - cir := func() client.ClientIdentityRepo { - repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci}) - if err != nil { - panic("Failed to create client identity repo: " + err.Error()) + if _, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{ci}); err != nil { + panic("Failed to create client identity repo: " + err.Error()) + } + + // Used in TestRevokeRefreshToken test. + refreshTokens := []struct { + clientID string + userID string + }{ + {"XXX", "ID-1"}, + {"XXX", "ID-2"}, + } + refreshRepo := db.NewRefreshTokenRepo(dbMap) + for _, token := range refreshTokens { + if _, err := refreshRepo.Create(token.userID, token.clientID); err != nil { + panic("Failed to create refresh token: " + err.Error()) } - return repo - }() + } emailer := &testEmailer{} - api := NewUsersAPI(mgr, cir, emailer, "local") + api := NewUsersAPI(dbMap, mgr, emailer, "local") return api, emailer } @@ -562,3 +573,57 @@ func TestResendEmailInvitation(t *testing.T) { } } } + +func TestRevokeRefreshToken(t *testing.T) { + tests := []struct { + userID string + toRevoke string + before []string // clientIDs expected before the change. + after []string // clientIDs expected after the change. + }{ + {"ID-1", "XXX", []string{"XXX"}, []string{}}, + {"ID-2", "XXX", []string{"XXX"}, []string{}}, + } + + api, _ := makeTestFixtures() + + listClientsWithRefreshTokens := func(creds Creds, userID string) ([]string, error) { + clients, err := api.ListClientsWithRefreshTokens(creds, userID) + if err != nil { + return nil, err + } + clientIDs := make([]string, len(clients)) + for i, client := range clients { + clientIDs[i] = client.ClientID + } + sort.Strings(clientIDs) + return clientIDs, nil + } + + for i, tt := range tests { + creds := Creds{User: user.User{ID: tt.userID}} + + gotBefore, err := listClientsWithRefreshTokens(creds, tt.userID) + if err != nil { + t.Errorf("case %d: list clients failed: %v", i, err) + } else { + if diff := pretty.Compare(tt.before, gotBefore); diff != "" { + t.Errorf("case %d: before exp!=got: %s", i, diff) + } + } + + if err := api.RevokeRefreshTokensForClient(creds, tt.userID, tt.toRevoke); err != nil { + t.Errorf("case %d: failed to revoke client: %v", i, err) + continue + } + + gotAfter, err := listClientsWithRefreshTokens(creds, tt.userID) + if err != nil { + t.Errorf("case %d: list clients failed: %v", i, err) + } else { + if diff := pretty.Compare(tt.after, gotAfter); diff != "" { + t.Errorf("case %d: after exp!=got: %s", i, diff) + } + } + } +}