From d658c24e8f2dbbf0bda333c88970d66242c13ade Mon Sep 17 00:00:00 2001 From: Rui Yang Date: Fri, 25 Sep 2020 11:59:42 -0400 Subject: [PATCH] add dex config flag for enabling client secret encryption * if enabled, it will make sure client secret is bcrypted correctly * if not, it falls back to old behaviour that allowing empty client secret and comparing plain text, though now it will do ConstantTimeCompare to avoid a timing attack. So in either way it should provide more secure of client secret verification. Co-authored-by: Alex Surraci Signed-off-by: Rui Yang --- server/handlers.go | 30 ++++---- server/server.go | 25 +++++++ server/server_test.go | 160 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 200 insertions(+), 15 deletions(-) diff --git a/server/handlers.go b/server/handlers.go index db835997..494af232 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -2,6 +2,7 @@ package server import ( "crypto/sha256" + "crypto/subtle" "encoding/base64" "encoding/json" "errors" @@ -681,22 +682,21 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { return } - if client.Secret != clientSecret { - if clientSecret == "" { - s.logger.Infof("missing client_secret on token request for client: %s", client.ID) - } else { - s.logger.Infof("invalid client_secret on token request for client: %s", client.ID) + if s.hashClientSecret { + if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil { + s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) + return + } + } else { + if subtle.ConstantTimeCompare([]byte(client.Secret), []byte(clientSecret)) != 1 { + if clientSecret == "" { + s.logger.Infof("missing client_secret on token request for client: %s", client.ID) + } else { + s.logger.Infof("invalid client_secret on token request for client: %s", client.ID) + } + s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) + return } - } - - if err := checkCost([]byte(client.Secret)); err != nil { - s.logger.Errorf("failed to check cost of client secret: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil { - s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) - return } grantType := r.PostFormValue("grant_type") diff --git a/server/server.go b/server/server.go index a79b7cfd..5909d5c3 100644 --- a/server/server.go +++ b/server/server.go @@ -77,6 +77,9 @@ type Config struct { // If enabled, the connectors selection page will always be shown even if there's only one AlwaysShowLoginScreen bool + // If enabled, the client secret is expected to be encrypted + HashClientSecret bool + RotateKeysAfter time.Duration // Defaults to 6 hours. IDTokensValidFor time.Duration // Defaults to 24 hours AuthRequestsValidFor time.Duration // Defaults to 24 hours @@ -151,6 +154,9 @@ type Server struct { // If enabled, show the connector selection screen even if there's only one alwaysShowLogin bool + // If enabled, the client secret is expected to be encrypted + hashClientSecret bool + // Used for password grant passwordConnector string @@ -189,6 +195,24 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) if c.Storage == nil { return nil, errors.New("server: storage cannot be nil") } + + if c.HashClientSecret { + clients, err := c.Storage.ListClients() + if err != nil { + return nil, fmt.Errorf("server: failed to list clients") + } + + for _, client := range clients { + if client.Secret == "" { + return nil, fmt.Errorf("server: client secret can't be empty") + } + + if err = checkCost([]byte(client.Secret)); err != nil { + return nil, fmt.Errorf("server: failed to check cost of client secret: %v", err) + } + } + } + if len(c.SupportedResponseTypes) == 0 { c.SupportedResponseTypes = []string{responseTypeCode} } @@ -232,6 +256,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute), skipApproval: c.SkipApprovalScreen, alwaysShowLogin: c.AlwaysShowLoginScreen, + hashClientSecret: c.HashClientSecret, now: now, templates: tmpls, passwordConnector: c.PasswordConnector, diff --git a/server/server_test.go b/server/server_test.go index 87ca6c17..cbb298e5 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1637,3 +1637,163 @@ func TestOAuth2DeviceFlow(t *testing.T) { }() } } + +func TestClientSecretEncryption(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.HashClientSecret = true + }) + defer httpServer.Close() + + clientID := "testclient" + clientSecret := "testclientsecret" + hash, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost) + if err != nil { + t.Fatalf("failed to bcrypt: %s", err) + } + + // Query server's provider metadata. + p, err := oidc.NewProvider(ctx, httpServer.URL) + if err != nil { + t.Fatalf("failed to get provider: %v", err) + } + + var ( + // If the OAuth2 client didn't get a response, we need + // to print the requests the user saw. + gotCode bool + reqDump, respDump []byte // Auth step, not token. + state = "a_state" + ) + defer func() { + if !gotCode { + t.Errorf("never got a code in callback\n%s\n%s", reqDump, respDump) + } + }() + + // Setup OAuth2 client. + var oauth2Config *oauth2.Config + + requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"} + + // Create the OAuth2 config. + oauth2Config = &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: p.Endpoint(), + Scopes: requestedScopes, + } + + oauth2Client := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/callback" { + // User is visiting app first time. Redirect to dex. + http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther) + return + } + + // User is at '/callback' so they were just redirected _from_ dex. + q := r.URL.Query() + + // Grab code, exchange for token. + if code := q.Get("code"); code != "" { + gotCode = true + token, err := oauth2Config.Exchange(ctx, code) + if err != nil { + t.Errorf("failed to exchange code for token: %v", err) + return + } + + oidcConfig := &oidc.Config{SkipClientIDCheck: true} + + idToken, ok := token.Extra("id_token").(string) + if !ok { + t.Errorf("no id token found") + return + } + if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil { + t.Errorf("failed to verify id token: %v", err) + return + } + } + + w.WriteHeader(http.StatusOK) + })) + + oauth2Config.RedirectURL = oauth2Client.URL + "/callback" + + defer oauth2Client.Close() + + // Regester the client above with dex. + client := storage.Client{ + ID: clientID, + Secret: string(hash), + RedirectURIs: []string{oauth2Client.URL + "/callback"}, + } + if err := s.storage.CreateClient(client); err != nil { + t.Fatalf("failed to create client: %v", err) + } + + // Login! + // + // 1. First request to client, redirects to dex. + // 2. Dex "logs in" the user, redirects to client with "code". + // 3. Client exchanges "code" for "token" (id_token, refresh_token, etc.). + // 4. Test is run with OAuth2 token response. + // + resp, err := http.Get(oauth2Client.URL + "/login") + if err != nil { + t.Fatalf("get failed: %v", err) + } + defer resp.Body.Close() + + if reqDump, err = httputil.DumpRequest(resp.Request, false); err != nil { + t.Fatal(err) + } + if respDump, err = httputil.DumpResponse(resp, true); err != nil { + t.Fatal(err) + } +} + +func TestClientSecretEncryptionCost(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientID := "testclient" + clientSecret := "testclientsecret" + hash, err := bcrypt.GenerateFromPassword([]byte(clientSecret), 5) + if err != nil { + t.Fatalf("failed to bcrypt: %s", err) + } + + // Register the client above with dex. + client := storage.Client{ + ID: clientID, + Secret: string(hash), + } + + config := Config{ + Storage: memory.New(logger), + Web: WebConfig{ + Dir: "../web", + }, + Logger: logger, + PrometheusRegistry: prometheus.NewRegistry(), + HashClientSecret: true, + } + + err = config.Storage.CreateClient(client) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + _, err = newServer(ctx, config, staticRotationStrategy(testKey)) + if err == nil { + t.Error("constructing server should have failed") + } + + if !strings.Contains(err.Error(), "failed to check cost") { + t.Error("should have failed with cost error") + } +}