fix: add an extra endpoint to avoid refresh generating AuthRequests.

By adding an extra endpoint and a redirect, we can avoid a situation
where it's trivially easy to generate a large number of AuthRequests
by hitting F5/refresh in the browser.

Signed-off-by: Alastair Houghton <alastair@alastairs-place.net>
This commit is contained in:
Alastair Houghton 2021-05-21 11:03:22 +01:00
parent 030a6459d6
commit cd0c24ec4d
3 changed files with 67 additions and 15 deletions

View file

@ -128,7 +128,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s.logger.Errorf("Failed to parse arguments: %v", err) s.logger.Errorf("Failed to parse arguments: %v", err)
s.renderError(r, w, http.StatusBadRequest, "Bad query/form arguments") s.renderError(r, w, http.StatusBadRequest, err.Error())
return return
} }
@ -141,6 +141,9 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
return return
} }
// We don't need connector_id any more
r.Form.Del("connector_id")
// Construct a URL with all of the arguments in its query // Construct a URL with all of the arguments in its query
connURL := url.URL{ connURL := url.URL{
RawQuery: r.Form.Encode(), RawQuery: r.Form.Encode(),
@ -160,11 +163,8 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
} }
if len(connectors) == 1 && !s.alwaysShowLogin { if len(connectors) == 1 && !s.alwaysShowLogin {
for _, c := range connectors { connURL.Path = s.absPath("/auth", connectors[0].ID)
connURL.Path = s.absPath("/auth", c.ID) http.Redirect(w, r, connURL.String(), http.StatusFound)
http.Redirect(w, r, connURL.String(), http.StatusFound)
return
}
} }
connectorInfos := make([]connectorInfo, len(connectors)) connectorInfos := make([]connectorInfo, len(connectors))
@ -258,9 +258,15 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
} }
http.Redirect(w, r, callbackURL, http.StatusFound) http.Redirect(w, r, callbackURL, http.StatusFound)
case connector.PasswordConnector: case connector.PasswordConnector:
if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(conn), false, backLink); err != nil { loginURL := url.URL{
s.logger.Errorf("Server template error: %v", err) Path: s.absPath("/auth", connID, "login"),
} }
q := loginURL.Query()
q.Set("state", authReq.ID)
q.Set("back", backLink)
loginURL.RawQuery = q.Encode()
http.Redirect(w, r, loginURL.String(), http.StatusFound)
case connector.SAMLConnector: case connector.SAMLConnector:
action, value, err := conn.POSTData(scopes, authReq.ID) action, value, err := conn.POSTData(scopes, authReq.ID)
if err != nil { if err != nil {
@ -289,29 +295,75 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
default: default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
} }
case http.MethodPost: default:
passwordConnector, ok := conn.Connector.(connector.PasswordConnector) s.renderError(r, w, http.StatusBadRequest, "Unsupported request method.")
if !ok { }
}
func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
authID := r.URL.Query().Get("state")
if authID == "" {
s.renderError(r, w, http.StatusBadRequest, "User session error.")
return
}
backLink := r.URL.Query().Get("back")
authReq, err := s.storage.GetAuthRequest(authID)
if err != nil {
if err == storage.ErrNotFound {
s.logger.Errorf("Invalid 'state' parameter provided: %v", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
return return
} }
s.logger.Errorf("Failed to get auth request: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Database error.")
return
}
if connID := mux.Vars(r)["connector"]; connID != "" && connID != authReq.ConnectorID {
s.logger.Errorf("Connector mismatch: authentication started with id %q, but password login for id %q was triggered", authReq.ConnectorID, connID)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return
}
conn, err := s.getConnector(authReq.ConnectorID)
if err != nil {
s.logger.Errorf("Failed to get connector with id %q : %v", authReq.ConnectorID, err)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return
}
pwConn, ok := conn.Connector.(connector.PasswordConnector)
if !ok {
s.logger.Errorf("Expected password connector in handlePasswordLogin(), but got %v", pwConn)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return
}
switch r.Method {
case http.MethodGet:
if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(pwConn), false, backLink); err != nil {
s.logger.Errorf("Server template error: %v", err)
}
case http.MethodPost:
username := r.FormValue("login") username := r.FormValue("login")
password := r.FormValue("password") password := r.FormValue("password")
scopes := parseScopes(authReq.Scopes)
identity, ok, err := passwordConnector.Login(r.Context(), scopes, username, password) identity, ok, err := pwConn.Login(r.Context(), scopes, username, password)
if err != nil { if err != nil {
s.logger.Errorf("Failed to login user: %v", err) s.logger.Errorf("Failed to login user: %v", err)
s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err)) s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err))
return return
} }
if !ok { if !ok {
if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(passwordConnector), true, backLink); err != nil { if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(pwConn), true, backLink); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Errorf("Server template error: %v", err)
} }
return return
} }
redirectURL, err := s.finalizeLogin(identity, *authReq, conn.Connector) redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
if err != nil { if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err) s.logger.Errorf("Failed to finalize login: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.") s.renderError(r, w, http.StatusInternalServerError, "Login error.")

View file

@ -7,7 +7,6 @@ import (
"errors" "errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"testing" "testing"
"time" "time"

View file

@ -341,6 +341,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
handleWithCORS("/userinfo", s.handleUserInfo) handleWithCORS("/userinfo", s.handleUserInfo)
handleFunc("/auth", s.handleAuthorization) handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/auth/{connector}/login", s.handlePasswordLogin)
handleFunc("/device", s.handleDeviceExchange) handleFunc("/device", s.handleDeviceExchange)
handleFunc("/device/auth/verify_code", s.verifyUserCode) handleFunc("/device/auth/verify_code", s.verifyUserCode)
handleFunc("/device/code", s.handleDeviceCode) handleFunc("/device/code", s.handleDeviceCode)