From 3bd0e91a6879e9f00f0b02e237f2f9e5812df53a Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Thu, 25 Feb 2021 11:53:25 +0400 Subject: [PATCH] Make /device/token deprecation warning more concise Signed-off-by: m.nabokikh --- server/deviceflowhandlers.go | 5 +- server/server.go | 2 +- server/server_test.go | 293 +++++++++++++++++++---------------- 3 files changed, 160 insertions(+), 140 deletions(-) diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index 039472b8..5ec7eb8e 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -151,9 +151,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) handleDeviceTokenGrant(w http.ResponseWriter, r *http.Request) { - s.logger.Warn(`Request to the deprecated "/device/token" endpoint was received.`) - s.logger.Warn(`The "/device/token" endpoint will be removed in a future release.`) +func (s *Server) handleDeviceTokenDeprecated(w http.ResponseWriter, r *http.Request) { + s.logger.Warn(`The deprecated "/device/token" endpoint was called. It will be removed, use "/token" instead.`) w.Header().Set("Content-Type", "application/json") switch r.Method { diff --git a/server/server.go b/server/server.go index 63305403..986d546b 100644 --- a/server/server.go +++ b/server/server.go @@ -321,7 +321,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc("/device/auth/verify_code", s.verifyUserCode) handleFunc("/device/code", s.handleDeviceCode) // TODO(nabokihms): "/device/token" endpoint is deprecated, consider using /token endpoint instead - handleFunc("/device/token", s.handleDeviceTokenGrant) + handleFunc("/device/token", s.handleDeviceTokenDeprecated) handleFunc(deviceCallbackURI, s.handleDeviceCallback) r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { // Strip the X-Remote-* headers to prevent security issues on diff --git a/server/server_test.go b/server/server_test.go index 34aa9eec..936d89f9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1497,143 +1497,164 @@ func TestOAuth2DeviceFlow(t *testing.T) { var conn *mock.Callback idTokensValidFor := time.Second * 30 - for _, tc := range makeOAuth2Tests(clientID, clientSecret, now).tests { - func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + tests := makeOAuth2Tests(clientID, clientSecret, now) + testCases := []struct { + name string + tokenEndpoint string + oauth2Tests oauth2Tests + }{ + { + name: "Actual token endpoint for devices", + tokenEndpoint: "/token", + oauth2Tests: tests, + }, + // TODO(nabokihms): delete temporary tests after removing the deprecated token endpoint support + { + name: "Deprecated token endpoint for devices", + tokenEndpoint: "/device/token", + oauth2Tests: tests, + }, + } - // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { - c.Issuer += "/non-root-path" - c.Now = now - c.IDTokensValidFor = idTokensValidFor + for _, testCase := range testCases { + for _, tc := range testCase.oauth2Tests.tests { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer += "/non-root-path" + c.Now = now + c.IDTokensValidFor = idTokensValidFor + }) + defer httpServer.Close() + + mockConn := s.connectors["mock"] + conn = mockConn.Connector.(*mock.Callback) + + p, err := oidc.NewProvider(ctx, httpServer.URL) + if err != nil { + t.Fatalf("failed to get provider: %v", err) + } + + // Add the Clients to the test server + client := storage.Client{ + ID: clientID, + RedirectURIs: []string{deviceCallbackURI}, + Public: true, + } + if err := s.storage.CreateClient(client); err != nil { + t.Fatalf("failed to create client: %v", err) + } + + // Grab the issuer that we'll reuse for the different endpoints to hit + issuer, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Errorf("Could not parse issuer URL %v", err) + } + + // Send a new Device Request + codeURL, _ := url.Parse(issuer.String()) + codeURL.Path = path.Join(codeURL.Path, "device/code") + + data := url.Values{} + data.Set("client_id", clientID) + data.Add("scope", strings.Join(requestedScopes, " ")) + resp, err := http.PostForm(codeURL.String(), data) + if err != nil { + t.Errorf("Could not request device code: %v", err) + } + defer resp.Body.Close() + responseBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read device code response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + if resp.Header.Get("Cache-Control") != "no-store" { + t.Errorf("Cache-Control header doesn't exist in Device Code Response") + } + + // Parse the code response + var deviceCode deviceCodeResponse + if err := json.Unmarshal(responseBody, &deviceCode); err != nil { + t.Errorf("Unexpected Device Code Response Format %v", string(responseBody)) + } + + // Mock the user hitting the verification URI and posting the form + verifyURL, _ := url.Parse(issuer.String()) + verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code") + urlData := url.Values{} + urlData.Set("user_code", deviceCode.UserCode) + resp, err = http.PostForm(verifyURL.String(), urlData) + if err != nil { + t.Errorf("Error Posting Form: %v", err) + } + defer resp.Body.Close() + responseBody, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read verification response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + + // Hit the Token Endpoint, and try and get an access token + tokenURL, _ := url.Parse(issuer.String()) + tokenURL.Path = path.Join(tokenURL.Path, testCase.tokenEndpoint) + v := url.Values{} + v.Add("grant_type", grantTypeDeviceCode) + v.Add("device_code", deviceCode.DeviceCode) + resp, err = http.PostForm(tokenURL.String(), v) + if err != nil { + t.Errorf("Could not request device token: %v", err) + } + defer resp.Body.Close() + responseBody, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read device token response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + + // Parse the response + var tokenRes accessTokenResponse + if err := json.Unmarshal(responseBody, &tokenRes); err != nil { + t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody)) + } + + token := &oauth2.Token{ + AccessToken: tokenRes.AccessToken, + TokenType: tokenRes.TokenType, + RefreshToken: tokenRes.RefreshToken, + } + raw := make(map[string]interface{}) + json.Unmarshal(responseBody, &raw) // no error checks for optional fields + token = token.WithExtra(raw) + if secs := tokenRes.ExpiresIn; secs > 0 { + token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) + } + + // Run token tests to validate info is correct + // Create the OAuth2 config. + oauth2Config := &oauth2.Config{ + ClientID: client.ID, + ClientSecret: client.Secret, + Endpoint: p.Endpoint(), + Scopes: requestedScopes, + RedirectURL: deviceCallbackURI, + } + if len(tc.scopes) != 0 { + oauth2Config.Scopes = tc.scopes + } + err = tc.handleToken(ctx, p, oauth2Config, token, conn) + if err != nil { + t.Errorf("%s: %v", tc.name, err) + } }) - defer httpServer.Close() - - mockConn := s.connectors["mock"] - conn = mockConn.Connector.(*mock.Callback) - - p, err := oidc.NewProvider(ctx, httpServer.URL) - if err != nil { - t.Fatalf("failed to get provider: %v", err) - } - - // Add the Clients to the test server - client := storage.Client{ - ID: clientID, - RedirectURIs: []string{deviceCallbackURI}, - Public: true, - } - if err := s.storage.CreateClient(client); err != nil { - t.Fatalf("failed to create client: %v", err) - } - - // Grab the issuer that we'll reuse for the different endpoints to hit - issuer, err := url.Parse(s.issuerURL.String()) - if err != nil { - t.Errorf("Could not parse issuer URL %v", err) - } - - // Send a new Device Request - codeURL, _ := url.Parse(issuer.String()) - codeURL.Path = path.Join(codeURL.Path, "device/code") - - data := url.Values{} - data.Set("client_id", clientID) - data.Add("scope", strings.Join(requestedScopes, " ")) - resp, err := http.PostForm(codeURL.String(), data) - if err != nil { - t.Errorf("Could not request device code: %v", err) - } - defer resp.Body.Close() - responseBody, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("Could read device code response %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) - } - if resp.Header.Get("Cache-Control") != "no-store" { - t.Errorf("Cache-Control header doesn't exist in Device Code Response") - } - - // Parse the code response - var deviceCode deviceCodeResponse - if err := json.Unmarshal(responseBody, &deviceCode); err != nil { - t.Errorf("Unexpected Device Code Response Format %v", string(responseBody)) - } - - // Mock the user hitting the verification URI and posting the form - verifyURL, _ := url.Parse(issuer.String()) - verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code") - urlData := url.Values{} - urlData.Set("user_code", deviceCode.UserCode) - resp, err = http.PostForm(verifyURL.String(), urlData) - if err != nil { - t.Errorf("Error Posting Form: %v", err) - } - defer resp.Body.Close() - responseBody, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("Could read verification response %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) - } - - // Hit the Token Endpoint, and try and get an access token - tokenURL, _ := url.Parse(issuer.String()) - tokenURL.Path = path.Join(tokenURL.Path, "/token") - v := url.Values{} - v.Add("grant_type", grantTypeDeviceCode) - v.Add("device_code", deviceCode.DeviceCode) - resp, err = http.PostForm(tokenURL.String(), v) - if err != nil { - t.Errorf("Could not request device token: %v", err) - } - defer resp.Body.Close() - responseBody, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("Could read device token response %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) - } - - // Parse the response - var tokenRes accessTokenResponse - if err := json.Unmarshal(responseBody, &tokenRes); err != nil { - t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody)) - } - - token := &oauth2.Token{ - AccessToken: tokenRes.AccessToken, - TokenType: tokenRes.TokenType, - RefreshToken: tokenRes.RefreshToken, - } - raw := make(map[string]interface{}) - json.Unmarshal(responseBody, &raw) // no error checks for optional fields - token = token.WithExtra(raw) - if secs := tokenRes.ExpiresIn; secs > 0 { - token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) - } - - // Run token tests to validate info is correct - // Create the OAuth2 config. - oauth2Config := &oauth2.Config{ - ClientID: client.ID, - ClientSecret: client.Secret, - Endpoint: p.Endpoint(), - Scopes: requestedScopes, - RedirectURL: deviceCallbackURI, - } - if len(tc.scopes) != 0 { - oauth2Config.Scopes = tc.scopes - } - err = tc.handleToken(ctx, p, oauth2Config, token, conn) - if err != nil { - t.Errorf("%s: %v", tc.name, err) - } - }() + } } }