*: move user API auth to middleware and fix return status

Move client authentication into its own middleware and provide
differentiation between HTTP requests that do not provide
credentials (401) and requests that authenticate as a non-admin
user (403).

Closes #152
This commit is contained in:
Eric Chiang 2016-01-19 13:40:16 -08:00
parent b5c7f1978e
commit 0ada4c8010
2 changed files with 28 additions and 38 deletions

View file

@ -51,24 +51,33 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler {
r := httprouter.New() r := httprouter.New()
r.RedirectTrailingSlash = false r.RedirectTrailingSlash = false
r.RedirectFixedPath = false r.RedirectFixedPath = false
r.GET(UsersListEndpoint, s.listUsers) r.GET(UsersListEndpoint, s.authAPIHandle(s.listUsers))
r.POST(UsersCreateEndpoint, s.createUser) r.POST(UsersCreateEndpoint, s.authAPIHandle(s.createUser))
r.POST(UsersDisableEndpoint, s.disableUser) r.POST(UsersDisableEndpoint, s.authAPIHandle(s.disableUser))
r.GET(UsersGetEndpoint, s.getUser) r.GET(UsersGetEndpoint, s.authAPIHandle(s.getUser))
return r return r
} }
func (s *UserMgmtServer) listUsers(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { // authedHandle is an HTTP handle which requires requests to be authenticated as an admin user.
creds, err := s.getCreds(r) type authedHandle func(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds)
if err != nil {
s.writeError(w, err)
return
}
// authAPIHandle is a middleware function with authenticates an HTTP request before passing
// it along to the authedHandle.
func (s *UserMgmtServer) authAPIHandle(handle authedHandle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
creds, err := s.getCreds(r)
if err != nil {
s.writeError(w, err)
return
}
handle(w, r, ps, creds)
}
}
func (s *UserMgmtServer) listUsers(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) {
nextPageToken := r.URL.Query().Get("nextPageToken") nextPageToken := r.URL.Query().Get("nextPageToken")
maxResults, err := intFromQuery(r.URL.Query(), "maxResults", defaultMaxResults) maxResults, err := intFromQuery(r.URL.Query(), "maxResults", defaultMaxResults)
if err != nil { if err != nil {
writeAPIError(w, http.StatusBadRequest, writeAPIError(w, http.StatusBadRequest,
newAPIError(errorInvalidRequest, "maxResults must be an integer")) newAPIError(errorInvalidRequest, "maxResults must be an integer"))
@ -88,13 +97,7 @@ func (s *UserMgmtServer) listUsers(w http.ResponseWriter, r *http.Request, ps ht
writeResponseWithBody(w, http.StatusOK, usersResponse) writeResponseWithBody(w, http.StatusOK, usersResponse)
} }
func (s *UserMgmtServer) getUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { func (s *UserMgmtServer) getUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) {
creds, err := s.getCreds(r)
if err != nil {
s.writeError(w, err)
return
}
id := ps.ByName("id") id := ps.ByName("id")
if id == "" { if id == "" {
writeAPIError(w, http.StatusBadRequest, writeAPIError(w, http.StatusBadRequest,
@ -113,16 +116,9 @@ func (s *UserMgmtServer) getUser(w http.ResponseWriter, r *http.Request, ps http
writeResponseWithBody(w, http.StatusOK, userResponse) writeResponseWithBody(w, http.StatusOK, userResponse)
} }
func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) {
creds, err := s.getCreds(r)
if err != nil {
s.writeError(w, err)
return
}
createReq := schema.UserCreateRequest{} createReq := schema.UserCreateRequest{}
err = json.NewDecoder(r.Body).Decode(&createReq) if err := json.NewDecoder(r.Body).Decode(&createReq); err != nil {
if err != nil {
writeInvalidRequest(w, "cannot parse JSON body") writeInvalidRequest(w, "cannot parse JSON body")
return return
} }
@ -143,13 +139,7 @@ func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps h
writeResponseWithBody(w, http.StatusOK, createdResponse) writeResponseWithBody(w, http.StatusOK, createdResponse)
} }
func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) {
creds, err := s.getCreds(r)
if err != nil {
s.writeError(w, err)
return
}
id := ps.ByName("id") id := ps.ByName("id")
if id == "" { if id == "" {
writeAPIError(w, http.StatusBadRequest, writeAPIError(w, http.StatusBadRequest,
@ -158,8 +148,7 @@ func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps
} }
disableReq := schema.UserDisableRequest{} disableReq := schema.UserDisableRequest{}
err = json.NewDecoder(r.Body).Decode(&disableReq) if err := json.NewDecoder(r.Body).Decode(&disableReq); err != nil {
if err != nil {
writeInvalidRequest(w, "cannot parse JSON body") writeInvalidRequest(w, "cannot parse JSON body")
} }
@ -240,7 +229,7 @@ func (s *UserMgmtServer) getCreds(r *http.Request) (api.Creds, error) {
return api.Creds{}, err return api.Creds{}, err
} }
if !isAdmin { if !isAdmin {
return api.Creds{}, api.ErrorUnauthorized return api.Creds{}, api.ErrorForbidden
} }
return api.Creds{ return api.Creds{

View file

@ -31,7 +31,8 @@ var (
ErrorDuplicateEmail = newError("duplicate_email", "Email already in use.", http.StatusBadRequest) ErrorDuplicateEmail = newError("duplicate_email", "Email already in use.", http.StatusBadRequest)
ErrorResourceNotFound = newError("resource_not_found", "Resource could not be found.", http.StatusNotFound) ErrorResourceNotFound = newError("resource_not_found", "Resource could not be found.", http.StatusNotFound)
ErrorUnauthorized = newError("unauthorized", "The given user and client are not authorized to make this request.", http.StatusUnauthorized) ErrorUnauthorized = newError("unauthorized", "Necessary credentials not provided.", http.StatusUnauthorized)
ErrorForbidden = newError("forbidden", "The given user and client are not authorized to make this request.", http.StatusForbidden)
ErrorMaxResultsTooHigh = newError("max_results_too_high", fmt.Sprintf("The max number of results per page is %d", maxUsersPerPage), http.StatusBadRequest) ErrorMaxResultsTooHigh = newError("max_results_too_high", fmt.Sprintf("The max number of results per page is %d", maxUsersPerPage), http.StatusBadRequest)