diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index a73dafe8..eef50866 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -151,7 +151,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { +func (s *Server) handleDeviceTokenGrant(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") switch r.Method { case http.MethodPost: @@ -162,71 +162,75 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { return } - deviceCode := r.Form.Get("device_code") - if deviceCode == "" { - s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest) - return - } - grantType := r.PostFormValue("grant_type") if grantType != grantTypeDeviceCode { s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) return } - now := s.now() - - // Grab the device token, check validity - deviceToken, err := s.storage.GetDeviceToken(deviceCode) - if err != nil { - if err != storage.ErrNotFound { - s.logger.Errorf("failed to get device code: %v", err) - } - s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest) - return - } else if now.After(deviceToken.Expiry) { - s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest) - return - } - - // Rate Limiting check - slowDown := false - pollInterval := deviceToken.PollIntervalSeconds - minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval)) - if now.Before(minRequestTime) { - slowDown = true - // Continually increase the poll interval until the user waits the proper time - pollInterval += 5 - } else { - pollInterval = 5 - } - - switch deviceToken.Status { - case deviceTokenPending: - updater := func(old storage.DeviceToken) (storage.DeviceToken, error) { - old.PollIntervalSeconds = pollInterval - old.LastRequestTime = now - return old, nil - } - // Update device token last request time in storage - if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil { - s.logger.Errorf("failed to update device token: %v", err) - s.renderError(r, w, http.StatusInternalServerError, "") - return - } - if slowDown { - s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest) - } else { - s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) - } - case deviceTokenComplete: - w.Write([]byte(deviceToken.Token)) - } + s.handleDeviceToken(w, r) default: s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") } } +func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { + deviceCode := r.Form.Get("device_code") + if deviceCode == "" { + s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest) + return + } + + now := s.now() + + // Grab the device token, check validity + deviceToken, err := s.storage.GetDeviceToken(deviceCode) + if err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get device code: %v", err) + } + s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest) + return + } else if now.After(deviceToken.Expiry) { + s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest) + return + } + + // Rate Limiting check + slowDown := false + pollInterval := deviceToken.PollIntervalSeconds + minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval)) + if now.Before(minRequestTime) { + slowDown = true + // Continually increase the poll interval until the user waits the proper time + pollInterval += 5 + } else { + pollInterval = 5 + } + + switch deviceToken.Status { + case deviceTokenPending: + updater := func(old storage.DeviceToken) (storage.DeviceToken, error) { + old.PollIntervalSeconds = pollInterval + old.LastRequestTime = now + return old, nil + } + // Update device token last request time in storage + if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil { + s.logger.Errorf("failed to update device token: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "") + return + } + if slowDown { + s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest) + } else { + s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) + } + case deviceTokenComplete: + w.Write([]byte(deviceToken.Token)) + } +} + func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: diff --git a/server/handlers.go b/server/handlers.go index eb65f490..3bf43b16 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -652,7 +652,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe http.Redirect(w, r, u.String(), http.StatusSeeOther) } -func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { +func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request, storage.Client)) { clientID, clientSecret, ok := r.BasicAuth() if ok { var err error @@ -689,14 +689,33 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { return } + handler(w, r, client) +} + +func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.Method != http.MethodPost { + s.tokenErrHelper(w, errInvalidRequest, "method not allowed", http.StatusBadRequest) + return + } + + err := r.ParseForm() + if err != nil { + s.logger.Errorf("Could not parse request body: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) + return + } + grantType := r.PostFormValue("grant_type") switch grantType { + case grantTypeDeviceCode: + s.handleDeviceToken(w, r) case grantTypeAuthorizationCode: - s.handleAuthCode(w, r, client) + s.withClientFromStorage(w, r, s.handleAuthCode) case grantTypeRefreshToken: - s.handleRefreshToken(w, r, client) + s.withClientFromStorage(w, r, s.handleRefreshToken) case grantTypePassword: - s.handlePasswordGrant(w, r, client) + s.withClientFromStorage(w, r, s.handlePasswordGrant) default: s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) } diff --git a/server/server.go b/server/server.go index a79b7cfd..1294ff76 100644 --- a/server/server.go +++ b/server/server.go @@ -320,7 +320,8 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc("/device", s.handleDeviceExchange) handleFunc("/device/auth/verify_code", s.verifyUserCode) handleFunc("/device/code", s.handleDeviceCode) - handleFunc("/device/token", s.handleDeviceToken) + // TODO(nabokihms): deprecate and remove this endpoint, use /token instead + handleFunc("/device/token", s.handleDeviceTokenGrant) handleFunc(deviceCallbackURI, s.handleDeviceCallback) r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { // Strip the X-Remote-* headers to prevent security issues on diff --git a/server/server_test.go b/server/server_test.go index 87ca6c17..34aa9eec 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1583,7 +1583,7 @@ func TestOAuth2DeviceFlow(t *testing.T) { // Hit the Token Endpoint, and try and get an access token tokenURL, _ := url.Parse(issuer.String()) - tokenURL.Path = path.Join(tokenURL.Path, "/device/token") + tokenURL.Path = path.Join(tokenURL.Path, "/token") v := url.Values{} v.Add("grant_type", grantTypeDeviceCode) v.Add("device_code", deviceCode.DeviceCode)