forked from mystiq/dex
server: expose user disable API endpoint
This commit is contained in:
parent
b33cfbf556
commit
e5db302312
11 changed files with 199 additions and 5 deletions
15
db/user.go
15
db/user.go
|
@ -97,10 +97,25 @@ func (r *userRepo) Create(tx repo.Transaction, usr user.User) (err error) {
|
|||
}
|
||||
|
||||
err = r.insert(tx, usr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) error {
|
||||
if userID == "" {
|
||||
return user.ErrorInvalidID
|
||||
}
|
||||
|
||||
qt := pq.QuoteIdentifier(userTableName)
|
||||
ex := r.executor(tx)
|
||||
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ct, err := result.RowsAffected(); err == nil && ct == 0 {
|
||||
return user.ErrorInvalidID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -141,6 +141,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
|||
|
||||
f.trans = &tokenHandlerTransport{
|
||||
Handler: usrSrv.HTTPHandler(),
|
||||
Token: userGoodToken,
|
||||
}
|
||||
hc := &http.Client{
|
||||
Transport: f.trans,
|
||||
|
@ -530,6 +531,48 @@ func TestCreateUser(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDisableUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
id string
|
||||
disable bool
|
||||
}{
|
||||
{
|
||||
id: "ID-2",
|
||||
disable: true,
|
||||
},
|
||||
{
|
||||
id: "ID-4",
|
||||
disable: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
f := makeUserAPITestFixtures()
|
||||
|
||||
usr, err := f.client.Users.Get(tt.id).Do()
|
||||
if err != nil {
|
||||
t.Fatalf("case %v: unexpected error: %v", i, err)
|
||||
}
|
||||
if usr.User.Disabled == tt.disable {
|
||||
t.Fatalf("case %v: misconfigured test, initial disabled state should be %v but was %v", i, !tt.disable, usr.User.Disabled)
|
||||
}
|
||||
|
||||
_, err = f.client.Users.Disable(tt.id, &schema.UserDisableRequest{
|
||||
Disable: tt.disable,
|
||||
}).Do()
|
||||
if err != nil {
|
||||
t.Fatalf("case %v: unexpected error: %v", i, err)
|
||||
}
|
||||
usr, err = f.client.Users.Get(tt.id).Do()
|
||||
if err != nil {
|
||||
t.Fatalf("case %v: unexpected error: %v", i, err)
|
||||
}
|
||||
if usr.User.Disabled != tt.disable {
|
||||
t.Errorf("case %v: user disabled state incorrect. wanted: %v found: %v", i, tt.disable, usr.User.Disabled)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type testEmailer struct {
|
||||
cantEmail bool
|
||||
lastEmail string
|
||||
|
|
|
@ -108,6 +108,8 @@ type User struct {
|
|||
|
||||
CreatedAt string `json:"createdAt,omitempty"`
|
||||
|
||||
Disabled bool `json:"disabled,omitempty"`
|
||||
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
|
||||
Email string `json:"email,omitempty"`
|
||||
|
|
|
@ -108,6 +108,9 @@ const DiscoveryJSON = `{
|
|||
"admin": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"disabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"createdAt": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
|
|
|
@ -102,6 +102,9 @@
|
|||
"admin": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"disabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"createdAt": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
|
|
|
@ -244,7 +244,11 @@ func (s *Server) HTTPHandler() http.Handler {
|
|||
mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler))
|
||||
|
||||
usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientIdentityRepo, s.UserEmailer, s.localConnectorID)
|
||||
mux.Handle(path.Join(apiBasePath, UsersSubTree), NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler())
|
||||
handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler()
|
||||
path := path.Join(apiBasePath, UsersSubTree)
|
||||
|
||||
mux.Handle(path, handler)
|
||||
mux.Handle(path+"/", handler)
|
||||
|
||||
return http.Handler(mux)
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ var (
|
|||
UsersListEndpoint = addBasePath(UsersSubTree)
|
||||
UsersCreateEndpoint = addBasePath(UsersSubTree)
|
||||
UsersGetEndpoint = addBasePath(UsersSubTree + "/:id")
|
||||
UsersDisableEndpoint = addBasePath(UsersSubTree + "/:id/disable")
|
||||
)
|
||||
|
||||
type UserMgmtServer struct {
|
||||
|
@ -51,6 +52,7 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler {
|
|||
r.RedirectFixedPath = false
|
||||
r.GET(UsersListEndpoint, s.listUsers)
|
||||
r.POST(UsersCreateEndpoint, s.createUser)
|
||||
r.POST(UsersDisableEndpoint, s.disableUser)
|
||||
r.GET(UsersGetEndpoint, s.getUser)
|
||||
return r
|
||||
}
|
||||
|
@ -140,6 +142,35 @@ func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps h
|
|||
writeResponseWithBody(w, http.StatusOK, createdResponse)
|
||||
}
|
||||
|
||||
func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
creds, err := s.getCreds(r)
|
||||
if err != nil {
|
||||
s.writeError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
id := ps.ByName("id")
|
||||
if id == "" {
|
||||
writeAPIError(w, http.StatusBadRequest,
|
||||
newAPIError(errorInvalidRequest, "id is required"))
|
||||
return
|
||||
}
|
||||
|
||||
disableReq := schema.UserDisableRequest{}
|
||||
err = json.NewDecoder(r.Body).Decode(&disableReq)
|
||||
if err != nil {
|
||||
writeInvalidRequest(w, "cannot parse JSON body")
|
||||
}
|
||||
|
||||
resp, err := s.api.DisableUser(creds, id, disableReq.Disable)
|
||||
if err != nil {
|
||||
s.writeError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
writeResponseWithBody(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *UserMgmtServer) writeError(w http.ResponseWriter, err error) {
|
||||
log.Errorf("Error calling user management API: %v: ", err)
|
||||
if apiErr, ok := err.(api.Error); ok {
|
||||
|
|
|
@ -121,6 +121,21 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) {
|
|||
return userToSchemaUser(usr), nil
|
||||
}
|
||||
|
||||
func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema.UserDisableResponse, error) {
|
||||
log.Infof("userAPI: DisableUser")
|
||||
if !u.Authorize(creds) {
|
||||
return schema.UserDisableResponse{}, ErrorUnauthorized
|
||||
}
|
||||
|
||||
if err := u.manager.Disable(userID, disable); err != nil {
|
||||
return schema.UserDisableResponse{}, mapError(err)
|
||||
}
|
||||
|
||||
return schema.UserDisableResponse{
|
||||
Ok: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (schema.UserCreateResponse, error) {
|
||||
log.Infof("userAPI: CreateUser")
|
||||
if !u.Authorize(creds) {
|
||||
|
@ -207,6 +222,7 @@ func userToSchemaUser(usr user.User) schema.User {
|
|||
EmailVerified: usr.EmailVerified,
|
||||
DisplayName: usr.DisplayName,
|
||||
Admin: usr.Admin,
|
||||
Disabled: usr.Disabled,
|
||||
CreatedAt: usr.CreatedAt.UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
@ -218,6 +234,7 @@ func schemaUserToUser(usr schema.User) user.User {
|
|||
EmailVerified: usr.EmailVerified,
|
||||
DisplayName: usr.DisplayName,
|
||||
Admin: usr.Admin,
|
||||
Disabled: usr.Disabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -94,6 +94,13 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
Email: "id3@example.com",
|
||||
CreatedAt: clock.Now(),
|
||||
},
|
||||
}, {
|
||||
User: user.User{
|
||||
ID: "ID-4",
|
||||
Email: "id4@example.com",
|
||||
CreatedAt: clock.Now(),
|
||||
Disabled: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||
|
@ -369,3 +376,44 @@ func TestCreateUser(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisableUsers(t *testing.T) {
|
||||
tests := []struct {
|
||||
id string
|
||||
disable bool
|
||||
}{
|
||||
{
|
||||
id: "ID-1",
|
||||
disable: true,
|
||||
},
|
||||
{
|
||||
id: "ID-1",
|
||||
disable: false,
|
||||
},
|
||||
{
|
||||
id: "ID-4",
|
||||
disable: true,
|
||||
},
|
||||
{
|
||||
id: "ID-4",
|
||||
disable: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
api, _ := makeTestFixtures()
|
||||
_, err := api.DisableUser(goodCreds, tt.id, tt.disable)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
usr, err := api.GetUser(goodCreds, tt.id)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
if usr.Disabled != tt.disable {
|
||||
t.Errorf("case %d: user disable state wrong. wanted: %v got: %v", i, tt.disable, usr.Disabled)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -102,6 +102,22 @@ func (m *Manager) CreateUser(user User, hashedPassword Password, connID string)
|
|||
return user.ID, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Disable(userID string, disabled bool) error {
|
||||
tx, err := m.begin()
|
||||
|
||||
if err = m.userRepo.Disable(tx, userID, disabled); err != nil {
|
||||
rollback(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
rollback(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterWithRemoteIdentity creates new user and attaches the given remote identity.
|
||||
func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) {
|
||||
tx, err := m.begin()
|
||||
|
|
12
user/user.go
12
user/user.go
|
@ -80,6 +80,8 @@ type UserRepo interface {
|
|||
|
||||
GetByEmail(tx repo.Transaction, email string) (User, error)
|
||||
|
||||
Disable(tx repo.Transaction, id string, disabled bool) error
|
||||
|
||||
Update(repo.Transaction, User) error
|
||||
|
||||
GetByRemoteIdentity(repo.Transaction, RemoteIdentity) (User, error)
|
||||
|
@ -254,6 +256,16 @@ func (r *memUserRepo) Update(_ repo.Transaction, user User) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error {
|
||||
user, ok := r.usersByID[id]
|
||||
if !ok {
|
||||
return ErrorNotFound
|
||||
}
|
||||
user.Disabled = disable
|
||||
r.set(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
|
||||
_, ok := r.usersByID[userID]
|
||||
if !ok {
|
||||
|
|
Loading…
Reference in a new issue