From 0c75ed12e2feac0ecb7f95afe821420b0e598572 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Wed, 6 Jan 2021 04:22:38 +0400 Subject: [PATCH] Add refresh token expiration tests and some refactoring Signed-off-by: m.nabokikh --- server/handlers.go | 232 --------------------- server/refreshhandlers.go | 319 +++++++++++++++++++++++++++++ server/refreshhandlers_test.go | 187 +++++++++++++++++ server/server_test.go | 10 +- storage/conformance/conformance.go | 2 +- 5 files changed, 513 insertions(+), 237 deletions(-) create mode 100644 server/refreshhandlers.go create mode 100644 server/refreshhandlers_test.go diff --git a/server/handlers.go b/server/handlers.go index aa08b7a5..1db1f68f 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -1005,237 +1004,6 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil } -// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6 -func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) { - code := r.PostFormValue("refresh_token") - scope := r.PostFormValue("scope") - if code == "" { - s.tokenErrHelper(w, errInvalidRequest, "No refresh token in request.", http.StatusBadRequest) - return - } - - token := new(internal.RefreshToken) - if err := internal.Unmarshal(code, token); err != nil { - // For backward compatibility, assume the refresh_token is a raw refresh token ID - // if it fails to decode. - // - // Because refresh_token values that aren't unmarshable were generated by servers - // that don't have a Token value, we'll still reject any attempts to claim a - // refresh_token twice. - token = &internal.RefreshToken{RefreshId: code, Token: ""} - } - - refresh, err := s.storage.GetRefresh(token.RefreshId) - if err != nil { - s.logger.Errorf("failed to get refresh token: %v", err) - if err == storage.ErrNotFound { - s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) - } else { - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - } - return - } - - if refresh.ClientID != client.ID { - s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) - return - } - if refresh.Token != token.Token { - switch { - case !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed): - fallthrough - case refresh.ObsoleteToken != token.Token: - fallthrough - case refresh.ObsoleteToken == "": - s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) - return - } - } - if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { - s.logger.Errorf("refresh token with id %s expired", refresh.ID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest) - return - } - if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { - s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest) - return - } - - // Per the OAuth2 spec, if the client has omitted the scopes, default to the original - // authorized scopes. - // - // https://tools.ietf.org/html/rfc6749#section-6 - scopes := refresh.Scopes - if scope != "" { - requestedScopes := strings.Fields(scope) - var unauthorizedScopes []string - - for _, s := range requestedScopes { - contains := func() bool { - for _, scope := range refresh.Scopes { - if s == scope { - return true - } - } - return false - }() - if !contains { - unauthorizedScopes = append(unauthorizedScopes, s) - } - } - - if len(unauthorizedScopes) > 0 { - msg := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes) - s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest) - return - } - scopes = requestedScopes - } - - var connectorData []byte - - session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) - switch { - case err != nil: - if err != storage.ErrNotFound { - s.logger.Errorf("failed to get offline session: %v", err) - return - } - case len(refresh.ConnectorData) > 0: - // Use the old connector data if it exists, should be deleted once used - connectorData = refresh.ConnectorData - default: - connectorData = session.ConnectorData - } - - conn, err := s.getConnector(refresh.ConnectorID) - if err != nil { - s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - ident := connector.Identity{ - UserID: refresh.Claims.UserID, - Username: refresh.Claims.Username, - PreferredUsername: refresh.Claims.PreferredUsername, - Email: refresh.Claims.Email, - EmailVerified: refresh.Claims.EmailVerified, - Groups: refresh.Claims.Groups, - ConnectorData: connectorData, - } - - // Can the connector refresh the identity? If so, attempt to refresh the data - // in the connector. - // - // TODO(ericchiang): We may want a strict mode where connectors that don't implement - // this interface can't perform refreshing. - if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { - newIdent, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident) - if err != nil { - s.logger.Errorf("failed to refresh identity: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - ident = newIdent - } - - claims := storage.Claims{ - UserID: ident.UserID, - Username: ident.Username, - PreferredUsername: ident.PreferredUsername, - Email: ident.Email, - EmailVerified: ident.EmailVerified, - Groups: ident.Groups, - } - - accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) - if err != nil { - s.logger.Errorf("failed to create new access token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID) - if err != nil { - s.logger.Errorf("failed to create ID token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - newToken := token - if s.refreshTokenPolicy.RotationEnabled() { - newToken = &internal.RefreshToken{ - RefreshId: refresh.ID, - Token: storage.NewID(), - } - } - - lastUsed := s.now() - updater := func(old storage.RefreshToken) (storage.RefreshToken, error) { - if s.refreshTokenPolicy.RotationEnabled() { - if old.Token != refresh.Token { - if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token { - newToken.Token = old.Token - return old, nil - } - return old, errors.New("refresh token claimed twice") - } - - old.ObsoleteToken = old.Token - } - - old.Token = newToken.Token - // Update the claims of the refresh token. - // - // UserID intentionally ignored for now. - old.Claims.Username = ident.Username - old.Claims.PreferredUsername = ident.PreferredUsername - old.Claims.Email = ident.Email - old.Claims.EmailVerified = ident.EmailVerified - old.Claims.Groups = ident.Groups - old.LastUsed = lastUsed - - // ConnectorData has been moved to OfflineSession - old.ConnectorData = []byte{} - return old, nil - } - - // Update LastUsed time stamp in refresh token reference object - // in offline session for the user. - if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { - if old.Refresh[refresh.ClientID].ID != refresh.ID { - return old, errors.New("refresh token invalid") - } - old.Refresh[refresh.ClientID].LastUsed = lastUsed - old.ConnectorData = ident.ConnectorData - return old, nil - }); err != nil { - s.logger.Errorf("failed to update offline session: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - // Update refresh token in the storage. - if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil { - s.logger.Errorf("failed to update refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - rawNewToken, err := internal.Marshal(newToken) - if err != nil { - s.logger.Errorf("failed to marshal refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) - s.writeAccessToken(w, resp) -} - func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { const prefix = "Bearer " diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go new file mode 100644 index 00000000..31709ad5 --- /dev/null +++ b/server/refreshhandlers.go @@ -0,0 +1,319 @@ +package server + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/server/internal" + "github.com/dexidp/dex/storage" +) + +func contains(arr []string, item string) bool { + for _, itemFromArray := range arr { + if itemFromArray == item { + return true + } + } + return false +} + +type refreshError struct { + msg string + code int + desc string +} + +func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) { + s.tokenErrHelper(w, err.msg, err.desc, err.code) +} + +func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) { + code := r.PostFormValue("refresh_token") + if code == "" { + return nil, &refreshError{msg: errInvalidRequest, desc: "No refresh token in request.", code: http.StatusBadRequest} + } + + token := new(internal.RefreshToken) + if err := internal.Unmarshal(code, token); err != nil { + // For backward compatibility, assume the refresh_token is a raw refresh token ID + // if it fails to decode. + // + // Because refresh_token values that aren't unmarshable were generated by servers + // that don't have a Token value, we'll still reject any attempts to claim a + // refresh_token twice. + token = &internal.RefreshToken{RefreshId: code, Token: ""} + } + + return token, nil +} + +// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info +func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (storage.RefreshToken, *refreshError) { + refresh, err := s.storage.GetRefresh(token.RefreshId) + rerr := refreshError{ + msg: errInvalidRequest, + desc: "Refresh token is invalid or has already been claimed by another client.", + code: http.StatusBadRequest, + } + + if err != nil { + s.logger.Errorf("failed to get refresh token: %v", err) + if err != storage.ErrNotFound { + return storage.RefreshToken{}, &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} + } + + return storage.RefreshToken{}, &rerr + } + + if refresh.ClientID != clientID { + s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID) + return storage.RefreshToken{}, &rerr + } + + if refresh.Token != token.Token { + switch { + case !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed): + fallthrough + case refresh.ObsoleteToken != token.Token: + fallthrough + case refresh.ObsoleteToken == "": + s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) + return storage.RefreshToken{}, &rerr + } + } + + rerr.desc = "Refresh token expired." + if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { + s.logger.Errorf("refresh token with id %s expired", refresh.ID) + return storage.RefreshToken{}, &rerr + } + if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { + s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) + return storage.RefreshToken{}, &rerr + } + + return refresh, nil +} + +func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) { + // Per the OAuth2 spec, if the client has omitted the scopes, default to the original + // authorized scopes. + // + // https://tools.ietf.org/html/rfc6749#section-6 + scope := r.PostFormValue("scope") + + if scope == "" { + return refresh.Scopes, nil + } + + requestedScopes := strings.Fields(scope) + var unauthorizedScopes []string + + // Per the OAuth2 spec, if the client has omitted the scopes, default to the original + // authorized scopes. + // + // https://tools.ietf.org/html/rfc6749#section-6 + for _, requestScope := range requestedScopes { + if !contains(refresh.Scopes, requestScope) { + unauthorizedScopes = append(unauthorizedScopes, requestScope) + } + } + + if len(unauthorizedScopes) > 0 { + desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes) + return nil, &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} + } + + return requestedScopes, nil +} + +func (s *Server) refreshWithConnector(ctx context.Context, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) { + var connectorData []byte + rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} + + session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) + switch { + case err != nil: + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get offline session: %v", err) + // TODO: previously there was a naked return without writing anything in response, need to figure it out + return connector.Identity{}, &rerr + } + case len(refresh.ConnectorData) > 0: + // Use the old connector data if it exists, should be deleted once used + connectorData = refresh.ConnectorData + default: + connectorData = session.ConnectorData + } + + conn, err := s.getConnector(refresh.ConnectorID) + if err != nil { + s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) + return connector.Identity{}, &rerr + } + + ident := connector.Identity{ + UserID: refresh.Claims.UserID, + Username: refresh.Claims.Username, + PreferredUsername: refresh.Claims.PreferredUsername, + Email: refresh.Claims.Email, + EmailVerified: refresh.Claims.EmailVerified, + Groups: refresh.Claims.Groups, + ConnectorData: connectorData, + } + + // Can the connector refresh the identity? If so, attempt to refresh the data + // in the connector. + // + // TODO(ericchiang): We may want a strict mode where connectors that don't implement + // this interface can't perform refreshing. + if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { + newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident) + if err != nil { + s.logger.Errorf("failed to refresh identity: %v", err) + return connector.Identity{}, &rerr + } + ident = newIdent + } + + return ident, nil +} + +// updateRefreshToken updates refresh token and offline session in the storage +func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) { + newToken := token + if s.refreshTokenPolicy.RotationEnabled() { + newToken = &internal.RefreshToken{ + RefreshId: refresh.ID, + Token: storage.NewID(), + } + } + + lastUsed := s.now() + refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { + if s.refreshTokenPolicy.RotationEnabled() { + if old.Token != refresh.Token { + if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token { + newToken.Token = old.Token + return old, nil + } + return old, errors.New("refresh token claimed twice") + } + + old.ObsoleteToken = old.Token + } + + old.Token = newToken.Token + // Update the claims of the refresh token. + // + // UserID intentionally ignored for now. + old.Claims.Username = ident.Username + old.Claims.PreferredUsername = ident.PreferredUsername + old.Claims.Email = ident.Email + old.Claims.EmailVerified = ident.EmailVerified + old.Claims.Groups = ident.Groups + old.LastUsed = lastUsed + + // ConnectorData has been moved to OfflineSession + old.ConnectorData = []byte{} + return old, nil + } + + offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if old.Refresh[refresh.ClientID].ID != refresh.ID { + return old, errors.New("refresh token invalid") + } + old.Refresh[refresh.ClientID].LastUsed = lastUsed + old.ConnectorData = ident.ConnectorData + return old, nil + } + + rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} + + // Update LastUsed time stamp in refresh token reference object + // in offline session for the user. + err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) + if err != nil { + s.logger.Errorf("failed to update offline session: %v", err) + return newToken, &rerr + } + + // Update refresh token in the storage. + err = s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater) + if err != nil { + s.logger.Errorf("failed to update refresh token: %v", err) + return newToken, &rerr + } + + return newToken, nil +} + +// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6 +func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) { + token, rerr := s.extractRefreshTokenFromRequest(r) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + refresh, rerr := s.getRefreshTokenFromStorage(client.ID, token) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + scopes, rerr := s.getRefreshScopes(r, &refresh) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + ident, rerr := s.refreshWithConnector(r.Context(), &refresh, scopes) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + claims := storage.Claims{ + UserID: ident.UserID, + Username: ident.Username, + PreferredUsername: ident.PreferredUsername, + Email: ident.Email, + EmailVerified: ident.EmailVerified, + Groups: ident.Groups, + } + + accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) + if err != nil { + s.logger.Errorf("failed to create new access token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID) + if err != nil { + s.logger.Errorf("failed to create ID token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + newToken, rerr := s.updateRefreshToken(token, &refresh, ident) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + rawNewToken, err := internal.Marshal(newToken) + if err != nil { + s.logger.Errorf("failed to marshal refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) + s.writeAccessToken(w, resp) +} diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go new file mode 100644 index 00000000..40e81435 --- /dev/null +++ b/server/refreshhandlers_test.go @@ -0,0 +1,187 @@ +package server + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "net/url" + "path" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/dexidp/dex/server/internal" + "github.com/dexidp/dex/storage" +) + +func TestRefreshTokenExpirationScenarios(t *testing.T) { + t0 := time.Now() + tests := []struct { + name string + policy *RefreshTokenPolicy + useObsolete bool + error string + }{ + { + name: "Normal", + policy: &RefreshTokenPolicy{rotateRefreshTokens: true}, + error: ``, + }, + { + name: "Not expired because used", + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: false, + validIfNotUsedFor: time.Second * 60, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: ``, + }, + { + name: "Expired because not used", + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: false, + validIfNotUsedFor: time.Second * 60, + now: func() time.Time { return t0.Add(time.Hour) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, + }, + { + name: "Absolutely expired", + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + absoluteLifetime: time.Second * 60, + now: func() time.Time { return t0.Add(time.Hour) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, + }, + { + name: "Obsolete tokens are not allowed", + useObsolete: true, + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`, + }, + { + name: "Obsolete tokens are allowed", + useObsolete: true, + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + reuseInterval: time.Second * 30, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: ``, + }, + { + name: "Obsolete tokens are allowed but token is expired globally", + useObsolete: true, + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + reuseInterval: time.Second * 30, + absoluteLifetime: time.Second * 20, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(*testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.RefreshTokenPolicy = tc.policy + c.Now = func() time.Time { return t0 } + }) + defer httpServer.Close() + + c := storage.Client{ + ID: "test", + Secret: "barfoo", + RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, + Name: "dex client", + LogoURL: "https://goo.gl/JIyzIC", + } + + err := s.storage.CreateClient(c) + require.NoError(t, err) + + c1 := storage.Connector{ + ID: "test", + Type: "mockCallback", + Name: "mockCallback", + Config: nil, + } + + err = s.storage.CreateConnector(c1) + require.NoError(t, err) + + refresh := storage.RefreshToken{ + ID: "test", + Token: "bar", + ObsoleteToken: "", + Nonce: "foo", + ClientID: "test", + ConnectorID: "test", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + ConnectorData: []byte(`{"some":"data"}`), + } + + if tc.useObsolete { + refresh.Token = "testtest" + refresh.ObsoleteToken = "bar" + } + + err = s.storage.CreateRefresh(refresh) + require.NoError(t, err) + + offlineSessions := storage.OfflineSessions{ + UserID: "1", + ConnID: "test", + Refresh: map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}}, + ConnectorData: nil, + } + + err = s.storage.CreateOfflineSessions(offlineSessions) + require.NoError(t, err) + + u, err := url.Parse(s.issuerURL.String()) + require.NoError(t, err) + + tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"}) + require.NoError(t, err) + + u.Path = path.Join(u.Path, "/token") + v := url.Values{} + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", tokenData) + + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + req.SetBasicAuth("test", "barfoo") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + if tc.error == "" { + require.Equal(t, 200, rr.Code) + } else { + require.Equal(t, rr.Body.String(), tc.error) + } + }) + } +} diff --git a/server/server_test.go b/server/server_test.go index d8b40991..62ba40c9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -119,11 +119,13 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. // Default rotation policy - server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "") - if err != nil { - t.Fatalf("failed to prepare rotation policy: %v", err) + if server.refreshTokenPolicy == nil { + server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "") + if err != nil { + t.Fatalf("failed to prepare rotation policy: %v", err) + } + server.refreshTokenPolicy.now = config.Now } - server.refreshTokenPolicy.now = config.Now return s, server } diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 0bae52cb..dde369c4 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -381,7 +381,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { refresh2 := storage.RefreshToken{ ID: id2, Token: "bar_2", - ObsoleteToken: "bar", + ObsoleteToken: refresh.Token, Nonce: "foo_2", ClientID: "client_id_2", ConnectorID: "client_secret",