forked from mystiq/dex
*: implement refresh revocation user API methods
This commit is contained in:
parent
aa00a4b094
commit
64380734e6
3 changed files with 178 additions and 15 deletions
|
@ -6,6 +6,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -99,10 +100,9 @@ var (
|
||||||
func makeUserAPITestFixtures() *userAPITestFixtures {
|
func makeUserAPITestFixtures() *userAPITestFixtures {
|
||||||
f := &userAPITestFixtures{}
|
f := &userAPITestFixtures{}
|
||||||
|
|
||||||
_, _, _, um := makeUserObjects(userUsers, userPasswords)
|
dbMap, _, _, um := makeUserObjects(userUsers, userPasswords)
|
||||||
|
|
||||||
cir := func() client.ClientIdentityRepo {
|
cir := func() client.ClientIdentityRepo {
|
||||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{
|
repo, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{
|
||||||
oidc.ClientIdentity{
|
oidc.ClientIdentity{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: testClientID,
|
ID: testClientID,
|
||||||
|
@ -144,8 +144,16 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
||||||
return oidc.NewJWTVerifier(testIssuerURL.String(), clientID, noop, keysFunc)
|
return oidc.NewJWTVerifier(testIssuerURL.String(), clientID, noop, keysFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||||
|
for _, user := range userUsers {
|
||||||
|
if _, err := refreshRepo.Create(user.User.ID, testClientID); err != nil {
|
||||||
|
panic("Failed to create refresh token: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
f.emailer = &testEmailer{}
|
f.emailer = &testEmailer{}
|
||||||
api := api.NewUsersAPI(um, cir, f.emailer, "local")
|
um.Clock = clock
|
||||||
|
api := api.NewUsersAPI(dbMap, um, f.emailer, "local")
|
||||||
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir)
|
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir)
|
||||||
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
|
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
|
||||||
|
|
||||||
|
@ -584,6 +592,48 @@ func TestDisableUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokenEndpoints(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
userID string
|
||||||
|
clients []string
|
||||||
|
}{
|
||||||
|
{"ID-1", []string{testClientID}},
|
||||||
|
{"ID-2", []string{testClientID}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
f := makeUserAPITestFixtures()
|
||||||
|
list, err := f.client.RefreshClient.List(tt.userID).Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("case %d: list clients: %v", i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var ids []string
|
||||||
|
for _, client := range list.Clients {
|
||||||
|
ids = append(ids, client.ClientID)
|
||||||
|
}
|
||||||
|
sort.Strings(ids)
|
||||||
|
sort.Strings(tt.clients)
|
||||||
|
if diff := pretty.Compare(tt.clients, ids); diff != "" {
|
||||||
|
t.Errorf("case %d: expected client ids did not match actual: %s", i, diff)
|
||||||
|
}
|
||||||
|
for _, clientID := range ids {
|
||||||
|
if err := f.client.Clients.Revoke(tt.userID, clientID).Do(); err != nil {
|
||||||
|
t.Errorf("case %d: failed to revoke client: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
list, err = f.client.RefreshClient.List(tt.userID).Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("case %d: list clients after revocation: %v", i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n := len(list.Clients); n != 0 {
|
||||||
|
t.Errorf("case %d: expected no refresh tokens after revocation, got %d", i, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestResendEmailInvitation(t *testing.T) {
|
func TestResendEmailInvitation(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
req schema.ResendEmailInvitationRequest
|
req schema.ResendEmailInvitationRequest
|
||||||
|
|
|
@ -9,8 +9,12 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-gorp/gorp"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
|
"github.com/coreos/dex/refresh"
|
||||||
schema "github.com/coreos/dex/schema/workerschema"
|
schema "github.com/coreos/dex/schema/workerschema"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/coreos/dex/user/manager"
|
"github.com/coreos/dex/user/manager"
|
||||||
|
@ -87,6 +91,7 @@ type UsersAPI struct {
|
||||||
manager *manager.UserManager
|
manager *manager.UserManager
|
||||||
localConnectorID string
|
localConnectorID string
|
||||||
clientIdentityRepo client.ClientIdentityRepo
|
clientIdentityRepo client.ClientIdentityRepo
|
||||||
|
refreshRepo refresh.RefreshTokenRepo
|
||||||
emailer Emailer
|
emailer Emailer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,10 +104,12 @@ type Creds struct {
|
||||||
User user.User
|
User user.User
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUsersAPI(manager *manager.UserManager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI {
|
// TODO(ericchiang): Don't pass a dbMap. See #385.
|
||||||
|
func NewUsersAPI(dbMap *gorp.DbMap, userManager *manager.UserManager, emailer Emailer, localConnectorID string) *UsersAPI {
|
||||||
return &UsersAPI{
|
return &UsersAPI{
|
||||||
manager: manager,
|
manager: userManager,
|
||||||
clientIdentityRepo: cir,
|
refreshRepo: db.NewRefreshTokenRepo(dbMap),
|
||||||
|
clientIdentityRepo: db.NewClientIdentityRepo(dbMap),
|
||||||
localConnectorID: localConnectorID,
|
localConnectorID: localConnectorID,
|
||||||
emailer: emailer,
|
emailer: emailer,
|
||||||
}
|
}
|
||||||
|
@ -258,6 +265,47 @@ func (u *UsersAPI) ListUsers(creds Creds, maxResults int, nextPageToken string)
|
||||||
return list, tok, nil
|
return list, tok, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListClientsWithRefreshTokens returns all clients issued refresh tokens
|
||||||
|
// for the authenticated user.
|
||||||
|
func (u *UsersAPI) ListClientsWithRefreshTokens(creds Creds, userID string) ([]*schema.RefreshClient, error) {
|
||||||
|
// Users must either be an admin or be requesting data associated with their own account.
|
||||||
|
if !creds.User.Admin && (creds.User.ID != userID) {
|
||||||
|
return nil, ErrorUnauthorized
|
||||||
|
}
|
||||||
|
clientIdentities, err := u.refreshRepo.ClientsWithRefreshTokens(userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
clients := make([]*schema.RefreshClient, len(clientIdentities))
|
||||||
|
|
||||||
|
urlToString := func(u *url.URL) string {
|
||||||
|
if u == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, identity := range clientIdentities {
|
||||||
|
clients[i] = &schema.RefreshClient{
|
||||||
|
ClientID: identity.Credentials.ID,
|
||||||
|
ClientName: identity.Metadata.ClientName,
|
||||||
|
ClientURI: urlToString(identity.Metadata.ClientURI),
|
||||||
|
LogoURI: urlToString(identity.Metadata.LogoURI),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return clients, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeClient revokes all refresh tokens issued to this client for the
|
||||||
|
// authenticiated user.
|
||||||
|
func (u *UsersAPI) RevokeRefreshTokensForClient(creds Creds, userID, clientID string) error {
|
||||||
|
// Users must either be an admin or be requesting data associated with their own account.
|
||||||
|
if !creds.User.Admin && (creds.User.ID != userID) {
|
||||||
|
return ErrorUnauthorized
|
||||||
|
}
|
||||||
|
return u.refreshRepo.RevokeTokensForClient(userID, clientID)
|
||||||
|
}
|
||||||
|
|
||||||
func (u *UsersAPI) Authorize(creds Creds) bool {
|
func (u *UsersAPI) Authorize(creds Creds) bool {
|
||||||
return creds.User.Admin && !creds.User.Disabled
|
return creds.User.Admin && !creds.User.Disabled
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package api
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -10,7 +11,6 @@ import (
|
||||||
"github.com/jonboulle/clockwork"
|
"github.com/jonboulle/clockwork"
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
schema "github.com/coreos/dex/schema/workerschema"
|
schema "github.com/coreos/dex/schema/workerschema"
|
||||||
|
@ -166,16 +166,27 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cir := func() client.ClientIdentityRepo {
|
if _, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{ci}); err != nil {
|
||||||
repo, err := db.NewClientIdentityRepoFromClients(db.NewMemDB(), []oidc.ClientIdentity{ci})
|
|
||||||
if err != nil {
|
|
||||||
panic("Failed to create client identity repo: " + err.Error())
|
panic("Failed to create client identity repo: " + err.Error())
|
||||||
}
|
}
|
||||||
return repo
|
|
||||||
}()
|
// Used in TestRevokeRefreshToken test.
|
||||||
|
refreshTokens := []struct {
|
||||||
|
clientID string
|
||||||
|
userID string
|
||||||
|
}{
|
||||||
|
{"XXX", "ID-1"},
|
||||||
|
{"XXX", "ID-2"},
|
||||||
|
}
|
||||||
|
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||||
|
for _, token := range refreshTokens {
|
||||||
|
if _, err := refreshRepo.Create(token.userID, token.clientID); err != nil {
|
||||||
|
panic("Failed to create refresh token: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
emailer := &testEmailer{}
|
emailer := &testEmailer{}
|
||||||
api := NewUsersAPI(mgr, cir, emailer, "local")
|
api := NewUsersAPI(dbMap, mgr, emailer, "local")
|
||||||
return api, emailer
|
return api, emailer
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -562,3 +573,57 @@ func TestResendEmailInvitation(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRevokeRefreshToken(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
userID string
|
||||||
|
toRevoke string
|
||||||
|
before []string // clientIDs expected before the change.
|
||||||
|
after []string // clientIDs expected after the change.
|
||||||
|
}{
|
||||||
|
{"ID-1", "XXX", []string{"XXX"}, []string{}},
|
||||||
|
{"ID-2", "XXX", []string{"XXX"}, []string{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
api, _ := makeTestFixtures()
|
||||||
|
|
||||||
|
listClientsWithRefreshTokens := func(creds Creds, userID string) ([]string, error) {
|
||||||
|
clients, err := api.ListClientsWithRefreshTokens(creds, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
clientIDs := make([]string, len(clients))
|
||||||
|
for i, client := range clients {
|
||||||
|
clientIDs[i] = client.ClientID
|
||||||
|
}
|
||||||
|
sort.Strings(clientIDs)
|
||||||
|
return clientIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
creds := Creds{User: user.User{ID: tt.userID}}
|
||||||
|
|
||||||
|
gotBefore, err := listClientsWithRefreshTokens(creds, tt.userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("case %d: list clients failed: %v", i, err)
|
||||||
|
} else {
|
||||||
|
if diff := pretty.Compare(tt.before, gotBefore); diff != "" {
|
||||||
|
t.Errorf("case %d: before exp!=got: %s", i, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := api.RevokeRefreshTokensForClient(creds, tt.userID, tt.toRevoke); err != nil {
|
||||||
|
t.Errorf("case %d: failed to revoke client: %v", i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
gotAfter, err := listClientsWithRefreshTokens(creds, tt.userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("case %d: list clients failed: %v", i, err)
|
||||||
|
} else {
|
||||||
|
if diff := pretty.Compare(tt.after, gotAfter); diff != "" {
|
||||||
|
t.Errorf("case %d: after exp!=got: %s", i, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue