diff --git a/server/deviceHandlers.go b/server/deviceflowhandlers.go similarity index 77% rename from server/deviceHandlers.go rename to server/deviceflowhandlers.go index 55255408..39ead503 100644 --- a/server/deviceHandlers.go +++ b/server/deviceflowhandlers.go @@ -29,7 +29,7 @@ type deviceCodeResponse struct { PollInterval int `json:"interval"` } -func (s *Server) getDeviceAuthURI() string { +func (s *Server) getDeviceVerificationURI() string { return path.Join(s.issuerURL.Path, "/device/auth/verify_code") } @@ -41,8 +41,9 @@ func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) { if err != nil { invalidAttempt = false } - if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, invalidAttempt); err != nil { + if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil { s.logger.Errorf("Server template error: %v", err) + s.renderError(r, w, http.StatusNotFound, "Page not found") } default: s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") @@ -63,7 +64,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { //Get the client id and scopes from the post clientID := r.Form.Get("client_id") - scopes := r.Form["scope"] + clientSecret := r.Form.Get("client_secret") + scopes := strings.Fields(r.Form.Get("scope")) s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes) @@ -82,11 +84,12 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { //Store the Device Request deviceReq := storage.DeviceRequest{ - UserCode: userCode, - DeviceCode: deviceCode, - ClientID: clientID, - Scopes: scopes, - Expiry: expireTime, + UserCode: userCode, + DeviceCode: deviceCode, + ClientID: clientID, + ClientSecret: clientSecret, + Scopes: scopes, + Expiry: expireTime, } if err := s.storage.CreateDeviceRequest(deviceReq); err != nil { @@ -100,8 +103,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { DeviceCode: deviceCode, Status: deviceTokenPending, Expiry: expireTime, - LastRequestTime: time.Now(), - PollIntervalSeconds: 5, + LastRequestTime: s.now(), + PollIntervalSeconds: 0, } if err := s.storage.CreateDeviceToken(deviceToken); err != nil { @@ -113,7 +116,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { u, err := url.Parse(s.issuerURL.String()) if err != nil { s.logger.Errorf("Could not parse issuer URL %v", err) - s.renderError(r, w, http.StatusInternalServerError, "") + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) return } u.Path = path.Join(u.Path, "device") @@ -134,6 +137,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { } enc := json.NewEncoder(w) + enc.SetEscapeHTML(false) enc.SetIndent("", " ") enc.Encode(code) @@ -168,21 +172,25 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { now := s.now() - //Grab the device token + //Grab the device token, check validity deviceToken, err := s.storage.GetDeviceToken(deviceCode) - if err != nil || now.After(deviceToken.Expiry) { + if err != nil { if err != storage.ErrNotFound { s.logger.Errorf("failed to get device code: %v", err) } - s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest) + s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest) + return + } else if now.After(deviceToken.Expiry) { + s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest) return } //Rate Limiting check + slowDown := false pollInterval := deviceToken.PollIntervalSeconds minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval)) if now.Before(minRequestTime) { - s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest) + slowDown = true //Continually increase the poll interval until the user waits the proper time pollInterval += 5 } else { @@ -202,7 +210,11 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { s.renderError(r, w, http.StatusInternalServerError, "") return } - s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) + if slowDown { + s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest) + } else { + s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) + } case deviceTokenComplete: w.Write([]byte(deviceToken.Token)) } @@ -230,44 +242,58 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { authCode, err := s.storage.GetAuthCode(code) if err != nil || s.now().After(authCode.Expiry) { - if err != storage.ErrNotFound { + errCode := http.StatusBadRequest + if err != nil && err != storage.ErrNotFound { s.logger.Errorf("failed to get auth code: %v", err) + errCode = http.StatusInternalServerError } - s.renderError(r, w, http.StatusBadRequest, "Invalid or expired auth code.") + s.renderError(r, w, errCode, "Invalid or expired auth code.") return } //Grab the device request from storage deviceReq, err := s.storage.GetDeviceRequest(userCode) if err != nil || s.now().After(deviceReq.Expiry) { - if err != storage.ErrNotFound { + errCode := http.StatusBadRequest + if err != nil && err != storage.ErrNotFound { s.logger.Errorf("failed to get device code: %v", err) + errCode = http.StatusInternalServerError } - s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.") + s.renderError(r, w, errCode, "Invalid or expired user code.") return } - reqClient, err := s.storage.GetClient(deviceReq.ClientID) + client, err := s.storage.GetClient(deviceReq.ClientID) if err != nil { - s.logger.Errorf("Failed to get reqClient %q: %v", deviceReq.ClientID, err) - s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve device client.") + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get client: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + } else { + s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) + } + return + } + if client.Secret != deviceReq.ClientSecret { + s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) return } - resp, err := s.exchangeAuthCode(w, authCode, reqClient) + resp, err := s.exchangeAuthCode(w, authCode, client) if err != nil { s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err) s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.") return } - //Grab the device request from storage + //Grab the device token from storage old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode) if err != nil || s.now().After(old.Expiry) { - if err != storage.ErrNotFound { + errCode := http.StatusBadRequest + if err != nil && err != storage.ErrNotFound { s.logger.Errorf("failed to get device token: %v", err) + errCode = http.StatusInternalServerError } - s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.") + s.renderError(r, w, errCode, "Invalid or expired device code.") return } @@ -290,12 +316,13 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { // Update refresh token in the storage, store the token and mark as complete if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil { s.logger.Errorf("failed to update device token: %v", err) - s.renderError(r, w, http.StatusInternalServerError, "") + s.renderError(r, w, http.StatusBadRequest, "") return } - if err := s.templates.deviceSuccess(r, w, reqClient.Name); err != nil { + if err := s.templates.deviceSuccess(r, w, client.Name); err != nil { s.logger.Errorf("Server template error: %v", err) + s.renderError(r, w, http.StatusNotFound, "Page not found") } default: @@ -309,9 +336,8 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) { case http.MethodPost: err := r.ParseForm() if err != nil { - message := "Could not parse user code verification Request body" - s.logger.Warnf("%s : %v", message, err) - s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest) + s.logger.Warnf("Could not parse user code verification request body : %v", err) + s.renderError(r, w, http.StatusBadRequest, "") return } @@ -326,12 +352,12 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) { //Find the user code in the available requests deviceRequest, err := s.storage.GetDeviceRequest(userCode) if err != nil || s.now().After(deviceRequest.Expiry) { - if err != storage.ErrNotFound { + if err != nil && err != storage.ErrNotFound { s.logger.Errorf("failed to get device request: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) } - if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, true); err != nil { + if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil { s.logger.Errorf("Server template error: %v", err) + s.renderError(r, w, http.StatusNotFound, "Page not found") } return } @@ -345,6 +371,7 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) { } q := u.Query() q.Set("client_id", deviceRequest.ClientID) + q.Set("client_secret", deviceRequest.ClientSecret) q.Set("state", deviceRequest.UserCode) q.Set("response_type", "code") q.Set("redirect_uri", path.Join(s.issuerURL.Path, "/device/callback")) diff --git a/server/deviceflowhandlers_test.go b/server/deviceflowhandlers_test.go new file mode 100644 index 00000000..35306971 --- /dev/null +++ b/server/deviceflowhandlers_test.go @@ -0,0 +1,672 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "path" + "strings" + "testing" + "time" + + "github.com/dexidp/dex/storage" +) + +func TestDeviceVerificationURI(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Errorf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, "/device/auth/verify_code") + + uri := s.getDeviceVerificationURI() + if uri != u.Path { + t.Errorf("Invalid verification URI. Expected %v got %v", u.Path, uri) + } +} + +func TestHandleDeviceCode(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + + tests := []struct { + testName string + clientID string + scopes []string + expectedResponseCode int + expectedServerResponse string + }{ + { + testName: "New Valid Code", + clientID: "test", + scopes: []string{"openid", "profile", "email"}, + expectedResponseCode: http.StatusOK, + }, + } + for _, tc := range tests { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Errorf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, "device/code") + + data := url.Values{} + data.Set("client_id", tc.clientID) + for _, scope := range tc.scopes { + data.Add("scope", scope) + } + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + if rr.Code != tc.expectedResponseCode { + t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code) + } + + body, err := ioutil.ReadAll(rr.Body) + if err != nil { + t.Errorf("Could read token response %v", err) + } + if tc.expectedResponseCode == http.StatusOK { + var resp deviceCodeResponse + if err := json.Unmarshal(body, &resp); err != nil { + t.Errorf("Unexpected Device Code Response Format %v", string(body)) + } + } + if tc.expectedResponseCode == http.StatusBadRequest || tc.expectedResponseCode == http.StatusUnauthorized { + expectErrorResponse(tc.testName, body, tc.expectedServerResponse, t) + } + }() + } +} + +func TestDeviceCallback(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + + type formValues struct { + state string + code string + error string + } + + // Base "Control" test values + baseFormValues := formValues{ + state: "XXXX-XXXX", + code: "somecode", + } + baseAuthCode := storage.AuthCode{ + ID: "somecode", + ClientID: "testclient", + RedirectURI: "/device/callback", + Nonce: "", + Scopes: []string{"openid", "profile", "email"}, + ConnectorID: "mock", + ConnectorData: nil, + Claims: storage.Claims{}, + Expiry: now().Add(5 * time.Minute), + } + baseDeviceRequest := storage.DeviceRequest{ + UserCode: "XXXX-XXXX", + DeviceCode: "devicecode", + ClientID: "testclient", + ClientSecret: "", + Scopes: []string{"openid", "profile", "email"}, + Expiry: now().Add(5 * time.Minute), + } + baseDeviceToken := storage.DeviceToken{ + DeviceCode: "devicecode", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + } + + tests := []struct { + testName string + expectedResponseCode int + values formValues + testAuthCode storage.AuthCode + testDeviceRequest storage.DeviceRequest + testDeviceToken storage.DeviceToken + }{ + { + testName: "Missing State", + values: formValues{ + state: "", + code: "somecode", + error: "", + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Missing Code", + values: formValues{ + state: "XXXX-XXXX", + code: "", + error: "", + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Error During Authorization", + values: formValues{ + state: "XXXX-XXXX", + code: "somecode", + error: "Error Condition", + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Expired Auth Code", + values: baseFormValues, + testAuthCode: storage.AuthCode{ + ID: "somecode", + ClientID: "testclient", + RedirectURI: "/device/callback", + Nonce: "", + Scopes: []string{"openid", "profile", "email"}, + ConnectorID: "pic", + ConnectorData: nil, + Claims: storage.Claims{}, + Expiry: now().Add(-5 * time.Minute), + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Invalid Auth Code", + values: baseFormValues, + testAuthCode: storage.AuthCode{ + ID: "somecode", + ClientID: "testclient", + RedirectURI: "/device/callback", + Nonce: "", + Scopes: []string{"openid", "profile", "email"}, + ConnectorID: "pic", + ConnectorData: nil, + Claims: storage.Claims{}, + Expiry: now().Add(5 * time.Minute), + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Expired Device Request", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: storage.DeviceRequest{ + UserCode: "XXXX-XXXX", + DeviceCode: "devicecode", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "email"}, + Expiry: now().Add(-5 * time.Minute), + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Non-Existent User Code", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: storage.DeviceRequest{ + UserCode: "ZZZZ-ZZZZ", + DeviceCode: "devicecode", + Scopes: []string{"openid", "profile", "email"}, + Expiry: now().Add(5 * time.Minute), + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Bad Device Request Client", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: storage.DeviceRequest{ + UserCode: "XXXX-XXXX", + DeviceCode: "devicecode", + Scopes: []string{"openid", "profile", "email"}, + Expiry: now().Add(5 * time.Minute), + }, + expectedResponseCode: http.StatusUnauthorized, + }, + { + testName: "Bad Device Request Secret", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: storage.DeviceRequest{ + UserCode: "XXXX-XXXX", + DeviceCode: "devicecode", + ClientSecret: "foobar", + Scopes: []string{"openid", "profile", "email"}, + Expiry: now().Add(5 * time.Minute), + }, + expectedResponseCode: http.StatusUnauthorized, + }, + { + testName: "Expired Device Token", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "devicecode", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(-5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Device Code Already Redeemed", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "devicecode", + Status: deviceTokenComplete, + Token: "", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Successful Exchange", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: baseDeviceRequest, + testDeviceToken: baseDeviceToken, + expectedResponseCode: http.StatusOK, + }, + } + for _, tc := range tests { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + //c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + if err := s.storage.CreateAuthCode(tc.testAuthCode); err != nil { + t.Errorf("failed to create auth code: %v", err) + } + + if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { + t.Errorf("failed to create device request: %v", err) + } + + if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil { + t.Errorf("failed to create device token: %v", err) + } + + client := storage.Client{ + ID: "testclient", + Secret: "", + RedirectURIs: []string{"/device/callback"}, + } + if err := s.storage.CreateClient(client); err != nil { + t.Fatalf("failed to create client: %v", err) + } + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Errorf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, "device/callback") + q := u.Query() + q.Set("state", tc.values.state) + q.Set("code", tc.values.code) + q.Set("error", tc.values.error) + u.RawQuery = q.Encode() + req, _ := http.NewRequest("GET", u.String(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + if rr.Code != tc.expectedResponseCode { + t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code) + } + }() + } +} + +func TestDeviceTokenResponse(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + + baseDeviceRequest := storage.DeviceRequest{ + UserCode: "ABCD-WXYZ", + DeviceCode: "foo", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "offline_access"}, + Expiry: now().Add(5 * time.Minute), + } + + tests := []struct { + testName string + testDeviceRequest storage.DeviceRequest + testDeviceToken storage.DeviceToken + testGrantType string + testDeviceCode string + expectedServerResponse string + expectedResponseCode int + }{ + { + testName: "Valid but pending token", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "f00bar", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "f00bar", + expectedServerResponse: deviceTokenPending, + expectedResponseCode: http.StatusUnauthorized, + }, + { + testName: "Invalid Grant Type", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "f00bar", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "f00bar", + testGrantType: grantTypeAuthorizationCode, + expectedServerResponse: errInvalidGrant, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Test Slow Down State", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "f00bar", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: now(), + PollIntervalSeconds: 10, + }, + testDeviceCode: "f00bar", + expectedServerResponse: deviceTokenSlowDown, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Test Expired Device Token", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "f00bar", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(-5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "f00bar", + expectedServerResponse: deviceTokenExpired, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Test Non-existent Device Code", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "foo", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(-5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "bar", + expectedServerResponse: errInvalidRequest, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Empty Device Code in Request", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "bar", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(-5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "", + expectedServerResponse: errInvalidRequest, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Claim validated token from Device Code", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "foo", + Status: deviceTokenComplete, + Token: "{\"access_token\": \"foobar\"}", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "foo", + expectedServerResponse: "{\"access_token\": \"foobar\"}", + expectedResponseCode: http.StatusOK, + }, + } + for _, tc := range tests { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { + t.Errorf("Failed to store device token %v", err) + } + + if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil { + t.Errorf("Failed to store device token %v", err) + } + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Errorf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, "device/token") + + data := url.Values{} + grantType := grantTypeDeviceCode + if tc.testGrantType != "" { + grantType = tc.testGrantType + } + data.Set("grant_type", grantType) + data.Set("device_code", tc.testDeviceCode) + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + if rr.Code != tc.expectedResponseCode { + t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code) + } + + body, err := ioutil.ReadAll(rr.Body) + if err != nil { + t.Errorf("Could read token response %v", err) + } + if tc.expectedResponseCode == http.StatusBadRequest || tc.expectedResponseCode == http.StatusUnauthorized { + expectErrorResponse(tc.testName, body, tc.expectedServerResponse, t) + } else if string(body) != tc.expectedServerResponse { + t.Errorf("Unexpected Server Response. Expected %v got %v", tc.expectedServerResponse, string(body)) + } + }() + } +} + +func expectErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) { + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(body, &jsonMap) + if err != nil { + t.Errorf("Unexpected error unmarshalling response: %v", err) + } + if jsonMap["error"] != expectedError { + t.Errorf("Test Case %s expected error %v, received %v", testCase, expectedError, jsonMap["error"]) + } +} + +func TestVerifyCodeResponse(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + + tests := []struct { + testName string + testDeviceRequest storage.DeviceRequest + userCode string + expectedResponseCode int + expectedRedirectPath string + }{ + { + testName: "Unknown user code", + testDeviceRequest: storage.DeviceRequest{ + UserCode: "ABCD-WXYZ", + DeviceCode: "f00bar", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "offline_access"}, + Expiry: now().Add(5 * time.Minute), + }, + userCode: "CODE-TEST", + expectedResponseCode: http.StatusBadRequest, + expectedRedirectPath: "", + }, + { + testName: "Expired user code", + testDeviceRequest: storage.DeviceRequest{ + UserCode: "ABCD-WXYZ", + DeviceCode: "f00bar", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "offline_access"}, + Expiry: now().Add(-5 * time.Minute), + }, + userCode: "ABCD-WXYZ", + expectedResponseCode: http.StatusBadRequest, + expectedRedirectPath: "", + }, + { + testName: "No user code", + testDeviceRequest: storage.DeviceRequest{ + UserCode: "ABCD-WXYZ", + DeviceCode: "f00bar", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "offline_access"}, + Expiry: now().Add(-5 * time.Minute), + }, + userCode: "", + expectedResponseCode: http.StatusBadRequest, + expectedRedirectPath: "", + }, + { + testName: "Valid user code, expect redirect to auth endpoint", + testDeviceRequest: storage.DeviceRequest{ + UserCode: "ABCD-WXYZ", + DeviceCode: "f00bar", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "offline_access"}, + Expiry: now().Add(5 * time.Minute), + }, + userCode: "ABCD-WXYZ", + expectedResponseCode: http.StatusFound, + expectedRedirectPath: "/auth", + }, + } + for _, tc := range tests { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { + t.Errorf("Failed to store device token %v", err) + } + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Errorf("Could not parse issuer URL %v", err) + } + + u.Path = path.Join(u.Path, "device/auth/verify_code") + data := url.Values{} + data.Set("user_code", tc.userCode) + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + if rr.Code != tc.expectedResponseCode { + t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code) + } + + u, err = url.Parse(s.issuerURL.String()) + if err != nil { + t.Errorf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, tc.expectedRedirectPath) + + location := rr.Header().Get("Location") + if rr.Code == http.StatusFound && !strings.HasPrefix(location, u.Path) { + t.Errorf("Invalid Redirect. Expected %v got %v", u.Path, location) + } + }() + } +} diff --git a/server/handlers.go b/server/handlers.go index 32a81b98..babd5417 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -15,11 +15,12 @@ import ( "time" oidc "github.com/coreos/go-oidc" + "github.com/gorilla/mux" + jose "gopkg.in/square/go-jose.v2" + "github.com/dexidp/dex/connector" "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" - "github.com/gorilla/mux" - jose "gopkg.in/square/go-jose.v2" ) // newHealthChecker returns the healthz handler. The handler runs until the @@ -153,7 +154,7 @@ type discovery struct { Keys string `json:"jwks_uri"` UserInfo string `json:"userinfo_endpoint"` DeviceEndpoint string `json:"device_authorization_endpoint"` - GrantTypes []string `json:"grant_types_supported"'` + GrantTypes []string `json:"grant_types_supported"` ResponseTypes []string `json:"response_types_supported"` Subjects []string `json:"subject_types_supported"` IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` @@ -1381,18 +1382,10 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli } } - s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) + resp := s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry) + s.writeAccessToken(w, resp) } -func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) { - resp := struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - RefreshToken string `json:"refresh_token,omitempty"` - IDToken string `json:"id_token"` - }{ - type accessTokenReponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` diff --git a/server/server.go b/server/server.go index 90e96327..661fc835 100644 --- a/server/server.go +++ b/server/server.go @@ -81,7 +81,7 @@ type Config struct { DeviceRequestsValidFor time.Duration // Defaults to 5 minutes // If set, the server will use this connector to handle password grants PasswordConnector string - + GCFrequency time.Duration // Defaults to 5 minutes // If specified, the server will use this function for determining time. diff --git a/server/server_test.go b/server/server_test.go index 8fe84c9a..edc00832 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -8,11 +8,13 @@ import ( "encoding/pem" "errors" "fmt" + "io/ioutil" "net/http" "net/http/httptest" "net/http/httputil" "net/url" "os" + "path" "reflect" "sort" "strings" @@ -203,6 +205,274 @@ func TestDiscovery(t *testing.T) { } } +type oauth2Tests struct { + clientID string + tests []test +} + +type test struct { + name string + // If specified these set of scopes will be used during the test case. + scopes []string + // handleToken provides the OAuth2 token response for the integration test. + handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token, *mock.Callback) error +} + +func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) oauth2Tests { + requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"} + + // Used later when configuring test servers to set how long id_tokens will be valid for. + // + // The actual value of 30s is completely arbitrary. We just need to set a value + // so tests can compute the expected "expires_in" field. + idTokensValidFor := time.Second * 30 + + oidcConfig := &oidc.Config{SkipClientIDCheck: true} + + return oauth2Tests{ + clientID: clientID, + tests: []test{ + { + name: "verify ID Token", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + idToken, ok := token.Extra("id_token").(string) + if !ok { + return fmt.Errorf("no id token found") + } + if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil { + return fmt.Errorf("failed to verify id token: %v", err) + } + return nil + }, + }, + { + name: "fetch userinfo", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + ui, err := p.UserInfo(ctx, config.TokenSource(ctx, token)) + if err != nil { + return fmt.Errorf("failed to fetch userinfo: %v", err) + } + if conn.Identity.Email != ui.Email { + return fmt.Errorf("expected email to be %v, got %v", conn.Identity.Email, ui.Email) + } + return nil + }, + }, + { + name: "verify id token and oauth2 token expiry", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + expectedExpiry := now().Add(idTokensValidFor) + + timeEq := func(t1, t2 time.Time, within time.Duration) bool { + return t1.Sub(t2) < within + } + + if !timeEq(token.Expiry, expectedExpiry, time.Second) { + return fmt.Errorf("expected expired_in to be %s, got %s", expectedExpiry, token.Expiry) + } + + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return fmt.Errorf("no id token found") + } + idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) + if err != nil { + return fmt.Errorf("failed to verify id token: %v", err) + } + if !timeEq(idToken.Expiry, expectedExpiry, time.Second) { + return fmt.Errorf("expected id token expiry to be %s, got %s", expectedExpiry, token.Expiry) + } + return nil + }, + }, + { + name: "verify at_hash", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return fmt.Errorf("no id token found") + } + idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) + if err != nil { + return fmt.Errorf("failed to verify id token: %v", err) + } + + var claims struct { + AtHash string `json:"at_hash"` + } + if err := idToken.Claims(&claims); err != nil { + return fmt.Errorf("failed to decode raw claims: %v", err) + } + if claims.AtHash == "" { + return errors.New("no at_hash value in id_token") + } + wantAtHash, err := accessTokenHash(jose.RS256, token.AccessToken) + if err != nil { + return fmt.Errorf("computed expected at hash: %v", err) + } + if wantAtHash != claims.AtHash { + return fmt.Errorf("expected at_hash=%q got=%q", wantAtHash, claims.AtHash) + } + + return nil + }, + }, + { + name: "refresh token", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + // have to use time.Now because the OAuth2 package uses it. + token.Expiry = time.Now().Add(time.Second * -10) + if token.Valid() { + return errors.New("token shouldn't be valid") + } + + newToken, err := config.TokenSource(ctx, token).Token() + if err != nil { + return fmt.Errorf("failed to refresh token: %v", err) + } + if token.RefreshToken == newToken.RefreshToken { + return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken) + } + + if _, err := config.TokenSource(ctx, token).Token(); err == nil { + return errors.New("was able to redeem the same refresh token twice") + } + return nil + }, + }, + { + name: "refresh with explicit scopes", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + v := url.Values{} + v.Add("client_id", clientID) + v.Add("client_secret", clientSecret) + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", token.RefreshToken) + v.Add("scope", strings.Join(requestedScopes, " ")) + resp, err := http.PostForm(p.Endpoint().TokenURL, v) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + panic(err) + } + return fmt.Errorf("unexpected response: %s", dump) + } + return nil + }, + }, + { + name: "refresh with extra spaces", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + v := url.Values{} + v.Add("client_id", clientID) + v.Add("client_secret", clientSecret) + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", token.RefreshToken) + + // go-oidc adds an additional space before scopes when refreshing. + // Since we support that client we choose to be more relaxed about + // scope parsing, disregarding extra whitespace. + v.Add("scope", " "+strings.Join(requestedScopes, " ")) + resp, err := http.PostForm(p.Endpoint().TokenURL, v) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + panic(err) + } + return fmt.Errorf("unexpected response: %s", dump) + } + return nil + }, + }, + { + name: "refresh with unauthorized scopes", + scopes: []string{"openid", "email"}, + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + v := url.Values{} + v.Add("client_id", clientID) + v.Add("client_secret", clientSecret) + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", token.RefreshToken) + // Request a scope that wasn't requestd initially. + v.Add("scope", "oidc email profile") + resp, err := http.PostForm(p.Endpoint().TokenURL, v) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + panic(err) + } + return fmt.Errorf("unexpected response: %s", dump) + } + return nil + }, + }, + { + // This test ensures that the connector.RefreshConnector interface is being + // used when clients request a refresh token. + name: "refresh with identity changes", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + // have to use time.Now because the OAuth2 package uses it. + token.Expiry = time.Now().Add(time.Second * -10) + if token.Valid() { + return errors.New("token shouldn't be valid") + } + + ident := connector.Identity{ + UserID: "fooid", + Username: "foo", + Email: "foo@bar.com", + EmailVerified: true, + Groups: []string{"foo", "bar"}, + } + conn.Identity = ident + + type claims struct { + Username string `json:"name"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Groups []string `json:"groups"` + } + want := claims{ident.Username, ident.Email, ident.EmailVerified, ident.Groups} + + newToken, err := config.TokenSource(ctx, token).Token() + if err != nil { + return fmt.Errorf("failed to refresh token: %v", err) + } + rawIDToken, ok := newToken.Extra("id_token").(string) + if !ok { + return fmt.Errorf("no id_token in refreshed token") + } + idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) + if err != nil { + return fmt.Errorf("failed to verify id token: %v", err) + } + var got claims + if err := idToken.Claims(&got); err != nil { + return fmt.Errorf("failed to unmarshal claims: %v", err) + } + + if diff := pretty.Compare(want, got); diff != "" { + return fmt.Errorf("got identity != want identity: %s", diff) + } + return nil + }, + }, + }, + } +} + // TestOAuth2CodeFlow runs integration tests against a test server. The tests stand up a server // which requires no interaction to login, logs in through a test client, then passes the client // and returned token to the test. @@ -226,255 +496,8 @@ func TestOAuth2CodeFlow(t *testing.T) { // Connector used by the tests. var conn *mock.Callback - oidcConfig := &oidc.Config{SkipClientIDCheck: true} - - tests := []struct { - name string - // If specified these set of scopes will be used during the test case. - scopes []string - // handleToken provides the OAuth2 token response for the integration test. - handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error - }{ - { - name: "verify ID Token", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - idToken, ok := token.Extra("id_token").(string) - if !ok { - return fmt.Errorf("no id token found") - } - if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil { - return fmt.Errorf("failed to verify id token: %v", err) - } - return nil - }, - }, - { - name: "fetch userinfo", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - ui, err := p.UserInfo(ctx, config.TokenSource(ctx, token)) - if err != nil { - return fmt.Errorf("failed to fetch userinfo: %v", err) - } - if conn.Identity.Email != ui.Email { - return fmt.Errorf("expected email to be %v, got %v", conn.Identity.Email, ui.Email) - } - return nil - }, - }, - { - name: "verify id token and oauth2 token expiry", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - expectedExpiry := now().Add(idTokensValidFor) - - timeEq := func(t1, t2 time.Time, within time.Duration) bool { - return t1.Sub(t2) < within - } - - if !timeEq(token.Expiry, expectedExpiry, time.Second) { - return fmt.Errorf("expected expired_in to be %s, got %s", expectedExpiry, token.Expiry) - } - - rawIDToken, ok := token.Extra("id_token").(string) - if !ok { - return fmt.Errorf("no id token found") - } - idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) - if err != nil { - return fmt.Errorf("failed to verify id token: %v", err) - } - if !timeEq(idToken.Expiry, expectedExpiry, time.Second) { - return fmt.Errorf("expected id token expiry to be %s, got %s", expectedExpiry, token.Expiry) - } - return nil - }, - }, - { - name: "verify at_hash", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - rawIDToken, ok := token.Extra("id_token").(string) - if !ok { - return fmt.Errorf("no id token found") - } - idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) - if err != nil { - return fmt.Errorf("failed to verify id token: %v", err) - } - - var claims struct { - AtHash string `json:"at_hash"` - } - if err := idToken.Claims(&claims); err != nil { - return fmt.Errorf("failed to decode raw claims: %v", err) - } - if claims.AtHash == "" { - return errors.New("no at_hash value in id_token") - } - wantAtHash, err := accessTokenHash(jose.RS256, token.AccessToken) - if err != nil { - return fmt.Errorf("computed expected at hash: %v", err) - } - if wantAtHash != claims.AtHash { - return fmt.Errorf("expected at_hash=%q got=%q", wantAtHash, claims.AtHash) - } - - return nil - }, - }, - { - name: "refresh token", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - // have to use time.Now because the OAuth2 package uses it. - token.Expiry = time.Now().Add(time.Second * -10) - if token.Valid() { - return errors.New("token shouldn't be valid") - } - - newToken, err := config.TokenSource(ctx, token).Token() - if err != nil { - return fmt.Errorf("failed to refresh token: %v", err) - } - if token.RefreshToken == newToken.RefreshToken { - return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken) - } - - if _, err := config.TokenSource(ctx, token).Token(); err == nil { - return errors.New("was able to redeem the same refresh token twice") - } - return nil - }, - }, - { - name: "refresh with explicit scopes", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - v := url.Values{} - v.Add("client_id", clientID) - v.Add("client_secret", clientSecret) - v.Add("grant_type", "refresh_token") - v.Add("refresh_token", token.RefreshToken) - v.Add("scope", strings.Join(requestedScopes, " ")) - resp, err := http.PostForm(p.Endpoint().TokenURL, v) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - dump, err := httputil.DumpResponse(resp, true) - if err != nil { - panic(err) - } - return fmt.Errorf("unexpected response: %s", dump) - } - return nil - }, - }, - { - name: "refresh with extra spaces", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - v := url.Values{} - v.Add("client_id", clientID) - v.Add("client_secret", clientSecret) - v.Add("grant_type", "refresh_token") - v.Add("refresh_token", token.RefreshToken) - - // go-oidc adds an additional space before scopes when refreshing. - // Since we support that client we choose to be more relaxed about - // scope parsing, disregarding extra whitespace. - v.Add("scope", " "+strings.Join(requestedScopes, " ")) - resp, err := http.PostForm(p.Endpoint().TokenURL, v) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - dump, err := httputil.DumpResponse(resp, true) - if err != nil { - panic(err) - } - return fmt.Errorf("unexpected response: %s", dump) - } - return nil - }, - }, - { - name: "refresh with unauthorized scopes", - scopes: []string{"openid", "email"}, - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - v := url.Values{} - v.Add("client_id", clientID) - v.Add("client_secret", clientSecret) - v.Add("grant_type", "refresh_token") - v.Add("refresh_token", token.RefreshToken) - // Request a scope that wasn't requestd initially. - v.Add("scope", "oidc email profile") - resp, err := http.PostForm(p.Endpoint().TokenURL, v) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode == http.StatusOK { - dump, err := httputil.DumpResponse(resp, true) - if err != nil { - panic(err) - } - return fmt.Errorf("unexpected response: %s", dump) - } - return nil - }, - }, - { - // This test ensures that the connector.RefreshConnector interface is being - // used when clients request a refresh token. - name: "refresh with identity changes", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - // have to use time.Now because the OAuth2 package uses it. - token.Expiry = time.Now().Add(time.Second * -10) - if token.Valid() { - return errors.New("token shouldn't be valid") - } - - ident := connector.Identity{ - UserID: "fooid", - Username: "foo", - Email: "foo@bar.com", - EmailVerified: true, - Groups: []string{"foo", "bar"}, - } - conn.Identity = ident - - type claims struct { - Username string `json:"name"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - Groups []string `json:"groups"` - } - want := claims{ident.Username, ident.Email, ident.EmailVerified, ident.Groups} - - newToken, err := config.TokenSource(ctx, token).Token() - if err != nil { - return fmt.Errorf("failed to refresh token: %v", err) - } - rawIDToken, ok := newToken.Extra("id_token").(string) - if !ok { - return fmt.Errorf("no id_token in refreshed token") - } - idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) - if err != nil { - return fmt.Errorf("failed to verify id token: %v", err) - } - var got claims - if err := idToken.Claims(&got); err != nil { - return fmt.Errorf("failed to unmarshal claims: %v", err) - } - - if diff := pretty.Compare(want, got); diff != "" { - return fmt.Errorf("got identity != want identity: %s", diff) - } - return nil - }, - }, - } - - for _, tc := range tests { + tests := makeOAuth2Tests(clientID, clientSecret, now) + for _, tc := range tests.tests { func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -540,7 +563,7 @@ func TestOAuth2CodeFlow(t *testing.T) { t.Errorf("failed to exchange code for token: %v", err) return } - err = tc.handleToken(ctx, p, oauth2Config, token) + err = tc.handleToken(ctx, p, oauth2Config, token, conn) if err != nil { t.Errorf("%s: %v", tc.name, err) } @@ -1253,3 +1276,160 @@ func TestRefreshTokenFlow(t *testing.T) { t.Errorf("Token refreshed with invalid refresh token, error expected.") } } + +// TestOAuth2DeviceFlow runs device flow integration tests against a test server +func TestOAuth2DeviceFlow(t *testing.T) { + clientID := "testclient" + clientSecret := "" + requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"} + + t0 := time.Now() + + // Always have the time function used by the server return the same time so + // we can predict expected values of "expires_in" fields exactly. + now := func() time.Time { return t0 } + + // Connector used by the tests. + var conn *mock.Callback + idTokensValidFor := time.Second * 30 + + for _, tc := range makeOAuth2Tests(clientID, clientSecret, now).tests { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + c.IDTokensValidFor = idTokensValidFor + }) + defer httpServer.Close() + + mockConn := s.connectors["mock"] + conn = mockConn.Connector.(*mock.Callback) + + p, err := oidc.NewProvider(ctx, httpServer.URL) + if err != nil { + t.Fatalf("failed to get provider: %v", err) + } + + //Add the Clients to the test server + client := storage.Client{ + ID: clientID, + //Secret: "testclientsecret", + RedirectURIs: []string{"/non-root-path/device/callback"}, + } + if err := s.storage.CreateClient(client); err != nil { + t.Fatalf("failed to create client: %v", err) + } + + //Grab the issuer that we'll reuse for the different endpoints to hit + issuer, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Errorf("Could not parse issuer URL %v", err) + } + + //Send a new Device Request + codeURL, _ := url.Parse(issuer.String()) + codeURL.Path = path.Join(codeURL.Path, "device/code") + + data := url.Values{} + data.Set("client_id", clientID) + data.Add("scope", strings.Join(requestedScopes, " ")) + //for _, scope := range requestedScopes { + // data.Add("scope", scope) + //} + resp, err := http.PostForm(codeURL.String(), data) + if err != nil { + t.Errorf("Could not request device code: %v", err) + } + defer resp.Body.Close() + responseBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read device code response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + + //Parse the code response + var deviceCode deviceCodeResponse + if err := json.Unmarshal(responseBody, &deviceCode); err != nil { + t.Errorf("Unexpected Device Code Response Format %v", string(responseBody)) + } + + //Mock the user hitting the verification URI and posting the form + verifyURL, _ := url.Parse(issuer.String()) + verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code") + urlData := url.Values{} + urlData.Set("user_code", deviceCode.UserCode) + resp, err = http.PostForm(verifyURL.String(), urlData) + if err != nil { + t.Errorf("Error Posting Form: %v", err) + } + defer resp.Body.Close() + responseBody, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read verification response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + + //Hit the Token Endpoint, and try and get an access token + tokenURL, _ := url.Parse(issuer.String()) + tokenURL.Path = path.Join(tokenURL.Path, "/device/token") + v := url.Values{} + v.Add("grant_type", grantTypeDeviceCode) + v.Add("device_code", deviceCode.DeviceCode) + resp, err = http.PostForm(tokenURL.String(), v) + if err != nil { + t.Errorf("Could not request device token: %v", err) + } + defer resp.Body.Close() + responseBody, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read device token response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + + //Parse the response + var tokenRes accessTokenReponse + if err := json.Unmarshal(responseBody, &tokenRes); err != nil { + t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody)) + } + + token := &oauth2.Token{ + AccessToken: tokenRes.AccessToken, + TokenType: tokenRes.TokenType, + RefreshToken: tokenRes.RefreshToken, + } + raw := make(map[string]interface{}) + json.Unmarshal(responseBody, &raw) // no error checks for optional fields + token = token.WithExtra(raw) + if secs := tokenRes.ExpiresIn; secs > 0 { + token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) + } + + //Run token tests to validate info is correct + // Create the OAuth2 config. + oauth2Config := &oauth2.Config{ + ClientID: client.ID, + ClientSecret: client.Secret, + Endpoint: p.Endpoint(), + Scopes: requestedScopes, + RedirectURL: "/non-root-path/device/callback", + } + if len(tc.scopes) != 0 { + oauth2Config.Scopes = tc.scopes + } + err = tc.handleToken(ctx, p, oauth2Config, token, conn) + if err != nil { + t.Errorf("%s: %v", tc.name, err) + } + }() + } +} diff --git a/server/templates.go b/server/templates.go index 934362b7..dd2678ea 100644 --- a/server/templates.go +++ b/server/templates.go @@ -250,6 +250,9 @@ func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name } func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] } func (t *templates) device(r *http.Request, w http.ResponseWriter, postURL string, userCode string, lastWasInvalid bool) error { + if lastWasInvalid { + w.WriteHeader(http.StatusBadRequest) + } data := struct { PostURL string UserCode string diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 6edc8350..224e1bfa 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -843,11 +843,12 @@ func testGC(t *testing.T, s storage.Storage) { } d := storage.DeviceRequest{ - UserCode: userCode, - DeviceCode: storage.NewID(), - ClientID: "client1", - Scopes: []string{"openid", "email"}, - Expiry: expiry, + UserCode: userCode, + DeviceCode: storage.NewID(), + ClientID: "client1", + ClientSecret: "secret1", + Scopes: []string{"openid", "email"}, + Expiry: expiry, } if err := s.CreateDeviceRequest(d); err != nil { @@ -863,9 +864,9 @@ func testGC(t *testing.T, s storage.Storage) { t.Errorf("expected no device garbage collection results, got %#v", result) } } - //if _, err := s.GetDeviceRequest(d.UserCode); err != nil { - // t.Errorf("expected to be able to get auth request after GC: %v", err) - //} + if _, err := s.GetDeviceRequest(d.UserCode); err != nil { + t.Errorf("expected to be able to get auth request after GC: %v", err) + } } if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) @@ -873,18 +874,19 @@ func testGC(t *testing.T, s storage.Storage) { t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests) } - //TODO add this code back once Getters are written for device requests - //if _, err := s.GetDeviceRequest(d.UserCode); err == nil { - // t.Errorf("expected device request to be GC'd") - //} else if err != storage.ErrNotFound { - // t.Errorf("expected storage.ErrNotFound, got %v", err) - //} + if _, err := s.GetDeviceRequest(d.UserCode); err == nil { + t.Errorf("expected device request to be GC'd") + } else if err != storage.ErrNotFound { + t.Errorf("expected storage.ErrNotFound, got %v", err) + } dt := storage.DeviceToken{ - DeviceCode: storage.NewID(), - Status: "pending", - Token: "foo", - Expiry: expiry, + DeviceCode: storage.NewID(), + Status: "pending", + Token: "foo", + Expiry: expiry, + LastRequestTime: time.Now(), + PollIntervalSeconds: 0, } if err := s.CreateDeviceToken(dt); err != nil { @@ -969,11 +971,12 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { panic(err) } d1 := storage.DeviceRequest{ - UserCode: userCode, - DeviceCode: storage.NewID(), - ClientID: "client1", - Scopes: []string{"openid", "email"}, - Expiry: neverExpire, + UserCode: userCode, + DeviceCode: storage.NewID(), + ClientID: "client1", + ClientSecret: "secret1", + Scopes: []string{"openid", "email"}, + Expiry: neverExpire, } if err := s.CreateDeviceRequest(d1); err != nil { diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index f41831cd..e8abe3d0 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -595,7 +595,7 @@ func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) defer cancel() - return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t)) + return c.txnCreate(ctx, keyID(deviceTokenPrefix, t.DeviceCode), fromStorageDeviceToken(t)) } func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) { diff --git a/storage/etcd/etcd_test.go b/storage/etcd/etcd_test.go index 4c17fdf1..122d7dae 100644 --- a/storage/etcd/etcd_test.go +++ b/storage/etcd/etcd_test.go @@ -44,6 +44,8 @@ func cleanDB(c *conn) error { passwordPrefix, offlineSessionPrefix, connectorPrefix, + deviceRequestPrefix, + deviceTokenPrefix, } { _, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix()) if err != nil { diff --git a/storage/etcd/types.go b/storage/etcd/types.go index cc8045c9..def95b55 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -219,20 +219,22 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { // DeviceRequest is a mirrored struct from storage with JSON struct tags type DeviceRequest struct { - UserCode string `json:"user_code"` - DeviceCode string `json:"device_code"` - ClientID string `json:"client_id"` - Scopes []string `json:"scopes"` - Expiry time.Time `json:"expiry"` + UserCode string `json:"user_code"` + DeviceCode string `json:"device_code"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + Scopes []string `json:"scopes"` + Expiry time.Time `json:"expiry"` } func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest { return DeviceRequest{ - UserCode: d.UserCode, - DeviceCode: d.DeviceCode, - ClientID: d.ClientID, - Scopes: d.Scopes, - Expiry: d.Expiry, + UserCode: d.UserCode, + DeviceCode: d.DeviceCode, + ClientID: d.ClientID, + ClientSecret: d.ClientSecret, + Scopes: d.Scopes, + Expiry: d.Expiry, } } diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 61794ccf..f856a731 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -672,10 +672,11 @@ type DeviceRequest struct { k8sapi.TypeMeta `json:",inline"` k8sapi.ObjectMeta `json:"metadata,omitempty"` - DeviceCode string `json:"device_code,omitempty"` - CLientID string `json:"client_id,omitempty"` - Scopes []string `json:"scopes,omitempty"` - Expiry time.Time `json:"expiry"` + DeviceCode string `json:"device_code,omitempty"` + ClientID string `json:"client_id,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` + Scopes []string `json:"scopes,omitempty"` + Expiry time.Time `json:"expiry"` } // AuthRequestList is a list of AuthRequests. @@ -695,21 +696,23 @@ func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceReque Name: strings.ToLower(a.UserCode), Namespace: cli.namespace, }, - DeviceCode: a.DeviceCode, - CLientID: a.ClientID, - Scopes: a.Scopes, - Expiry: a.Expiry, + DeviceCode: a.DeviceCode, + ClientID: a.ClientID, + ClientSecret: a.ClientSecret, + Scopes: a.Scopes, + Expiry: a.Expiry, } return req } func toStorageDeviceRequest(req DeviceRequest) storage.DeviceRequest { return storage.DeviceRequest{ - UserCode: strings.ToUpper(req.ObjectMeta.Name), - DeviceCode: req.DeviceCode, - ClientID: req.CLientID, - Scopes: req.Scopes, - Expiry: req.Expiry, + UserCode: strings.ToUpper(req.ObjectMeta.Name), + DeviceCode: req.DeviceCode, + ClientID: req.ClientID, + ClientSecret: req.ClientSecret, + Scopes: req.Scopes, + Expiry: req.Expiry, } } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index a85c972b..b74b76e1 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -888,12 +888,12 @@ func (c *conn) delete(table, field, id string) error { func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { _, err := c.Exec(` insert into device_request ( - user_code, device_code, client_id, scopes, expiry + user_code, device_code, client_id, client_secret, scopes, expiry ) values ( - $1, $2, $3, $4, $5 + $1, $2, $3, $4, $5, $6 );`, - d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.Expiry, + d.UserCode, d.DeviceCode, d.ClientID, d.ClientSecret, encoder(d.Scopes), d.Expiry, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -930,10 +930,10 @@ func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) { err = q.QueryRow(` select - device_code, client_id, scopes, expiry + device_code, client_id, client_secret, scopes, expiry from device_request where user_code = $1; `, userCode).Scan( - &d.DeviceCode, &d.ClientID, decoder(&d.Scopes), &d.Expiry, + &d.DeviceCode, &d.ClientID, &d.ClientSecret, decoder(&d.Scopes), &d.Expiry, ) if err != nil { if err == sql.ErrNoRows { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index e399d2b8..43230581 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -235,6 +235,7 @@ var migrations = []migration{ user_code text not null primary key, device_code text not null, client_id text not null, + client_secret text , scopes bytea not null, -- JSON array of strings expiry timestamptz not null );`, diff --git a/storage/storage.go b/storage/storage.go index 005e9190..2134678e 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -392,6 +392,8 @@ type DeviceRequest struct { DeviceCode string //The client ID the code is for ClientID string + //The Client Secret + ClientSecret string //The scopes the device requests Scopes []string //The expire time