*: implement refresh revocation user API methods

This commit is contained in:
Eric Chiang 2016-04-06 11:29:09 -07:00
parent aa00a4b094
commit 64380734e6
3 changed files with 178 additions and 15 deletions

View file

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"sort"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -99,10 +100,9 @@ var (
func makeUserAPITestFixtures() *userAPITestFixtures { func makeUserAPITestFixtures() *userAPITestFixtures {
f := &userAPITestFixtures{} f := &userAPITestFixtures{}
_, _, _, um := makeUserObjects(userUsers, userPasswords) dbMap, _, _, um := makeUserObjects(userUsers, userPasswords)
cir := func() client.ClientIdentityRepo { cir := func() client.ClientIdentityRepo {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ repo, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: testClientID, ID: testClientID,
@ -144,8 +144,16 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
return oidc.NewJWTVerifier(testIssuerURL.String(), clientID, noop, keysFunc) return oidc.NewJWTVerifier(testIssuerURL.String(), clientID, noop, keysFunc)
} }
refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, user := range userUsers {
if _, err := refreshRepo.Create(user.User.ID, testClientID); err != nil {
panic("Failed to create refresh token: " + err.Error())
}
}
f.emailer = &testEmailer{} f.emailer = &testEmailer{}
api := api.NewUsersAPI(um, cir, f.emailer, "local") um.Clock = clock
api := api.NewUsersAPI(dbMap, um, f.emailer, "local")
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir) usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir)
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler()) f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
@ -584,6 +592,48 @@ func TestDisableUser(t *testing.T) {
} }
} }
func TestRefreshTokenEndpoints(t *testing.T) {
tests := []struct {
userID string
clients []string
}{
{"ID-1", []string{testClientID}},
{"ID-2", []string{testClientID}},
}
for i, tt := range tests {
f := makeUserAPITestFixtures()
list, err := f.client.RefreshClient.List(tt.userID).Do()
if err != nil {
t.Errorf("case %d: list clients: %v", i, err)
continue
}
var ids []string
for _, client := range list.Clients {
ids = append(ids, client.ClientID)
}
sort.Strings(ids)
sort.Strings(tt.clients)
if diff := pretty.Compare(tt.clients, ids); diff != "" {
t.Errorf("case %d: expected client ids did not match actual: %s", i, diff)
}
for _, clientID := range ids {
if err := f.client.Clients.Revoke(tt.userID, clientID).Do(); err != nil {
t.Errorf("case %d: failed to revoke client: %v", i, err)
}
}
list, err = f.client.RefreshClient.List(tt.userID).Do()
if err != nil {
t.Errorf("case %d: list clients after revocation: %v", i, err)
continue
}
if n := len(list.Clients); n != 0 {
t.Errorf("case %d: expected no refresh tokens after revocation, got %d", i, n)
}
}
}
func TestResendEmailInvitation(t *testing.T) { func TestResendEmailInvitation(t *testing.T) {
tests := []struct { tests := []struct {
req schema.ResendEmailInvitationRequest req schema.ResendEmailInvitationRequest

View file

@ -9,8 +9,12 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" "github.com/coreos/dex/user/manager"
@ -87,6 +91,7 @@ type UsersAPI struct {
manager *manager.UserManager manager *manager.UserManager
localConnectorID string localConnectorID string
clientIdentityRepo client.ClientIdentityRepo clientIdentityRepo client.ClientIdentityRepo
refreshRepo refresh.RefreshTokenRepo
emailer Emailer emailer Emailer
} }
@ -99,10 +104,12 @@ type Creds struct {
User user.User User user.User
} }
func NewUsersAPI(manager *manager.UserManager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI { // TODO(ericchiang): Don't pass a dbMap. See #385.
func NewUsersAPI(dbMap *gorp.DbMap, userManager *manager.UserManager, emailer Emailer, localConnectorID string) *UsersAPI {
return &UsersAPI{ return &UsersAPI{
manager: manager, manager: userManager,
clientIdentityRepo: cir, refreshRepo: db.NewRefreshTokenRepo(dbMap),
clientIdentityRepo: db.NewClientIdentityRepo(dbMap),
localConnectorID: localConnectorID, localConnectorID: localConnectorID,
emailer: emailer, emailer: emailer,
} }
@ -258,6 +265,47 @@ func (u *UsersAPI) ListUsers(creds Creds, maxResults int, nextPageToken string)
return list, tok, nil return list, tok, nil
} }
// ListClientsWithRefreshTokens returns all clients issued refresh tokens
// for the authenticated user.
func (u *UsersAPI) ListClientsWithRefreshTokens(creds Creds, userID string) ([]*schema.RefreshClient, error) {
// Users must either be an admin or be requesting data associated with their own account.
if !creds.User.Admin && (creds.User.ID != userID) {
return nil, ErrorUnauthorized
}
clientIdentities, err := u.refreshRepo.ClientsWithRefreshTokens(userID)
if err != nil {
return nil, err
}
clients := make([]*schema.RefreshClient, len(clientIdentities))
urlToString := func(u *url.URL) string {
if u == nil {
return ""
}
return u.String()
}
for i, identity := range clientIdentities {
clients[i] = &schema.RefreshClient{
ClientID: identity.Credentials.ID,
ClientName: identity.Metadata.ClientName,
ClientURI: urlToString(identity.Metadata.ClientURI),
LogoURI: urlToString(identity.Metadata.LogoURI),
}
}
return clients, nil
}
// RevokeClient revokes all refresh tokens issued to this client for the
// authenticiated user.
func (u *UsersAPI) RevokeRefreshTokensForClient(creds Creds, userID, clientID string) error {
// Users must either be an admin or be requesting data associated with their own account.
if !creds.User.Admin && (creds.User.ID != userID) {
return ErrorUnauthorized
}
return u.refreshRepo.RevokeTokensForClient(userID, clientID)
}
func (u *UsersAPI) Authorize(creds Creds) bool { func (u *UsersAPI) Authorize(creds Creds) bool {
return creds.User.Admin && !creds.User.Disabled return creds.User.Admin && !creds.User.Disabled
} }

View file

@ -3,6 +3,7 @@ package api
import ( import (
"encoding/base64" "encoding/base64"
"net/url" "net/url"
"sort"
"testing" "testing"
"time" "time"
@ -10,7 +11,6 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
@ -166,16 +166,27 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
}, },
}, },
} }
cir := func() client.ClientIdentityRepo { if _, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{ci}); err != nil {
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci}) panic("Failed to create client identity repo: " + err.Error())
if err != nil { }
panic("Failed to create client identity repo: " + err.Error())
// Used in TestRevokeRefreshToken test.
refreshTokens := []struct {
clientID string
userID string
}{
{"XXX", "ID-1"},
{"XXX", "ID-2"},
}
refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, token := range refreshTokens {
if _, err := refreshRepo.Create(token.userID, token.clientID); err != nil {
panic("Failed to create refresh token: " + err.Error())
} }
return repo }
}()
emailer := &testEmailer{} emailer := &testEmailer{}
api := NewUsersAPI(mgr, cir, emailer, "local") api := NewUsersAPI(dbMap, mgr, emailer, "local")
return api, emailer return api, emailer
} }
@ -562,3 +573,57 @@ func TestResendEmailInvitation(t *testing.T) {
} }
} }
} }
func TestRevokeRefreshToken(t *testing.T) {
tests := []struct {
userID string
toRevoke string
before []string // clientIDs expected before the change.
after []string // clientIDs expected after the change.
}{
{"ID-1", "XXX", []string{"XXX"}, []string{}},
{"ID-2", "XXX", []string{"XXX"}, []string{}},
}
api, _ := makeTestFixtures()
listClientsWithRefreshTokens := func(creds Creds, userID string) ([]string, error) {
clients, err := api.ListClientsWithRefreshTokens(creds, userID)
if err != nil {
return nil, err
}
clientIDs := make([]string, len(clients))
for i, client := range clients {
clientIDs[i] = client.ClientID
}
sort.Strings(clientIDs)
return clientIDs, nil
}
for i, tt := range tests {
creds := Creds{User: user.User{ID: tt.userID}}
gotBefore, err := listClientsWithRefreshTokens(creds, tt.userID)
if err != nil {
t.Errorf("case %d: list clients failed: %v", i, err)
} else {
if diff := pretty.Compare(tt.before, gotBefore); diff != "" {
t.Errorf("case %d: before exp!=got: %s", i, diff)
}
}
if err := api.RevokeRefreshTokensForClient(creds, tt.userID, tt.toRevoke); err != nil {
t.Errorf("case %d: failed to revoke client: %v", i, err)
continue
}
gotAfter, err := listClientsWithRefreshTokens(creds, tt.userID)
if err != nil {
t.Errorf("case %d: list clients failed: %v", i, err)
} else {
if diff := pretty.Compare(tt.after, gotAfter); diff != "" {
t.Errorf("case %d: after exp!=got: %s", i, diff)
}
}
}
}