diff --git a/server/http.go b/server/http.go index a9d2bc46..8beb1784 100644 --- a/server/http.go +++ b/server/http.go @@ -255,7 +255,7 @@ func renderLoginPage(w http.ResponseWriter, r *http.Request, srv OIDCServer, idp v := r.URL.Query() v.Set("connector_id", idpc.ID()) - v.Set("response_type", "code") + v.Set("response_type", q.Get("response_type")) link.URL = httpPathAuth + "?" + v.Encode() td.Links = append(td.Links, link) } @@ -273,77 +273,92 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T } q := r.URL.Query() - register := q.Get("register") == "1" && registrationEnabled - e := q.Get("error") - if e != "" { - sessionKey := q.Get("state") - if err := srv.KillSession(sessionKey); err != nil { - log.Errorf("Failed killing sessionKey %q: %v", sessionKey, err) + + // Retrieve client id + clientid := q.Get("client_id") + + // Retrieve state + state := q.Get("state") + + // Retrieve response_type + responseType := q.Get("response_type") + + // Retrieve scopes + qscope := strings.Fields(q.Get("scope")) + + // Check client ID param + if clientid == "" { + log.Errorf("Invalid auth request: no client_id received") + writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) + return + } + + // Check redirect_uri param, but if it's empty we don't return any error here + qru := q.Get("redirect_uri") + var rURL *url.URL + if qru != "" { + ru, err := url.Parse(qru) + if err != nil { + log.Errorf("Invalid auth request: %v", err) + writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) + return } - renderLoginPage(w, r, srv, idpcs, register, tpl) - return + rURL = ru } - connectorID := q.Get("connector_id") - idpc, ok := idx[connectorID] - if !ok { - renderLoginPage(w, r, srv, idpcs, register, tpl) - return - } - - acr, err := oauth2.ParseAuthCodeRequest(q) + cm, err := srv.ClientMetadata(clientid) if err != nil { - log.Errorf("Invalid auth request: %v", err) - writeAuthError(w, err, acr.State) - return - } - - cm, err := srv.ClientMetadata(acr.ClientID) - if err != nil { - log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err) - writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State) + log.Errorf("Failed fetching client %q from repo: %v", clientid, err) + writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), state) return } if cm == nil { - log.Errorf("Client %q not found", acr.ClientID) - writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) + log.Errorf("Client %q not found", clientid) + writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) return } if len(cm.RedirectURIs) == 0 { - log.Errorf("Client %q has no redirect URLs", acr.ClientID) - writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State) + log.Errorf("Client %q has no redirect URLs", clientid) + writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), state) return } - redirectURL, err := client.ValidRedirectURL(acr.RedirectURL, cm.RedirectURIs) + redirectURL, err := client.ValidRedirectURL(rURL, cm.RedirectURIs) if err != nil { switch err { case (client.ErrorCantChooseRedirectURL): - log.Errorf("Request must provide redirect URL as client %q has registered many", acr.ClientID) - writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) + log.Errorf("Request must provide redirect URL as client %q has registered many", clientid) + writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) return case (client.ErrorInvalidRedirectURL): - log.Errorf("Request provided unregistered redirect URL: %s", acr.RedirectURL) - writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) + log.Errorf("Request provided unregistered redirect URL: %s", rURL) + writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) return case (client.ErrorNoValidRedirectURLs): - log.Errorf("There are no registered URLs for the requested client: %s", acr.RedirectURL) - writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) + log.Errorf("There are no registered URLs for the requested client: %s", rURL) + writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) + return + default: + log.Errorf("Unexpected error checking redirect URL for client %q: %v", clientid, err) + writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), state) return } } - if acr.ResponseType != oauth2.ResponseTypeCode { - log.Errorf("unexpected ResponseType: %v: ", acr.ResponseType) - redirectAuthError(w, oauth2.NewError(oauth2.ErrorUnsupportedResponseType), acr.State, redirectURL) + // Response type check + switch responseType { + case "code": // Add more cases as we support more response types + default: + log.Errorf("Invalid auth request: unsupported response_type") + redirectAuthError(w, oauth2.NewError(oauth2.ErrorUnsupportedResponseType), state, redirectURL) return } // Check scopes. var scopes []string foundOpenIDScope := false - for _, scope := range acr.Scope { + for _, scope := range qscope { switch scope { case "openid": foundOpenIDScope = true @@ -364,16 +379,33 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T if !foundOpenIDScope { log.Errorf("Invalid auth request: missing 'openid' in 'scope'") - writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) + writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) + return + } + + register := q.Get("register") == "1" && registrationEnabled + e := q.Get("error") + if e != "" { + if err := srv.KillSession(state); err != nil { + log.Errorf("Failed killing sessionKey %q: %v", state, err) + } + renderLoginPage(w, r, srv, idpcs, register, tpl) + return + } + + connectorID := q.Get("connector_id") + idpc, ok := idx[connectorID] + if !ok { + renderLoginPage(w, r, srv, idpcs, register, tpl) return } nonce := q.Get("nonce") - key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope) + key, err := srv.NewSession(connectorID, clientid, state, redirectURL, nonce, register, qscope) if err != nil { log.Errorf("Error creating new session: %v: ", err) - redirectAuthError(w, err, acr.State, redirectURL) + redirectAuthError(w, err, state, redirectURL) return } @@ -399,7 +431,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T lu, err := idpc.LoginURL(key, p) if err != nil { log.Errorf("Connector.LoginURL failed: %v", err) - redirectAuthError(w, err, acr.State, redirectURL) + redirectAuthError(w, err, state, redirectURL) return } diff --git a/server/http_test.go b/server/http_test.go index 94052b06..94912bc7 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -175,6 +175,29 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { }, wantCode: http.StatusBadRequest, }, + + // empty response_type + { + query: url.Values{ + "redirect_uri": []string{"http://client.example.com/callback"}, + "client_id": []string{"XXX"}, + "connector_id": []string{"fake"}, + "scope": []string{"openid"}, + }, + wantCode: http.StatusFound, + wantLocation: "http://client.example.com/callback?error=unsupported_response_type&state=", + }, + + // empty client_id + { + query: url.Values{ + "response_type": []string{"code"}, + "redirect_uri": []string{"http://unrecognized.example.com/callback"}, + "connector_id": []string{"fake"}, + "scope": []string{"openid"}, + }, + wantCode: http.StatusBadRequest, + }, } for i, tt := range tests {