From 89295a5b4ad96aa15d978926e99a9c3183329eac Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Fri, 15 Jan 2021 01:15:56 +0400 Subject: [PATCH] More refresh token handler refactoring, more tests Signed-off-by: m.nabokikh --- server/refreshhandlers.go | 115 +++++++++++++----------- server/refreshhandlers_test.go | 159 +++++++++++++++++++-------------- 2 files changed, 156 insertions(+), 118 deletions(-) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 311eb30a..588b91f1 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/dexidp/dex/connector" "github.com/dexidp/dex/server/internal" @@ -27,6 +28,12 @@ type refreshError struct { desc string } +var internalErr = &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} + +func newBadRequestError(desc string) *refreshError { + return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} +} + func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) { s.tokenErrHelper(w, err.msg, err.desc, err.code) } @@ -34,7 +41,7 @@ func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) { code := r.PostFormValue("refresh_token") if code == "" { - return nil, &refreshError{msg: errInvalidRequest, desc: "No refresh token in request.", code: http.StatusBadRequest} + return nil, newBadRequestError("No refresh token is found in request.") } token := new(internal.RefreshToken) @@ -52,26 +59,22 @@ func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.Refr } // getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info -func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (storage.RefreshToken, *refreshError) { - refresh, err := s.storage.GetRefresh(token.RefreshId) - rerr := refreshError{ - msg: errInvalidRequest, - desc: "Refresh token is invalid or has already been claimed by another client.", - code: http.StatusBadRequest, - } +func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*storage.RefreshToken, *refreshError) { + invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.") + refresh, err := s.storage.GetRefresh(token.RefreshId) if err != nil { s.logger.Errorf("failed to get refresh token: %v", err) if err != storage.ErrNotFound { - return storage.RefreshToken{}, &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} + return nil, internalErr } - return storage.RefreshToken{}, &rerr + return nil, invalidErr } if refresh.ClientID != clientID { s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID) - return storage.RefreshToken{}, &rerr + return nil, invalidErr } if refresh.Token != token.Token { @@ -82,22 +85,22 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref fallthrough case refresh.ObsoleteToken == "": s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) - return storage.RefreshToken{}, &rerr + return nil, invalidErr } } - rerr.desc = "Refresh token expired." + expiredErr := newBadRequestError("Refresh token expired.") if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { s.logger.Errorf("refresh token with id %s expired", refresh.ID) - return storage.RefreshToken{}, &rerr + return nil, expiredErr } if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) - return storage.RefreshToken{}, &rerr + return nil, expiredErr } - return refresh, nil + return &refresh, nil } func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) { @@ -126,7 +129,7 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken if len(unauthorizedScopes) > 0 { desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes) - return nil, &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} + return nil, newBadRequestError(desc) } return requestedScopes, nil @@ -134,15 +137,15 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) { var connectorData []byte - rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) switch { case err != nil: if err != storage.ErrNotFound { s.logger.Errorf("failed to get offline session: %v", err) - // TODO: previously there was a naked return without writing anything in response, need to figure it out - return connector.Identity{}, &rerr + // TODO: previously there was a naked return without writing anything in response + // Need to ensure that everything works as expected. + return connector.Identity{}, internalErr } case len(refresh.ConnectorData) > 0: // Use the old connector data if it exists, should be deleted once used @@ -154,7 +157,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre conn, err := s.getConnector(refresh.ConnectorID) if err != nil { s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) - return connector.Identity{}, &rerr + return connector.Identity{}, internalErr } ident := connector.Identity{ @@ -182,7 +185,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident) if err != nil { s.logger.Errorf("failed to refresh identity: %v", err) - return connector.Identity{}, &rerr + return connector.Identity{}, internalErr } ident = newIdent } @@ -190,6 +193,28 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre return ident, nil } +// updateOfflineSession updates offline session in the storage +func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError { + offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if old.Refresh[refresh.ClientID].ID != refresh.ID { + return old, errors.New("refresh token invalid") + } + old.Refresh[refresh.ClientID].LastUsed = lastUsed + old.ConnectorData = ident.ConnectorData + return old, nil + } + + // Update LastUsed time stamp in refresh token reference object + // in offline session for the user. + err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) + if err != nil { + s.logger.Errorf("failed to update offline session: %v", err) + return internalErr + } + + return nil +} + // updateRefreshToken updates refresh token and offline session in the storage func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) { newToken := token @@ -201,10 +226,16 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora } lastUsed := s.now() + + rerr := s.updateOfflineSession(refresh, ident, lastUsed) + if rerr != nil { + return nil, rerr + } + refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { if s.refreshTokenPolicy.RotationEnabled() { - if old.Token != refresh.Token { - if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token { + if old.Token != token.Token { + if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token { newToken.Token = old.Token return old, nil } @@ -230,36 +261,18 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora return old, nil } - offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { - if old.Refresh[refresh.ClientID].ID != refresh.ID { - return old, errors.New("refresh token invalid") - } - old.Refresh[refresh.ClientID].LastUsed = lastUsed - old.ConnectorData = ident.ConnectorData - return old, nil - } - - rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} - - // Update LastUsed time stamp in refresh token reference object - // in offline session for the user. - err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) - if err != nil { - s.logger.Errorf("failed to update offline session: %v", err) - return newToken, &rerr - } - // Update refresh token in the storage. - err = s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater) + err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater) if err != nil { s.logger.Errorf("failed to update refresh token: %v", err) - return newToken, &rerr + return nil, internalErr } return newToken, nil } // handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6 +// this method is the entrypoint for refresh tokens handling func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) { token, rerr := s.extractRefreshTokenFromRequest(r) if rerr != nil { @@ -273,13 +286,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - scopes, rerr := s.getRefreshScopes(r, &refresh) + scopes, rerr := s.getRefreshScopes(r, refresh) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return } - ident, rerr := s.refreshWithConnector(r.Context(), token, &refresh, scopes) + ident, rerr := s.refreshWithConnector(r.Context(), token, refresh, scopes) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return @@ -297,18 +310,18 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) if err != nil { s.logger.Errorf("failed to create new access token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + s.refreshTokenErrHelper(w, internalErr) return } idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + s.refreshTokenErrHelper(w, internalErr) return } - newToken, rerr := s.updateRefreshToken(token, &refresh, ident) + newToken, rerr := s.updateRefreshToken(token, refresh, ident) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return @@ -317,7 +330,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie rawNewToken, err := internal.Marshal(newToken) if err != nil { s.logger.Errorf("failed to marshal refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + s.refreshTokenErrHelper(w, internalErr) return } diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go index 40e81435..c64c50b3 100644 --- a/server/refreshhandlers_test.go +++ b/server/refreshhandlers_test.go @@ -3,6 +3,7 @@ package server import ( "bytes" "context" + "encoding/json" "net/http" "net/http/httptest" "net/url" @@ -16,6 +17,67 @@ import ( "github.com/dexidp/dex/storage" ) +func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) { + c := storage.Client{ + ID: "test", + Secret: "barfoo", + RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, + Name: "dex client", + LogoURL: "https://goo.gl/JIyzIC", + } + + err := s.CreateClient(c) + require.NoError(t, err) + + c1 := storage.Connector{ + ID: "test", + Type: "mockCallback", + Name: "mockCallback", + Config: nil, + } + + err = s.CreateConnector(c1) + require.NoError(t, err) + + refresh := storage.RefreshToken{ + ID: "test", + Token: "bar", + ObsoleteToken: "", + Nonce: "foo", + ClientID: "test", + ConnectorID: "test", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + ConnectorData: []byte(`{"some":"data"}`), + } + + if useObsolete { + refresh.Token = "testtest" + refresh.ObsoleteToken = "bar" + } + + err = s.CreateRefresh(refresh) + require.NoError(t, err) + + offlineSessions := storage.OfflineSessions{ + UserID: "1", + ConnID: "test", + Refresh: map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}}, + ConnectorData: nil, + } + + err = s.CreateOfflineSessions(offlineSessions) + require.NoError(t, err) +} + func TestRefreshTokenExpirationScenarios(t *testing.T) { t0 := time.Now() tests := []struct { @@ -56,15 +118,6 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { }, error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, }, - { - name: "Obsolete tokens are not allowed", - useObsolete: true, - policy: &RefreshTokenPolicy{ - rotateRefreshTokens: true, - now: func() time.Time { return t0.Add(time.Second * 25) }, - }, - error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`, - }, { name: "Obsolete tokens are allowed", useObsolete: true, @@ -75,6 +128,15 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { }, error: ``, }, + { + name: "Obsolete tokens are not allowed", + useObsolete: true, + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`, + }, { name: "Obsolete tokens are allowed but token is expired globally", useObsolete: true, @@ -100,64 +162,7 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { }) defer httpServer.Close() - c := storage.Client{ - ID: "test", - Secret: "barfoo", - RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, - Name: "dex client", - LogoURL: "https://goo.gl/JIyzIC", - } - - err := s.storage.CreateClient(c) - require.NoError(t, err) - - c1 := storage.Connector{ - ID: "test", - Type: "mockCallback", - Name: "mockCallback", - Config: nil, - } - - err = s.storage.CreateConnector(c1) - require.NoError(t, err) - - refresh := storage.RefreshToken{ - ID: "test", - Token: "bar", - ObsoleteToken: "", - Nonce: "foo", - ClientID: "test", - ConnectorID: "test", - Scopes: []string{"openid", "email", "profile"}, - CreatedAt: time.Now().UTC().Round(time.Millisecond), - LastUsed: time.Now().UTC().Round(time.Millisecond), - Claims: storage.Claims{ - UserID: "1", - Username: "jane", - Email: "jane.doe@example.com", - EmailVerified: true, - Groups: []string{"a", "b"}, - }, - ConnectorData: []byte(`{"some":"data"}`), - } - - if tc.useObsolete { - refresh.Token = "testtest" - refresh.ObsoleteToken = "bar" - } - - err = s.storage.CreateRefresh(refresh) - require.NoError(t, err) - - offlineSessions := storage.OfflineSessions{ - UserID: "1", - ConnID: "test", - Refresh: map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}}, - ConnectorData: nil, - } - - err = s.storage.CreateOfflineSessions(offlineSessions) - require.NoError(t, err) + mockRefreshTokenTestStorage(t, s.storage, tc.useObsolete) u, err := url.Parse(s.issuerURL.String()) require.NoError(t, err) @@ -181,6 +186,26 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { require.Equal(t, 200, rr.Code) } else { require.Equal(t, rr.Body.String(), tc.error) + return + } + + // Check that we received expected refresh token + var ref struct { + Token string `json:"refresh_token"` + } + err = json.Unmarshal(rr.Body.Bytes(), &ref) + require.NoError(t, err) + + if tc.policy.rotateRefreshTokens == false { + require.Equal(t, tokenData, ref.Token) + } else { + require.NotEqual(t, tokenData, ref.Token) + } + + if tc.useObsolete { + updatedTokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "testtest"}) + require.NoError(t, err) + require.Equal(t, updatedTokenData, ref.Token) } }) }