forked from mystiq/dex
More refresh token handler refactoring, more tests
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
parent
4e73f39f57
commit
89295a5b4a
2 changed files with 156 additions and 118 deletions
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/dexidp/dex/connector"
|
"github.com/dexidp/dex/connector"
|
||||||
"github.com/dexidp/dex/server/internal"
|
"github.com/dexidp/dex/server/internal"
|
||||||
|
@ -27,6 +28,12 @@ type refreshError struct {
|
||||||
desc string
|
desc string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var internalErr = &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
|
||||||
|
|
||||||
|
func newBadRequestError(desc string) *refreshError {
|
||||||
|
return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) {
|
func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) {
|
||||||
s.tokenErrHelper(w, err.msg, err.desc, err.code)
|
s.tokenErrHelper(w, err.msg, err.desc, err.code)
|
||||||
}
|
}
|
||||||
|
@ -34,7 +41,7 @@ func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError)
|
||||||
func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) {
|
func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) {
|
||||||
code := r.PostFormValue("refresh_token")
|
code := r.PostFormValue("refresh_token")
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, &refreshError{msg: errInvalidRequest, desc: "No refresh token in request.", code: http.StatusBadRequest}
|
return nil, newBadRequestError("No refresh token is found in request.")
|
||||||
}
|
}
|
||||||
|
|
||||||
token := new(internal.RefreshToken)
|
token := new(internal.RefreshToken)
|
||||||
|
@ -52,26 +59,22 @@ func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.Refr
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info
|
// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info
|
||||||
func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (storage.RefreshToken, *refreshError) {
|
func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*storage.RefreshToken, *refreshError) {
|
||||||
refresh, err := s.storage.GetRefresh(token.RefreshId)
|
invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.")
|
||||||
rerr := refreshError{
|
|
||||||
msg: errInvalidRequest,
|
|
||||||
desc: "Refresh token is invalid or has already been claimed by another client.",
|
|
||||||
code: http.StatusBadRequest,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
refresh, err := s.storage.GetRefresh(token.RefreshId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to get refresh token: %v", err)
|
s.logger.Errorf("failed to get refresh token: %v", err)
|
||||||
if err != storage.ErrNotFound {
|
if err != storage.ErrNotFound {
|
||||||
return storage.RefreshToken{}, &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
|
return nil, internalErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return storage.RefreshToken{}, &rerr
|
return nil, invalidErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if refresh.ClientID != clientID {
|
if refresh.ClientID != clientID {
|
||||||
s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID)
|
s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID)
|
||||||
return storage.RefreshToken{}, &rerr
|
return nil, invalidErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if refresh.Token != token.Token {
|
if refresh.Token != token.Token {
|
||||||
|
@ -82,22 +85,22 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref
|
||||||
fallthrough
|
fallthrough
|
||||||
case refresh.ObsoleteToken == "":
|
case refresh.ObsoleteToken == "":
|
||||||
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
|
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
|
||||||
return storage.RefreshToken{}, &rerr
|
return nil, invalidErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rerr.desc = "Refresh token expired."
|
expiredErr := newBadRequestError("Refresh token expired.")
|
||||||
if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) {
|
if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) {
|
||||||
s.logger.Errorf("refresh token with id %s expired", refresh.ID)
|
s.logger.Errorf("refresh token with id %s expired", refresh.ID)
|
||||||
return storage.RefreshToken{}, &rerr
|
return nil, expiredErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) {
|
if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) {
|
||||||
s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID)
|
s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID)
|
||||||
return storage.RefreshToken{}, &rerr
|
return nil, expiredErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return refresh, nil
|
return &refresh, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) {
|
func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) {
|
||||||
|
@ -126,7 +129,7 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken
|
||||||
|
|
||||||
if len(unauthorizedScopes) > 0 {
|
if len(unauthorizedScopes) > 0 {
|
||||||
desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes)
|
desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes)
|
||||||
return nil, &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest}
|
return nil, newBadRequestError(desc)
|
||||||
}
|
}
|
||||||
|
|
||||||
return requestedScopes, nil
|
return requestedScopes, nil
|
||||||
|
@ -134,15 +137,15 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken
|
||||||
|
|
||||||
func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) {
|
func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) {
|
||||||
var connectorData []byte
|
var connectorData []byte
|
||||||
rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
|
|
||||||
|
|
||||||
session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
|
session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
|
||||||
switch {
|
switch {
|
||||||
case err != nil:
|
case err != nil:
|
||||||
if err != storage.ErrNotFound {
|
if err != storage.ErrNotFound {
|
||||||
s.logger.Errorf("failed to get offline session: %v", err)
|
s.logger.Errorf("failed to get offline session: %v", err)
|
||||||
// TODO: previously there was a naked return without writing anything in response, need to figure it out
|
// TODO: previously there was a naked return without writing anything in response
|
||||||
return connector.Identity{}, &rerr
|
// Need to ensure that everything works as expected.
|
||||||
|
return connector.Identity{}, internalErr
|
||||||
}
|
}
|
||||||
case len(refresh.ConnectorData) > 0:
|
case len(refresh.ConnectorData) > 0:
|
||||||
// Use the old connector data if it exists, should be deleted once used
|
// Use the old connector data if it exists, should be deleted once used
|
||||||
|
@ -154,7 +157,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre
|
||||||
conn, err := s.getConnector(refresh.ConnectorID)
|
conn, err := s.getConnector(refresh.ConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
|
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
|
||||||
return connector.Identity{}, &rerr
|
return connector.Identity{}, internalErr
|
||||||
}
|
}
|
||||||
|
|
||||||
ident := connector.Identity{
|
ident := connector.Identity{
|
||||||
|
@ -182,7 +185,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre
|
||||||
newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident)
|
newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to refresh identity: %v", err)
|
s.logger.Errorf("failed to refresh identity: %v", err)
|
||||||
return connector.Identity{}, &rerr
|
return connector.Identity{}, internalErr
|
||||||
}
|
}
|
||||||
ident = newIdent
|
ident = newIdent
|
||||||
}
|
}
|
||||||
|
@ -190,6 +193,28 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre
|
||||||
return ident, nil
|
return ident, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateOfflineSession updates offline session in the storage
|
||||||
|
func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError {
|
||||||
|
offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
|
||||||
|
if old.Refresh[refresh.ClientID].ID != refresh.ID {
|
||||||
|
return old, errors.New("refresh token invalid")
|
||||||
|
}
|
||||||
|
old.Refresh[refresh.ClientID].LastUsed = lastUsed
|
||||||
|
old.ConnectorData = ident.ConnectorData
|
||||||
|
return old, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update LastUsed time stamp in refresh token reference object
|
||||||
|
// in offline session for the user.
|
||||||
|
err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to update offline session: %v", err)
|
||||||
|
return internalErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// updateRefreshToken updates refresh token and offline session in the storage
|
// updateRefreshToken updates refresh token and offline session in the storage
|
||||||
func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) {
|
func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) {
|
||||||
newToken := token
|
newToken := token
|
||||||
|
@ -201,10 +226,16 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora
|
||||||
}
|
}
|
||||||
|
|
||||||
lastUsed := s.now()
|
lastUsed := s.now()
|
||||||
|
|
||||||
|
rerr := s.updateOfflineSession(refresh, ident, lastUsed)
|
||||||
|
if rerr != nil {
|
||||||
|
return nil, rerr
|
||||||
|
}
|
||||||
|
|
||||||
refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
||||||
if s.refreshTokenPolicy.RotationEnabled() {
|
if s.refreshTokenPolicy.RotationEnabled() {
|
||||||
if old.Token != refresh.Token {
|
if old.Token != token.Token {
|
||||||
if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token {
|
if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token {
|
||||||
newToken.Token = old.Token
|
newToken.Token = old.Token
|
||||||
return old, nil
|
return old, nil
|
||||||
}
|
}
|
||||||
|
@ -230,36 +261,18 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora
|
||||||
return old, nil
|
return old, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
|
|
||||||
if old.Refresh[refresh.ClientID].ID != refresh.ID {
|
|
||||||
return old, errors.New("refresh token invalid")
|
|
||||||
}
|
|
||||||
old.Refresh[refresh.ClientID].LastUsed = lastUsed
|
|
||||||
old.ConnectorData = ident.ConnectorData
|
|
||||||
return old, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
|
|
||||||
|
|
||||||
// Update LastUsed time stamp in refresh token reference object
|
|
||||||
// in offline session for the user.
|
|
||||||
err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater)
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("failed to update offline session: %v", err)
|
|
||||||
return newToken, &rerr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update refresh token in the storage.
|
// Update refresh token in the storage.
|
||||||
err = s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater)
|
err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to update refresh token: %v", err)
|
s.logger.Errorf("failed to update refresh token: %v", err)
|
||||||
return newToken, &rerr
|
return nil, internalErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return newToken, nil
|
return newToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
||||||
|
// this method is the entrypoint for refresh tokens handling
|
||||||
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
|
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
|
||||||
token, rerr := s.extractRefreshTokenFromRequest(r)
|
token, rerr := s.extractRefreshTokenFromRequest(r)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
|
@ -273,13 +286,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes, rerr := s.getRefreshScopes(r, &refresh)
|
scopes, rerr := s.getRefreshScopes(r, refresh)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
s.refreshTokenErrHelper(w, rerr)
|
s.refreshTokenErrHelper(w, rerr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ident, rerr := s.refreshWithConnector(r.Context(), token, &refresh, scopes)
|
ident, rerr := s.refreshWithConnector(r.Context(), token, refresh, scopes)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
s.refreshTokenErrHelper(w, rerr)
|
s.refreshTokenErrHelper(w, rerr)
|
||||||
return
|
return
|
||||||
|
@ -297,18 +310,18 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
||||||
accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
|
accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to create new access token: %v", err)
|
s.logger.Errorf("failed to create new access token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.refreshTokenErrHelper(w, internalErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID)
|
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to create ID token: %v", err)
|
s.logger.Errorf("failed to create ID token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.refreshTokenErrHelper(w, internalErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newToken, rerr := s.updateRefreshToken(token, &refresh, ident)
|
newToken, rerr := s.updateRefreshToken(token, refresh, ident)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
s.refreshTokenErrHelper(w, rerr)
|
s.refreshTokenErrHelper(w, rerr)
|
||||||
return
|
return
|
||||||
|
@ -317,7 +330,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
||||||
rawNewToken, err := internal.Marshal(newToken)
|
rawNewToken, err := internal.Marshal(newToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.refreshTokenErrHelper(w, internalErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package server
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -16,6 +17,67 @@ import (
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) {
|
||||||
|
c := storage.Client{
|
||||||
|
ID: "test",
|
||||||
|
Secret: "barfoo",
|
||||||
|
RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"},
|
||||||
|
Name: "dex client",
|
||||||
|
LogoURL: "https://goo.gl/JIyzIC",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.CreateClient(c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
c1 := storage.Connector{
|
||||||
|
ID: "test",
|
||||||
|
Type: "mockCallback",
|
||||||
|
Name: "mockCallback",
|
||||||
|
Config: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.CreateConnector(c1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
refresh := storage.RefreshToken{
|
||||||
|
ID: "test",
|
||||||
|
Token: "bar",
|
||||||
|
ObsoleteToken: "",
|
||||||
|
Nonce: "foo",
|
||||||
|
ClientID: "test",
|
||||||
|
ConnectorID: "test",
|
||||||
|
Scopes: []string{"openid", "email", "profile"},
|
||||||
|
CreatedAt: time.Now().UTC().Round(time.Millisecond),
|
||||||
|
LastUsed: time.Now().UTC().Round(time.Millisecond),
|
||||||
|
Claims: storage.Claims{
|
||||||
|
UserID: "1",
|
||||||
|
Username: "jane",
|
||||||
|
Email: "jane.doe@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Groups: []string{"a", "b"},
|
||||||
|
},
|
||||||
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
if useObsolete {
|
||||||
|
refresh.Token = "testtest"
|
||||||
|
refresh.ObsoleteToken = "bar"
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.CreateRefresh(refresh)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
offlineSessions := storage.OfflineSessions{
|
||||||
|
UserID: "1",
|
||||||
|
ConnID: "test",
|
||||||
|
Refresh: map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}},
|
||||||
|
ConnectorData: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.CreateOfflineSessions(offlineSessions)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestRefreshTokenExpirationScenarios(t *testing.T) {
|
func TestRefreshTokenExpirationScenarios(t *testing.T) {
|
||||||
t0 := time.Now()
|
t0 := time.Now()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
@ -56,15 +118,6 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) {
|
||||||
},
|
},
|
||||||
error: `{"error":"invalid_request","error_description":"Refresh token expired."}`,
|
error: `{"error":"invalid_request","error_description":"Refresh token expired."}`,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "Obsolete tokens are not allowed",
|
|
||||||
useObsolete: true,
|
|
||||||
policy: &RefreshTokenPolicy{
|
|
||||||
rotateRefreshTokens: true,
|
|
||||||
now: func() time.Time { return t0.Add(time.Second * 25) },
|
|
||||||
},
|
|
||||||
error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "Obsolete tokens are allowed",
|
name: "Obsolete tokens are allowed",
|
||||||
useObsolete: true,
|
useObsolete: true,
|
||||||
|
@ -75,6 +128,15 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) {
|
||||||
},
|
},
|
||||||
error: ``,
|
error: ``,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Obsolete tokens are not allowed",
|
||||||
|
useObsolete: true,
|
||||||
|
policy: &RefreshTokenPolicy{
|
||||||
|
rotateRefreshTokens: true,
|
||||||
|
now: func() time.Time { return t0.Add(time.Second * 25) },
|
||||||
|
},
|
||||||
|
error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Obsolete tokens are allowed but token is expired globally",
|
name: "Obsolete tokens are allowed but token is expired globally",
|
||||||
useObsolete: true,
|
useObsolete: true,
|
||||||
|
@ -100,64 +162,7 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) {
|
||||||
})
|
})
|
||||||
defer httpServer.Close()
|
defer httpServer.Close()
|
||||||
|
|
||||||
c := storage.Client{
|
mockRefreshTokenTestStorage(t, s.storage, tc.useObsolete)
|
||||||
ID: "test",
|
|
||||||
Secret: "barfoo",
|
|
||||||
RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"},
|
|
||||||
Name: "dex client",
|
|
||||||
LogoURL: "https://goo.gl/JIyzIC",
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.storage.CreateClient(c)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
c1 := storage.Connector{
|
|
||||||
ID: "test",
|
|
||||||
Type: "mockCallback",
|
|
||||||
Name: "mockCallback",
|
|
||||||
Config: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.storage.CreateConnector(c1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
refresh := storage.RefreshToken{
|
|
||||||
ID: "test",
|
|
||||||
Token: "bar",
|
|
||||||
ObsoleteToken: "",
|
|
||||||
Nonce: "foo",
|
|
||||||
ClientID: "test",
|
|
||||||
ConnectorID: "test",
|
|
||||||
Scopes: []string{"openid", "email", "profile"},
|
|
||||||
CreatedAt: time.Now().UTC().Round(time.Millisecond),
|
|
||||||
LastUsed: time.Now().UTC().Round(time.Millisecond),
|
|
||||||
Claims: storage.Claims{
|
|
||||||
UserID: "1",
|
|
||||||
Username: "jane",
|
|
||||||
Email: "jane.doe@example.com",
|
|
||||||
EmailVerified: true,
|
|
||||||
Groups: []string{"a", "b"},
|
|
||||||
},
|
|
||||||
ConnectorData: []byte(`{"some":"data"}`),
|
|
||||||
}
|
|
||||||
|
|
||||||
if tc.useObsolete {
|
|
||||||
refresh.Token = "testtest"
|
|
||||||
refresh.ObsoleteToken = "bar"
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.storage.CreateRefresh(refresh)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
offlineSessions := storage.OfflineSessions{
|
|
||||||
UserID: "1",
|
|
||||||
ConnID: "test",
|
|
||||||
Refresh: map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}},
|
|
||||||
ConnectorData: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.storage.CreateOfflineSessions(offlineSessions)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
u, err := url.Parse(s.issuerURL.String())
|
u, err := url.Parse(s.issuerURL.String())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -181,6 +186,26 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) {
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
} else {
|
} else {
|
||||||
require.Equal(t, rr.Body.String(), tc.error)
|
require.Equal(t, rr.Body.String(), tc.error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that we received expected refresh token
|
||||||
|
var ref struct {
|
||||||
|
Token string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(rr.Body.Bytes(), &ref)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if tc.policy.rotateRefreshTokens == false {
|
||||||
|
require.Equal(t, tokenData, ref.Token)
|
||||||
|
} else {
|
||||||
|
require.NotEqual(t, tokenData, ref.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.useObsolete {
|
||||||
|
updatedTokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "testtest"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, updatedTokenData, ref.Token)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue