handlers: change all handlers to pass down http request

Signed-off-by: Yannis Zarkadas <yanniszark@arrikto.com>
This commit is contained in:
Yannis Zarkadas 2019-09-27 16:56:32 +03:00
parent 8427f0f15c
commit 839130f01c

View file

@ -101,7 +101,7 @@ func (h *healthChecker) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.mu.RUnlock() h.mu.RUnlock()
if err != nil { if err != nil {
h.s.renderError(w, http.StatusInternalServerError, "Health check failed.") h.s.renderError(r, w, http.StatusInternalServerError, "Health check failed.")
return return
} }
fmt.Fprintf(w, "Health check passed in %s", t) fmt.Fprintf(w, "Health check passed in %s", t)
@ -112,13 +112,13 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
keys, err := s.storage.GetKeys() keys, err := s.storage.GetKeys()
if err != nil { if err != nil {
s.logger.Errorf("failed to get keys: %v", err) s.logger.Errorf("failed to get keys: %v", err)
s.renderError(w, http.StatusInternalServerError, "Internal server error.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return return
} }
if keys.SigningKeyPub == nil { if keys.SigningKeyPub == nil {
s.logger.Errorf("No public keys found.") s.logger.Errorf("No public keys found.")
s.renderError(w, http.StatusInternalServerError, "Internal server error.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return return
} }
@ -133,7 +133,7 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
data, err := json.MarshalIndent(jwks, "", " ") data, err := json.MarshalIndent(jwks, "", " ")
if err != nil { if err != nil {
s.logger.Errorf("failed to marshal discovery data: %v", err) s.logger.Errorf("failed to marshal discovery data: %v", err)
s.renderError(w, http.StatusInternalServerError, "Internal server error.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return return
} }
maxAge := keys.NextRotation.Sub(s.now()) maxAge := keys.NextRotation.Sub(s.now())
@ -214,7 +214,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
status = err.Status() status = err.Status()
} }
s.renderError(w, status, err.Error()) s.renderError(r, w, status, err.Error())
return return
} }
@ -226,14 +226,14 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
authReq.Expiry = s.now().Add(s.authRequestsValidFor) authReq.Expiry = s.now().Add(s.authRequestsValidFor)
if err := s.storage.CreateAuthRequest(*authReq); err != nil { if err := s.storage.CreateAuthRequest(*authReq); err != nil {
s.logger.Errorf("Failed to create authorization request: %v", err) s.logger.Errorf("Failed to create authorization request: %v", err)
s.renderError(w, http.StatusInternalServerError, "Failed to connect to the database.") s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.")
return return
} }
connectors, err := s.storage.ListConnectors() connectors, err := s.storage.ListConnectors()
if err != nil { if err != nil {
s.logger.Errorf("Failed to get list of connectors: %v", err) s.logger.Errorf("Failed to get list of connectors: %v", err)
s.renderError(w, http.StatusInternalServerError, "Failed to retrieve connector list.") s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.")
return return
} }
@ -271,7 +271,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
i++ i++
} }
if err := s.templates.login(w, connectorInfos); err != nil { if err := s.templates.login(r, w, connectorInfos, r.URL.Path); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Errorf("Server template error: %v", err)
} }
} }
@ -281,7 +281,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
conn, err := s.getConnector(connID) conn, err := s.getConnector(connID)
if err != nil { if err != nil {
s.logger.Errorf("Failed to create authorization request: %v", err) s.logger.Errorf("Failed to create authorization request: %v", err)
s.renderError(w, http.StatusBadRequest, "Requested resource does not exist") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
return return
} }
@ -291,9 +291,9 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
s.logger.Errorf("Failed to get auth request: %v", err) s.logger.Errorf("Failed to get auth request: %v", err)
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
s.renderError(w, http.StatusBadRequest, "Login session expired.") s.renderError(r, w, http.StatusBadRequest, "Login session expired.")
} else { } else {
s.renderError(w, http.StatusInternalServerError, "Database error.") s.renderError(r, w, http.StatusInternalServerError, "Database error.")
} }
return return
} }
@ -306,7 +306,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
} }
if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil { if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil {
s.logger.Errorf("Failed to set connector ID on auth request: %v", err) s.logger.Errorf("Failed to set connector ID on auth request: %v", err)
s.renderError(w, http.StatusInternalServerError, "Database error.") s.renderError(r, w, http.StatusInternalServerError, "Database error.")
return return
} }
} }
@ -324,19 +324,19 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID) callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID)
if err != nil { if err != nil {
s.logger.Errorf("Connector %q returned error when creating callback: %v", connID, err) s.logger.Errorf("Connector %q returned error when creating callback: %v", connID, err)
s.renderError(w, http.StatusInternalServerError, "Login error.") s.renderError(r, w, http.StatusInternalServerError, "Login error.")
return return
} }
http.Redirect(w, r, callbackURL, http.StatusFound) http.Redirect(w, r, callbackURL, http.StatusFound)
case connector.PasswordConnector: case connector.PasswordConnector:
if err := s.templates.password(w, r.URL.String(), "", usernamePrompt(conn), false, showBacklink); err != nil { if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(conn), false, showBacklink, r.URL.Path); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Errorf("Server template error: %v", err)
} }
case connector.SAMLConnector: case connector.SAMLConnector:
action, value, err := conn.POSTData(scopes, authReqID) action, value, err := conn.POSTData(scopes, authReqID)
if err != nil { if err != nil {
s.logger.Errorf("Creating SAML data: %v", err) s.logger.Errorf("Creating SAML data: %v", err)
s.renderError(w, http.StatusInternalServerError, "Connector Login Error") s.renderError(r, w, http.StatusInternalServerError, "Connector Login Error")
return return
} }
@ -358,12 +358,12 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
</body> </body>
</html>`, action, value, authReqID) </html>`, action, value, authReqID)
default: default:
s.renderError(w, http.StatusBadRequest, "Requested resource does not exist.") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
} }
case http.MethodPost: case http.MethodPost:
passwordConnector, ok := conn.Connector.(connector.PasswordConnector) passwordConnector, ok := conn.Connector.(connector.PasswordConnector)
if !ok { if !ok {
s.renderError(w, http.StatusBadRequest, "Requested resource does not exist.") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
return return
} }
@ -373,11 +373,11 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
identity, ok, err := passwordConnector.Login(r.Context(), scopes, username, password) identity, ok, err := passwordConnector.Login(r.Context(), scopes, username, password)
if err != nil { if err != nil {
s.logger.Errorf("Failed to login user: %v", err) s.logger.Errorf("Failed to login user: %v", err)
s.renderError(w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err)) s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err))
return return
} }
if !ok { if !ok {
if err := s.templates.password(w, r.URL.String(), username, usernamePrompt(passwordConnector), true, showBacklink); err != nil { if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(passwordConnector), true, showBacklink, r.URL.Path); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Errorf("Server template error: %v", err)
} }
return return
@ -385,13 +385,13 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
if err != nil { if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err) s.logger.Errorf("Failed to finalize login: %v", err)
s.renderError(w, http.StatusInternalServerError, "Login error.") s.renderError(r, w, http.StatusInternalServerError, "Login error.")
return return
} }
http.Redirect(w, r, redirectURL, http.StatusSeeOther) http.Redirect(w, r, redirectURL, http.StatusSeeOther)
default: default:
s.renderError(w, http.StatusBadRequest, "Unsupported request method.") s.renderError(r, w, http.StatusBadRequest, "Unsupported request method.")
} }
} }
@ -400,16 +400,16 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
switch r.Method { switch r.Method {
case http.MethodGet: // OAuth2 callback case http.MethodGet: // OAuth2 callback
if authID = r.URL.Query().Get("state"); authID == "" { if authID = r.URL.Query().Get("state"); authID == "" {
s.renderError(w, http.StatusBadRequest, "User session error.") s.renderError(r, w, http.StatusBadRequest, "User session error.")
return return
} }
case http.MethodPost: // SAML POST binding case http.MethodPost: // SAML POST binding
if authID = r.PostFormValue("RelayState"); authID == "" { if authID = r.PostFormValue("RelayState"); authID == "" {
s.renderError(w, http.StatusBadRequest, "User session error.") s.renderError(r, w, http.StatusBadRequest, "User session error.")
return return
} }
default: default:
s.renderError(w, http.StatusBadRequest, "Method not supported") s.renderError(r, w, http.StatusBadRequest, "Method not supported")
return return
} }
@ -417,24 +417,24 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
if err != nil { if err != nil {
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
s.logger.Errorf("Invalid 'state' parameter provided: %v", err) s.logger.Errorf("Invalid 'state' parameter provided: %v", err)
s.renderError(w, http.StatusBadRequest, "Requested resource does not exist.") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
return return
} }
s.logger.Errorf("Failed to get auth request: %v", err) s.logger.Errorf("Failed to get auth request: %v", err)
s.renderError(w, http.StatusInternalServerError, "Database error.") s.renderError(r, w, http.StatusInternalServerError, "Database error.")
return return
} }
if connID := mux.Vars(r)["connector"]; connID != "" && connID != authReq.ConnectorID { if connID := mux.Vars(r)["connector"]; connID != "" && connID != authReq.ConnectorID {
s.logger.Errorf("Connector mismatch: authentication started with id %q, but callback for id %q was triggered", authReq.ConnectorID, connID) s.logger.Errorf("Connector mismatch: authentication started with id %q, but callback for id %q was triggered", authReq.ConnectorID, connID)
s.renderError(w, http.StatusInternalServerError, "Requested resource does not exist.") s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return return
} }
conn, err := s.getConnector(authReq.ConnectorID) conn, err := s.getConnector(authReq.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("Failed to get connector with id %q : %v", authReq.ConnectorID, err) s.logger.Errorf("Failed to get connector with id %q : %v", authReq.ConnectorID, err)
s.renderError(w, http.StatusInternalServerError, "Requested resource does not exist.") s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return return
} }
@ -443,32 +443,32 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
case connector.CallbackConnector: case connector.CallbackConnector:
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
s.logger.Errorf("SAML request mapped to OAuth2 connector") s.logger.Errorf("SAML request mapped to OAuth2 connector")
s.renderError(w, http.StatusBadRequest, "Invalid request") s.renderError(r, w, http.StatusBadRequest, "Invalid request")
return return
} }
identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), r) identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), r)
case connector.SAMLConnector: case connector.SAMLConnector:
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
s.logger.Errorf("OAuth2 request mapped to SAML connector") s.logger.Errorf("OAuth2 request mapped to SAML connector")
s.renderError(w, http.StatusBadRequest, "Invalid request") s.renderError(r, w, http.StatusBadRequest, "Invalid request")
return return
} }
identity, err = conn.HandlePOST(parseScopes(authReq.Scopes), r.PostFormValue("SAMLResponse"), authReq.ID) identity, err = conn.HandlePOST(parseScopes(authReq.Scopes), r.PostFormValue("SAMLResponse"), authReq.ID)
default: default:
s.renderError(w, http.StatusInternalServerError, "Requested resource does not exist.") s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return return
} }
if err != nil { if err != nil {
s.logger.Errorf("Failed to authenticate: %v", err) s.logger.Errorf("Failed to authenticate: %v", err)
s.renderError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to authenticate: %v", err)) s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Failed to authenticate: %v", err))
return return
} }
redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
if err != nil { if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err) s.logger.Errorf("Failed to finalize login: %v", err)
s.renderError(w, http.StatusInternalServerError, "Login error.") s.renderError(r, w, http.StatusInternalServerError, "Login error.")
return return
} }
@ -511,12 +511,12 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
authReq, err := s.storage.GetAuthRequest(r.FormValue("req")) authReq, err := s.storage.GetAuthRequest(r.FormValue("req"))
if err != nil { if err != nil {
s.logger.Errorf("Failed to get auth request: %v", err) s.logger.Errorf("Failed to get auth request: %v", err)
s.renderError(w, http.StatusInternalServerError, "Database error.") s.renderError(r, w, http.StatusInternalServerError, "Database error.")
return return
} }
if !authReq.LoggedIn { if !authReq.LoggedIn {
s.logger.Errorf("Auth request does not have an identity for approval") s.logger.Errorf("Auth request does not have an identity for approval")
s.renderError(w, http.StatusInternalServerError, "Login process not yet finalized.") s.renderError(r, w, http.StatusInternalServerError, "Login process not yet finalized.")
return return
} }
@ -529,15 +529,15 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
client, err := s.storage.GetClient(authReq.ClientID) client, err := s.storage.GetClient(authReq.ClientID)
if err != nil { if err != nil {
s.logger.Errorf("Failed to get client %q: %v", authReq.ClientID, err) s.logger.Errorf("Failed to get client %q: %v", authReq.ClientID, err)
s.renderError(w, http.StatusInternalServerError, "Failed to retrieve client.") s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve client.")
return return
} }
if err := s.templates.approval(w, authReq.ID, authReq.Claims.Username, client.Name, authReq.Scopes); err != nil { if err := s.templates.approval(r, w, authReq.ID, authReq.Claims.Username, client.Name, authReq.Scopes, r.URL.Path); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Errorf("Server template error: %v", err)
} }
case http.MethodPost: case http.MethodPost:
if r.FormValue("approval") != "approve" { if r.FormValue("approval") != "approve" {
s.renderError(w, http.StatusInternalServerError, "Approval rejected.") s.renderError(r, w, http.StatusInternalServerError, "Approval rejected.")
return return
} }
s.sendCodeResponse(w, r, authReq) s.sendCodeResponse(w, r, authReq)
@ -546,22 +546,22 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) { func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) {
if s.now().After(authReq.Expiry) { if s.now().After(authReq.Expiry) {
s.renderError(w, http.StatusBadRequest, "User session has expired.") s.renderError(r, w, http.StatusBadRequest, "User session has expired.")
return return
} }
if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil { if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("Failed to delete authorization request: %v", err) s.logger.Errorf("Failed to delete authorization request: %v", err)
s.renderError(w, http.StatusInternalServerError, "Internal server error.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
} else { } else {
s.renderError(w, http.StatusBadRequest, "User session error.") s.renderError(r, w, http.StatusBadRequest, "User session error.")
} }
return return
} }
u, err := url.Parse(authReq.RedirectURI) u, err := url.Parse(authReq.RedirectURI)
if err != nil { if err != nil {
s.renderError(w, http.StatusInternalServerError, "Invalid redirect URI.") s.renderError(r, w, http.StatusInternalServerError, "Invalid redirect URI.")
return return
} }
@ -598,14 +598,14 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
} }
if err := s.storage.CreateAuthCode(code); err != nil { if err := s.storage.CreateAuthCode(code); err != nil {
s.logger.Errorf("Failed to create auth code: %v", err) s.logger.Errorf("Failed to create auth code: %v", err)
s.renderError(w, http.StatusInternalServerError, "Internal server error.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return return
} }
// Implicit and hybrid flows that try to use the OOB redirect URI are // Implicit and hybrid flows that try to use the OOB redirect URI are
// rejected earlier. If we got here we're using the code flow. // rejected earlier. If we got here we're using the code flow.
if authReq.RedirectURI == redirectURIOOB { if authReq.RedirectURI == redirectURIOOB {
if err := s.templates.oob(w, code.ID); err != nil { if err := s.templates.oob(r, w, code.ID, r.URL.Path); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Errorf("Server template error: %v", err)
} }
return return
@ -1119,8 +1119,8 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, r
w.Write(data) w.Write(data)
} }
func (s *Server) renderError(w http.ResponseWriter, status int, description string) { func (s *Server) renderError(r *http.Request, w http.ResponseWriter, status int, description string) {
if err := s.templates.err(w, status, description); err != nil { if err := s.templates.err(r, w, status, description); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Errorf("Server template error: %v", err)
} }
} }