Merge pull request #2010 from flant/switch-device-token-endpoint-to-token

fix: use /token endpoint to get tokens with device flow
This commit is contained in:
Márk Sági-Kazár 2021-05-01 13:24:55 +02:00 committed by GitHub
commit 94a2b3ed87
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 243 additions and 196 deletions

View file

@ -151,7 +151,9 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
} }
} }
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { func (s *Server) handleDeviceTokenDeprecated(w http.ResponseWriter, r *http.Request) {
s.logger.Warn(`The deprecated "/device/token" endpoint was called. It will be removed, use "/token" instead.`)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
switch r.Method { switch r.Method {
case http.MethodPost: case http.MethodPost:
@ -162,18 +164,25 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
return return
} }
deviceCode := r.Form.Get("device_code")
if deviceCode == "" {
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
return
}
grantType := r.PostFormValue("grant_type") grantType := r.PostFormValue("grant_type")
if grantType != grantTypeDeviceCode { if grantType != grantTypeDeviceCode {
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
return return
} }
s.handleDeviceToken(w, r)
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
}
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
deviceCode := r.Form.Get("device_code")
if deviceCode == "" {
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
return
}
now := s.now() now := s.now()
// Grab the device token, check validity // Grab the device token, check validity
@ -222,9 +231,6 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
case deviceTokenComplete: case deviceTokenComplete:
w.Write([]byte(deviceToken.Token)) w.Write([]byte(deviceToken.Token))
} }
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
} }
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {

View file

@ -651,7 +651,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
http.Redirect(w, r, u.String(), http.StatusSeeOther) http.Redirect(w, r, u.String(), http.StatusSeeOther)
} }
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request, storage.Client)) {
clientID, clientSecret, ok := r.BasicAuth() clientID, clientSecret, ok := r.BasicAuth()
if ok { if ok {
var err error var err error
@ -688,14 +688,33 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
return return
} }
handler(w, r, client)
}
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.Method != http.MethodPost {
s.tokenErrHelper(w, errInvalidRequest, "method not allowed", http.StatusBadRequest)
return
}
err := r.ParseForm()
if err != nil {
s.logger.Errorf("Could not parse request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return
}
grantType := r.PostFormValue("grant_type") grantType := r.PostFormValue("grant_type")
switch grantType { switch grantType {
case grantTypeDeviceCode:
s.handleDeviceToken(w, r)
case grantTypeAuthorizationCode: case grantTypeAuthorizationCode:
s.handleAuthCode(w, r, client) s.withClientFromStorage(w, r, s.handleAuthCode)
case grantTypeRefreshToken: case grantTypeRefreshToken:
s.handleRefreshToken(w, r, client) s.withClientFromStorage(w, r, s.handleRefreshToken)
case grantTypePassword: case grantTypePassword:
s.handlePasswordGrant(w, r, client) s.withClientFromStorage(w, r, s.handlePasswordGrant)
default: default:
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
} }

View file

@ -343,7 +343,8 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
handleFunc("/device", s.handleDeviceExchange) handleFunc("/device", s.handleDeviceExchange)
handleFunc("/device/auth/verify_code", s.verifyUserCode) handleFunc("/device/auth/verify_code", s.verifyUserCode)
handleFunc("/device/code", s.handleDeviceCode) handleFunc("/device/code", s.handleDeviceCode)
handleFunc("/device/token", s.handleDeviceToken) // TODO(nabokihms): "/device/token" endpoint is deprecated, consider using /token endpoint instead
handleFunc("/device/token", s.handleDeviceTokenDeprecated)
handleFunc(deviceCallbackURI, s.handleDeviceCallback) handleFunc(deviceCallbackURI, s.handleDeviceCallback)
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
// Strip the X-Remote-* headers to prevent security issues on // Strip the X-Remote-* headers to prevent security issues on

View file

@ -1507,8 +1507,28 @@ func TestOAuth2DeviceFlow(t *testing.T) {
var conn *mock.Callback var conn *mock.Callback
idTokensValidFor := time.Second * 30 idTokensValidFor := time.Second * 30
for _, tc := range makeOAuth2Tests(clientID, clientSecret, now).tests { tests := makeOAuth2Tests(clientID, clientSecret, now)
func() { testCases := []struct {
name string
tokenEndpoint string
oauth2Tests oauth2Tests
}{
{
name: "Actual token endpoint for devices",
tokenEndpoint: "/token",
oauth2Tests: tests,
},
// TODO(nabokihms): delete temporary tests after removing the deprecated token endpoint support
{
name: "Deprecated token endpoint for devices",
tokenEndpoint: "/device/token",
oauth2Tests: tests,
},
}
for _, testCase := range testCases {
for _, tc := range testCase.oauth2Tests.tests {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -1593,7 +1613,7 @@ func TestOAuth2DeviceFlow(t *testing.T) {
// Hit the Token Endpoint, and try and get an access token // Hit the Token Endpoint, and try and get an access token
tokenURL, _ := url.Parse(issuer.String()) tokenURL, _ := url.Parse(issuer.String())
tokenURL.Path = path.Join(tokenURL.Path, "/device/token") tokenURL.Path = path.Join(tokenURL.Path, testCase.tokenEndpoint)
v := url.Values{} v := url.Values{}
v.Add("grant_type", grantTypeDeviceCode) v.Add("grant_type", grantTypeDeviceCode)
v.Add("device_code", deviceCode.DeviceCode) v.Add("device_code", deviceCode.DeviceCode)
@ -1644,6 +1664,7 @@ func TestOAuth2DeviceFlow(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("%s: %v", tc.name, err) t.Errorf("%s: %v", tc.name, err)
} }
}() })
}
} }
} }