ClientCredentials flow in UserAPI

Fixes #528
This commit is contained in:
Adrián López Gómez 2016-08-01 10:17:06 +02:00
parent fa8f98acac
commit 9b8ab3bdc6
7 changed files with 347 additions and 83 deletions

View file

@ -51,6 +51,9 @@ func main() {
enableClientRegistration := fs.Bool("enable-client-registration", false, "Allow dynamic registration of clients") enableClientRegistration := fs.Bool("enable-client-registration", false, "Allow dynamic registration of clients")
// Client credentials administration
apiUseClientCredentials := fs.Bool("api-use-client-credentials", false, "Forces API to authenticate using client credentials instead of ID token. Clients must be 'admin clients' to use the API.")
noDB := fs.Bool("no-db", false, "manage entities in-process w/o any encryption, used only for single-node testing") noDB := fs.Bool("no-db", false, "manage entities in-process w/o any encryption, used only for single-node testing")
// UI-related: // UI-related:
@ -146,6 +149,7 @@ func main() {
IssuerLogoURL: *issuerLogoURL, IssuerLogoURL: *issuerLogoURL,
EnableRegistration: *enableRegistration, EnableRegistration: *enableRegistration,
EnableClientRegistration: *enableClientRegistration, EnableClientRegistration: *enableClientRegistration,
EnableClientCredentialAccess: *apiUseClientCredentials,
RegisterOnFirstLogin: *registerOnFirstLogin, RegisterOnFirstLogin: *registerOnFirstLogin,
} }

View file

