forked from mystiq/dex
fix: use /token endpoint to get tokens with device flow
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
parent
3c5a631ce3
commit
1211a86d58
4 changed files with 85 additions and 61 deletions
|
@ -151,7 +151,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleDeviceTokenGrant(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodPost:
|
case http.MethodPost:
|
||||||
|
@ -162,71 +162,75 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceCode := r.Form.Get("device_code")
|
|
||||||
if deviceCode == "" {
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
grantType := r.PostFormValue("grant_type")
|
grantType := r.PostFormValue("grant_type")
|
||||||
if grantType != grantTypeDeviceCode {
|
if grantType != grantTypeDeviceCode {
|
||||||
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
now := s.now()
|
s.handleDeviceToken(w, r)
|
||||||
|
|
||||||
// Grab the device token, check validity
|
|
||||||
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
|
||||||
if err != nil {
|
|
||||||
if err != storage.ErrNotFound {
|
|
||||||
s.logger.Errorf("failed to get device code: %v", err)
|
|
||||||
}
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
} else if now.After(deviceToken.Expiry) {
|
|
||||||
s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rate Limiting check
|
|
||||||
slowDown := false
|
|
||||||
pollInterval := deviceToken.PollIntervalSeconds
|
|
||||||
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
|
|
||||||
if now.Before(minRequestTime) {
|
|
||||||
slowDown = true
|
|
||||||
// Continually increase the poll interval until the user waits the proper time
|
|
||||||
pollInterval += 5
|
|
||||||
} else {
|
|
||||||
pollInterval = 5
|
|
||||||
}
|
|
||||||
|
|
||||||
switch deviceToken.Status {
|
|
||||||
case deviceTokenPending:
|
|
||||||
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
|
||||||
old.PollIntervalSeconds = pollInterval
|
|
||||||
old.LastRequestTime = now
|
|
||||||
return old, nil
|
|
||||||
}
|
|
||||||
// Update device token last request time in storage
|
|
||||||
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
|
|
||||||
s.logger.Errorf("failed to update device token: %v", err)
|
|
||||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if slowDown {
|
|
||||||
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
|
|
||||||
} else {
|
|
||||||
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
|
||||||
}
|
|
||||||
case deviceTokenComplete:
|
|
||||||
w.Write([]byte(deviceToken.Token))
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
deviceCode := r.Form.Get("device_code")
|
||||||
|
if deviceCode == "" {
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := s.now()
|
||||||
|
|
||||||
|
// Grab the device token, check validity
|
||||||
|
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get device code: %v", err)
|
||||||
|
}
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
} else if now.After(deviceToken.Expiry) {
|
||||||
|
s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rate Limiting check
|
||||||
|
slowDown := false
|
||||||
|
pollInterval := deviceToken.PollIntervalSeconds
|
||||||
|
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
|
||||||
|
if now.Before(minRequestTime) {
|
||||||
|
slowDown = true
|
||||||
|
// Continually increase the poll interval until the user waits the proper time
|
||||||
|
pollInterval += 5
|
||||||
|
} else {
|
||||||
|
pollInterval = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
switch deviceToken.Status {
|
||||||
|
case deviceTokenPending:
|
||||||
|
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||||
|
old.PollIntervalSeconds = pollInterval
|
||||||
|
old.LastRequestTime = now
|
||||||
|
return old, nil
|
||||||
|
}
|
||||||
|
// Update device token last request time in storage
|
||||||
|
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
|
||||||
|
s.logger.Errorf("failed to update device token: %v", err)
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if slowDown {
|
||||||
|
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
|
||||||
|
} else {
|
||||||
|
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
case deviceTokenComplete:
|
||||||
|
w.Write([]byte(deviceToken.Token))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
|
|
|
@ -652,7 +652,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
|
||||||
http.Redirect(w, r, u.String(), http.StatusSeeOther)
|
http.Redirect(w, r, u.String(), http.StatusSeeOther)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request, storage.Client)) {
|
||||||
clientID, clientSecret, ok := r.BasicAuth()
|
clientID, clientSecret, ok := r.BasicAuth()
|
||||||
if ok {
|
if ok {
|
||||||
var err error
|
var err error
|
||||||
|
@ -689,14 +689,33 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handler(w, r, client)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "method not allowed", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Could not parse request body: %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
grantType := r.PostFormValue("grant_type")
|
grantType := r.PostFormValue("grant_type")
|
||||||
switch grantType {
|
switch grantType {
|
||||||
|
case grantTypeDeviceCode:
|
||||||
|
s.handleDeviceToken(w, r)
|
||||||
case grantTypeAuthorizationCode:
|
case grantTypeAuthorizationCode:
|
||||||
s.handleAuthCode(w, r, client)
|
s.withClientFromStorage(w, r, s.handleAuthCode)
|
||||||
case grantTypeRefreshToken:
|
case grantTypeRefreshToken:
|
||||||
s.handleRefreshToken(w, r, client)
|
s.withClientFromStorage(w, r, s.handleRefreshToken)
|
||||||
case grantTypePassword:
|
case grantTypePassword:
|
||||||
s.handlePasswordGrant(w, r, client)
|
s.withClientFromStorage(w, r, s.handlePasswordGrant)
|
||||||
default:
|
default:
|
||||||
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
|
@ -320,7 +320,8 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||||
handleFunc("/device", s.handleDeviceExchange)
|
handleFunc("/device", s.handleDeviceExchange)
|
||||||
handleFunc("/device/auth/verify_code", s.verifyUserCode)
|
handleFunc("/device/auth/verify_code", s.verifyUserCode)
|
||||||
handleFunc("/device/code", s.handleDeviceCode)
|
handleFunc("/device/code", s.handleDeviceCode)
|
||||||
handleFunc("/device/token", s.handleDeviceToken)
|
// TODO(nabokihms): deprecate and remove this endpoint, use /token instead
|
||||||
|
handleFunc("/device/token", s.handleDeviceTokenGrant)
|
||||||
handleFunc(deviceCallbackURI, s.handleDeviceCallback)
|
handleFunc(deviceCallbackURI, s.handleDeviceCallback)
|
||||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Strip the X-Remote-* headers to prevent security issues on
|
// Strip the X-Remote-* headers to prevent security issues on
|
||||||
|
|
|
@ -1583,7 +1583,7 @@ func TestOAuth2DeviceFlow(t *testing.T) {
|
||||||
|
|
||||||
// Hit the Token Endpoint, and try and get an access token
|
// Hit the Token Endpoint, and try and get an access token
|
||||||
tokenURL, _ := url.Parse(issuer.String())
|
tokenURL, _ := url.Parse(issuer.String())
|
||||||
tokenURL.Path = path.Join(tokenURL.Path, "/device/token")
|
tokenURL.Path = path.Join(tokenURL.Path, "/token")
|
||||||
v := url.Values{}
|
v := url.Values{}
|
||||||
v.Add("grant_type", grantTypeDeviceCode)
|
v.Add("grant_type", grantTypeDeviceCode)
|
||||||
v.Add("device_code", deviceCode.DeviceCode)
|
v.Add("device_code", deviceCode.DeviceCode)
|
||||||
|
|
Loading…
Reference in a new issue