diff --git a/server/user.go b/server/user.go index 713263f7..b5cfb95b 100644 --- a/server/user.go +++ b/server/user.go @@ -51,24 +51,33 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler { r := httprouter.New() r.RedirectTrailingSlash = false r.RedirectFixedPath = false - r.GET(UsersListEndpoint, s.listUsers) - r.POST(UsersCreateEndpoint, s.createUser) - r.POST(UsersDisableEndpoint, s.disableUser) - r.GET(UsersGetEndpoint, s.getUser) + r.GET(UsersListEndpoint, s.authAPIHandle(s.listUsers)) + r.POST(UsersCreateEndpoint, s.authAPIHandle(s.createUser)) + r.POST(UsersDisableEndpoint, s.authAPIHandle(s.disableUser)) + r.GET(UsersGetEndpoint, s.authAPIHandle(s.getUser)) return r } -func (s *UserMgmtServer) listUsers(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - creds, err := s.getCreds(r) - if err != nil { - s.writeError(w, err) - return - } +// authedHandle is an HTTP handle which requires requests to be authenticated as an admin user. +type authedHandle func(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) +// authAPIHandle is a middleware function with authenticates an HTTP request before passing +// it along to the authedHandle. +func (s *UserMgmtServer) authAPIHandle(handle authedHandle) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + creds, err := s.getCreds(r) + if err != nil { + s.writeError(w, err) + return + } + handle(w, r, ps, creds) + } +} + +func (s *UserMgmtServer) listUsers(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) { nextPageToken := r.URL.Query().Get("nextPageToken") maxResults, err := intFromQuery(r.URL.Query(), "maxResults", defaultMaxResults) - if err != nil { writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidRequest, "maxResults must be an integer")) @@ -88,13 +97,7 @@ func (s *UserMgmtServer) listUsers(w http.ResponseWriter, r *http.Request, ps ht writeResponseWithBody(w, http.StatusOK, usersResponse) } -func (s *UserMgmtServer) getUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - creds, err := s.getCreds(r) - if err != nil { - s.writeError(w, err) - return - } - +func (s *UserMgmtServer) getUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) { id := ps.ByName("id") if id == "" { writeAPIError(w, http.StatusBadRequest, @@ -113,16 +116,9 @@ func (s *UserMgmtServer) getUser(w http.ResponseWriter, r *http.Request, ps http writeResponseWithBody(w, http.StatusOK, userResponse) } -func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - creds, err := s.getCreds(r) - if err != nil { - s.writeError(w, err) - return - } - +func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) { createReq := schema.UserCreateRequest{} - err = json.NewDecoder(r.Body).Decode(&createReq) - if err != nil { + if err := json.NewDecoder(r.Body).Decode(&createReq); err != nil { writeInvalidRequest(w, "cannot parse JSON body") return } @@ -143,13 +139,7 @@ func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps h writeResponseWithBody(w, http.StatusOK, createdResponse) } -func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - creds, err := s.getCreds(r) - if err != nil { - s.writeError(w, err) - return - } - +func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) { id := ps.ByName("id") if id == "" { writeAPIError(w, http.StatusBadRequest, @@ -158,8 +148,7 @@ func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps } disableReq := schema.UserDisableRequest{} - err = json.NewDecoder(r.Body).Decode(&disableReq) - if err != nil { + if err := json.NewDecoder(r.Body).Decode(&disableReq); err != nil { writeInvalidRequest(w, "cannot parse JSON body") } @@ -240,7 +229,7 @@ func (s *UserMgmtServer) getCreds(r *http.Request) (api.Creds, error) { return api.Creds{}, err } if !isAdmin { - return api.Creds{}, api.ErrorUnauthorized + return api.Creds{}, api.ErrorForbidden } return api.Creds{ diff --git a/user/api/api.go b/user/api/api.go index 2eea7f5c..d5218ac9 100644 --- a/user/api/api.go +++ b/user/api/api.go @@ -31,7 +31,8 @@ var ( ErrorDuplicateEmail = newError("duplicate_email", "Email already in use.", http.StatusBadRequest) ErrorResourceNotFound = newError("resource_not_found", "Resource could not be found.", http.StatusNotFound) - ErrorUnauthorized = newError("unauthorized", "The given user and client are not authorized to make this request.", http.StatusUnauthorized) + ErrorUnauthorized = newError("unauthorized", "Necessary credentials not provided.", http.StatusUnauthorized) + ErrorForbidden = newError("forbidden", "The given user and client are not authorized to make this request.", http.StatusForbidden) ErrorMaxResultsTooHigh = newError("max_results_too_high", fmt.Sprintf("The max number of results per page is %d", maxUsersPerPage), http.StatusBadRequest)