@ -84,6 +84,12 @@ var (
userGoodToken = makeUserToken(testIssuerURL, userGoodToken = makeUserToken(testIssuerURL,
"ID-1", testClientID, time.Hour*1, testPrivKey) "ID-1", testClientID, time.Hour*1, testPrivKey)
clientToken = makeClientToken(testIssuerURL,
testClientID, time.Hour*1, testPrivKey)
badClientToken = makeClientToken(testIssuerURL,
userBadClientID, time.Hour*1, testPrivKey)
userBadTokenNotAdmin = makeUserToken(testIssuerURL, userBadTokenNotAdmin = makeUserToken(testIssuerURL,
"ID-2", testClientID, time.Hour*1, testPrivKey) "ID-2", testClientID, time.Hour*1, testPrivKey)
@ -97,7 +103,7 @@ var (
"ID-4", testClientID, time.Hour*1, testPrivKey) "ID-4", testClientID, time.Hour*1, testPrivKey)
) )
func makeUserAPITestFixtures() *userAPITestFixtures { func makeUserAPITestFixtures(clientCredsFlag bool) *userAPITestFixtures {
f := &userAPITestFixtures{} f := &userAPITestFixtures{}
dbMap, _, _, um := makeUserObjects(userUsers, userPasswords) dbMap, _, _, um := makeUserObjects(userUsers, userPasswords)
@ -157,8 +163,8 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
f.emailer = &testEmailer{} f.emailer = &testEmailer{}
um.Clock = clock um.Clock = clock
api := api.NewUsersAPI(um, clientManager, refreshRepo, f.emailer, "local") api := api.NewUsersAPI(um, clientManager, refreshRepo, f.emailer, "local", clientCredsFlag)
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, clientManager) usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, clientManager, clientCredsFlag)
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler()) f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
f.trans = &tokenHandlerTransport{ f.trans = &tokenHandlerTransport{
@ -180,48 +186,89 @@ func TestGetUser(t *testing.T) {
token string token string
errCode int errCode int
clientCredsFlag bool
}{ }{
{ {
id: "ID-1", id: "ID-1",
token: userGoodToken, token: userGoodToken,
errCode: 0, errCode: 0,
}, {
clientCredsFlag: false,
},
{
id: "ID-1",
token: clientToken,
errCode: 0,
clientCredsFlag: true,
},
{
id: "ID-1",
token: badClientToken,
errCode: http.StatusForbidden,
clientCredsFlag: true,
},
{
id: "ID-1",
token: clientToken,
errCode: http.StatusUnauthorized,
clientCredsFlag: false,
},
{
id: "NOONE", id: "NOONE",
token: userGoodToken, token: userGoodToken,
errCode: http.StatusNotFound, errCode: http.StatusNotFound,
clientCredsFlag: false,
}, { }, {
id: "ID-1", id: "ID-1",
token: userBadTokenNotAdmin, token: userBadTokenNotAdmin,
errCode: http.StatusUnauthorized, errCode: http.StatusUnauthorized,
clientCredsFlag: false,
}, { }, {
id: "ID-1", id: "ID-1",
token: userBadTokenExpired, token: userBadTokenExpired,
errCode: http.StatusUnauthorized, errCode: http.StatusUnauthorized,
clientCredsFlag: false,
}, { }, {
id: "ID-1", id: "ID-1",
token: userBadTokenDisabled, token: userBadTokenDisabled,
errCode: http.StatusUnauthorized, errCode: http.StatusUnauthorized,
clientCredsFlag: false,
}, { }, {
id: "ID-1", id: "ID-1",
token: "", token: "",
errCode: http.StatusUnauthorized, errCode: http.StatusUnauthorized,
clientCredsFlag: false,
}, { }, {
id: "ID-1", id: "ID-1",
token: "gibberish", token: "gibberish",
errCode: http.StatusUnauthorized, errCode: http.StatusUnauthorized,
clientCredsFlag: false,
}, },
} }
for i, tt := range tests { for i, tt := range tests {
func() { func() {
f := makeUserAPITestFixtures() f := makeUserAPITestFixtures(tt.clientCredsFlag)
f.trans.Token = tt.token f.trans.Token = tt.token
defer f.close() defer f.close()
@ -318,7 +365,7 @@ func TestListUsers(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
func() { func() {
f := makeUserAPITestFixtures() f := makeUserAPITestFixtures(false)
defer f.close() defer f.close()
f.trans.Token = tt.token f.trans.Token = tt.token
@ -382,6 +429,8 @@ func TestCreateUser(t *testing.T) {
wantResponse schema.UserCreateResponse wantResponse schema.UserCreateResponse
wantCode int wantCode int
clientCredsFlag bool
}{ }{
{ {
@ -409,6 +458,53 @@ func TestCreateUser(t *testing.T) {
}, },
}, },
}, },
{
req: schema.UserCreateRequest{
User: &schema.User{
Email: "newuser@example.com",
DisplayName: "New User",
EmailVerified: true,
Admin: false,
CreatedAt: clock.Now().Format(time.RFC3339),
},
RedirectURL: testRedirectURL.String(),
},
token: clientToken,
wantResponse: schema.UserCreateResponse{
EmailSent: true,
User: &schema.User{
Email: "newuser@example.com",
DisplayName: "New User",
EmailVerified: true,
Admin: false,
CreatedAt: clock.Now().Format(time.RFC3339),
},
},
clientCredsFlag: true,
},
{
req: schema.UserCreateRequest{
User: &schema.User{
Email: "newuser@example.com",
DisplayName: "New User",
EmailVerified: true,
Admin: false,
CreatedAt: clock.Now().Format(time.RFC3339),
},
RedirectURL: testRedirectURL.String(),
},
token: badClientToken,
wantCode: http.StatusForbidden,
clientCredsFlag: true,
},
{ {
// Duplicate email // Duplicate email
@ -488,6 +584,28 @@ func TestCreateUser(t *testing.T) {
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
{
req: schema.UserCreateRequest{
User: &schema.User{
Email: "newuser@example.com",
DisplayName: "New User",
EmailVerified: true,
Admin: false,
CreatedAt: clock.Now().Format(time.RFC3339),
},
RedirectURL: testRedirectURL.String(),
},
// make sure that the endpoint is protected, but don't exhaustively
// try every variation like in TestGetUser
token: clientToken,
wantCode: http.StatusUnauthorized,
clientCredsFlag: false,
},
{ {
req: schema.UserCreateRequest{ req: schema.UserCreateRequest{
User: &schema.User{ User: &schema.User{
@ -507,7 +625,7 @@ func TestCreateUser(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
func() { func() {
f := makeUserAPITestFixtures() f := makeUserAPITestFixtures(tt.clientCredsFlag)
defer f.close() defer f.close()
f.trans.Token = tt.token f.trans.Token = tt.token
f.emailer.cantEmail = tt.cantEmail f.emailer.cantEmail = tt.cantEmail
@ -588,7 +706,7 @@ func TestDisableUser(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
f := makeUserAPITestFixtures() f := makeUserAPITestFixtures(false)
usr, err := f.client.Users.Get(tt.id).Do() usr, err := f.client.Users.Get(tt.id).Do()
if err != nil { if err != nil {
@ -625,7 +743,7 @@ func TestRefreshTokenEndpoints(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
f := makeUserAPITestFixtures() f := makeUserAPITestFixtures(false)
list, err := f.client.RefreshClient.List(tt.userID).Do() list, err := f.client.RefreshClient.List(tt.userID).Do()
if err != nil { if err != nil {
t.Errorf("case %d: list clients: %v", i, err) t.Errorf("case %d: list clients: %v", i, err)
@ -666,6 +784,8 @@ func TestResendEmailInvitation(t *testing.T) {
wantResponse schema.ResendEmailInvitationResponse wantResponse schema.ResendEmailInvitationResponse
wantCode int wantCode int
clientCredsFlag bool
}{ }{
{ {
@ -687,6 +807,36 @@ func TestResendEmailInvitation(t *testing.T) {
RedirectURL: testRedirectURL.String(), RedirectURL: testRedirectURL.String(),
}, },
userID: "ID-3",
email: "Email-3@example.com",
token: clientToken,
wantResponse: schema.ResendEmailInvitationResponse{
EmailSent: true,
},
clientCredsFlag: true,
},
{
req: schema.ResendEmailInvitationRequest{
RedirectURL: testRedirectURL.String(),
},
userID: "ID-3",
email: "Email-3@example.com",
token: badClientToken,
wantCode: http.StatusForbidden,
clientCredsFlag: true,
},
{
req: schema.ResendEmailInvitationRequest{
RedirectURL: testRedirectURL.String(),
},
userID: "ID-3", userID: "ID-3",
email: "Email-3@example.com", email: "Email-3@example.com",
cantEmail: true, cantEmail: true,
@ -747,6 +897,19 @@ func TestResendEmailInvitation(t *testing.T) {
RedirectURL: testRedirectURL.String(), RedirectURL: testRedirectURL.String(),
}, },
userID: "ID-3",
email: "Email-3@example.com",
token: clientToken,
wantCode: http.StatusUnauthorized,
clientCredsFlag: false,
},
{
req: schema.ResendEmailInvitationRequest{
RedirectURL: testRedirectURL.String(),
},
userID: "ID-3", userID: "ID-3",
email: "Email-3@example.com", email: "Email-3@example.com",
token: userBadTokenExpired, token: userBadTokenExpired,
@ -778,7 +941,7 @@ func TestResendEmailInvitation(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
func() { func() {
f := makeUserAPITestFixtures() f := makeUserAPITestFixtures(tt.clientCredsFlag)
defer f.close() defer f.close()
f.trans.Token = tt.token f.trans.Token = tt.token
f.emailer.cantEmail = tt.cantEmail f.emailer.cantEmail = tt.cantEmail
@ -869,6 +1032,10 @@ func (t *testEmailer) SendInviteEmail(email string, redirectURL url.URL, clientI
return retURL, nil return retURL, nil
} }
func makeClientToken(issuerURL url.URL, clientID string, expires time.Duration, privKey *key.PrivateKey) string {
return makeUserToken(issuerURL, clientID, clientID, expires, privKey)
}
func makeUserToken(issuerURL url.URL, userID, clientID string, expires time.Duration, privKey *key.PrivateKey) string { func makeUserToken(issuerURL url.URL, userID, clientID string, expires time.Duration, privKey *key.PrivateKey) string {
signer := key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, signer := key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey},

View file

@ -39,6 +39,7 @@ type ServerConfig struct {
StateConfig StateConfigurer StateConfig StateConfigurer
EnableRegistration bool EnableRegistration bool
EnableClientRegistration bool EnableClientRegistration bool
EnableClientCredentialAccess bool
RegisterOnFirstLogin bool RegisterOnFirstLogin bool
} }
@ -80,6 +81,7 @@ func (cfg *ServerConfig) Server() (*Server, error) {
EnableRegistration: cfg.EnableRegistration, EnableRegistration: cfg.EnableRegistration,
EnableClientRegistration: cfg.EnableClientRegistration, EnableClientRegistration: cfg.EnableClientRegistration,
EnableClientCredentialAccess: cfg.EnableClientCredentialAccess,
RegisterOnFirstLogin: cfg.RegisterOnFirstLogin, RegisterOnFirstLogin: cfg.RegisterOnFirstLogin,
} }

View file

@ -95,6 +95,7 @@ type Server struct {
EnableRegistration bool EnableRegistration bool
EnableClientRegistration bool EnableClientRegistration bool
EnableClientCredentialAccess bool
RegisterOnFirstLogin bool RegisterOnFirstLogin bool
dbMap *gorp.DbMap dbMap *gorp.DbMap
@ -300,8 +301,8 @@ func (s *Server) HTTPHandler() http.Handler {
apiBasePath := path.Join(httpPathAPI, APIVersion) apiBasePath := path.Join(httpPathAPI, APIVersion)
registerDiscoveryResource(apiBasePath, mux) registerDiscoveryResource(apiBasePath, mux)
usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientManager, s.RefreshTokenRepo, s.UserEmailer, s.localConnectorID) usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientManager, s.RefreshTokenRepo, s.UserEmailer, s.localConnectorID, s.EnableClientCredentialAccess)
handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientManager).HTTPHandler() handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientManager, s.EnableClientCredentialAccess).HTTPHandler()
handleStripPrefix(apiBasePath+"/", handler) handleStripPrefix(apiBasePath+"/", handler)

View file

@ -39,14 +39,16 @@ type UserMgmtServer struct {
jwtvFactory JWTVerifierFactory jwtvFactory JWTVerifierFactory
um *usermanager.UserManager um *usermanager.UserManager
cm *clientmanager.ClientManager cm *clientmanager.ClientManager
allowClientCredsAuth bool
} }
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *usermanager.UserManager, cm *clientmanager.ClientManager) *UserMgmtServer { func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *usermanager.UserManager, cm *clientmanager.ClientManager, allowClientCredsAuth bool) *UserMgmtServer {
return &UserMgmtServer{ return &UserMgmtServer{
api: userMgmtAPI, api: userMgmtAPI,
jwtvFactory: jwtvFactory, jwtvFactory: jwtvFactory,
um: um, um: um,
cm: cm, cm: cm,
allowClientCredsAuth: allowClientCredsAuth,
} }
} }
@ -92,7 +94,7 @@ func (s *UserMgmtServer) authAPIHandle(handle authedHandle, requiresAdmin bool)
s.writeError(w, err) s.writeError(w, err)
return return
} }
if creds.User.Disabled || (requiresAdmin && !creds.User.Admin) { if !s.allowClientCredsAuth && (creds.User.Disabled || (requiresAdmin && !creds.User.Admin)) {
s.writeError(w, api.ErrorUnauthorized) s.writeError(w, api.ErrorUnauthorized)
return return
} }
@ -299,6 +301,20 @@ func (s *UserMgmtServer) getCreds(r *http.Request, requiresAdmin bool) (api.Cred
return api.Creds{}, api.ErrorUnauthorized return api.Creds{}, api.ErrorUnauthorized
} }
if s.allowClientCredsAuth && (len(clientIDs) == 1) && (sub == clientIDs[0]) {
isAdmin, err := s.cm.IsDexAdmin(clientIDs[0])
if err != nil {
log.Errorf("userMgmtServer: GetCreds err: %q", err)
return api.Creds{}, err
}
if requiresAdmin && !isAdmin {
return api.Creds{}, api.ErrorForbidden
}
return api.Creds{
ClientIDs: clientIDs,
}, nil
}
usr, err := s.um.Get(sub) usr, err := s.um.Get(sub)
if err != nil { if err != nil {
if err == user.ErrorNotFound { if err == user.ErrorNotFound {

View file

@ -91,6 +91,7 @@ type UsersAPI struct {
clientManager *clientmanager.ClientManager clientManager *clientmanager.ClientManager
refreshRepo refresh.RefreshTokenRepo refreshRepo refresh.RefreshTokenRepo
emailer Emailer emailer Emailer
allowClientCreds bool
} }
type Emailer interface { type Emailer interface {
@ -104,19 +105,19 @@ type Creds struct {
} }
// TODO(ericchiang): Don't pass a dbMap. See #385. // TODO(ericchiang): Don't pass a dbMap. See #385.
func NewUsersAPI(userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, refreshRepo refresh.RefreshTokenRepo, emailer Emailer, localConnectorID string) *UsersAPI { func NewUsersAPI(userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, refreshRepo refresh.RefreshTokenRepo, emailer Emailer, localConnectorID string, allowClientCreds bool) *UsersAPI {
return &UsersAPI{ return &UsersAPI{
userManager: userManager, userManager: userManager,
refreshRepo: refreshRepo, refreshRepo: refreshRepo,
clientManager: clientManager, clientManager: clientManager,
localConnectorID: localConnectorID, localConnectorID: localConnectorID,
emailer: emailer, emailer: emailer,
allowClientCreds: allowClientCreds,
} }
} }
func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) { func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) {
log.Infof("userAPI: GetUser") log.Infof("userAPI: GetUser")
if !u.Authorize(creds) { if !u.Authorize(creds) {
return schema.User{}, ErrorUnauthorized return schema.User{}, ErrorUnauthorized
} }
@ -312,6 +313,11 @@ func (u *UsersAPI) RevokeRefreshTokensForClient(creds Creds, userID, clientID st
} }
func (u *UsersAPI) Authorize(creds Creds) bool { func (u *UsersAPI) Authorize(creds Creds) bool {
if u.allowClientCreds {
if creds.User.ID == "" {
return true
}
}
return creds.User.Admin && !creds.User.Disabled return creds.User.Admin && !creds.User.Disabled
} }

View file

@ -63,6 +63,14 @@ var (
ClientIDs: []string{goodClientID}, ClientIDs: []string{goodClientID},
} }
clientCreds = Creds{
User: user.User{
ID: "",
Admin: false,
},
ClientIDs: []string{goodClientID},
}
badCreds = Creds{ badCreds = Creds{
User: user.User{ User: user.User{
ID: "ID-2", ID: "ID-2",
@ -103,7 +111,7 @@ var (
} }
) )
func makeTestFixtures() (*UsersAPI, *testEmailer) { func makeTestFixtures(clientCredsFlag bool) (*UsersAPI, *testEmailer) {
dbMap := db.NewMemDB() dbMap := db.NewMemDB()
ur := func() user.UserRepo { ur := func() user.UserRepo {
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{ repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
@ -223,7 +231,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
} }
emailer := &testEmailer{} emailer := &testEmailer{}
api := NewUsersAPI(mgr, clientManager, refreshRepo, emailer, "local") api := NewUsersAPI(mgr, clientManager, refreshRepo, emailer, "local", clientCredsFlag)
return api, emailer return api, emailer
} }
@ -233,29 +241,39 @@ func TestGetUser(t *testing.T) {
creds Creds creds Creds
id string id string
wantErr error wantErr error
clientCredsFlag bool
}{ }{
{ {
creds: goodCreds, creds: goodCreds,
id: "ID-1", id: "ID-1",
clientCredsFlag: false,
}, },
{ {
creds: badCreds, creds: badCreds,
id: "ID-1", id: "ID-1",
wantErr: ErrorUnauthorized, wantErr: ErrorUnauthorized,
clientCredsFlag: false,
}, },
{ {
creds: goodCreds, creds: goodCreds,
id: "NO_ID", id: "NO_ID",
wantErr: ErrorResourceNotFound, wantErr: ErrorResourceNotFound,
clientCredsFlag: false,
}, },
{ {
creds: credsWithMultipleAudiences, creds: credsWithMultipleAudiences,
id: "ID-1", id: "ID-1",
clientCredsFlag: false,
},
{
creds: clientCreds,
id: "ID-1",
clientCredsFlag: true,
}, },
} }
for i, tt := range tests { for i, tt := range tests {
api, _ := makeTestFixtures() api, _ := makeTestFixtures(tt.clientCredsFlag)
usr, err := api.GetUser(tt.creds, tt.id) usr, err := api.GetUser(tt.creds, tt.id)
if tt.wantErr != nil { if tt.wantErr != nil {
if err != tt.wantErr { if err != tt.wantErr {
@ -309,7 +327,7 @@ func TestListUsers(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
api, _ := makeTestFixtures() api, _ := makeTestFixtures(false)
gotIDs := [][]string{} gotIDs := [][]string{}
var next string var next string
@ -346,6 +364,8 @@ func TestCreateUser(t *testing.T) {
redirURL url.URL redirURL url.URL
cantEmail bool cantEmail bool
clientCredsFlag bool
wantResponse schema.UserCreateResponse wantResponse schema.UserCreateResponse
wantClientID string wantClientID string
wantErr error wantErr error
@ -371,6 +391,30 @@ func TestCreateUser(t *testing.T) {
}, },
}, },
wantClientID: goodClientID, wantClientID: goodClientID,
clientCredsFlag: false,
},
{
creds: clientCreds,
usr: schema.User{
Email: "newuser01@example.com",
DisplayName: "New User",
EmailVerified: true,
Admin: false,
},
redirURL: validRedirURL,
wantResponse: schema.UserCreateResponse{
EmailSent: true,
User: &schema.User{
Email: "newuser01@example.com",
DisplayName: "New User",
EmailVerified: true,
Admin: false,
CreatedAt: clock.Now().Format(time.RFC3339),
},
},
wantClientID: goodClientID,
clientCredsFlag: true,
}, },
{ {
creds: credsWithMultipleAudiences, creds: credsWithMultipleAudiences,
@ -393,6 +437,7 @@ func TestCreateUser(t *testing.T) {
}, },
}, },
wantClientID: goodClientID, wantClientID: goodClientID,
clientCredsFlag: false,
}, },
{ {
creds: goodCreds, creds: goodCreds,
@ -416,6 +461,7 @@ func TestCreateUser(t *testing.T) {
ResetPasswordLink: resetPasswordURL.String(), ResetPasswordLink: resetPasswordURL.String(),
}, },
wantClientID: goodClientID, wantClientID: goodClientID,
clientCredsFlag: false,
}, },
{ {
creds: goodCreds, creds: goodCreds,
@ -440,11 +486,12 @@ func TestCreateUser(t *testing.T) {
redirURL: validRedirURL, redirURL: validRedirURL,
wantErr: ErrorUnauthorized, wantErr: ErrorUnauthorized,
clientCredsFlag: false,
}, },
} }
for i, tt := range tests { for i, tt := range tests {
api, emailer := makeTestFixtures() api, emailer := makeTestFixtures(tt.clientCredsFlag)
emailer.cantEmail = tt.cantEmail emailer.cantEmail = tt.cantEmail
response, err := api.CreateUser(tt.creds, tt.usr, tt.redirURL) response, err := api.CreateUser(tt.creds, tt.usr, tt.redirURL)
@ -528,7 +575,7 @@ func TestDisableUsers(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
api, _ := makeTestFixtures() api, _ := makeTestFixtures(false)
_, err := api.DisableUser(goodCreds, tt.id, tt.disable) _, err := api.DisableUser(goodCreds, tt.id, tt.disable)
if err != nil { if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Fatalf("case %d: unexpected error: %v", i, err)
@ -552,6 +599,8 @@ func TestResendEmailInvitation(t *testing.T) {
redirURL url.URL redirURL url.URL
cantEmail bool cantEmail bool
clientCredsFlag bool
wantResponse schema.ResendEmailInvitationResponse wantResponse schema.ResendEmailInvitationResponse
wantErr error wantErr error
wantClientID string wantClientID string
@ -566,6 +615,19 @@ func TestResendEmailInvitation(t *testing.T) {
EmailSent: true, EmailSent: true,
}, },
wantClientID: goodClientID, wantClientID: goodClientID,
clientCredsFlag: false,
},
{
creds: clientCreds,
userID: "ID-1",
email: "id1@example.com",
redirURL: validRedirURL,
wantResponse: schema.ResendEmailInvitationResponse{
EmailSent: true,
},
wantClientID: goodClientID,
clientCredsFlag: true,
}, },
{ {
creds: goodCreds, creds: goodCreds,
@ -579,6 +641,7 @@ func TestResendEmailInvitation(t *testing.T) {
ResetPasswordLink: resetPasswordURL.String(), ResetPasswordLink: resetPasswordURL.String(),
}, },
wantClientID: goodClientID, wantClientID: goodClientID,
clientCredsFlag: false,
}, },
{ {
creds: credsWithMultipleAudiences, creds: credsWithMultipleAudiences,
@ -592,6 +655,7 @@ func TestResendEmailInvitation(t *testing.T) {
ResetPasswordLink: resetPasswordURL.String(), ResetPasswordLink: resetPasswordURL.String(),
}, },
wantClientID: goodClientID, wantClientID: goodClientID,
clientCredsFlag: false,
}, },
{ {
creds: badCreds, creds: badCreds,
@ -600,6 +664,7 @@ func TestResendEmailInvitation(t *testing.T) {
redirURL: validRedirURL, redirURL: validRedirURL,
wantErr: ErrorUnauthorized, wantErr: ErrorUnauthorized,
clientCredsFlag: false,
}, },
{ {
creds: goodCreds, creds: goodCreds,
@ -608,6 +673,7 @@ func TestResendEmailInvitation(t *testing.T) {
redirURL: url.URL{Host: "scammers.com"}, redirURL: url.URL{Host: "scammers.com"},
wantErr: ErrorInvalidRedirectURL, wantErr: ErrorInvalidRedirectURL,
clientCredsFlag: false,
}, },
{ {
creds: goodCreds, creds: goodCreds,
@ -616,6 +682,7 @@ func TestResendEmailInvitation(t *testing.T) {
redirURL: validRedirURL, redirURL: validRedirURL,
wantErr: ErrorVerifiedEmail, wantErr: ErrorVerifiedEmail,
clientCredsFlag: false,
}, },
{ {
creds: goodCreds, creds: goodCreds,
@ -624,11 +691,12 @@ func TestResendEmailInvitation(t *testing.T) {
redirURL: validRedirURL, redirURL: validRedirURL,
wantErr: ErrorResourceNotFound, wantErr: ErrorResourceNotFound,
clientCredsFlag: false,
}, },
} }
for i, tt := range tests { for i, tt := range tests {
api, emailer := makeTestFixtures() api, emailer := makeTestFixtures(tt.clientCredsFlag)
emailer.cantEmail = tt.cantEmail emailer.cantEmail = tt.cantEmail
response, err := api.ResendEmailInvitation(tt.creds, tt.userID, tt.redirURL) response, err := api.ResendEmailInvitation(tt.creds, tt.userID, tt.redirURL)
@ -670,7 +738,7 @@ func TestRevokeRefreshToken(t *testing.T) {
{"ID-2", goodClientID, []string{goodClientID}, []string{}}, {"ID-2", goodClientID, []string{goodClientID}, []string{}},
} }
api, _ := makeTestFixtures() api, _ := makeTestFixtures(false)
listClientsWithRefreshTokens := func(creds Creds, userID string) ([]string, error) { listClientsWithRefreshTokens := func(creds Creds, userID string) ([]string, error) {
clients, err := api.ListClientsWithRefreshTokens(creds, userID) clients, err := api.ListClientsWithRefreshTokens(creds, userID)