diff --git a/server/handlers.go b/server/handlers.go index d25cac52..2a4f8c71 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -158,7 +158,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { return } } - s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound) + s.renderError(r, w, http.StatusBadRequest, "Connector ID does not match a valid Connector") return } @@ -187,21 +187,16 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { authReq, err := s.parseAuthorizationRequest(r) if err != nil { s.logger.Errorf("Failed to parse authorization request: %v", err) - status := http.StatusInternalServerError - // If this is an authErr, let's let it handle the error, or update the HTTP - // status code - if err, ok := err.(*authErr); ok { - if handler, ok := err.Handle(); ok { - // client_id and redirect_uri checked out and we can redirect back to - // the client with the error. - handler.ServeHTTP(w, r) - return - } - status = err.Status() + switch authErr := err.(type) { + case *redirectedAuthErr: + authErr.Handler().ServeHTTP(w, r) + case *displayedAuthErr: + s.renderError(r, w, authErr.Status, err.Error()) + default: + panic("unsupported error type") } - s.renderError(r, w, status, err.Error()) return } @@ -770,7 +765,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { case grantTypePassword: s.withClientFromStorage(w, r, s.handlePasswordGrant) default: - s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) + s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest) } } diff --git a/server/oauth2.go b/server/oauth2.go index 00beb6ff..23f06b82 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -29,32 +29,35 @@ import ( // TODO(ericchiang): clean this file up and figure out more idiomatic error handling. -// authErr is an error response to an authorization request. // See: https://tools.ietf.org/html/rfc6749#section-4.1.2.1 -type authErr struct { + +// displayedAuthErr is an error that should be displayed to the user as a web page +type displayedAuthErr struct { + Status int + Description string +} + +func (err *displayedAuthErr) Error() string { + return err.Description +} + +func newDisplayedErr(status int, format string, a ...interface{}) *displayedAuthErr { + return &displayedAuthErr{status, fmt.Sprintf(format, a...)} +} + +// redirectedAuthErr is an error that should be reported back to the client by 302 redirect +type redirectedAuthErr struct { State string RedirectURI string Type string Description string } -func (err *authErr) Status() int { - if err.State == errServerError { - return http.StatusInternalServerError - } - return http.StatusBadRequest -} - -func (err *authErr) Error() string { +func (err *redirectedAuthErr) Error() string { return err.Description } -func (err *authErr) Handle() (http.Handler, bool) { - // Didn't get a valid redirect URI. - if err.RedirectURI == "" { - return nil, false - } - +func (err *redirectedAuthErr) Handler() http.Handler { hf := func(w http.ResponseWriter, r *http.Request) { v := url.Values{} v.Add("state", err.State) @@ -70,7 +73,7 @@ func (err *authErr) Handle() (http.Handler, bool) { } http.Redirect(w, r, redirectURI, http.StatusSeeOther) } - return http.HandlerFunc(hf), true + return http.HandlerFunc(hf) } func tokenErr(w http.ResponseWriter, typ, description string, statusCode int) error { @@ -102,7 +105,6 @@ const ( errUnsupportedGrantType = "unsupported_grant_type" errInvalidGrant = "invalid_grant" errInvalidClient = "invalid_client" - errInvalidConnectorID = "invalid_connector_id" ) const ( @@ -408,12 +410,12 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str // parse the initial request from the OAuth2 client. func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthRequest, error) { if err := r.ParseForm(); err != nil { - return nil, &authErr{"", "", errInvalidRequest, "Failed to parse request body."} + return nil, newDisplayedErr(http.StatusBadRequest, "Failed to parse request.") } q := r.Form redirectURI, err := url.QueryUnescape(q.Get("redirect_uri")) if err != nil { - return nil, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."} + return nil, newDisplayedErr(http.StatusBadRequest, "No redirect_uri provided.") } clientID := q.Get("client_id") @@ -434,45 +436,44 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques client, err := s.storage.GetClient(clientID) if err != nil { if err == storage.ErrNotFound { - description := fmt.Sprintf("Invalid client_id (%q).", clientID) - return nil, &authErr{"", "", errUnauthorizedClient, description} + return nil, newDisplayedErr(http.StatusNotFound, "Invalid client_id (%q).", clientID) } s.logger.Errorf("Failed to get client: %v", err) - return nil, &authErr{"", "", errServerError, ""} - } - - if connectorID != "" { - connectors, err := s.storage.ListConnectors() - if err != nil { - return nil, &authErr{"", "", errServerError, "Unable to retrieve connectors"} - } - if !validateConnectorID(connectors, connectorID) { - return nil, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"} - } + return nil, newDisplayedErr(http.StatusInternalServerError, "Database error.") } if !validateRedirectURI(client, redirectURI) { - description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI) - return nil, &authErr{"", "", errInvalidRequest, description} + return nil, newDisplayedErr(http.StatusBadRequest, "Unregistered redirect_uri (%q).", redirectURI) } if redirectURI == deviceCallbackURI && client.Public { redirectURI = s.issuerURL.Path + deviceCallbackURI } // From here on out, we want to redirect back to the client with an error. - newErr := func(typ, format string, a ...interface{}) *authErr { - return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)} + newRedirectedErr := func(typ, format string, a ...interface{}) *redirectedAuthErr { + return &redirectedAuthErr{state, redirectURI, typ, fmt.Sprintf(format, a...)} + } + + if connectorID != "" { + connectors, err := s.storage.ListConnectors() + if err != nil { + s.logger.Errorf("Failed to list connectors: %v", err) + return nil, newRedirectedErr(errServerError, "Unable to retrieve connectors") + } + if !validateConnectorID(connectors, connectorID) { + return nil, newRedirectedErr(errInvalidRequest, "Invalid ConnectorID") + } } // dex doesn't support request parameter and must return request_not_supported error // https://openid.net/specs/openid-connect-core-1_0.html#6.1 if q.Get("request") != "" { - return nil, newErr(errRequestNotSupported, "Server does not support request parameter.") + return nil, newRedirectedErr(errRequestNotSupported, "Server does not support request parameter.") } if codeChallengeMethod != codeChallengeMethodS256 && codeChallengeMethod != codeChallengeMethodPlain { description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod) - return nil, newErr(errInvalidRequest, description) + return nil, newRedirectedErr(errInvalidRequest, description) } var ( @@ -494,7 +495,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques isTrusted, err := s.validateCrossClientTrust(clientID, peerID) if err != nil { - return nil, newErr(errServerError, "Internal server error.") + return nil, newRedirectedErr(errServerError, "Internal server error.") } if !isTrusted { invalidScopes = append(invalidScopes, scope) @@ -502,13 +503,13 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques } } if !hasOpenIDScope { - return nil, newErr(errInvalidScope, `Missing required scope(s) ["openid"].`) + return nil, newRedirectedErr(errInvalidScope, `Missing required scope(s) ["openid"].`) } if len(unrecognized) > 0 { - return nil, newErr(errInvalidScope, "Unrecognized scope(s) %q", unrecognized) + return nil, newRedirectedErr(errInvalidScope, "Unrecognized scope(s) %q", unrecognized) } if len(invalidScopes) > 0 { - return nil, newErr(errInvalidScope, "Client can't request scope(s) %q", invalidScopes) + return nil, newRedirectedErr(errInvalidScope, "Client can't request scope(s) %q", invalidScopes) } var rt struct { @@ -526,23 +527,23 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques case responseTypeToken: rt.token = true default: - return nil, newErr(errInvalidRequest, "Invalid response type %q", responseType) + return nil, newRedirectedErr(errInvalidRequest, "Invalid response type %q", responseType) } if !s.supportedResponseTypes[responseType] { - return nil, newErr(errUnsupportedResponseType, "Unsupported response type %q", responseType) + return nil, newRedirectedErr(errUnsupportedResponseType, "Unsupported response type %q", responseType) } } if len(responseTypes) == 0 { - return nil, newErr(errInvalidRequest, "No response_type provided") + return nil, newRedirectedErr(errInvalidRequest, "No response_type provided") } if rt.token && !rt.code && !rt.idToken { // "token" can't be provided by its own. // // https://openid.net/specs/openid-connect-core-1_0.html#Authentication - return nil, newErr(errInvalidRequest, "Response type 'token' must be provided with type 'id_token' and/or 'code'") + return nil, newRedirectedErr(errInvalidRequest, "Response type 'token' must be provided with type 'id_token' and/or 'code'") } if !rt.code { // Either "id_token token" or "id_token" has been provided which implies the @@ -550,13 +551,13 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques // // https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest if nonce == "" { - return nil, newErr(errInvalidRequest, "Response type 'token' requires a 'nonce' value.") + return nil, newRedirectedErr(errInvalidRequest, "Response type 'token' requires a 'nonce' value.") } } if rt.token { if redirectURI == redirectURIOOB { err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB) - return nil, newErr(errInvalidRequest, err) + return nil, newRedirectedErr(errInvalidRequest, err) } } diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 518e22ee..710382aa 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -10,7 +10,6 @@ import ( "strings" "testing" - "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" "github.com/dexidp/dex/storage" @@ -27,8 +26,7 @@ func TestParseAuthorizationRequest(t *testing.T) { queryParams map[string]string - wantErr bool - exactError *authErr + expectedError error }{ { name: "normal request", @@ -78,7 +76,7 @@ func TestParseAuthorizationRequest(t *testing.T) { "response_type": "code", "scope": "openid email profile", }, - wantErr: true, + expectedError: &displayedAuthErr{Status: http.StatusNotFound}, }, { name: "invalid redirect uri", @@ -95,7 +93,7 @@ func TestParseAuthorizationRequest(t *testing.T) { "response_type": "code", "scope": "openid email profile", }, - wantErr: true, + expectedError: &displayedAuthErr{Status: http.StatusBadRequest}, }, { name: "implicit flow", @@ -128,7 +126,7 @@ func TestParseAuthorizationRequest(t *testing.T) { "response_type": "code id_token", "scope": "openid email profile", }, - wantErr: true, + expectedError: &redirectedAuthErr{Type: errUnsupportedResponseType}, }, { name: "only token response type", @@ -145,7 +143,7 @@ func TestParseAuthorizationRequest(t *testing.T) { "response_type": "token", "scope": "openid email profile", }, - wantErr: true, + expectedError: &redirectedAuthErr{Type: errInvalidRequest}, }, { name: "choose connector_id", @@ -197,7 +195,7 @@ func TestParseAuthorizationRequest(t *testing.T) { "response_type": "code id_token", "scope": "openid email profile", }, - wantErr: true, + expectedError: &redirectedAuthErr{Type: errInvalidRequest}, }, { name: "PKCE code_challenge_method plain", @@ -269,7 +267,7 @@ func TestParseAuthorizationRequest(t *testing.T) { "code_challenge_method": "invalid_method", "scope": "openid email profile", }, - wantErr: true, + expectedError: &redirectedAuthErr{Type: errInvalidRequest}, }, { name: "No response type", @@ -287,12 +285,7 @@ func TestParseAuthorizationRequest(t *testing.T) { "code_challenge_method": "plain", "scope": "openid email profile", }, - wantErr: true, - exactError: &authErr{ - RedirectURI: "https://example.com/bar", - Type: "invalid_request", - Description: "No response_type provided", - }, + expectedError: &redirectedAuthErr{Type: errInvalidRequest}, }, } @@ -321,13 +314,34 @@ func TestParseAuthorizationRequest(t *testing.T) { } _, err := server.parseAuthorizationRequest(req) - if tc.wantErr { - require.Error(t, err) - if tc.exactError != nil { - require.Equal(t, tc.exactError, err) + if tc.expectedError == nil { + if err != nil { + t.Errorf("%s: expected no error", tc.name) } } else { - require.NoError(t, err) + switch expectedErr := tc.expectedError.(type) { + case *redirectedAuthErr: + e, ok := err.(*redirectedAuthErr) + if !ok { + t.Fatalf("%s: expected redirectedAuthErr error", tc.name) + } + if e.Type != expectedErr.Type { + t.Errorf("%s: expected error type %v, got %v", tc.name, expectedErr.Type, e.Type) + } + if e.RedirectURI != tc.queryParams["redirect_uri"] { + t.Errorf("%s: expected error to be returned in redirect to %v", tc.name, tc.queryParams["redirect_uri"]) + } + case *displayedAuthErr: + e, ok := err.(*displayedAuthErr) + if !ok { + t.Fatalf("%s: expected displayedAuthErr error", tc.name) + } + if e.Status != expectedErr.Status { + t.Errorf("%s: expected http status %v, got %v", tc.name, expectedErr.Status, e.Status) + } + default: + t.Fatalf("%s: unsupported error type", tc.name) + } } }() } diff --git a/server/server_test.go b/server/server_test.go index 24f76199..682d16a7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -532,6 +532,17 @@ func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) return nil }, }, + { + name: "unsupported grant type", + retrieveTokenOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("grant_type", "unsupported"), + }, + handleToken: basicIDTokenVerify, + tokenError: ErrorResponse{ + Error: errUnsupportedGrantType, + StatusCode: http.StatusBadRequest, + }, + }, { // This test ensures that PKCE work in "plain" mode (no code_challenge_method specified) name: "PKCE with plain", @@ -678,7 +689,7 @@ func TestOAuth2CodeFlow(t *testing.T) { tests := makeOAuth2Tests(clientID, clientSecret, now) for _, tc := range tests.tests { - func() { + t.Run(tc.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -812,7 +823,7 @@ func TestOAuth2CodeFlow(t *testing.T) { if respDump, err = httputil.DumpResponse(resp, true); err != nil { t.Fatal(err) } - }() + }) } }