server: expose user disable API endpoint

This commit is contained in:
Joe Bowers 2015-09-28 13:51:48 -07:00
parent b33cfbf556
commit e5db302312
11 changed files with 199 additions and 5 deletions

View file

@ -97,10 +97,25 @@ func (r *userRepo) Create(tx repo.Transaction, usr user.User) (err error) {
} }
err = r.insert(tx, usr) err = r.insert(tx, usr)
return err
}
func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) error {
if userID == "" {
return user.ErrorInvalidID
}
qt := pq.QuoteIdentifier(userTableName)
ex := r.executor(tx)
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable)
if err != nil { if err != nil {
return err return err
} }
if ct, err := result.RowsAffected(); err == nil && ct == 0 {
return user.ErrorInvalidID
}
return nil return nil
} }

View file

@ -141,6 +141,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
f.trans = &tokenHandlerTransport{ f.trans = &tokenHandlerTransport{
Handler: usrSrv.HTTPHandler(), Handler: usrSrv.HTTPHandler(),
Token: userGoodToken,
} }
hc := &http.Client{ hc := &http.Client{
Transport: f.trans, Transport: f.trans,
@ -530,6 +531,48 @@ func TestCreateUser(t *testing.T) {
} }
} }
func TestDisableUser(t *testing.T) {
tests := []struct {
id string
disable bool
}{
{
id: "ID-2",
disable: true,
},
{
id: "ID-4",
disable: false,
},
}
for i, tt := range tests {
f := makeUserAPITestFixtures()
usr, err := f.client.Users.Get(tt.id).Do()
if err != nil {
t.Fatalf("case %v: unexpected error: %v", i, err)
}
if usr.User.Disabled == tt.disable {
t.Fatalf("case %v: misconfigured test, initial disabled state should be %v but was %v", i, !tt.disable, usr.User.Disabled)
}
_, err = f.client.Users.Disable(tt.id, &schema.UserDisableRequest{
Disable: tt.disable,
}).Do()
if err != nil {
t.Fatalf("case %v: unexpected error: %v", i, err)
}
usr, err = f.client.Users.Get(tt.id).Do()
if err != nil {
t.Fatalf("case %v: unexpected error: %v", i, err)
}
if usr.User.Disabled != tt.disable {
t.Errorf("case %v: user disabled state incorrect. wanted: %v found: %v", i, tt.disable, usr.User.Disabled)
}
}
}
type testEmailer struct { type testEmailer struct {
cantEmail bool cantEmail bool
lastEmail string lastEmail string

View file

@ -108,6 +108,8 @@ type User struct {
CreatedAt string `json:"createdAt,omitempty"` CreatedAt string `json:"createdAt,omitempty"`
Disabled bool `json:"disabled,omitempty"`
DisplayName string `json:"displayName,omitempty"` DisplayName string `json:"displayName,omitempty"`
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`

View file

@ -108,6 +108,9 @@ const DiscoveryJSON = `{
"admin": { "admin": {
"type": "boolean" "type": "boolean"
}, },
"disabled": {
"type": "boolean"
},
"createdAt": { "createdAt": {
"type": "string", "type": "string",
"format": "date-time" "format": "date-time"

View file

@ -102,6 +102,9 @@
"admin": { "admin": {
"type": "boolean" "type": "boolean"
}, },
"disabled": {
"type": "boolean"
},
"createdAt": { "createdAt": {
"type": "string", "type": "string",
"format": "date-time" "format": "date-time"

View file

@ -244,7 +244,11 @@ func (s *Server) HTTPHandler() http.Handler {
mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler)) mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler))
usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientIdentityRepo, s.UserEmailer, s.localConnectorID) usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientIdentityRepo, s.UserEmailer, s.localConnectorID)
mux.Handle(path.Join(apiBasePath, UsersSubTree), NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler()) handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler()
path := path.Join(apiBasePath, UsersSubTree)
mux.Handle(path, handler)
mux.Handle(path+"/", handler)
return http.Handler(mux) return http.Handler(mux)
} }

View file

@ -23,10 +23,11 @@ const (
) )
var ( var (
UsersSubTree = "/users" UsersSubTree = "/users"
UsersListEndpoint = addBasePath(UsersSubTree) UsersListEndpoint = addBasePath(UsersSubTree)
UsersCreateEndpoint = addBasePath(UsersSubTree) UsersCreateEndpoint = addBasePath(UsersSubTree)
UsersGetEndpoint = addBasePath(UsersSubTree + "/:id") UsersGetEndpoint = addBasePath(UsersSubTree + "/:id")
UsersDisableEndpoint = addBasePath(UsersSubTree + "/:id/disable")
) )
type UserMgmtServer struct { type UserMgmtServer struct {
@ -51,6 +52,7 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler {
r.RedirectFixedPath = false r.RedirectFixedPath = false
r.GET(UsersListEndpoint, s.listUsers) r.GET(UsersListEndpoint, s.listUsers)
r.POST(UsersCreateEndpoint, s.createUser) r.POST(UsersCreateEndpoint, s.createUser)
r.POST(UsersDisableEndpoint, s.disableUser)
r.GET(UsersGetEndpoint, s.getUser) r.GET(UsersGetEndpoint, s.getUser)
return r return r
} }
@ -140,6 +142,35 @@ func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps h
writeResponseWithBody(w, http.StatusOK, createdResponse) writeResponseWithBody(w, http.StatusOK, createdResponse)
} }
func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
creds, err := s.getCreds(r)
if err != nil {
s.writeError(w, err)
return
}
id := ps.ByName("id")
if id == "" {
writeAPIError(w, http.StatusBadRequest,
newAPIError(errorInvalidRequest, "id is required"))
return
}
disableReq := schema.UserDisableRequest{}
err = json.NewDecoder(r.Body).Decode(&disableReq)
if err != nil {
writeInvalidRequest(w, "cannot parse JSON body")
}
resp, err := s.api.DisableUser(creds, id, disableReq.Disable)
if err != nil {
s.writeError(w, err)
return
}
writeResponseWithBody(w, http.StatusOK, resp)
}
func (s *UserMgmtServer) writeError(w http.ResponseWriter, err error) { func (s *UserMgmtServer) writeError(w http.ResponseWriter, err error) {
log.Errorf("Error calling user management API: %v: ", err) log.Errorf("Error calling user management API: %v: ", err)
if apiErr, ok := err.(api.Error); ok { if apiErr, ok := err.(api.Error); ok {

View file

@ -121,6 +121,21 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) {
return userToSchemaUser(usr), nil return userToSchemaUser(usr), nil
} }
func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema.UserDisableResponse, error) {
log.Infof("userAPI: DisableUser")
if !u.Authorize(creds) {
return schema.UserDisableResponse{}, ErrorUnauthorized
}
if err := u.manager.Disable(userID, disable); err != nil {
return schema.UserDisableResponse{}, mapError(err)
}
return schema.UserDisableResponse{
Ok: true,
}, nil
}
func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (schema.UserCreateResponse, error) { func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (schema.UserCreateResponse, error) {
log.Infof("userAPI: CreateUser") log.Infof("userAPI: CreateUser")
if !u.Authorize(creds) { if !u.Authorize(creds) {
@ -207,6 +222,7 @@ func userToSchemaUser(usr user.User) schema.User {
EmailVerified: usr.EmailVerified, EmailVerified: usr.EmailVerified,
DisplayName: usr.DisplayName, DisplayName: usr.DisplayName,
Admin: usr.Admin, Admin: usr.Admin,
Disabled: usr.Disabled,
CreatedAt: usr.CreatedAt.UTC().Format(time.RFC3339), CreatedAt: usr.CreatedAt.UTC().Format(time.RFC3339),
} }
} }
@ -218,6 +234,7 @@ func schemaUserToUser(usr schema.User) user.User {
EmailVerified: usr.EmailVerified, EmailVerified: usr.EmailVerified,
DisplayName: usr.DisplayName, DisplayName: usr.DisplayName,
Admin: usr.Admin, Admin: usr.Admin,
Disabled: usr.Disabled,
} }
} }

