Merge pull request #140 from joeatwork/disable-users-api

Expose API to enable and disable users
This commit is contained in:
Joe Bowers 2015-09-29 16:47:43 -07:00
commit a426943054
13 changed files with 442 additions and 12 deletions

View file

@ -97,10 +97,29 @@ func (r *userRepo) Create(tx repo.Transaction, usr user.User) (err error) {
} }
err = r.insert(tx, usr) err = r.insert(tx, usr)
return err
}
func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) error {
if userID == "" {
return user.ErrorInvalidID
}
qt := pq.QuoteIdentifier(userTableName)
ex := r.executor(tx)
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable)
if err != nil { if err != nil {
return err return err
} }
ct, err := result.RowsAffected()
switch {
case err != nil:
return err
case ct == 0:
return user.ErrorNotFound
}
return nil return nil
} }

View file

@ -35,6 +35,7 @@ var (
ID: "ID-2", ID: "ID-2",
Email: "Email-2@example.com", Email: "Email-2@example.com",
CreatedAt: time.Now(), CreatedAt: time.Now(),
Disabled: true,
}, },
RemoteIdentities: []user.RemoteIdentity{ RemoteIdentities: []user.RemoteIdentity{
{ {
@ -232,6 +233,61 @@ func TestUpdateUser(t *testing.T) {
} }
} }
func TestDisableUser(t *testing.T) {
tests := []struct {
id string
disable bool
err error
}{
{
id: "ID-1",
},
{
id: "ID-1",
disable: true,
},
{
id: "ID-2",
},
{
id: "ID-2",
disable: true,
},
{
id: "NO SUCH ID",
err: user.ErrorNotFound,
},
{
id: "NO SUCH ID",
err: user.ErrorNotFound,
disable: true,
},
{
id: "",
err: user.ErrorInvalidID,
},
}
for i, tt := range tests {
repo := makeTestUserRepo()
err := repo.Disable(nil, tt.id, tt.disable)
switch {
case err != tt.err:
t.Errorf("case %d: want=%q, got=%q", i, tt.err, err)
case tt.err == nil:
gotUser, err := repo.Get(nil, tt.id)
if err != nil {
t.Fatalf("case %d: want nil err, got %q", i, err)
}
if gotUser.Disabled != tt.disable {
t.Errorf("case %d: disabled status want=%v got=%v",
i, tt.disable, gotUser.Disabled)
}
}
}
}
func TestAttachRemoteIdentity(t *testing.T) { func TestAttachRemoteIdentity(t *testing.T) {
tests := []struct { tests := []struct {
id string id string

View file

@ -141,6 +141,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
f.trans = &tokenHandlerTransport{ f.trans = &tokenHandlerTransport{
Handler: usrSrv.HTTPHandler(), Handler: usrSrv.HTTPHandler(),
Token: userGoodToken,
} }
hc := &http.Client{ hc := &http.Client{
Transport: f.trans, Transport: f.trans,
@ -530,6 +531,48 @@ func TestCreateUser(t *testing.T) {
} }
} }
func TestDisableUser(t *testing.T) {
tests := []struct {
id string
disable bool
}{
{
id: "ID-2",
disable: true,
},
{
id: "ID-4",
disable: false,
},
}
for i, tt := range tests {
f := makeUserAPITestFixtures()
usr, err := f.client.Users.Get(tt.id).Do()
if err != nil {
t.Fatalf("case %v: unexpected error: %v", i, err)
}
if usr.User.Disabled == tt.disable {
t.Fatalf("case %v: misconfigured test, initial disabled state should be %v but was %v", i, !tt.disable, usr.User.Disabled)
}
_, err = f.client.Users.Disable(tt.id, &schema.UserDisableRequest{
Disable: tt.disable,
}).Do()
if err != nil {
t.Fatalf("case %v: unexpected error: %v", i, err)
}
usr, err = f.client.Users.Get(tt.id).Do()
if err != nil {
t.Fatalf("case %v: unexpected error: %v", i, err)
}
if usr.User.Disabled != tt.disable {
t.Errorf("case %v: user disabled state incorrect. wanted: %v found: %v", i, tt.disable, usr.User.Disabled)
}
}
}
type testEmailer struct { type testEmailer struct {
cantEmail bool cantEmail bool
lastEmail string lastEmail string

View file

@ -4,7 +4,7 @@
// //
// Usage example: // Usage example:
// //
// import "github.com/coreos/dex/Godeps/_workspace/src/google.golang.org/api/adminschema/v1" // import "google.golang.org/api/adminschema/v1"
// ... // ...
// adminschemaService, err := adminschema.New(oauthHttpClient) // adminschemaService, err := adminschema.New(oauthHttpClient)
package adminschema package adminschema

View file

@ -108,6 +108,8 @@ type User struct {
CreatedAt string `json:"createdAt,omitempty"` CreatedAt string `json:"createdAt,omitempty"`
Disabled bool `json:"disabled,omitempty"`
DisplayName string `json:"displayName,omitempty"` DisplayName string `json:"displayName,omitempty"`
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
@ -134,6 +136,15 @@ type UserCreateResponse struct {
type UserCreateResponseUser struct { type UserCreateResponseUser struct {
} }
type UserDisableRequest struct {
// Disable: If true, disable this user, if false, enable them
Disable bool `json:"disable,omitempty"`
}
type UserDisableResponse struct {
Ok bool `json:"ok,omitempty"`
}
type UserResponse struct { type UserResponse struct {
User *User `json:"user,omitempty"` User *User `json:"user,omitempty"`
} }
@ -355,6 +366,89 @@ func (c *UsersCreateCall) Do() (*UserCreateResponse, error) {
} }
// method id "dex.User.Disable":
type UsersDisableCall struct {
s *Service
id string
userdisablerequest *UserDisableRequest
opt_ map[string]interface{}
}
// Disable: Enable or disable a user.
func (r *UsersService) Disable(id string, userdisablerequest *UserDisableRequest) *UsersDisableCall {
c := &UsersDisableCall{s: r.s, opt_: make(map[string]interface{})}
c.id = id
c.userdisablerequest = userdisablerequest
return c
}
// Fields allows partial responses to be retrieved.
// See https://developers.google.com/gdata/docs/2.0/basics#PartialResponse
// for more information.
func (c *UsersDisableCall) Fields(s ...googleapi.Field) *UsersDisableCall {
c.opt_["fields"] = googleapi.CombineFields(s)
return c
}
func (c *UsersDisableCall) Do() (*UserDisableResponse, error) {
var body io.Reader = nil
body, err := googleapi.WithoutDataWrapper.JSONReader(c.userdisablerequest)
if err != nil {
return nil, err
}
ctype := "application/json"
params := make(url.Values)
params.Set("alt", "json")
if v, ok := c.opt_["fields"]; ok {
params.Set("fields", fmt.Sprintf("%v", v))
}
urls := googleapi.ResolveRelative(c.s.BasePath, "users/{id}/disable")
urls += "?" + params.Encode()
req, _ := http.NewRequest("POST", urls, body)
googleapi.Expand(req.URL, map[string]string{
"id": c.id,
})
req.Header.Set("Content-Type", ctype)
req.Header.Set("User-Agent", "google-api-go-client/0.5")
res, err := c.s.client.Do(req)
if err != nil {
return nil, err
}
defer googleapi.CloseBody(res)
if err := googleapi.CheckResponse(res); err != nil {
return nil, err
}
var ret *UserDisableResponse
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
return nil, err
}
return ret, nil
// {
// "description": "Enable or disable a user.",
// "httpMethod": "POST",
// "id": "dex.User.Disable",
// "parameterOrder": [
// "id"
// ],
// "parameters": {
// "id": {
// "location": "path",
// "required": true,
// "type": "string"
// }
// },
// "path": "users/{id}/disable",
// "request": {
// "$ref": "UserDisableRequest"
// },
// "response": {
// "$ref": "UserDisableResponse"
// }
// }
}
// method id "dex.User.Get": // method id "dex.User.Get":
type UsersGetCall struct { type UsersGetCall struct {
@ -363,7 +457,7 @@ type UsersGetCall struct {
opt_ map[string]interface{} opt_ map[string]interface{}
} }
// Get: Get a single use object. // Get: Get a single User object by id.
func (r *UsersService) Get(id string) *UsersGetCall { func (r *UsersService) Get(id string) *UsersGetCall {
c := &UsersGetCall{s: r.s, opt_: make(map[string]interface{})} c := &UsersGetCall{s: r.s, opt_: make(map[string]interface{})}
c.id = id c.id = id
@ -406,7 +500,7 @@ func (c *UsersGetCall) Do() (*UserResponse, error) {
} }
return ret, nil return ret, nil
// { // {
// "description": "Get a single use object.", // "description": "Get a single User object by id.",
// "httpMethod": "GET", // "httpMethod": "GET",
// "id": "dex.User.Get", // "id": "dex.User.Get",
// "parameterOrder": [ // "parameterOrder": [

View file

@ -1,5 +1,4 @@
package workerschema package workerschema
// //
// This file is automatically generated by schema/generator // This file is automatically generated by schema/generator
// //
@ -109,6 +108,9 @@ const DiscoveryJSON = `{
"admin": { "admin": {
"type": "boolean" "type": "boolean"
}, },
"disabled": {
"type": "boolean"
},
"createdAt": { "createdAt": {
"type": "string", "type": "string",
"format": "date-time" "format": "date-time"
@ -167,6 +169,25 @@ const DiscoveryJSON = `{
"type": "boolean" "type": "boolean"
} }
} }
},
"UserDisableRequest": {
"id": "UserDisableRequest",
"type": "object",
"properties": {
"disable": {
"type": "boolean",
"description": "If true, disable this user, if false, enable them"
}
}
},
"UserDisableResponse": {
"id": "UserDisableResponse",
"type": "object",
"properties": {
"ok": {
"type": "boolean"
}
}
} }
}, },
"resources": { "resources": {
@ -224,7 +245,7 @@ const DiscoveryJSON = `{
}, },
"Get": { "Get": {
"id": "dex.User.Get", "id": "dex.User.Get",
"description": "Get a single use object.", "description": "Get a single User object by id.",
"httpMethod": "GET", "httpMethod": "GET",
"path": "users/{id}", "path": "users/{id}",
"parameters": { "parameters": {
@ -252,6 +273,28 @@ const DiscoveryJSON = `{
"response": { "response": {
"$ref": "UserCreateResponse" "$ref": "UserCreateResponse"
} }
},
"Disable": {
"id": "dex.User.Disable",
"description": "Enable or disable a user.",
"httpMethod": "POST",
"path": "users/{id}/disable",
"parameters": {
"id": {
"type": "string",
"required": true,
"location": "path"
}
},
"parameterOrder": [
"id"
],
"request": {
"$ref": "UserDisableRequest"
},
"response": {
"$ref": "UserDisableResponse"
}
} }
} }
} }

