diff --git a/server/handlers.go b/server/handlers.go index bd134813..342849ee 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -23,6 +25,11 @@ import ( "github.com/dexidp/dex/storage" ) +const ( + CodeChallengeMethodPlain = "plain" + CodeChallengeMethodS256 = "S256" +) + // newHealthChecker returns the healthz handler. The handler runs until the // provided context is canceled. func (s *Server) newHealthChecker(ctx context.Context) http.Handler { @@ -148,34 +155,36 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { } type discovery struct { - Issuer string `json:"issuer"` - Auth string `json:"authorization_endpoint"` - Token string `json:"token_endpoint"` - Keys string `json:"jwks_uri"` - UserInfo string `json:"userinfo_endpoint"` - DeviceEndpoint string `json:"device_authorization_endpoint"` - GrantTypes []string `json:"grant_types_supported"` - ResponseTypes []string `json:"response_types_supported"` - Subjects []string `json:"subject_types_supported"` - IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` - Scopes []string `json:"scopes_supported"` - AuthMethods []string `json:"token_endpoint_auth_methods_supported"` - Claims []string `json:"claims_supported"` + Issuer string `json:"issuer"` + Auth string `json:"authorization_endpoint"` + Token string `json:"token_endpoint"` + Keys string `json:"jwks_uri"` + UserInfo string `json:"userinfo_endpoint"` + DeviceEndpoint string `json:"device_authorization_endpoint"` + GrantTypes []string `json:"grant_types_supported"` + ResponseTypes []string `json:"response_types_supported"` + Subjects []string `json:"subject_types_supported"` + IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` + CodeChallengeAlgs []string `json:"code_challenge_methods_supported"` + Scopes []string `json:"scopes_supported"` + AuthMethods []string `json:"token_endpoint_auth_methods_supported"` + Claims []string `json:"claims_supported"` } func (s *Server) discoveryHandler() (http.HandlerFunc, error) { d := discovery{ - Issuer: s.issuerURL.String(), - Auth: s.absURL("/auth"), - Token: s.absURL("/token"), - Keys: s.absURL("/keys"), - UserInfo: s.absURL("/userinfo"), - DeviceEndpoint: s.absURL("/device/code"), - Subjects: []string{"public"}, - GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode}, - IDTokenAlgs: []string{string(jose.RS256)}, - Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, - AuthMethods: []string{"client_secret_basic"}, + Issuer: s.issuerURL.String(), + Auth: s.absURL("/auth"), + Token: s.absURL("/token"), + Keys: s.absURL("/keys"), + UserInfo: s.absURL("/userinfo"), + DeviceEndpoint: s.absURL("/device/code"), + Subjects: []string{"public"}, + GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode}, + IDTokenAlgs: []string{string(jose.RS256)}, + CodeChallengeAlgs: []string{CodeChallengeMethodS256, CodeChallengeMethodPlain}, + Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, + AuthMethods: []string{"client_secret_basic"}, Claims: []string{ "aud", "email", "email_verified", "exp", "iat", "iss", "locale", "name", "sub", @@ -643,6 +652,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe Expiry: s.now().Add(time.Minute * 30), RedirectURI: authReq.RedirectURI, ConnectorData: authReq.ConnectorData, + PKCE: authReq.PKCE, } if err := s.storage.CreateAuthCode(code); err != nil { s.logger.Errorf("Failed to create auth code: %v", err) @@ -756,6 +766,11 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { return } if client.Secret != clientSecret { + if clientSecret == "" { + s.logger.Infof("missing client_secret on token request for client: %s", client.ID) + } else { + s.logger.Infof("invalid client_secret on token request for client: %s", client.ID) + } s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) return } @@ -773,6 +788,18 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { } } +func (s *Server) calculateCodeChallenge(codeVerifier, codeChallengeMethod string) (string, error) { + switch codeChallengeMethod { + case CodeChallengeMethodPlain: + return codeVerifier, nil + case CodeChallengeMethodS256: + shaSum := sha256.Sum256([]byte(codeVerifier)) + return base64.RawURLEncoding.EncodeToString(shaSum[:]), nil + default: + return "", fmt.Errorf("unknown challenge method (%v)", codeChallengeMethod) + } +} + // handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3 func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) { code := r.PostFormValue("code") @@ -789,6 +816,31 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s return } + // RFC 7636 (PKCE) + codeChallengeFromStorage := authCode.PKCE.CodeChallenge + providedCodeVerifier := r.PostFormValue("code_verifier") + + if providedCodeVerifier != "" && codeChallengeFromStorage != "" { + calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, authCode.PKCE.CodeChallengeMethod) + if err != nil { + s.logger.Error(err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + if codeChallengeFromStorage != calculatedCodeChallenge { + s.tokenErrHelper(w, errInvalidGrant, "Invalid code_verifier.", http.StatusBadRequest) + return + } + } else if providedCodeVerifier != "" { + // Received no code_challenge on /auth, but a code_verifier on /token + s.tokenErrHelper(w, errInvalidRequest, "No PKCE flow started. Cannot check code_verifier.", http.StatusBadRequest) + return + } else if codeChallengeFromStorage != "" { + // Received PKCE request on /auth, but no code_verifier on /token + s.tokenErrHelper(w, errInvalidGrant, "Expecting parameter code_verifier in PKCE flow.", http.StatusBadRequest) + return + } + if authCode.RedirectURI != redirectURI { s.tokenErrHelper(w, errInvalidRequest, "redirect_uri did not match URI from initial request.", http.StatusBadRequest) return diff --git a/server/oauth2.go b/server/oauth2.go index 2596fd4e..79e54c32 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -413,6 +413,13 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques scopes := strings.Fields(q.Get("scope")) responseTypes := strings.Fields(q.Get("response_type")) + codeChallenge := q.Get("code_challenge") + codeChallengeMethod := q.Get("code_challenge_method") + + if codeChallengeMethod == "" { + codeChallengeMethod = CodeChallengeMethodPlain + } + client, err := s.storage.GetClient(clientID) if err != nil { if err == storage.ErrNotFound { @@ -446,6 +453,11 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)} } + if codeChallengeMethod != CodeChallengeMethodS256 && codeChallengeMethod != CodeChallengeMethodPlain { + description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod) + return nil, newErr(errInvalidRequest, description) + } + var ( unrecognized []string invalidScopes []string @@ -541,6 +553,10 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques RedirectURI: redirectURI, ResponseTypes: responseTypes, ConnectorID: connectorID, + PKCE: storage.PKCE{ + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + }, }, nil } diff --git a/server/oauth2_test.go b/server/oauth2_test.go index ad122055..8db9ea59 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -197,6 +197,78 @@ func TestParseAuthorizationRequest(t *testing.T) { }, wantErr: true, }, + { + name: "PKCE code_challenge_method plain", + clients: []storage.Client{ + { + ID: "bar", + RedirectURIs: []string{"https://example.com/bar"}, + }, + }, + supportedResponseTypes: []string{"code"}, + queryParams: map[string]string{ + "client_id": "bar", + "redirect_uri": "https://example.com/bar", + "response_type": "code", + "code_challenge": "123", + "code_challenge_method": "plain", + "scope": "openid email profile", + }, + }, + { + name: "PKCE code_challenge_method default plain", + clients: []storage.Client{ + { + ID: "bar", + RedirectURIs: []string{"https://example.com/bar"}, + }, + }, + supportedResponseTypes: []string{"code"}, + queryParams: map[string]string{ + "client_id": "bar", + "redirect_uri": "https://example.com/bar", + "response_type": "code", + "code_challenge": "123", + "scope": "openid email profile", + }, + }, + { + name: "PKCE code_challenge_method S256", + clients: []storage.Client{ + { + ID: "bar", + RedirectURIs: []string{"https://example.com/bar"}, + }, + }, + supportedResponseTypes: []string{"code"}, + queryParams: map[string]string{ + "client_id": "bar", + "redirect_uri": "https://example.com/bar", + "response_type": "code", + "code_challenge": "123", + "code_challenge_method": "S256", + "scope": "openid email profile", + }, + }, + { + name: "PKCE invalid code_challenge_method", + clients: []storage.Client{ + { + ID: "bar", + RedirectURIs: []string{"https://example.com/bar"}, + }, + }, + supportedResponseTypes: []string{"code"}, + queryParams: map[string]string{ + "client_id": "bar", + "redirect_uri": "https://example.com/bar", + "response_type": "code", + "code_challenge": "123", + "code_challenge_method": "invalid_method", + "scope": "openid email profile", + }, + wantErr: true, + }, } for _, tc := range tests { diff --git a/server/server_test.go b/server/server_test.go index c36f2e85..a909d98b 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -216,6 +216,29 @@ type test struct { scopes []string // handleToken provides the OAuth2 token response for the integration test. handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token, *mock.Callback) error + + // extra parameters to pass when requesting auth_code + authCodeOptions []oauth2.AuthCodeOption + + // extra parameters to pass when retrieving id token + retrieveTokenOptions []oauth2.AuthCodeOption + + // define an error response, when the test expects an error on the token endpoint + tokenError ErrorResponse +} + +// Defines an expected error by HTTP Status Code and +// the OAuth2 error int the response json +type ErrorResponse struct { + Error string + StatusCode int +} + +// https://tools.ietf.org/html/rfc6749#section-5.2 +type OAuth2ErrorResponse struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` } func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) oauth2Tests { @@ -229,6 +252,17 @@ func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) oidcConfig := &oidc.Config{SkipClientIDCheck: true} + basicIDTokenVerify := func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + idToken, ok := token.Extra("id_token").(string) + if !ok { + return fmt.Errorf("no id token found") + } + if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil { + return fmt.Errorf("failed to verify id token: %v", err) + } + return nil + } + return oauth2Tests{ clientID: clientID, tests: []test{ @@ -469,6 +503,110 @@ func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) return nil }, }, + { + // This test ensures that PKCE work in "plain" mode (no code_challenge_method specified) + name: "PKCE with plain", + authCodeOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_challenge", "challenge123"), + }, + retrieveTokenOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_verifier", "challenge123"), + }, + handleToken: basicIDTokenVerify, + }, + { + // This test ensures that PKCE works in "S256" mode + name: "PKCE with S256", + authCodeOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + }, + retrieveTokenOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_verifier", "challenge123"), + }, + handleToken: basicIDTokenVerify, + }, + { + // This test ensures that PKCE does fail with wrong code_verifier in "plain" mode + name: "PKCE with plain and wrong code_verifier", + authCodeOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_challenge", "challenge123"), + }, + retrieveTokenOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_verifier", "challenge124"), + }, + handleToken: basicIDTokenVerify, + tokenError: ErrorResponse{ + Error: errInvalidGrant, + StatusCode: http.StatusBadRequest, + }, + }, + { + // This test ensures that PKCE fail with wrong code_verifier in "S256" mode + name: "PKCE with S256 and wrong code_verifier", + authCodeOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + }, + retrieveTokenOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_verifier", "challenge124"), + }, + handleToken: basicIDTokenVerify, + tokenError: ErrorResponse{ + Error: errInvalidGrant, + StatusCode: http.StatusBadRequest, + }, + }, + { + // Ensure that, when PKCE flow started on /auth + // we stay in PKCE flow on /token + name: "PKCE flow expected on /token", + authCodeOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + }, + retrieveTokenOptions: []oauth2.AuthCodeOption{ + // No PKCE call on /token + }, + handleToken: basicIDTokenVerify, + tokenError: ErrorResponse{ + Error: errInvalidGrant, + StatusCode: http.StatusBadRequest, + }, + }, + { + // Ensure that when no PKCE flow was started on /auth + // we cannot switch to PKCE on /token + name: "No PKCE flow started on /auth", + authCodeOptions: []oauth2.AuthCodeOption{ + // No PKCE call on /auth + }, + retrieveTokenOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_verifier", "challenge123"), + }, + handleToken: basicIDTokenVerify, + tokenError: ErrorResponse{ + Error: errInvalidRequest, + StatusCode: http.StatusBadRequest, + }, + }, + { + // Make sure that, when we start with "S256" on /auth, we cannot downgrade to "plain" on /token + name: "PKCE with S256 and try to downgrade to plain", + authCodeOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + }, + retrieveTokenOptions: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_verifier", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"), + oauth2.SetAuthURLParam("code_challenge_method", "plain"), + }, + handleToken: basicIDTokenVerify, + tokenError: ErrorResponse{ + Error: errInvalidGrant, + StatusCode: http.StatusBadRequest, + }, + }, }, } } @@ -537,7 +675,7 @@ func TestOAuth2CodeFlow(t *testing.T) { oauth2Client := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/callback" { // User is visiting app first time. Redirect to dex. - http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther) + http.Redirect(w, r, oauth2Config.AuthCodeURL(state, tc.authCodeOptions...), http.StatusSeeOther) return } @@ -558,7 +696,11 @@ func TestOAuth2CodeFlow(t *testing.T) { // Grab code, exchange for token. if code := q.Get("code"); code != "" { gotCode = true - token, err := oauth2Config.Exchange(ctx, code) + token, err := oauth2Config.Exchange(ctx, code, tc.retrieveTokenOptions...) + if tc.tokenError.StatusCode != 0 { + checkErrorResponse(err, t, tc) + return + } if err != nil { t.Errorf("failed to exchange code for token: %v", err) return @@ -1170,6 +1312,30 @@ func TestKeyCacher(t *testing.T) { } } +func checkErrorResponse(err error, t *testing.T, tc test) { + if err == nil { + t.Errorf("%s: DANGEROUS! got a token when we should not get one!", tc.name) + return + } + if rErr, ok := err.(*oauth2.RetrieveError); ok { + if rErr.Response.StatusCode != tc.tokenError.StatusCode { + t.Errorf("%s: got wrong StatusCode from server %d. expected %d", + tc.name, rErr.Response.StatusCode, tc.tokenError.StatusCode) + } + details := new(OAuth2ErrorResponse) + if err := json.Unmarshal(rErr.Body, details); err != nil { + t.Errorf("%s: could not parse return json: %s", tc.name, err) + return + } + if tc.tokenError.Error != "" && details.Error != tc.tokenError.Error { + t.Errorf("%s: got wrong Error in response: %s (%s). expected %s", + tc.name, details.Error, details.ErrorDescription, tc.tokenError.Error) + } + } else { + t.Errorf("%s: unexpected error type: %s. expected *oauth2.RetrieveError", tc.name, reflect.TypeOf(err)) + } +} + type oauth2Client struct { config *oauth2.Config token *oauth2.Token diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index f3f208e1..dd2083ae 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -81,6 +81,11 @@ func mustBeErrAlreadyExists(t *testing.T, kind string, err error) { } func testAuthRequestCRUD(t *testing.T, s storage.Storage) { + codeChallenge := storage.PKCE{ + CodeChallenge: "code_challenge_test", + CodeChallengeMethod: "plain", + } + a1 := storage.AuthRequest{ ID: storage.NewID(), ClientID: "client1", @@ -101,6 +106,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { EmailVerified: true, Groups: []string{"a", "b"}, }, + PKCE: codeChallenge, } identity := storage.Claims{Email: "foobar"} @@ -155,6 +161,10 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { t.Fatalf("update failed, wanted identity=%#v got %#v", identity, got.Claims) } + if !reflect.DeepEqual(got.PKCE, codeChallenge) { + t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", codeChallenge, got.PKCE) + } + if err := s.DeleteAuthRequest(a1.ID); err != nil { t.Fatalf("failed to delete auth request: %v", err) } diff --git a/storage/etcd/types.go b/storage/etcd/types.go index def95b55..22e083af 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -21,19 +21,24 @@ type AuthCode struct { Claims Claims `json:"claims,omitempty"` Expiry time.Time `json:"expiry"` + + CodeChallenge string `json:"code_challenge,omitempty"` + CodeChallengeMethod string `json:"code_challenge_method,omitempty"` } func fromStorageAuthCode(a storage.AuthCode) AuthCode { return AuthCode{ - ID: a.ID, - ClientID: a.ClientID, - RedirectURI: a.RedirectURI, - ConnectorID: a.ConnectorID, - ConnectorData: a.ConnectorData, - Nonce: a.Nonce, - Scopes: a.Scopes, - Claims: fromStorageClaims(a.Claims), - Expiry: a.Expiry, + ID: a.ID, + ClientID: a.ClientID, + RedirectURI: a.RedirectURI, + ConnectorID: a.ConnectorID, + ConnectorData: a.ConnectorData, + Nonce: a.Nonce, + Scopes: a.Scopes, + Claims: fromStorageClaims(a.Claims), + Expiry: a.Expiry, + CodeChallenge: a.PKCE.CodeChallenge, + CodeChallengeMethod: a.PKCE.CodeChallengeMethod, } } @@ -58,6 +63,9 @@ type AuthRequest struct { ConnectorID string `json:"connector_id"` ConnectorData []byte `json:"connector_data"` + + CodeChallenge string `json:"code_challenge,omitempty"` + CodeChallengeMethod string `json:"code_challenge_method,omitempty"` } func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { @@ -75,6 +83,8 @@ func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { Claims: fromStorageClaims(a.Claims), ConnectorID: a.ConnectorID, ConnectorData: a.ConnectorData, + CodeChallenge: a.PKCE.CodeChallenge, + CodeChallengeMethod: a.PKCE.CodeChallengeMethod, } } @@ -93,6 +103,10 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest { ConnectorData: a.ConnectorData, Expiry: a.Expiry, Claims: toStorageClaims(a.Claims), + PKCE: storage.PKCE{ + CodeChallenge: a.CodeChallenge, + CodeChallengeMethod: a.CodeChallengeMethod, + }, } } diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index c3eb4172..41b14f37 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -299,6 +299,9 @@ type AuthRequest struct { ConnectorData []byte `json:"connectorData,omitempty"` Expiry time.Time `json:"expiry"` + + CodeChallenge string `json:"code_challenge,omitempty"` + CodeChallengeMethod string `json:"code_challenge_method,omitempty"` } // AuthRequestList is a list of AuthRequests. @@ -323,6 +326,10 @@ func toStorageAuthRequest(req AuthRequest) storage.AuthRequest { ConnectorData: req.ConnectorData, Expiry: req.Expiry, Claims: toStorageClaims(req.Claims), + PKCE: storage.PKCE{ + CodeChallenge: req.CodeChallenge, + CodeChallengeMethod: req.CodeChallengeMethod, + }, } return a } @@ -349,6 +356,8 @@ func (cli *client) fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { ConnectorData: a.ConnectorData, Expiry: a.Expiry, Claims: fromStorageClaims(a.Claims), + CodeChallenge: a.PKCE.CodeChallenge, + CodeChallengeMethod: a.PKCE.CodeChallengeMethod, } return req } @@ -422,6 +431,9 @@ type AuthCode struct { ConnectorData []byte `json:"connectorData,omitempty"` Expiry time.Time `json:"expiry"` + + CodeChallenge string `json:"code_challenge,omitempty"` + CodeChallengeMethod string `json:"code_challenge_method,omitempty"` } // AuthCodeList is a list of AuthCodes. @@ -441,14 +453,16 @@ func (cli *client) fromStorageAuthCode(a storage.AuthCode) AuthCode { Name: a.ID, Namespace: cli.namespace, }, - ClientID: a.ClientID, - RedirectURI: a.RedirectURI, - ConnectorID: a.ConnectorID, - ConnectorData: a.ConnectorData, - Nonce: a.Nonce, - Scopes: a.Scopes, - Claims: fromStorageClaims(a.Claims), - Expiry: a.Expiry, + ClientID: a.ClientID, + RedirectURI: a.RedirectURI, + ConnectorID: a.ConnectorID, + ConnectorData: a.ConnectorData, + Nonce: a.Nonce, + Scopes: a.Scopes, + Claims: fromStorageClaims(a.Claims), + Expiry: a.Expiry, + CodeChallenge: a.PKCE.CodeChallenge, + CodeChallengeMethod: a.PKCE.CodeChallengeMethod, } } @@ -463,6 +477,10 @@ func toStorageAuthCode(a AuthCode) storage.AuthCode { Scopes: a.Scopes, Claims: toStorageClaims(a.Claims), Expiry: a.Expiry, + PKCE: storage.PKCE{ + CodeChallenge: a.CodeChallenge, + CodeChallengeMethod: a.CodeChallengeMethod, + }, } } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 325756f6..dedfd2a8 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -130,10 +130,11 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - expiry + expiry, + code_challenge, code_challenge_method ) values ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20 ); `, a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, @@ -142,6 +143,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, a.Expiry, + a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -172,8 +174,9 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) claims_email = $12, claims_email_verified = $13, claims_groups = $14, connector_id = $15, connector_data = $16, - expiry = $17 - where id = $18; + expiry = $17, + code_challenge = $18, code_challenge_method = $19 + where id = $20; `, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ForceApprovalPrompt, a.LoggedIn, @@ -181,7 +184,9 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, - a.Expiry, r.ID, + a.Expiry, + a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, + r.ID, ) if err != nil { return fmt.Errorf("update auth request: %v", err) @@ -201,7 +206,8 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { force_approval_prompt, logged_in, claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, - connector_id, connector_data, expiry + connector_id, connector_data, expiry, + code_challenge, code_challenge_method from auth_request where id = $1; `, id).Scan( &a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, @@ -210,6 +216,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups), &a.ConnectorID, &a.ConnectorData, &a.Expiry, + &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, ) if err != nil { if err == sql.ErrNoRows { @@ -227,13 +234,15 @@ func (c *conn) CreateAuthCode(a storage.AuthCode) error { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - expiry + expiry, + code_challenge, code_challenge_method ) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14); + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16); `, a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI, a.Claims.UserID, a.Claims.Username, a.Claims.PreferredUsername, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, a.Expiry, + a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, ) if err != nil { @@ -252,12 +261,14 @@ func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - expiry + expiry, + code_challenge, code_challenge_method from auth_code where id = $1; `, id).Scan( &a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI, &a.Claims.UserID, &a.Claims.Username, &a.Claims.PreferredUsername, &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups), &a.ConnectorID, &a.ConnectorData, &a.Expiry, + &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, ) if err != nil { if err == sql.ErrNoRows { @@ -317,7 +328,7 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok claims_email_verified = $8, claims_groups = $9, connector_id = $10, - connector_data = $11, + connector_data = $11, token = $12, created_at = $13, last_used = $14 diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 73934b1b..8201d443 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -250,4 +250,19 @@ var migrations = []migration{ );`, }, }, + { + stmts: []string{` + alter table auth_request + add column code_challenge text not null default '';`, + ` + alter table auth_request + add column code_challenge_method text not null default '';`, + ` + alter table auth_code + add column code_challenge text not null default '';`, + ` + alter table auth_code + add column code_challenge_method text not null default '';`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index a8eeea40..06f718e1 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -169,6 +169,12 @@ type Claims struct { Groups []string } +// Data needed for PKCE (RFC 7636) +type PKCE struct { + CodeChallenge string + CodeChallengeMethod string +} + // AuthRequest represents a OAuth2 client authorization request. It holds the state // of a single auth flow up to the point that the user authorizes the client. type AuthRequest struct { @@ -206,6 +212,9 @@ type AuthRequest struct { // Set when the user authenticates. ConnectorID string ConnectorData []byte + + // PKCE CodeChallenge and CodeChallengeMethod + PKCE PKCE } // AuthCode represents a code which can be exchanged for an OAuth2 token response. @@ -241,6 +250,9 @@ type AuthCode struct { Claims Claims Expiry time.Time + + // PKCE CodeChallenge and CodeChallengeMethod + PKCE PKCE } // RefreshToken is an OAuth2 refresh token which allows a client to request new