fix: unsupported request parameter error

Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
m.nabokikh 2021-01-22 02:06:18 +04:00
parent 369e16e97e
commit 30a5dade0f
2 changed files with 36 additions and 6 deletions

View file

@ -95,6 +95,7 @@ const (
errUnauthorizedClient = "unauthorized_client" errUnauthorizedClient = "unauthorized_client"
errAccessDenied = "access_denied" errAccessDenied = "access_denied"
errUnsupportedResponseType = "unsupported_response_type" errUnsupportedResponseType = "unsupported_response_type"
errRequestNotSupported = "request_not_supported"
errInvalidScope = "invalid_scope" errInvalidScope = "invalid_scope"
errServerError = "server_error" errServerError = "server_error"
errTemporarilyUnavailable = "temporarily_unavailable" errTemporarilyUnavailable = "temporarily_unavailable"
@ -453,6 +454,12 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)} return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)}
} }
// dex doesn't support request parameter and must return request_not_supported error
// https://openid.net/specs/openid-connect-core-1_0.html#6.1
if q.Get("request") != "" {
return nil, newErr(errRequestNotSupported, "Server does not support request parameter.")
}
if codeChallengeMethod != CodeChallengeMethodS256 && codeChallengeMethod != CodeChallengeMethodPlain { if codeChallengeMethod != CodeChallengeMethodS256 && codeChallengeMethod != CodeChallengeMethodPlain {
description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod) description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod)
return nil, newErr(errInvalidRequest, description) return nil, newErr(errInvalidRequest, description)

View file

@ -25,6 +25,7 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2" "golang.org/x/oauth2"
jose "gopkg.in/square/go-jose.v2" jose "gopkg.in/square/go-jose.v2"
@ -223,6 +224,9 @@ type test struct {
// extra parameters to pass when retrieving id token // extra parameters to pass when retrieving id token
retrieveTokenOptions []oauth2.AuthCodeOption retrieveTokenOptions []oauth2.AuthCodeOption
// define an error response, when the test expects an error on the auth endpoint
authError *OAuth2ErrorResponse
// define an error response, when the test expects an error on the token endpoint // define an error response, when the test expects an error on the token endpoint
tokenError ErrorResponse tokenError ErrorResponse
} }
@ -607,6 +611,19 @@ func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time)
StatusCode: http.StatusBadRequest, StatusCode: http.StatusBadRequest,
}, },
}, },
{
name: "Request parameter in authorization query",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("request", "anything"),
},
authError: &OAuth2ErrorResponse{
Error: errRequestNotSupported,
ErrorDescription: "Server does not support request parameter.",
},
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
return nil
},
},
}, },
} }
} }
@ -665,7 +682,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
state = "a_state" state = "a_state"
) )
defer func() { defer func() {
if !gotCode { if !gotCode && tc.authError == nil {
t.Errorf("never got a code in callback\n%s\n%s", reqDump, respDump) t.Errorf("never got a code in callback\n%s\n%s", reqDump, respDump)
} }
}() }()
@ -684,14 +701,20 @@ func TestOAuth2CodeFlow(t *testing.T) {
// Did dex return an error? // Did dex return an error?
if errType := q.Get("error"); errType != "" { if errType := q.Get("error"); errType != "" {
if desc := q.Get("error_description"); desc != "" { description := q.Get("error_description")
t.Errorf("got error from server %s: %s", errType, desc)
if tc.authError == nil {
if description != "" {
t.Errorf("got error from server %s: %s", errType, description)
} else { } else {
t.Errorf("got error from server %s", errType) t.Errorf("got error from server %s", errType)
} }
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
require.Equal(t, *tc.authError, OAuth2ErrorResponse{Error: errType, ErrorDescription: description})
return
}
// Grab code, exchange for token. // Grab code, exchange for token.
if code := q.Get("code"); code != "" { if code := q.Get("code"); code != "" {