View file

@ -94,6 +94,13 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
Email: "id3@example.com", Email: "id3@example.com",
CreatedAt: clock.Now(), CreatedAt: clock.Now(),
}, },
}, {
User: user.User{
ID: "ID-4",
Email: "id4@example.com",
CreatedAt: clock.Now(),
Disabled: true,
},
}, },
}) })
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
@ -369,3 +376,44 @@ func TestCreateUser(t *testing.T) {
} }
} }
} }
func TestDisableUsers(t *testing.T) {
tests := []struct {
id string
disable bool
}{
{
id: "ID-1",
disable: true,
},
{
id: "ID-1",
disable: false,
},
{
id: "ID-4",
disable: true,
},
{
id: "ID-4",
disable: false,
},
}
for i, tt := range tests {
api, _ := makeTestFixtures()
_, err := api.DisableUser(goodCreds, tt.id, tt.disable)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
usr, err := api.GetUser(goodCreds, tt.id)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
if usr.Disabled != tt.disable {
t.Errorf("case %d: user disable state wrong. wanted: %v got: %v", i, tt.disable, usr.Disabled)
}
}
}

View file

@ -102,6 +102,22 @@ func (m *Manager) CreateUser(user User, hashedPassword Password, connID string)
return user.ID, nil return user.ID, nil
} }
func (m *Manager) Disable(userID string, disabled bool) error {
tx, err := m.begin()
if err = m.userRepo.Disable(tx, userID, disabled); err != nil {
rollback(tx)
return err
}
if err = tx.Commit(); err != nil {
rollback(tx)
return err
}
return nil
}
// RegisterWithRemoteIdentity creates new user and attaches the given remote identity. // RegisterWithRemoteIdentity creates new user and attaches the given remote identity.
func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) { func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) {
tx, err := m.begin() tx, err := m.begin()

View file

@ -80,6 +80,8 @@ type UserRepo interface {
GetByEmail(tx repo.Transaction, email string) (User, error) GetByEmail(tx repo.Transaction, email string) (User, error)
Disable(tx repo.Transaction, id string, disabled bool) error
Update(repo.Transaction, User) error Update(repo.Transaction, User) error
GetByRemoteIdentity(repo.Transaction, RemoteIdentity) (User, error) GetByRemoteIdentity(repo.Transaction, RemoteIdentity) (User, error)
@ -254,6 +256,16 @@ func (r *memUserRepo) Update(_ repo.Transaction, user User) error {
return nil return nil
} }
func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error {
user, ok := r.usersByID[id]
if !ok {
return ErrorNotFound
}
user.Disabled = disable
r.set(user)
return nil
}
func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error { func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
_, ok := r.usersByID[userID] _, ok := r.usersByID[userID]
if !ok { if !ok {