diff --git a/db/user.go b/db/user.go index 5f9c0435..303911ea 100644 --- a/db/user.go +++ b/db/user.go @@ -97,10 +97,29 @@ func (r *userRepo) Create(tx repo.Transaction, usr user.User) (err error) { } err = r.insert(tx, usr) + return err +} + +func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) error { + if userID == "" { + return user.ErrorInvalidID + } + + qt := pq.QuoteIdentifier(userTableName) + ex := r.executor(tx) + result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable) if err != nil { return err } + ct, err := result.RowsAffected() + switch { + case err != nil: + return err + case ct == 0: + return user.ErrorNotFound + } + return nil } diff --git a/functional/repo/user_repo_test.go b/functional/repo/user_repo_test.go index 5a1ac934..f99ce80b 100644 --- a/functional/repo/user_repo_test.go +++ b/functional/repo/user_repo_test.go @@ -35,6 +35,7 @@ var ( ID: "ID-2", Email: "Email-2@example.com", CreatedAt: time.Now(), + Disabled: true, }, RemoteIdentities: []user.RemoteIdentity{ { @@ -232,6 +233,61 @@ func TestUpdateUser(t *testing.T) { } } +func TestDisableUser(t *testing.T) { + tests := []struct { + id string + disable bool + err error + }{ + { + id: "ID-1", + }, + { + id: "ID-1", + disable: true, + }, + { + id: "ID-2", + }, + { + id: "ID-2", + disable: true, + }, + { + id: "NO SUCH ID", + err: user.ErrorNotFound, + }, + { + id: "NO SUCH ID", + err: user.ErrorNotFound, + disable: true, + }, + { + id: "", + err: user.ErrorInvalidID, + }, + } + + for i, tt := range tests { + repo := makeTestUserRepo() + err := repo.Disable(nil, tt.id, tt.disable) + switch { + case err != tt.err: + t.Errorf("case %d: want=%q, got=%q", i, tt.err, err) + case tt.err == nil: + gotUser, err := repo.Get(nil, tt.id) + if err != nil { + t.Fatalf("case %d: want nil err, got %q", i, err) + } + + if gotUser.Disabled != tt.disable { + t.Errorf("case %d: disabled status want=%v got=%v", + i, tt.disable, gotUser.Disabled) + } + } + } +} + func TestAttachRemoteIdentity(t *testing.T) { tests := []struct { id string diff --git a/integration/user_api_test.go b/integration/user_api_test.go index 6d18c857..7c321d12 100644 --- a/integration/user_api_test.go +++ b/integration/user_api_test.go @@ -141,6 +141,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures { f.trans = &tokenHandlerTransport{ Handler: usrSrv.HTTPHandler(), + Token: userGoodToken, } hc := &http.Client{ Transport: f.trans, @@ -530,6 +531,48 @@ func TestCreateUser(t *testing.T) { } } +func TestDisableUser(t *testing.T) { + tests := []struct { + id string + disable bool + }{ + { + id: "ID-2", + disable: true, + }, + { + id: "ID-4", + disable: false, + }, + } + + for i, tt := range tests { + f := makeUserAPITestFixtures() + + usr, err := f.client.Users.Get(tt.id).Do() + if err != nil { + t.Fatalf("case %v: unexpected error: %v", i, err) + } + if usr.User.Disabled == tt.disable { + t.Fatalf("case %v: misconfigured test, initial disabled state should be %v but was %v", i, !tt.disable, usr.User.Disabled) + } + + _, err = f.client.Users.Disable(tt.id, &schema.UserDisableRequest{ + Disable: tt.disable, + }).Do() + if err != nil { + t.Fatalf("case %v: unexpected error: %v", i, err) + } + usr, err = f.client.Users.Get(tt.id).Do() + if err != nil { + t.Fatalf("case %v: unexpected error: %v", i, err) + } + if usr.User.Disabled != tt.disable { + t.Errorf("case %v: user disabled state incorrect. wanted: %v found: %v", i, tt.disable, usr.User.Disabled) + } + } +} + type testEmailer struct { cantEmail bool lastEmail string diff --git a/schema/adminschema/v1-gen.go b/schema/adminschema/v1-gen.go index 88da0c92..7e7549aa 100644 --- a/schema/adminschema/v1-gen.go +++ b/schema/adminschema/v1-gen.go @@ -4,7 +4,7 @@ // // Usage example: // -// import "github.com/coreos/dex/Godeps/_workspace/src/google.golang.org/api/adminschema/v1" +// import "google.golang.org/api/adminschema/v1" // ... // adminschemaService, err := adminschema.New(oauthHttpClient) package adminschema diff --git a/schema/workerschema/v1-gen.go b/schema/workerschema/v1-gen.go index f090b79f..950c9278 100644 --- a/schema/workerschema/v1-gen.go +++ b/schema/workerschema/v1-gen.go @@ -108,6 +108,8 @@ type User struct { CreatedAt string `json:"createdAt,omitempty"` + Disabled bool `json:"disabled,omitempty"` + DisplayName string `json:"displayName,omitempty"` Email string `json:"email,omitempty"` @@ -134,6 +136,15 @@ type UserCreateResponse struct { type UserCreateResponseUser struct { } +type UserDisableRequest struct { + // Disable: If true, disable this user, if false, enable them + Disable bool `json:"disable,omitempty"` +} + +type UserDisableResponse struct { + Ok bool `json:"ok,omitempty"` +} + type UserResponse struct { User *User `json:"user,omitempty"` } @@ -355,6 +366,89 @@ func (c *UsersCreateCall) Do() (*UserCreateResponse, error) { } +// method id "dex.User.Disable": + +type UsersDisableCall struct { + s *Service + id string + userdisablerequest *UserDisableRequest + opt_ map[string]interface{} +} + +// Disable: Enable or disable a user. +func (r *UsersService) Disable(id string, userdisablerequest *UserDisableRequest) *UsersDisableCall { + c := &UsersDisableCall{s: r.s, opt_: make(map[string]interface{})} + c.id = id + c.userdisablerequest = userdisablerequest + return c +} + +// Fields allows partial responses to be retrieved. +// See https://developers.google.com/gdata/docs/2.0/basics#PartialResponse +// for more information. +func (c *UsersDisableCall) Fields(s ...googleapi.Field) *UsersDisableCall { + c.opt_["fields"] = googleapi.CombineFields(s) + return c +} + +func (c *UsersDisableCall) Do() (*UserDisableResponse, error) { + var body io.Reader = nil + body, err := googleapi.WithoutDataWrapper.JSONReader(c.userdisablerequest) + if err != nil { + return nil, err + } + ctype := "application/json" + params := make(url.Values) + params.Set("alt", "json") + if v, ok := c.opt_["fields"]; ok { + params.Set("fields", fmt.Sprintf("%v", v)) + } + urls := googleapi.ResolveRelative(c.s.BasePath, "users/{id}/disable") + urls += "?" + params.Encode() + req, _ := http.NewRequest("POST", urls, body) + googleapi.Expand(req.URL, map[string]string{ + "id": c.id, + }) + req.Header.Set("Content-Type", ctype) + req.Header.Set("User-Agent", "google-api-go-client/0.5") + res, err := c.s.client.Do(req) + if err != nil { + return nil, err + } + defer googleapi.CloseBody(res) + if err := googleapi.CheckResponse(res); err != nil { + return nil, err + } + var ret *UserDisableResponse + if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { + return nil, err + } + return ret, nil + // { + // "description": "Enable or disable a user.", + // "httpMethod": "POST", + // "id": "dex.User.Disable", + // "parameterOrder": [ + // "id" + // ], + // "parameters": { + // "id": { + // "location": "path", + // "required": true, + // "type": "string" + // } + // }, + // "path": "users/{id}/disable", + // "request": { + // "$ref": "UserDisableRequest" + // }, + // "response": { + // "$ref": "UserDisableResponse" + // } + // } + +} + // method id "dex.User.Get": type UsersGetCall struct { @@ -363,7 +457,7 @@ type UsersGetCall struct { opt_ map[string]interface{} } -// Get: Get a single use object. +// Get: Get a single User object by id. func (r *UsersService) Get(id string) *UsersGetCall { c := &UsersGetCall{s: r.s, opt_: make(map[string]interface{})} c.id = id @@ -406,7 +500,7 @@ func (c *UsersGetCall) Do() (*UserResponse, error) { } return ret, nil // { - // "description": "Get a single use object.", + // "description": "Get a single User object by id.", // "httpMethod": "GET", // "id": "dex.User.Get", // "parameterOrder": [ diff --git a/schema/workerschema/v1-json.go b/schema/workerschema/v1-json.go index 0e9d7665..5e6c576d 100644 --- a/schema/workerschema/v1-json.go +++ b/schema/workerschema/v1-json.go @@ -1,5 +1,4 @@ package workerschema - // // This file is automatically generated by schema/generator // @@ -109,6 +108,9 @@ const DiscoveryJSON = `{ "admin": { "type": "boolean" }, + "disabled": { + "type": "boolean" + }, "createdAt": { "type": "string", "format": "date-time" @@ -167,6 +169,25 @@ const DiscoveryJSON = `{ "type": "boolean" } } + }, + "UserDisableRequest": { + "id": "UserDisableRequest", + "type": "object", + "properties": { + "disable": { + "type": "boolean", + "description": "If true, disable this user, if false, enable them" + } + } + }, + "UserDisableResponse": { + "id": "UserDisableResponse", + "type": "object", + "properties": { + "ok": { + "type": "boolean" + } + } } }, "resources": { @@ -224,7 +245,7 @@ const DiscoveryJSON = `{ }, "Get": { "id": "dex.User.Get", - "description": "Get a single use object.", + "description": "Get a single User object by id.", "httpMethod": "GET", "path": "users/{id}", "parameters": { @@ -252,9 +273,31 @@ const DiscoveryJSON = `{ "response": { "$ref": "UserCreateResponse" } + }, + "Disable": { + "id": "dex.User.Disable", + "description": "Enable or disable a user.", + "httpMethod": "POST", + "path": "users/{id}/disable", + "parameters": { + "id": { + "type": "string", + "required": true, + "location": "path" + } + }, + "parameterOrder": [ + "id" + ], + "request": { + "$ref": "UserDisableRequest" + }, + "response": { + "$ref": "UserDisableResponse" + } } } } } } -` +` \ No newline at end of file diff --git a/schema/workerschema/v1.json b/schema/workerschema/v1.json index 546f3efa..7d3570f9 100644 --- a/schema/workerschema/v1.json +++ b/schema/workerschema/v1.json @@ -102,6 +102,9 @@ "admin": { "type": "boolean" }, + "disabled": { + "type": "boolean" + }, "createdAt": { "type": "string", "format": "date-time" @@ -160,6 +163,25 @@ "type": "boolean" } } + }, + "UserDisableRequest": { + "id": "UserDisableRequest", + "type": "object", + "properties": { + "disable": { + "type": "boolean", + "description": "If true, disable this user, if false, enable them. No error is signaled if the user state doesn't change." + } + } + }, + "UserDisableResponse": { + "id": "UserDisableResponse", + "type": "object", + "properties": { + "ok": { + "type": "boolean" + } + } } }, "resources": { @@ -217,7 +239,7 @@ }, "Get": { "id": "dex.User.Get", - "description": "Get a single use object.", + "description": "Get a single User object by id.", "httpMethod": "GET", "path": "users/{id}", "parameters": { @@ -245,6 +267,28 @@ "response": { "$ref": "UserCreateResponse" } + }, + "Disable": { + "id": "dex.User.Disable", + "description": "Enable or disable a user.", + "httpMethod": "POST", + "path": "users/{id}/disable", + "parameters": { + "id": { + "type": "string", + "required": true, + "location": "path" + } + }, + "parameterOrder": [ + "id" + ], + "request": { + "$ref": "UserDisableRequest" + }, + "response": { + "$ref": "UserDisableResponse" + } } } } diff --git a/server/server.go b/server/server.go index 7860989d..38714781 100644 --- a/server/server.go +++ b/server/server.go @@ -244,7 +244,11 @@ func (s *Server) HTTPHandler() http.Handler { mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler)) usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientIdentityRepo, s.UserEmailer, s.localConnectorID) - mux.Handle(path.Join(apiBasePath, UsersSubTree), NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler()) + handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler() + path := path.Join(apiBasePath, UsersSubTree) + + mux.Handle(path, handler) + mux.Handle(path+"/", handler) return http.Handler(mux) } diff --git a/server/user.go b/server/user.go index 0c5350d9..ce2aecd6 100644 --- a/server/user.go +++ b/server/user.go @@ -23,10 +23,11 @@ const ( ) var ( - UsersSubTree = "/users" - UsersListEndpoint = addBasePath(UsersSubTree) - UsersCreateEndpoint = addBasePath(UsersSubTree) - UsersGetEndpoint = addBasePath(UsersSubTree + "/:id") + UsersSubTree = "/users" + UsersListEndpoint = addBasePath(UsersSubTree) + UsersCreateEndpoint = addBasePath(UsersSubTree) + UsersGetEndpoint = addBasePath(UsersSubTree + "/:id") + UsersDisableEndpoint = addBasePath(UsersSubTree + "/:id/disable") ) type UserMgmtServer struct { @@ -51,6 +52,7 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler { r.RedirectFixedPath = false r.GET(UsersListEndpoint, s.listUsers) r.POST(UsersCreateEndpoint, s.createUser) + r.POST(UsersDisableEndpoint, s.disableUser) r.GET(UsersGetEndpoint, s.getUser) return r } @@ -140,6 +142,35 @@ func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps h writeResponseWithBody(w, http.StatusOK, createdResponse) } +func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + creds, err := s.getCreds(r) + if err != nil { + s.writeError(w, err) + return + } + + id := ps.ByName("id") + if id == "" { + writeAPIError(w, http.StatusBadRequest, + newAPIError(errorInvalidRequest, "id is required")) + return + } + + disableReq := schema.UserDisableRequest{} + err = json.NewDecoder(r.Body).Decode(&disableReq) + if err != nil { + writeInvalidRequest(w, "cannot parse JSON body") + } + + resp, err := s.api.DisableUser(creds, id, disableReq.Disable) + if err != nil { + s.writeError(w, err) + return + } + + writeResponseWithBody(w, http.StatusOK, resp) +} + func (s *UserMgmtServer) writeError(w http.ResponseWriter, err error) { log.Errorf("Error calling user management API: %v: ", err) if apiErr, ok := err.(api.Error); ok { diff --git a/user/api/api.go b/user/api/api.go index 06085ada..a05cc746 100644 --- a/user/api/api.go +++ b/user/api/api.go @@ -121,6 +121,21 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) { return userToSchemaUser(usr), nil } +func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema.UserDisableResponse, error) { + log.Infof("userAPI: DisableUser") + if !u.Authorize(creds) { + return schema.UserDisableResponse{}, ErrorUnauthorized + } + + if err := u.manager.Disable(userID, disable); err != nil { + return schema.UserDisableResponse{}, mapError(err) + } + + return schema.UserDisableResponse{ + Ok: true, + }, nil +} + func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (schema.UserCreateResponse, error) { log.Infof("userAPI: CreateUser") if !u.Authorize(creds) { @@ -207,6 +222,7 @@ func userToSchemaUser(usr user.User) schema.User { EmailVerified: usr.EmailVerified, DisplayName: usr.DisplayName, Admin: usr.Admin, + Disabled: usr.Disabled, CreatedAt: usr.CreatedAt.UTC().Format(time.RFC3339), } } @@ -218,6 +234,7 @@ func schemaUserToUser(usr schema.User) user.User { EmailVerified: usr.EmailVerified, DisplayName: usr.DisplayName, Admin: usr.Admin, + Disabled: usr.Disabled, } } diff --git a/user/api/api_test.go b/user/api/api_test.go index fce1f61e..fa5f85b1 100644 --- a/user/api/api_test.go +++ b/user/api/api_test.go @@ -94,6 +94,13 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { Email: "id3@example.com", CreatedAt: clock.Now(), }, + }, { + User: user.User{ + ID: "ID-4", + Email: "id4@example.com", + CreatedAt: clock.Now(), + Disabled: true, + }, }, }) pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ @@ -369,3 +376,44 @@ func TestCreateUser(t *testing.T) { } } } + +func TestDisableUsers(t *testing.T) { + tests := []struct { + id string + disable bool + }{ + { + id: "ID-1", + disable: true, + }, + { + id: "ID-1", + disable: false, + }, + { + id: "ID-4", + disable: true, + }, + { + id: "ID-4", + disable: false, + }, + } + + for i, tt := range tests { + api, _ := makeTestFixtures() + _, err := api.DisableUser(goodCreds, tt.id, tt.disable) + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + usr, err := api.GetUser(goodCreds, tt.id) + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + if usr.Disabled != tt.disable { + t.Errorf("case %d: user disable state wrong. wanted: %v got: %v", i, tt.disable, usr.Disabled) + } + } +} diff --git a/user/manager.go b/user/manager.go index cc09444c..1e046589 100644 --- a/user/manager.go +++ b/user/manager.go @@ -102,6 +102,22 @@ func (m *Manager) CreateUser(user User, hashedPassword Password, connID string) return user.ID, nil } +func (m *Manager) Disable(userID string, disabled bool) error { + tx, err := m.begin() + + if err = m.userRepo.Disable(tx, userID, disabled); err != nil { + rollback(tx) + return err + } + + if err = tx.Commit(); err != nil { + rollback(tx) + return err + } + + return nil +} + // RegisterWithRemoteIdentity creates new user and attaches the given remote identity. func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) { tx, err := m.begin() diff --git a/user/user.go b/user/user.go index c6b48465..e771e667 100644 --- a/user/user.go +++ b/user/user.go @@ -80,6 +80,8 @@ type UserRepo interface { GetByEmail(tx repo.Transaction, email string) (User, error) + Disable(tx repo.Transaction, id string, disabled bool) error + Update(repo.Transaction, User) error GetByRemoteIdentity(repo.Transaction, RemoteIdentity) (User, error) @@ -254,6 +256,19 @@ func (r *memUserRepo) Update(_ repo.Transaction, user User) error { return nil } +func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error { + if id == "" { + return ErrorInvalidID + } + user, ok := r.usersByID[id] + if !ok { + return ErrorNotFound + } + user.Disabled = disable + r.set(user) + return nil +} + func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error { _, ok := r.usersByID[userID] if !ok {