diff --git a/db/user.go b/db/user.go index 5f9c0435..6863883a 100644 --- a/db/user.go +++ b/db/user.go @@ -97,10 +97,25 @@ func (r *userRepo) Create(tx repo.Transaction, usr user.User) (err error) { } err = r.insert(tx, usr) + return err +} + +func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) error { + if userID == "" { + return user.ErrorInvalidID + } + + qt := pq.QuoteIdentifier(userTableName) + ex := r.executor(tx) + result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable) if err != nil { return err } + if ct, err := result.RowsAffected(); err == nil && ct == 0 { + return user.ErrorInvalidID + } + return nil } diff --git a/integration/user_api_test.go b/integration/user_api_test.go index 6d18c857..7c321d12 100644 --- a/integration/user_api_test.go +++ b/integration/user_api_test.go @@ -141,6 +141,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures { f.trans = &tokenHandlerTransport{ Handler: usrSrv.HTTPHandler(), + Token: userGoodToken, } hc := &http.Client{ Transport: f.trans, @@ -530,6 +531,48 @@ func TestCreateUser(t *testing.T) { } } +func TestDisableUser(t *testing.T) { + tests := []struct { + id string + disable bool + }{ + { + id: "ID-2", + disable: true, + }, + { + id: "ID-4", + disable: false, + }, + } + + for i, tt := range tests { + f := makeUserAPITestFixtures() + + usr, err := f.client.Users.Get(tt.id).Do() + if err != nil { + t.Fatalf("case %v: unexpected error: %v", i, err) + } + if usr.User.Disabled == tt.disable { + t.Fatalf("case %v: misconfigured test, initial disabled state should be %v but was %v", i, !tt.disable, usr.User.Disabled) + } + + _, err = f.client.Users.Disable(tt.id, &schema.UserDisableRequest{ + Disable: tt.disable, + }).Do() + if err != nil { + t.Fatalf("case %v: unexpected error: %v", i, err) + } + usr, err = f.client.Users.Get(tt.id).Do() + if err != nil { + t.Fatalf("case %v: unexpected error: %v", i, err) + } + if usr.User.Disabled != tt.disable { + t.Errorf("case %v: user disabled state incorrect. wanted: %v found: %v", i, tt.disable, usr.User.Disabled) + } + } +} + type testEmailer struct { cantEmail bool lastEmail string diff --git a/schema/workerschema/v1-gen.go b/schema/workerschema/v1-gen.go index 385c829a..950c9278 100644 --- a/schema/workerschema/v1-gen.go +++ b/schema/workerschema/v1-gen.go @@ -108,6 +108,8 @@ type User struct { CreatedAt string `json:"createdAt,omitempty"` + Disabled bool `json:"disabled,omitempty"` + DisplayName string `json:"displayName,omitempty"` Email string `json:"email,omitempty"` diff --git a/schema/workerschema/v1-json.go b/schema/workerschema/v1-json.go index d3cb5ab7..5e6c576d 100644 --- a/schema/workerschema/v1-json.go +++ b/schema/workerschema/v1-json.go @@ -108,6 +108,9 @@ const DiscoveryJSON = `{ "admin": { "type": "boolean" }, + "disabled": { + "type": "boolean" + }, "createdAt": { "type": "string", "format": "date-time" diff --git a/schema/workerschema/v1.json b/schema/workerschema/v1.json index 78fdfa68..7d3570f9 100644 --- a/schema/workerschema/v1.json +++ b/schema/workerschema/v1.json @@ -102,6 +102,9 @@ "admin": { "type": "boolean" }, + "disabled": { + "type": "boolean" + }, "createdAt": { "type": "string", "format": "date-time" diff --git a/server/server.go b/server/server.go index 7860989d..38714781 100644 --- a/server/server.go +++ b/server/server.go @@ -244,7 +244,11 @@ func (s *Server) HTTPHandler() http.Handler { mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler)) usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientIdentityRepo, s.UserEmailer, s.localConnectorID) - mux.Handle(path.Join(apiBasePath, UsersSubTree), NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler()) + handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler() + path := path.Join(apiBasePath, UsersSubTree) + + mux.Handle(path, handler) + mux.Handle(path+"/", handler) return http.Handler(mux) } diff --git a/server/user.go b/server/user.go index 0c5350d9..ce2aecd6 100644 --- a/server/user.go +++ b/server/user.go @@ -23,10 +23,11 @@ const ( ) var ( - UsersSubTree = "/users" - UsersListEndpoint = addBasePath(UsersSubTree) - UsersCreateEndpoint = addBasePath(UsersSubTree) - UsersGetEndpoint = addBasePath(UsersSubTree + "/:id") + UsersSubTree = "/users" + UsersListEndpoint = addBasePath(UsersSubTree) + UsersCreateEndpoint = addBasePath(UsersSubTree) + UsersGetEndpoint = addBasePath(UsersSubTree + "/:id") + UsersDisableEndpoint = addBasePath(UsersSubTree + "/:id/disable") ) type UserMgmtServer struct { @@ -51,6 +52,7 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler { r.RedirectFixedPath = false r.GET(UsersListEndpoint, s.listUsers) r.POST(UsersCreateEndpoint, s.createUser) + r.POST(UsersDisableEndpoint, s.disableUser) r.GET(UsersGetEndpoint, s.getUser) return r } @@ -140,6 +142,35 @@ func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps h writeResponseWithBody(w, http.StatusOK, createdResponse) } +func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + creds, err := s.getCreds(r) + if err != nil { + s.writeError(w, err) + return + } + + id := ps.ByName("id") + if id == "" { + writeAPIError(w, http.StatusBadRequest, + newAPIError(errorInvalidRequest, "id is required")) + return + } + + disableReq := schema.UserDisableRequest{} + err = json.NewDecoder(r.Body).Decode(&disableReq) + if err != nil { + writeInvalidRequest(w, "cannot parse JSON body") + } + + resp, err := s.api.DisableUser(creds, id, disableReq.Disable) + if err != nil { + s.writeError(w, err) + return + } + + writeResponseWithBody(w, http.StatusOK, resp) +} + func (s *UserMgmtServer) writeError(w http.ResponseWriter, err error) { log.Errorf("Error calling user management API: %v: ", err) if apiErr, ok := err.(api.Error); ok { diff --git a/user/api/api.go b/user/api/api.go index 06085ada..a05cc746 100644 --- a/user/api/api.go +++ b/user/api/api.go @@ -121,6 +121,21 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) { return userToSchemaUser(usr), nil } +func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema.UserDisableResponse, error) { + log.Infof("userAPI: DisableUser") + if !u.Authorize(creds) { + return schema.UserDisableResponse{}, ErrorUnauthorized + } + + if err := u.manager.Disable(userID, disable); err != nil { + return schema.UserDisableResponse{}, mapError(err) + } + + return schema.UserDisableResponse{ + Ok: true, + }, nil +} + func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (schema.UserCreateResponse, error) { log.Infof("userAPI: CreateUser") if !u.Authorize(creds) { @@ -207,6 +222,7 @@ func userToSchemaUser(usr user.User) schema.User { EmailVerified: usr.EmailVerified, DisplayName: usr.DisplayName, Admin: usr.Admin, + Disabled: usr.Disabled, CreatedAt: usr.CreatedAt.UTC().Format(time.RFC3339), } } @@ -218,6 +234,7 @@ func schemaUserToUser(usr schema.User) user.User { EmailVerified: usr.EmailVerified, DisplayName: usr.DisplayName, Admin: usr.Admin, + Disabled: usr.Disabled, } } diff --git a/user/api/api_test.go b/user/api/api_test.go index fce1f61e..fa5f85b1 100644 --- a/user/api/api_test.go +++ b/user/api/api_test.go @@ -94,6 +94,13 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { Email: "id3@example.com", CreatedAt: clock.Now(), }, + }, { + User: user.User{ + ID: "ID-4", + Email: "id4@example.com", + CreatedAt: clock.Now(), + Disabled: true, + }, }, }) pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ @@ -369,3 +376,44 @@ func TestCreateUser(t *testing.T) { } } } + +func TestDisableUsers(t *testing.T) { + tests := []struct { + id string + disable bool + }{ + { + id: "ID-1", + disable: true, + }, + { + id: "ID-1", + disable: false, + }, + { + id: "ID-4", + disable: true, + }, + { + id: "ID-4", + disable: false, + }, + } + + for i, tt := range tests { + api, _ := makeTestFixtures() + _, err := api.DisableUser(goodCreds, tt.id, tt.disable) + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + usr, err := api.GetUser(goodCreds, tt.id) + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + if usr.Disabled != tt.disable { + t.Errorf("case %d: user disable state wrong. wanted: %v got: %v", i, tt.disable, usr.Disabled) + } + } +} diff --git a/user/manager.go b/user/manager.go index cc09444c..1e046589 100644 --- a/user/manager.go +++ b/user/manager.go @@ -102,6 +102,22 @@ func (m *Manager) CreateUser(user User, hashedPassword Password, connID string) return user.ID, nil } +func (m *Manager) Disable(userID string, disabled bool) error { + tx, err := m.begin() + + if err = m.userRepo.Disable(tx, userID, disabled); err != nil { + rollback(tx) + return err + } + + if err = tx.Commit(); err != nil { + rollback(tx) + return err + } + + return nil +} + // RegisterWithRemoteIdentity creates new user and attaches the given remote identity. func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) { tx, err := m.begin() diff --git a/user/user.go b/user/user.go index c6b48465..9fecb835 100644 --- a/user/user.go +++ b/user/user.go @@ -80,6 +80,8 @@ type UserRepo interface { GetByEmail(tx repo.Transaction, email string) (User, error) + Disable(tx repo.Transaction, id string, disabled bool) error + Update(repo.Transaction, User) error GetByRemoteIdentity(repo.Transaction, RemoteIdentity) (User, error) @@ -254,6 +256,16 @@ func (r *memUserRepo) Update(_ repo.Transaction, user User) error { return nil } +func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error { + user, ok := r.usersByID[id] + if !ok { + return ErrorNotFound + } + user.Disabled = disable + r.set(user) + return nil +} + func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error { _, ok := r.usersByID[userID] if !ok {