forked from mystiq/dex
Merge pull request #2010 from flant/switch-device-token-endpoint-to-token
fix: use /token endpoint to get tokens with device flow
This commit is contained in:
commit
94a2b3ed87
4 changed files with 243 additions and 196 deletions
|
@ -151,7 +151,9 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handleDeviceTokenDeprecated(w http.ResponseWriter, r *http.Request) {
|
||||
s.logger.Warn(`The deprecated "/device/token" endpoint was called. It will be removed, use "/token" instead.`)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
|
@ -162,71 +164,75 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
deviceCode := r.Form.Get("device_code")
|
||||
if deviceCode == "" {
|
||||
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.PostFormValue("grant_type")
|
||||
if grantType != grantTypeDeviceCode {
|
||||
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
now := s.now()
|
||||
|
||||
// Grab the device token, check validity
|
||||
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||
if err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device code: %v", err)
|
||||
}
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
|
||||
return
|
||||
} else if now.After(deviceToken.Expiry) {
|
||||
s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Rate Limiting check
|
||||
slowDown := false
|
||||
pollInterval := deviceToken.PollIntervalSeconds
|
||||
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
|
||||
if now.Before(minRequestTime) {
|
||||
slowDown = true
|
||||
// Continually increase the poll interval until the user waits the proper time
|
||||
pollInterval += 5
|
||||
} else {
|
||||
pollInterval = 5
|
||||
}
|
||||
|
||||
switch deviceToken.Status {
|
||||
case deviceTokenPending:
|
||||
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||
old.PollIntervalSeconds = pollInterval
|
||||
old.LastRequestTime = now
|
||||
return old, nil
|
||||
}
|
||||
// Update device token last request time in storage
|
||||
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
|
||||
s.logger.Errorf("failed to update device token: %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
if slowDown {
|
||||
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
|
||||
} else {
|
||||
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||
}
|
||||
case deviceTokenComplete:
|
||||
w.Write([]byte(deviceToken.Token))
|
||||
}
|
||||
s.handleDeviceToken(w, r)
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||
deviceCode := r.Form.Get("device_code")
|
||||
if deviceCode == "" {
|
||||
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
now := s.now()
|
||||
|
||||
// Grab the device token, check validity
|
||||
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||
if err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device code: %v", err)
|
||||
}
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
|
||||
return
|
||||
} else if now.After(deviceToken.Expiry) {
|
||||
s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Rate Limiting check
|
||||
slowDown := false
|
||||
pollInterval := deviceToken.PollIntervalSeconds
|
||||
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
|
||||
if now.Before(minRequestTime) {
|
||||
slowDown = true
|
||||
// Continually increase the poll interval until the user waits the proper time
|
||||
pollInterval += 5
|
||||
} else {
|
||||
pollInterval = 5
|
||||
}
|
||||
|
||||
switch deviceToken.Status {
|
||||
case deviceTokenPending:
|
||||
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||
old.PollIntervalSeconds = pollInterval
|
||||
old.LastRequestTime = now
|
||||
return old, nil
|
||||
}
|
||||
// Update device token last request time in storage
|
||||
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
|
||||
s.logger.Errorf("failed to update device token: %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
if slowDown {
|
||||
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
|
||||
} else {
|
||||
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||
}
|
||||
case deviceTokenComplete:
|
||||
w.Write([]byte(deviceToken.Token))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
|
|
|
@ -651,7 +651,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
|
|||
http.Redirect(w, r, u.String(), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request, storage.Client)) {
|
||||
clientID, clientSecret, ok := r.BasicAuth()
|
||||
if ok {
|
||||
var err error
|
||||
|
@ -688,14 +688,33 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
handler(w, r, client)
|
||||
}
|
||||
|
||||
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if r.Method != http.MethodPost {
|
||||
s.tokenErrHelper(w, errInvalidRequest, "method not allowed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not parse request body: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.PostFormValue("grant_type")
|
||||
switch grantType {
|
||||
case grantTypeDeviceCode:
|
||||
s.handleDeviceToken(w, r)
|
||||
case grantTypeAuthorizationCode:
|
||||
s.handleAuthCode(w, r, client)
|
||||
s.withClientFromStorage(w, r, s.handleAuthCode)
|
||||
case grantTypeRefreshToken:
|
||||
s.handleRefreshToken(w, r, client)
|
||||
s.withClientFromStorage(w, r, s.handleRefreshToken)
|
||||
case grantTypePassword:
|
||||
s.handlePasswordGrant(w, r, client)
|
||||
s.withClientFromStorage(w, r, s.handlePasswordGrant)
|
||||
default:
|
||||
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||
}
|
||||
|
|
|
@ -343,7 +343,8 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||
handleFunc("/device", s.handleDeviceExchange)
|
||||
handleFunc("/device/auth/verify_code", s.verifyUserCode)
|
||||
handleFunc("/device/code", s.handleDeviceCode)
|
||||
handleFunc("/device/token", s.handleDeviceToken)
|
||||
// TODO(nabokihms): "/device/token" endpoint is deprecated, consider using /token endpoint instead
|
||||
handleFunc("/device/token", s.handleDeviceTokenDeprecated)
|
||||
handleFunc(deviceCallbackURI, s.handleDeviceCallback)
|
||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||
// Strip the X-Remote-* headers to prevent security issues on
|
||||
|
|
|
@ -1507,143 +1507,164 @@ func TestOAuth2DeviceFlow(t *testing.T) {
|
|||
var conn *mock.Callback
|
||||
idTokensValidFor := time.Second * 30
|
||||
|
||||
for _, tc := range makeOAuth2Tests(clientID, clientSecret, now).tests {
|
||||
func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
tests := makeOAuth2Tests(clientID, clientSecret, now)
|
||||
testCases := []struct {
|
||||
name string
|
||||
tokenEndpoint string
|
||||
oauth2Tests oauth2Tests
|
||||
}{
|
||||
{
|
||||
name: "Actual token endpoint for devices",
|
||||
tokenEndpoint: "/token",
|
||||
oauth2Tests: tests,
|
||||
},
|
||||
// TODO(nabokihms): delete temporary tests after removing the deprecated token endpoint support
|
||||
{
|
||||
name: "Deprecated token endpoint for devices",
|
||||
tokenEndpoint: "/device/token",
|
||||
oauth2Tests: tests,
|
||||
},
|
||||
}
|
||||
|
||||
// Setup a dex server.
|
||||
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||
c.Issuer += "/non-root-path"
|
||||
c.Now = now
|
||||
c.IDTokensValidFor = idTokensValidFor
|
||||
for _, testCase := range testCases {
|
||||
for _, tc := range testCase.oauth2Tests.tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Setup a dex server.
|
||||
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||
c.Issuer += "/non-root-path"
|
||||
c.Now = now
|
||||
c.IDTokensValidFor = idTokensValidFor
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
||||
mockConn := s.connectors["mock"]
|
||||
conn = mockConn.Connector.(*mock.Callback)
|
||||
|
||||
p, err := oidc.NewProvider(ctx, httpServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get provider: %v", err)
|
||||
}
|
||||
|
||||
// Add the Clients to the test server
|
||||
client := storage.Client{
|
||||
ID: clientID,
|
||||
RedirectURIs: []string{deviceCallbackURI},
|
||||
Public: true,
|
||||
}
|
||||
if err := s.storage.CreateClient(client); err != nil {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
// Grab the issuer that we'll reuse for the different endpoints to hit
|
||||
issuer, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
t.Errorf("Could not parse issuer URL %v", err)
|
||||
}
|
||||
|
||||
// Send a new Device Request
|
||||
codeURL, _ := url.Parse(issuer.String())
|
||||
codeURL.Path = path.Join(codeURL.Path, "device/code")
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("client_id", clientID)
|
||||
data.Add("scope", strings.Join(requestedScopes, " "))
|
||||
resp, err := http.PostForm(codeURL.String(), data)
|
||||
if err != nil {
|
||||
t.Errorf("Could not request device code: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read device code response %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||
}
|
||||
if resp.Header.Get("Cache-Control") != "no-store" {
|
||||
t.Errorf("Cache-Control header doesn't exist in Device Code Response")
|
||||
}
|
||||
|
||||
// Parse the code response
|
||||
var deviceCode deviceCodeResponse
|
||||
if err := json.Unmarshal(responseBody, &deviceCode); err != nil {
|
||||
t.Errorf("Unexpected Device Code Response Format %v", string(responseBody))
|
||||
}
|
||||
|
||||
// Mock the user hitting the verification URI and posting the form
|
||||
verifyURL, _ := url.Parse(issuer.String())
|
||||
verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code")
|
||||
urlData := url.Values{}
|
||||
urlData.Set("user_code", deviceCode.UserCode)
|
||||
resp, err = http.PostForm(verifyURL.String(), urlData)
|
||||
if err != nil {
|
||||
t.Errorf("Error Posting Form: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read verification response %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||
}
|
||||
|
||||
// Hit the Token Endpoint, and try and get an access token
|
||||
tokenURL, _ := url.Parse(issuer.String())
|
||||
tokenURL.Path = path.Join(tokenURL.Path, testCase.tokenEndpoint)
|
||||
v := url.Values{}
|
||||
v.Add("grant_type", grantTypeDeviceCode)
|
||||
v.Add("device_code", deviceCode.DeviceCode)
|
||||
resp, err = http.PostForm(tokenURL.String(), v)
|
||||
if err != nil {
|
||||
t.Errorf("Could not request device token: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read device token response %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||
}
|
||||
|
||||
// Parse the response
|
||||
var tokenRes accessTokenResponse
|
||||
if err := json.Unmarshal(responseBody, &tokenRes); err != nil {
|
||||
t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody))
|
||||
}
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: tokenRes.AccessToken,
|
||||
TokenType: tokenRes.TokenType,
|
||||
RefreshToken: tokenRes.RefreshToken,
|
||||
}
|
||||
raw := make(map[string]interface{})
|
||||
json.Unmarshal(responseBody, &raw) // no error checks for optional fields
|
||||
token = token.WithExtra(raw)
|
||||
if secs := tokenRes.ExpiresIn; secs > 0 {
|
||||
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
|
||||
}
|
||||
|
||||
// Run token tests to validate info is correct
|
||||
// Create the OAuth2 config.
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: client.ID,
|
||||
ClientSecret: client.Secret,
|
||||
Endpoint: p.Endpoint(),
|
||||
Scopes: requestedScopes,
|
||||
RedirectURL: deviceCallbackURI,
|
||||
}
|
||||
if len(tc.scopes) != 0 {
|
||||
oauth2Config.Scopes = tc.scopes
|
||||
}
|
||||
err = tc.handleToken(ctx, p, oauth2Config, token, conn)
|
||||
if err != nil {
|
||||
t.Errorf("%s: %v", tc.name, err)
|
||||
}
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
||||
mockConn := s.connectors["mock"]
|
||||
conn = mockConn.Connector.(*mock.Callback)
|
||||
|
||||
p, err := oidc.NewProvider(ctx, httpServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get provider: %v", err)
|
||||
}
|
||||
|
||||
// Add the Clients to the test server
|
||||
client := storage.Client{
|
||||
ID: clientID,
|
||||
RedirectURIs: []string{deviceCallbackURI},
|
||||
Public: true,
|
||||
}
|
||||
if err := s.storage.CreateClient(client); err != nil {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
// Grab the issuer that we'll reuse for the different endpoints to hit
|
||||
issuer, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
t.Errorf("Could not parse issuer URL %v", err)
|
||||
}
|
||||
|
||||
// Send a new Device Request
|
||||
codeURL, _ := url.Parse(issuer.String())
|
||||
codeURL.Path = path.Join(codeURL.Path, "device/code")
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("client_id", clientID)
|
||||
data.Add("scope", strings.Join(requestedScopes, " "))
|
||||
resp, err := http.PostForm(codeURL.String(), data)
|
||||
if err != nil {
|
||||
t.Errorf("Could not request device code: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read device code response %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||
}
|
||||
if resp.Header.Get("Cache-Control") != "no-store" {
|
||||
t.Errorf("Cache-Control header doesn't exist in Device Code Response")
|
||||
}
|
||||
|
||||
// Parse the code response
|
||||
var deviceCode deviceCodeResponse
|
||||
if err := json.Unmarshal(responseBody, &deviceCode); err != nil {
|
||||
t.Errorf("Unexpected Device Code Response Format %v", string(responseBody))
|
||||
}
|
||||
|
||||
// Mock the user hitting the verification URI and posting the form
|
||||
verifyURL, _ := url.Parse(issuer.String())
|
||||
verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code")
|
||||
urlData := url.Values{}
|
||||
urlData.Set("user_code", deviceCode.UserCode)
|
||||
resp, err = http.PostForm(verifyURL.String(), urlData)
|
||||
if err != nil {
|
||||
t.Errorf("Error Posting Form: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read verification response %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||
}
|
||||
|
||||
// Hit the Token Endpoint, and try and get an access token
|
||||
tokenURL, _ := url.Parse(issuer.String())
|
||||
tokenURL.Path = path.Join(tokenURL.Path, "/device/token")
|
||||
v := url.Values{}
|
||||
v.Add("grant_type", grantTypeDeviceCode)
|
||||
v.Add("device_code", deviceCode.DeviceCode)
|
||||
resp, err = http.PostForm(tokenURL.String(), v)
|
||||
if err != nil {
|
||||
t.Errorf("Could not request device token: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read device token response %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||
}
|
||||
|
||||
// Parse the response
|
||||
var tokenRes accessTokenResponse
|
||||
if err := json.Unmarshal(responseBody, &tokenRes); err != nil {
|
||||
t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody))
|
||||
}
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: tokenRes.AccessToken,
|
||||
TokenType: tokenRes.TokenType,
|
||||
RefreshToken: tokenRes.RefreshToken,
|
||||
}
|
||||
raw := make(map[string]interface{})
|
||||
json.Unmarshal(responseBody, &raw) // no error checks for optional fields
|
||||
token = token.WithExtra(raw)
|
||||
if secs := tokenRes.ExpiresIn; secs > 0 {
|
||||
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
|
||||
}
|
||||
|
||||
// Run token tests to validate info is correct
|
||||
// Create the OAuth2 config.
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: client.ID,
|
||||
ClientSecret: client.Secret,
|
||||
Endpoint: p.Endpoint(),
|
||||
Scopes: requestedScopes,
|
||||
RedirectURL: deviceCallbackURI,
|
||||
}
|
||||
if len(tc.scopes) != 0 {
|
||||
oauth2Config.Scopes = tc.scopes
|
||||
}
|
||||
err = tc.handleToken(ctx, p, oauth2Config, token, conn)
|
||||
if err != nil {
|
||||
t.Errorf("%s: %v", tc.name, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue