forked from mystiq/dex
*: implement refresh revocation user API methods
This commit is contained in:
parent
aa00a4b094
commit
64380734e6
3 changed files with 178 additions and 15 deletions
|
@ -6,6 +6,7 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -99,10 +100,9 @@ var (
|
|||
func makeUserAPITestFixtures() *userAPITestFixtures {
|
||||
f := &userAPITestFixtures{}
|
||||
|
||||
_, _, _, um := makeUserObjects(userUsers, userPasswords)
|
||||
|
||||
dbMap, _, _, um := makeUserObjects(userUsers, userPasswords)
|
||||
cir := func() client.ClientIdentityRepo {
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
|
||||
repo, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: testClientID,
|
||||
|
@ -144,8 +144,16 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
|||
return oidc.NewJWTVerifier(testIssuerURL.String(), clientID, noop, keysFunc)
|
||||
}
|
||||
|
||||
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||
for _, user := range userUsers {
|
||||
if _, err := refreshRepo.Create(user.User.ID, testClientID); err != nil {
|
||||
panic("Failed to create refresh token: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
f.emailer = &testEmailer{}
|
||||
api := api.NewUsersAPI(um, cir, f.emailer, "local")
|
||||
um.Clock = clock
|
||||
api := api.NewUsersAPI(dbMap, um, f.emailer, "local")
|
||||
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir)
|
||||
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
|
||||
|
||||
|
@ -584,6 +592,48 @@ func TestDisableUser(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenEndpoints(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
userID string
|
||||
clients []string
|
||||
}{
|
||||
{"ID-1", []string{testClientID}},
|
||||
{"ID-2", []string{testClientID}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
f := makeUserAPITestFixtures()
|
||||
list, err := f.client.RefreshClient.List(tt.userID).Do()
|
||||
if err != nil {
|
||||
t.Errorf("case %d: list clients: %v", i, err)
|
||||
continue
|
||||
}
|
||||
var ids []string
|
||||
for _, client := range list.Clients {
|
||||
ids = append(ids, client.ClientID)
|
||||
}
|
||||
sort.Strings(ids)
|
||||
sort.Strings(tt.clients)
|
||||
if diff := pretty.Compare(tt.clients, ids); diff != "" {
|
||||
t.Errorf("case %d: expected client ids did not match actual: %s", i, diff)
|
||||
}
|
||||
for _, clientID := range ids {
|
||||
if err := f.client.Clients.Revoke(tt.userID, clientID).Do(); err != nil {
|
||||
t.Errorf("case %d: failed to revoke client: %v", i, err)
|
||||
}
|
||||
}
|
||||
list, err = f.client.RefreshClient.List(tt.userID).Do()
|
||||
if err != nil {
|
||||
t.Errorf("case %d: list clients after revocation: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if n := len(list.Clients); n != 0 {
|
||||
t.Errorf("case %d: expected no refresh tokens after revocation, got %d", i, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResendEmailInvitation(t *testing.T) {
|
||||
tests := []struct {
|
||||
req schema.ResendEmailInvitationRequest
|
||||
|
|
|
@ -9,8 +9,12 @@ import (
|
|||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/refresh"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
|
@ -87,6 +91,7 @@ type UsersAPI struct {
|
|||
manager *manager.UserManager
|
||||
localConnectorID string
|
||||
clientIdentityRepo client.ClientIdentityRepo
|
||||
refreshRepo refresh.RefreshTokenRepo
|
||||
emailer Emailer
|
||||
}
|
||||
|
||||
|
@ -99,10 +104,12 @@ type Creds struct {
|
|||
User user.User
|
||||
}
|
||||
|
||||
func NewUsersAPI(manager *manager.UserManager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI {
|
||||
// TODO(ericchiang): Don't pass a dbMap. See #385.
|
||||
func NewUsersAPI(dbMap *gorp.DbMap, userManager *manager.UserManager, emailer Emailer, localConnectorID string) *UsersAPI {
|
||||
return &UsersAPI{
|
||||
manager: manager,
|
||||
clientIdentityRepo: cir,
|
||||
manager: userManager,
|
||||
refreshRepo: db.NewRefreshTokenRepo(dbMap),
|
||||
clientIdentityRepo: db.NewClientIdentityRepo(dbMap),
|
||||
localConnectorID: localConnectorID,
|
||||
emailer: emailer,
|
||||
}
|
||||
|
@ -258,6 +265,47 @@ func (u *UsersAPI) ListUsers(creds Creds, maxResults int, nextPageToken string)
|
|||
return list, tok, nil
|
||||
}
|
||||
|
||||
// ListClientsWithRefreshTokens returns all clients issued refresh tokens
|
||||
// for the authenticated user.
|
||||
func (u *UsersAPI) ListClientsWithRefreshTokens(creds Creds, userID string) ([]*schema.RefreshClient, error) {
|
||||
// Users must either be an admin or be requesting data associated with their own account.
|
||||
if !creds.User.Admin && (creds.User.ID != userID) {
|
||||
return nil, ErrorUnauthorized
|
||||
}
|
||||
clientIdentities, err := u.refreshRepo.ClientsWithRefreshTokens(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clients := make([]*schema.RefreshClient, len(clientIdentities))
|
||||
|
||||
urlToString := func(u *url.URL) string {
|
||||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
for i, identity := range clientIdentities {
|
||||
clients[i] = &schema.RefreshClient{
|
||||
ClientID: identity.Credentials.ID,
|
||||
ClientName: identity.Metadata.ClientName,
|
||||
ClientURI: urlToString(identity.Metadata.ClientURI),
|
||||
LogoURI: urlToString(identity.Metadata.LogoURI),
|
||||
}
|
||||
}
|
||||
return clients, nil
|
||||
}
|
||||
|
||||
// RevokeClient revokes all refresh tokens issued to this client for the
|
||||
// authenticiated user.
|
||||
func (u *UsersAPI) RevokeRefreshTokensForClient(creds Creds, userID, clientID string) error {
|
||||
// Users must either be an admin or be requesting data associated with their own account.
|
||||
if !creds.User.Admin && (creds.User.ID != userID) {
|
||||
return ErrorUnauthorized
|
||||
}
|
||||
return u.refreshRepo.RevokeTokensForClient(userID, clientID)
|
||||
}
|
||||
|
||||
func (u *UsersAPI) Authorize(creds Creds) bool {
|
||||
return creds.User.Admin && !creds.User.Disabled
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package api
|
|||
import (
|
||||
"encoding/base64"
|
||||
"net/url"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -10,7 +11,6 @@ import (
|
|||
"github.com/jonboulle/clockwork"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
|
@ -166,16 +166,27 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
},
|
||||
},
|
||||
}
|
||||
cir := func() client.ClientIdentityRepo {
|
||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
|
||||
if err != nil {
|
||||
panic("Failed to create client identity repo: " + err.Error())
|
||||
if _, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{ci}); err != nil {
|
||||
panic("Failed to create client identity repo: " + err.Error())
|
||||
}
|
||||
|
||||
// Used in TestRevokeRefreshToken test.
|
||||
refreshTokens := []struct {
|
||||
clientID string
|
||||
userID string
|
||||
}{
|
||||
{"XXX", "ID-1"},
|
||||
{"XXX", "ID-2"},
|
||||
}
|
||||
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||
for _, token := range refreshTokens {
|
||||
if _, err := refreshRepo.Create(token.userID, token.clientID); err != nil {
|
||||
panic("Failed to create refresh token: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
}
|
||||
|
||||
emailer := &testEmailer{}
|
||||
api := NewUsersAPI(mgr, cir, emailer, "local")
|
||||
api := NewUsersAPI(dbMap, mgr, emailer, "local")
|
||||
return api, emailer
|
||||
|
||||
}
|
||||
|
@ -562,3 +573,57 @@ func TestResendEmailInvitation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevokeRefreshToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
userID string
|
||||
toRevoke string
|
||||
before []string // clientIDs expected before the change.
|
||||
after []string // clientIDs expected after the change.
|
||||
}{
|
||||
{"ID-1", "XXX", []string{"XXX"}, []string{}},
|
||||
{"ID-2", "XXX", []string{"XXX"}, []string{}},
|
||||
}
|
||||
|
||||
api, _ := makeTestFixtures()
|
||||
|
||||
listClientsWithRefreshTokens := func(creds Creds, userID string) ([]string, error) {
|
||||
clients, err := api.ListClientsWithRefreshTokens(creds, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientIDs := make([]string, len(clients))
|
||||
for i, client := range clients {
|
||||
clientIDs[i] = client.ClientID
|
||||
}
|
||||
sort.Strings(clientIDs)
|
||||
return clientIDs, nil
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
creds := Creds{User: user.User{ID: tt.userID}}
|
||||
|
||||
gotBefore, err := listClientsWithRefreshTokens(creds, tt.userID)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: list clients failed: %v", i, err)
|
||||
} else {
|
||||
if diff := pretty.Compare(tt.before, gotBefore); diff != "" {
|
||||
t.Errorf("case %d: before exp!=got: %s", i, diff)
|
||||
}
|
||||
}
|
||||
|
||||
if err := api.RevokeRefreshTokensForClient(creds, tt.userID, tt.toRevoke); err != nil {
|
||||
t.Errorf("case %d: failed to revoke client: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
gotAfter, err := listClientsWithRefreshTokens(creds, tt.userID)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: list clients failed: %v", i, err)
|
||||
} else {
|
||||
if diff := pretty.Compare(tt.after, gotAfter); diff != "" {
|
||||
t.Errorf("case %d: after exp!=got: %s", i, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue