fix: add an extra endpoint to avoid refresh generating AuthRequests.

By adding an extra endpoint and a redirect, we can avoid a situation
where it's trivially easy to generate a large number of AuthRequests
by hitting F5/refresh in the browser.

Signed-off-by: Alastair Houghton <alastair@alastairs-place.net>
This commit is contained in:
Alastair Houghton 2021-05-21 11:03:22 +01:00
parent 030a6459d6
commit cd0c24ec4d
3 changed files with 67 additions and 15 deletions

View file

@ -128,7 +128,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
s.logger.Errorf("Failed to parse arguments: %v", err)
s.renderError(r, w, http.StatusBadRequest, "Bad query/form arguments")
s.renderError(r, w, http.StatusBadRequest, err.Error())
return
}
@ -141,6 +141,9 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
return
}
// We don't need connector_id any more
r.Form.Del("connector_id")
// Construct a URL with all of the arguments in its query
connURL := url.URL{
RawQuery: r.Form.Encode(),
@ -160,11 +163,8 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
}
if len(connectors) == 1 && !s.alwaysShowLogin {
for _, c := range connectors {
connURL.Path = s.absPath("/auth", c.ID)
http.Redirect(w, r, connURL.String(), http.StatusFound)
return
}
connURL.Path = s.absPath("/auth", connectors[0].ID)
http.Redirect(w, r, connURL.String(), http.StatusFound)
}
connectorInfos := make([]connectorInfo, len(connectors))
@ -258,9 +258,15 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
}
http.Redirect(w, r, callbackURL, http.StatusFound)
case connector.PasswordConnector:
if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(conn), false, backLink); err != nil {
s.logger.Errorf("Server template error: %v", err)
loginURL := url.URL{
Path: s.absPath("/auth", connID, "login"),
}
q := loginURL.Query()
q.Set("state", authReq.ID)
q.Set("back", backLink)
loginURL.RawQuery = q.Encode()
http.Redirect(w, r, loginURL.String(), http.StatusFound)
case connector.SAMLConnector:
action, value, err := conn.POSTData(scopes, authReq.ID)
if err != nil {
@ -289,29 +295,75 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
case http.MethodPost:
passwordConnector, ok := conn.Connector.(connector.PasswordConnector)
if !ok {
default:
s.renderError(r, w, http.StatusBadRequest, "Unsupported request method.")
}
}
func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
authID := r.URL.Query().Get("state")
if authID == "" {
s.renderError(r, w, http.StatusBadRequest, "User session error.")
return
}
backLink := r.URL.Query().Get("back")
authReq, err := s.storage.GetAuthRequest(authID)
if err != nil {
if err == storage.ErrNotFound {
s.logger.Errorf("Invalid 'state' parameter provided: %v", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
return
}
s.logger.Errorf("Failed to get auth request: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Database error.")
return
}
if connID := mux.Vars(r)["connector"]; connID != "" && connID != authReq.ConnectorID {
s.logger.Errorf("Connector mismatch: authentication started with id %q, but password login for id %q was triggered", authReq.ConnectorID, connID)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return
}
conn, err := s.getConnector(authReq.ConnectorID)
if err != nil {
s.logger.Errorf("Failed to get connector with id %q : %v", authReq.ConnectorID, err)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return
}
pwConn, ok := conn.Connector.(connector.PasswordConnector)
if !ok {
s.logger.Errorf("Expected password connector in handlePasswordLogin(), but got %v", pwConn)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return
}
switch r.Method {
case http.MethodGet:
if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(pwConn), false, backLink); err != nil {
s.logger.Errorf("Server template error: %v", err)
}
case http.MethodPost:
username := r.FormValue("login")
password := r.FormValue("password")
scopes := parseScopes(authReq.Scopes)
identity, ok, err := passwordConnector.Login(r.Context(), scopes, username, password)
identity, ok, err := pwConn.Login(r.Context(), scopes, username, password)
if err != nil {
s.logger.Errorf("Failed to login user: %v", err)
s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err))
return
}
if !ok {
if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(passwordConnector), true, backLink); err != nil {
if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(pwConn), true, backLink); err != nil {
s.logger.Errorf("Server template error: %v", err)
}
return
}
redirectURL, err := s.finalizeLogin(identity, *authReq, conn.Connector)
redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.")

View file

@ -7,7 +7,6 @@ import (
"errors"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

View file

@ -341,6 +341,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
handleWithCORS("/userinfo", s.handleUserInfo)
handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/auth/{connector}/login", s.handlePasswordLogin)
handleFunc("/device", s.handleDeviceExchange)
handleFunc("/device/auth/verify_code", s.verifyUserCode)
handleFunc("/device/code", s.handleDeviceCode)