From 1211a86d58a001d75c62e64ecd304f1ec47455e4 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Fri, 19 Feb 2021 19:41:19 +0400 Subject: [PATCH 1/3] fix: use /token endpoint to get tokens with device flow Signed-off-by: m.nabokikh --- server/deviceflowhandlers.go | 114 ++++++++++++++++++----------------- server/handlers.go | 27 +++++++-- server/server.go | 3 +- server/server_test.go | 2 +- 4 files changed, 85 insertions(+), 61 deletions(-) diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index a73dafe8..eef50866 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -151,7 +151,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { +func (s *Server) handleDeviceTokenGrant(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") switch r.Method { case http.MethodPost: @@ -162,71 +162,75 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { return } - deviceCode := r.Form.Get("device_code") - if deviceCode == "" { - s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest) - return - } - grantType := r.PostFormValue("grant_type") if grantType != grantTypeDeviceCode { s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) return } - now := s.now() - - // Grab the device token, check validity - deviceToken, err := s.storage.GetDeviceToken(deviceCode) - if err != nil { - if err != storage.ErrNotFound { - s.logger.Errorf("failed to get device code: %v", err) - } - s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest) - return - } else if now.After(deviceToken.Expiry) { - s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest) - return - } - - // Rate Limiting check - slowDown := false - pollInterval := deviceToken.PollIntervalSeconds - minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval)) - if now.Before(minRequestTime) { - slowDown = true - // Continually increase the poll interval until the user waits the proper time - pollInterval += 5 - } else { - pollInterval = 5 - } - - switch deviceToken.Status { - case deviceTokenPending: - updater := func(old storage.DeviceToken) (storage.DeviceToken, error) { - old.PollIntervalSeconds = pollInterval - old.LastRequestTime = now - return old, nil - } - // Update device token last request time in storage - if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil { - s.logger.Errorf("failed to update device token: %v", err) - s.renderError(r, w, http.StatusInternalServerError, "") - return - } - if slowDown { - s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest) - } else { - s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) - } - case deviceTokenComplete: - w.Write([]byte(deviceToken.Token)) - } + s.handleDeviceToken(w, r) default: s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") } } +func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { + deviceCode := r.Form.Get("device_code") + if deviceCode == "" { + s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest) + return + } + + now := s.now() + + // Grab the device token, check validity + deviceToken, err := s.storage.GetDeviceToken(deviceCode) + if err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get device code: %v", err) + } + s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest) + return + } else if now.After(deviceToken.Expiry) { + s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest) + return + } + + // Rate Limiting check + slowDown := false + pollInterval := deviceToken.PollIntervalSeconds + minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval)) + if now.Before(minRequestTime) { + slowDown = true + // Continually increase the poll interval until the user waits the proper time + pollInterval += 5 + } else { + pollInterval = 5 + } + + switch deviceToken.Status { + case deviceTokenPending: + updater := func(old storage.DeviceToken) (storage.DeviceToken, error) { + old.PollIntervalSeconds = pollInterval + old.LastRequestTime = now + return old, nil + } + // Update device token last request time in storage + if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil { + s.logger.Errorf("failed to update device token: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "") + return + } + if slowDown { + s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest) + } else { + s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) + } + case deviceTokenComplete: + w.Write([]byte(deviceToken.Token)) + } +} + func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: diff --git a/server/handlers.go b/server/handlers.go index eb65f490..3bf43b16 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -652,7 +652,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe http.Redirect(w, r, u.String(), http.StatusSeeOther) } -func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { +func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request, storage.Client)) { clientID, clientSecret, ok := r.BasicAuth() if ok { var err error @@ -689,14 +689,33 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { return } + handler(w, r, client) +} + +func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.Method != http.MethodPost { + s.tokenErrHelper(w, errInvalidRequest, "method not allowed", http.StatusBadRequest) + return + } + + err := r.ParseForm() + if err != nil { + s.logger.Errorf("Could not parse request body: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) + return + } + grantType := r.PostFormValue("grant_type") switch grantType { + case grantTypeDeviceCode: + s.handleDeviceToken(w, r) case grantTypeAuthorizationCode: - s.handleAuthCode(w, r, client) + s.withClientFromStorage(w, r, s.handleAuthCode) case grantTypeRefreshToken: - s.handleRefreshToken(w, r, client) + s.withClientFromStorage(w, r, s.handleRefreshToken) case grantTypePassword: - s.handlePasswordGrant(w, r, client) + s.withClientFromStorage(w, r, s.handlePasswordGrant) default: s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) } diff --git a/server/server.go b/server/server.go index a79b7cfd..1294ff76 100644 --- a/server/server.go +++ b/server/server.go @@ -320,7 +320,8 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc("/device", s.handleDeviceExchange) handleFunc("/device/auth/verify_code", s.verifyUserCode) handleFunc("/device/code", s.handleDeviceCode) - handleFunc("/device/token", s.handleDeviceToken) + // TODO(nabokihms): deprecate and remove this endpoint, use /token instead + handleFunc("/device/token", s.handleDeviceTokenGrant) handleFunc(deviceCallbackURI, s.handleDeviceCallback) r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { // Strip the X-Remote-* headers to prevent security issues on diff --git a/server/server_test.go b/server/server_test.go index 87ca6c17..34aa9eec 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1583,7 +1583,7 @@ func TestOAuth2DeviceFlow(t *testing.T) { // Hit the Token Endpoint, and try and get an access token tokenURL, _ := url.Parse(issuer.String()) - tokenURL.Path = path.Join(tokenURL.Path, "/device/token") + tokenURL.Path = path.Join(tokenURL.Path, "/token") v := url.Values{} v.Add("grant_type", grantTypeDeviceCode) v.Add("device_code", deviceCode.DeviceCode) From 9ed5cc00cfcc0bcb41f8f6831064ca464583831e Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Wed, 24 Feb 2021 17:14:28 +0400 Subject: [PATCH 2/3] Add deprecation warning for /device/token endpoint Signed-off-by: m.nabokikh --- server/deviceflowhandlers.go | 3 +++ server/server.go | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index eef50866..039472b8 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -152,6 +152,9 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleDeviceTokenGrant(w http.ResponseWriter, r *http.Request) { + s.logger.Warn(`Request to the deprecated "/device/token" endpoint was received.`) + s.logger.Warn(`The "/device/token" endpoint will be removed in a future release.`) + w.Header().Set("Content-Type", "application/json") switch r.Method { case http.MethodPost: diff --git a/server/server.go b/server/server.go index 1294ff76..63305403 100644 --- a/server/server.go +++ b/server/server.go @@ -320,7 +320,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc("/device", s.handleDeviceExchange) handleFunc("/device/auth/verify_code", s.verifyUserCode) handleFunc("/device/code", s.handleDeviceCode) - // TODO(nabokihms): deprecate and remove this endpoint, use /token instead + // TODO(nabokihms): "/device/token" endpoint is deprecated, consider using /token endpoint instead handleFunc("/device/token", s.handleDeviceTokenGrant) handleFunc(deviceCallbackURI, s.handleDeviceCallback) r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { From 3bd0e91a6879e9f00f0b02e237f2f9e5812df53a Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Thu, 25 Feb 2021 11:53:25 +0400 Subject: [PATCH 3/3] Make /device/token deprecation warning more concise Signed-off-by: m.nabokikh --- server/deviceflowhandlers.go | 5 +- server/server.go | 2 +- server/server_test.go | 293 +++++++++++++++++++---------------- 3 files changed, 160 insertions(+), 140 deletions(-) diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index 039472b8..5ec7eb8e 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -151,9 +151,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) handleDeviceTokenGrant(w http.ResponseWriter, r *http.Request) { - s.logger.Warn(`Request to the deprecated "/device/token" endpoint was received.`) - s.logger.Warn(`The "/device/token" endpoint will be removed in a future release.`) +func (s *Server) handleDeviceTokenDeprecated(w http.ResponseWriter, r *http.Request) { + s.logger.Warn(`The deprecated "/device/token" endpoint was called. It will be removed, use "/token" instead.`) w.Header().Set("Content-Type", "application/json") switch r.Method { diff --git a/server/server.go b/server/server.go index 63305403..986d546b 100644 --- a/server/server.go +++ b/server/server.go @@ -321,7 +321,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc("/device/auth/verify_code", s.verifyUserCode) handleFunc("/device/code", s.handleDeviceCode) // TODO(nabokihms): "/device/token" endpoint is deprecated, consider using /token endpoint instead - handleFunc("/device/token", s.handleDeviceTokenGrant) + handleFunc("/device/token", s.handleDeviceTokenDeprecated) handleFunc(deviceCallbackURI, s.handleDeviceCallback) r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { // Strip the X-Remote-* headers to prevent security issues on diff --git a/server/server_test.go b/server/server_test.go index 34aa9eec..936d89f9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1497,143 +1497,164 @@ func TestOAuth2DeviceFlow(t *testing.T) { var conn *mock.Callback idTokensValidFor := time.Second * 30 - for _, tc := range makeOAuth2Tests(clientID, clientSecret, now).tests { - func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + tests := makeOAuth2Tests(clientID, clientSecret, now) + testCases := []struct { + name string + tokenEndpoint string + oauth2Tests oauth2Tests + }{ + { + name: "Actual token endpoint for devices", + tokenEndpoint: "/token", + oauth2Tests: tests, + }, + // TODO(nabokihms): delete temporary tests after removing the deprecated token endpoint support + { + name: "Deprecated token endpoint for devices", + tokenEndpoint: "/device/token", + oauth2Tests: tests, + }, + } - // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { - c.Issuer += "/non-root-path" - c.Now = now - c.IDTokensValidFor = idTokensValidFor + for _, testCase := range testCases { + for _, tc := range testCase.oauth2Tests.tests { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer += "/non-root-path" + c.Now = now + c.IDTokensValidFor = idTokensValidFor + }) + defer httpServer.Close() + + mockConn := s.connectors["mock"] + conn = mockConn.Connector.(*mock.Callback) + + p, err := oidc.NewProvider(ctx, httpServer.URL) + if err != nil { + t.Fatalf("failed to get provider: %v", err) + } + + // Add the Clients to the test server + client := storage.Client{ + ID: clientID, + RedirectURIs: []string{deviceCallbackURI}, + Public: true, + } + if err := s.storage.CreateClient(client); err != nil { + t.Fatalf("failed to create client: %v", err) + } + + // Grab the issuer that we'll reuse for the different endpoints to hit + issuer, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Errorf("Could not parse issuer URL %v", err) + } + + // Send a new Device Request + codeURL, _ := url.Parse(issuer.String()) + codeURL.Path = path.Join(codeURL.Path, "device/code") + + data := url.Values{} + data.Set("client_id", clientID) + data.Add("scope", strings.Join(requestedScopes, " ")) + resp, err := http.PostForm(codeURL.String(), data) + if err != nil { + t.Errorf("Could not request device code: %v", err) + } + defer resp.Body.Close() + responseBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read device code response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + if resp.Header.Get("Cache-Control") != "no-store" { + t.Errorf("Cache-Control header doesn't exist in Device Code Response") + } + + // Parse the code response + var deviceCode deviceCodeResponse + if err := json.Unmarshal(responseBody, &deviceCode); err != nil { + t.Errorf("Unexpected Device Code Response Format %v", string(responseBody)) + } + + // Mock the user hitting the verification URI and posting the form + verifyURL, _ := url.Parse(issuer.String()) + verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code") + urlData := url.Values{} + urlData.Set("user_code", deviceCode.UserCode) + resp, err = http.PostForm(verifyURL.String(), urlData) + if err != nil { + t.Errorf("Error Posting Form: %v", err) + } + defer resp.Body.Close() + responseBody, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read verification response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + + // Hit the Token Endpoint, and try and get an access token + tokenURL, _ := url.Parse(issuer.String()) + tokenURL.Path = path.Join(tokenURL.Path, testCase.tokenEndpoint) + v := url.Values{} + v.Add("grant_type", grantTypeDeviceCode) + v.Add("device_code", deviceCode.DeviceCode) + resp, err = http.PostForm(tokenURL.String(), v) + if err != nil { + t.Errorf("Could not request device token: %v", err) + } + defer resp.Body.Close() + responseBody, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read device token response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + + // Parse the response + var tokenRes accessTokenResponse + if err := json.Unmarshal(responseBody, &tokenRes); err != nil { + t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody)) + } + + token := &oauth2.Token{ + AccessToken: tokenRes.AccessToken, + TokenType: tokenRes.TokenType, + RefreshToken: tokenRes.RefreshToken, + } + raw := make(map[string]interface{}) + json.Unmarshal(responseBody, &raw) // no error checks for optional fields + token = token.WithExtra(raw) + if secs := tokenRes.ExpiresIn; secs > 0 { + token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) + } + + // Run token tests to validate info is correct + // Create the OAuth2 config. + oauth2Config := &oauth2.Config{ + ClientID: client.ID, + ClientSecret: client.Secret, + Endpoint: p.Endpoint(), + Scopes: requestedScopes, + RedirectURL: deviceCallbackURI, + } + if len(tc.scopes) != 0 { + oauth2Config.Scopes = tc.scopes + } + err = tc.handleToken(ctx, p, oauth2Config, token, conn) + if err != nil { + t.Errorf("%s: %v", tc.name, err) + } }) - defer httpServer.Close() - - mockConn := s.connectors["mock"] - conn = mockConn.Connector.(*mock.Callback) - - p, err := oidc.NewProvider(ctx, httpServer.URL) - if err != nil { - t.Fatalf("failed to get provider: %v", err) - } - - // Add the Clients to the test server - client := storage.Client{ - ID: clientID, - RedirectURIs: []string{deviceCallbackURI}, - Public: true, - } - if err := s.storage.CreateClient(client); err != nil { - t.Fatalf("failed to create client: %v", err) - } - - // Grab the issuer that we'll reuse for the different endpoints to hit - issuer, err := url.Parse(s.issuerURL.String()) - if err != nil { - t.Errorf("Could not parse issuer URL %v", err) - } - - // Send a new Device Request - codeURL, _ := url.Parse(issuer.String()) - codeURL.Path = path.Join(codeURL.Path, "device/code") - - data := url.Values{} - data.Set("client_id", clientID) - data.Add("scope", strings.Join(requestedScopes, " ")) - resp, err := http.PostForm(codeURL.String(), data) - if err != nil { - t.Errorf("Could not request device code: %v", err) - } - defer resp.Body.Close() - responseBody, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("Could read device code response %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) - } - if resp.Header.Get("Cache-Control") != "no-store" { - t.Errorf("Cache-Control header doesn't exist in Device Code Response") - } - - // Parse the code response - var deviceCode deviceCodeResponse - if err := json.Unmarshal(responseBody, &deviceCode); err != nil { - t.Errorf("Unexpected Device Code Response Format %v", string(responseBody)) - } - - // Mock the user hitting the verification URI and posting the form - verifyURL, _ := url.Parse(issuer.String()) - verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code") - urlData := url.Values{} - urlData.Set("user_code", deviceCode.UserCode) - resp, err = http.PostForm(verifyURL.String(), urlData) - if err != nil { - t.Errorf("Error Posting Form: %v", err) - } - defer resp.Body.Close() - responseBody, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("Could read verification response %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) - } - - // Hit the Token Endpoint, and try and get an access token - tokenURL, _ := url.Parse(issuer.String()) - tokenURL.Path = path.Join(tokenURL.Path, "/token") - v := url.Values{} - v.Add("grant_type", grantTypeDeviceCode) - v.Add("device_code", deviceCode.DeviceCode) - resp, err = http.PostForm(tokenURL.String(), v) - if err != nil { - t.Errorf("Could not request device token: %v", err) - } - defer resp.Body.Close() - responseBody, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("Could read device token response %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) - } - - // Parse the response - var tokenRes accessTokenResponse - if err := json.Unmarshal(responseBody, &tokenRes); err != nil { - t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody)) - } - - token := &oauth2.Token{ - AccessToken: tokenRes.AccessToken, - TokenType: tokenRes.TokenType, - RefreshToken: tokenRes.RefreshToken, - } - raw := make(map[string]interface{}) - json.Unmarshal(responseBody, &raw) // no error checks for optional fields - token = token.WithExtra(raw) - if secs := tokenRes.ExpiresIn; secs > 0 { - token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) - } - - // Run token tests to validate info is correct - // Create the OAuth2 config. - oauth2Config := &oauth2.Config{ - ClientID: client.ID, - ClientSecret: client.Secret, - Endpoint: p.Endpoint(), - Scopes: requestedScopes, - RedirectURL: deviceCallbackURI, - } - if len(tc.scopes) != 0 { - oauth2Config.Scopes = tc.scopes - } - err = tc.handleToken(ctx, p, oauth2Config, token, conn) - if err != nil { - t.Errorf("%s: %v", tc.name, err) - } - }() + } } }