View file

@ -102,6 +102,9 @@
"admin": { "admin": {
"type": "boolean" "type": "boolean"
}, },
"disabled": {
"type": "boolean"
},
"createdAt": { "createdAt": {
"type": "string", "type": "string",
"format": "date-time" "format": "date-time"
@ -160,6 +163,25 @@
"type": "boolean" "type": "boolean"
} }
} }
},
"UserDisableRequest": {
"id": "UserDisableRequest",
"type": "object",
"properties": {
"disable": {
"type": "boolean",
"description": "If true, disable this user, if false, enable them. No error is signaled if the user state doesn't change."
}
}
},
"UserDisableResponse": {
"id": "UserDisableResponse",
"type": "object",
"properties": {
"ok": {
"type": "boolean"
}
}
} }
}, },
"resources": { "resources": {
@ -217,7 +239,7 @@
}, },
"Get": { "Get": {
"id": "dex.User.Get", "id": "dex.User.Get",
"description": "Get a single use object.", "description": "Get a single User object by id.",
"httpMethod": "GET", "httpMethod": "GET",
"path": "users/{id}", "path": "users/{id}",
"parameters": { "parameters": {
@ -245,6 +267,28 @@
"response": { "response": {
"$ref": "UserCreateResponse" "$ref": "UserCreateResponse"
} }
},
"Disable": {
"id": "dex.User.Disable",
"description": "Enable or disable a user.",
"httpMethod": "POST",
"path": "users/{id}/disable",
"parameters": {
"id": {
"type": "string",
"required": true,
"location": "path"
}
},
"parameterOrder": [
"id"
],
"request": {
"$ref": "UserDisableRequest"
},
"response": {
"$ref": "UserDisableResponse"
}
} }
} }
} }

View file

@ -244,7 +244,11 @@ func (s *Server) HTTPHandler() http.Handler {
mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler)) mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler))
usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientIdentityRepo, s.UserEmailer, s.localConnectorID) usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientIdentityRepo, s.UserEmailer, s.localConnectorID)
mux.Handle(path.Join(apiBasePath, UsersSubTree), NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler()) handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler()
path := path.Join(apiBasePath, UsersSubTree)
mux.Handle(path, handler)
mux.Handle(path+"/", handler)
return http.Handler(mux) return http.Handler(mux)
} }

View file

@ -27,6 +27,7 @@ var (
UsersListEndpoint = addBasePath(UsersSubTree) UsersListEndpoint = addBasePath(UsersSubTree)
UsersCreateEndpoint = addBasePath(UsersSubTree) UsersCreateEndpoint = addBasePath(UsersSubTree)
UsersGetEndpoint = addBasePath(UsersSubTree + "/:id") UsersGetEndpoint = addBasePath(UsersSubTree + "/:id")
UsersDisableEndpoint = addBasePath(UsersSubTree + "/:id/disable")
) )
type UserMgmtServer struct { type UserMgmtServer struct {
@ -51,6 +52,7 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler {
r.RedirectFixedPath = false r.RedirectFixedPath = false
r.GET(UsersListEndpoint, s.listUsers) r.GET(UsersListEndpoint, s.listUsers)
r.POST(UsersCreateEndpoint, s.createUser) r.POST(UsersCreateEndpoint, s.createUser)
r.POST(UsersDisableEndpoint, s.disableUser)
r.GET(UsersGetEndpoint, s.getUser) r.GET(UsersGetEndpoint, s.getUser)
return r return r
} }
@ -140,6 +142,35 @@ func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps h
writeResponseWithBody(w, http.StatusOK, createdResponse) writeResponseWithBody(w, http.StatusOK, createdResponse)
} }
func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
creds, err := s.getCreds(r)
if err != nil {
s.writeError(w, err)
return
}
id := ps.ByName("id")
if id == "" {
writeAPIError(w, http.StatusBadRequest,
newAPIError(errorInvalidRequest, "id is required"))
return
}
disableReq := schema.UserDisableRequest{}
err = json.NewDecoder(r.Body).Decode(&disableReq)
if err != nil {
writeInvalidRequest(w, "cannot parse JSON body")
}
resp, err := s.api.DisableUser(creds, id, disableReq.Disable)
if err != nil {
s.writeError(w, err)
return
}
writeResponseWithBody(w, http.StatusOK, resp)
}
func (s *UserMgmtServer) writeError(w http.ResponseWriter, err error) { func (s *UserMgmtServer) writeError(w http.ResponseWriter, err error) {
log.Errorf("Error calling user management API: %v: ", err) log.Errorf("Error calling user management API: %v: ", err)
if apiErr, ok := err.(api.Error); ok { if apiErr, ok := err.(api.Error); ok {

View file

@ -121,6 +121,21 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) {
return userToSchemaUser(usr), nil return userToSchemaUser(usr), nil
} }
func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema.UserDisableResponse, error) {
log.Infof("userAPI: DisableUser")
if !u.Authorize(creds) {
return schema.UserDisableResponse{}, ErrorUnauthorized
}
if err := u.manager.Disable(userID, disable); err != nil {
return schema.UserDisableResponse{}, mapError(err)
}
return schema.UserDisableResponse{
Ok: true,
}, nil
}
func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (schema.UserCreateResponse, error) { func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (schema.UserCreateResponse, error) {
log.Infof("userAPI: CreateUser") log.Infof("userAPI: CreateUser")
if !u.Authorize(creds) { if !u.Authorize(creds) {
@ -207,6 +222,7 @@ func userToSchemaUser(usr user.User) schema.User {
EmailVerified: usr.EmailVerified, EmailVerified: usr.EmailVerified,
DisplayName: usr.DisplayName, DisplayName: usr.DisplayName,
Admin: usr.Admin, Admin: usr.Admin,
Disabled: usr.Disabled,
CreatedAt: usr.CreatedAt.UTC().Format(time.RFC3339), CreatedAt: usr.CreatedAt.UTC().Format(time.RFC3339),
} }
} }
@ -218,6 +234,7 @@ func schemaUserToUser(usr schema.User) user.User {
EmailVerified: usr.EmailVerified, EmailVerified: usr.EmailVerified,
DisplayName: usr.DisplayName, DisplayName: usr.DisplayName,
Admin: usr.Admin, Admin: usr.Admin,
Disabled: usr.Disabled,
} }
} }

View file

@ -94,6 +94,13 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
Email: "id3@example.com", Email: "id3@example.com",
CreatedAt: clock.Now(), CreatedAt: clock.Now(),
}, },
}, {
User: user.User{
ID: "ID-4",
Email: "id4@example.com",
CreatedAt: clock.Now(),
Disabled: true,
},
}, },
}) })
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
@ -369,3 +376,44 @@ func TestCreateUser(t *testing.T) {
} }
} }
} }
func TestDisableUsers(t *testing.T) {
tests := []struct {
id string
disable bool
}{
{
id: "ID-1",
disable: true,
},
{
id: "ID-1",
disable: false,
},
{
id: "ID-4",
disable: true,
},
{
id: "ID-4",
disable: false,
},
}
for i, tt := range tests {
api, _ := makeTestFixtures()
_, err := api.DisableUser(goodCreds, tt.id, tt.disable)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
usr, err := api.GetUser(goodCreds, tt.id)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
if usr.Disabled != tt.disable {
t.Errorf("case %d: user disable state wrong. wanted: %v got: %v", i, tt.disable, usr.Disabled)
}
}
}

View file

@ -102,6 +102,22 @@ func (m *Manager) CreateUser(user User, hashedPassword Password, connID string)
return user.ID, nil return user.ID, nil
} }
func (m *Manager) Disable(userID string, disabled bool) error {
tx, err := m.begin()
if err = m.userRepo.Disable(tx, userID, disabled); err != nil {
rollback(tx)
return err
}
if err = tx.Commit(); err != nil {
rollback(tx)
return err
}
return nil
}
// RegisterWithRemoteIdentity creates new user and attaches the given remote identity. // RegisterWithRemoteIdentity creates new user and attaches the given remote identity.
func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) { func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) {
tx, err := m.begin() tx, err := m.begin()

View file

@ -80,6 +80,8 @@ type UserRepo interface {
GetByEmail(tx repo.Transaction, email string) (User, error) GetByEmail(tx repo.Transaction, email string) (User, error)
Disable(tx repo.Transaction, id string, disabled bool) error
Update(repo.Transaction, User) error Update(repo.Transaction, User) error
GetByRemoteIdentity(repo.Transaction, RemoteIdentity) (User, error) GetByRemoteIdentity(repo.Transaction, RemoteIdentity) (User, error)
@ -254,6 +256,19 @@ func (r *memUserRepo) Update(_ repo.Transaction, user User) error {
return nil return nil
} }
func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error {
if id == "" {
return ErrorInvalidID
}
user, ok := r.usersByID[id]
if !ok {
return ErrorNotFound
}
user.Disabled = disable
r.set(user)
return nil
}
func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error { func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
_, ok := r.usersByID[userID] _, ok := r.usersByID[userID]
if !ok { if !ok {