feat: Add refresh token expiration and rotation settings

Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
m.nabokikh 2020-10-28 10:26:34 +04:00
parent 10597cf09f
commit 91de99d57e
14 changed files with 226 additions and 42 deletions

View file

@ -304,6 +304,9 @@ type Expiry struct {
// DeviceRequests defines the duration of time for which the DeviceRequests will be valid. // DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
DeviceRequests string `json:"deviceRequests"` DeviceRequests string `json:"deviceRequests"`
// RefreshToken defines refresh tokens expiry policy
RefreshToken RefreshTokenExpiry `json:"refreshTokens"`
} }
// Logger holds configuration required to customize logging for dex. // Logger holds configuration required to customize logging for dex.
@ -314,3 +317,10 @@ type Logger struct {
// Format specifies the format to be used for logging. // Format specifies the format to be used for logging.
Format string `json:"format"` Format string `json:"format"`
} }
type RefreshTokenExpiry struct {
DisableRotation bool `json:"disableRotation"`
ReuseInterval string `json:"reuseInterval"`
AbsoluteLifetime string `json:"absoluteLifetime"`
ValidIfNotUsedFor string `json:"validIfNotUsedFor"`
}

View file

@ -317,6 +317,18 @@ func runServe(options serveOptions) error {
logger.Infof("config device requests valid for: %v", deviceRequests) logger.Infof("config device requests valid for: %v", deviceRequests)
serverConfig.DeviceRequestsValidFor = deviceRequests serverConfig.DeviceRequestsValidFor = deviceRequests
} }
refreshTokenPolicy, err := server.NewRefreshTokenPolicyFromConfig(
logger,
c.Expiry.RefreshToken.DisableRotation,
c.Expiry.RefreshToken.ValidIfNotUsedFor,
c.Expiry.RefreshToken.AbsoluteLifetime,
c.Expiry.RefreshToken.ReuseInterval,
)
if err != nil {
return fmt.Errorf("invalid refresh token expiration policy config: %v", err)
}
serverConfig.RefreshTokenPolicy = refreshTokenPolicy
serv, err := server.NewServer(context.Background(), serverConfig) serv, err := server.NewServer(context.Background(), serverConfig)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize server: %v", err) return fmt.Errorf("failed to initialize server: %v", err)

View file

@ -77,6 +77,10 @@ telemetry:
# deviceRequests: "5m" # deviceRequests: "5m"
# signingKeys: "6h" # signingKeys: "6h"
# idTokens: "24h" # idTokens: "24h"
# refreshTokens:
# reuseInterval: "3s"
# validIfNotUsedFor: "2190h"
# absoluteLifetime: "5000h"
# Options for controlling the logger. # Options for controlling the logger.
# logger: # logger:

View file

@ -1035,14 +1035,27 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
} }
return return
} }
if refresh.ClientID != client.ID { if refresh.ClientID != client.ID {
s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID) s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID)
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
return return
} }
if refresh.Token != token.Token { if refresh.Token != token.Token {
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) if !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) || refresh.ObsoleteToken != token.Token {
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
return
}
}
if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) {
s.logger.Errorf("refresh token with id %s expired", refresh.ID)
s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest)
return
}
if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) {
s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID)
s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest)
return return
} }
@ -1147,22 +1160,28 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return return
} }
newToken := &internal.RefreshToken{ newToken := token
RefreshId: refresh.ID, if s.refreshTokenPolicy.RotationEnabled() {
Token: storage.NewID(), newToken = &internal.RefreshToken{
} RefreshId: refresh.ID,
rawNewToken, err := internal.Marshal(newToken) Token: storage.NewID(),
if err != nil { }
s.logger.Errorf("failed to marshal refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
} }
lastUsed := s.now() lastUsed := s.now()
updater := func(old storage.RefreshToken) (storage.RefreshToken, error) { updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
if old.Token != refresh.Token { if s.refreshTokenPolicy.RotationEnabled() {
return old, errors.New("refresh token claimed twice") if old.Token != refresh.Token {
if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token {
newToken.Token = old.Token
return old, nil
}
return old, errors.New("refresh token claimed twice")
}
old.ObsoleteToken = old.Token
} }
old.Token = newToken.Token old.Token = newToken.Token
// Update the claims of the refresh token. // Update the claims of the refresh token.
// //
@ -1201,6 +1220,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return return
} }
rawNewToken, err := internal.Marshal(newToken)
if err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
s.writeAccessToken(w, resp) s.writeAccessToken(w, resp)
} }

View file

@ -177,3 +177,73 @@ func (k keyRotator) rotate() error {
k.logger.Infof("keys rotated, next rotation: %s", nextRotation) k.logger.Infof("keys rotated, next rotation: %s", nextRotation)
return nil return nil
} }
type RefreshTokenPolicy struct {
rotateRefreshTokens bool // enable rotation
absoluteLifetime time.Duration // interval from token creation to the end of its life
validIfNotUsedFor time.Duration // interval from last token update to the end of its life
reuseInterval time.Duration // interval within which old refresh token is allowed to be reused
Clock func() time.Time
logger log.Logger
}
func NewRefreshTokenPolicyFromConfig(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) {
r := RefreshTokenPolicy{Clock: time.Now, logger: logger}
var err error
if validIfNotUsedFor != "" {
r.validIfNotUsedFor, err = time.ParseDuration(validIfNotUsedFor)
if err != nil {
return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err)
}
logger.Infof("config refresh tokens valid if not used for: %v", validIfNotUsedFor)
}
if absoluteLifetime != "" {
r.absoluteLifetime, err = time.ParseDuration(absoluteLifetime)
if err != nil {
return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err)
}
logger.Infof("config refresh tokens absolute lifetime: %v", absoluteLifetime)
}
if reuseInterval != "" {
r.reuseInterval, err = time.ParseDuration(reuseInterval)
if err != nil {
return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err)
}
logger.Infof("config refresh tokens reuse interval: %v", reuseInterval)
}
r.rotateRefreshTokens = !rotation
logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens)
return &r, nil
}
func (r *RefreshTokenPolicy) RotationEnabled() bool {
return r.rotateRefreshTokens
}
func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool {
if r.absoluteLifetime == 0 {
return false // expiration disabled
}
return r.Clock().After(lastUsed.Add(r.absoluteLifetime))
}
func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool {
if r.validIfNotUsedFor == 0 {
return false // expiration disabled
}
return r.Clock().After(lastUsed.Add(r.validIfNotUsedFor))
}
func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool {
if r.reuseInterval == 0 {
return false // expiration disabled
}
return !r.Clock().After(lastUsed.Add(r.reuseInterval))
}

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory" "github.com/dexidp/dex/storage/memory"
@ -100,3 +101,30 @@ func TestKeyRotator(t *testing.T) {
} }
} }
} }
func TestRefreshTokenPolicy(t *testing.T) {
lastTime := time.Now()
l := &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
r, err := NewRefreshTokenPolicyFromConfig(l, true, "1m", "1m", "1m")
require.NoError(t, err)
t.Run("Allowed", func(t *testing.T) {
r.Clock = func() time.Time { return lastTime }
require.Equal(t, true, r.AllowedToReuse(lastTime))
require.Equal(t, false, r.ExpiredBecauseUnused(lastTime))
require.Equal(t, false, r.CompletelyExpired(lastTime))
})
t.Run("Expired", func(t *testing.T) {
r.Clock = func() time.Time { return lastTime.Add(2 * time.Minute) }
time.Sleep(1 * time.Second)
require.Equal(t, false, r.AllowedToReuse(lastTime))
require.Equal(t, true, r.ExpiredBecauseUnused(lastTime))
require.Equal(t, true, r.CompletelyExpired(lastTime))
})
}

View file

@ -80,6 +80,10 @@ type Config struct {
IDTokensValidFor time.Duration // Defaults to 24 hours IDTokensValidFor time.Duration // Defaults to 24 hours
AuthRequestsValidFor time.Duration // Defaults to 24 hours AuthRequestsValidFor time.Duration // Defaults to 24 hours
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
// Refresh token expiration settings
RefreshTokenPolicy *RefreshTokenPolicy
// If set, the server will use this connector to handle password grants // If set, the server will use this connector to handle password grants
PasswordConnector string PasswordConnector string
@ -159,6 +163,8 @@ type Server struct {
authRequestsValidFor time.Duration authRequestsValidFor time.Duration
deviceRequestsValidFor time.Duration deviceRequestsValidFor time.Duration
refreshTokenPolicy *RefreshTokenPolicy
logger log.Logger logger log.Logger
} }
@ -227,6 +233,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour), authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute), deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
refreshTokenPolicy: c.RefreshTokenPolicy,
skipApproval: c.SkipApprovalScreen, skipApproval: c.SkipApprovalScreen,
alwaysShowLogin: c.AlwaysShowLoginScreen, alwaysShowLogin: c.AlwaysShowLoginScreen,
now: now, now: now,

View file

@ -677,6 +677,13 @@ func TestOAuth2CodeFlow(t *testing.T) {
}) })
defer httpServer.Close() defer httpServer.Close()
policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "")
if err != nil {
t.Fatalf("failed to prepare rotation policy: %v", err)
}
policy.Clock = now
s.refreshTokenPolicy = policy
mockConn := s.connectors["mock"] mockConn := s.connectors["mock"]
conn = mockConn.Connector.(*mock.Callback) conn = mockConn.Connector.(*mock.Callback)
@ -1508,6 +1515,13 @@ func TestOAuth2DeviceFlow(t *testing.T) {
}) })
defer httpServer.Close() defer httpServer.Close()
policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "")
if err != nil {
t.Fatalf("failed to prepare rotation policy: %v", err)
}
policy.Clock = now
s.refreshTokenPolicy = policy
mockConn := s.connectors["mock"] mockConn := s.connectors["mock"]
conn = mockConn.Connector.(*mock.Callback) conn = mockConn.Connector.(*mock.Callback)

View file

@ -324,14 +324,15 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
id := storage.NewID() id := storage.NewID()
refresh := storage.RefreshToken{ refresh := storage.RefreshToken{
ID: id, ID: id,
Token: "bar", Token: "bar",
Nonce: "foo", ObsoleteToken: "",
ClientID: "client_id", Nonce: "foo",
ConnectorID: "client_secret", ClientID: "client_id",
Scopes: []string{"openid", "email", "profile"}, ConnectorID: "client_secret",
CreatedAt: time.Now().UTC().Round(time.Millisecond), Scopes: []string{"openid", "email", "profile"},
LastUsed: time.Now().UTC().Round(time.Millisecond), CreatedAt: time.Now().UTC().Round(time.Millisecond),
LastUsed: time.Now().UTC().Round(time.Millisecond),
Claims: storage.Claims{ Claims: storage.Claims{
UserID: "1", UserID: "1",
Username: "jane", Username: "jane",
@ -378,14 +379,15 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
id2 := storage.NewID() id2 := storage.NewID()
refresh2 := storage.RefreshToken{ refresh2 := storage.RefreshToken{
ID: id2, ID: id2,
Token: "bar_2", Token: "bar_2",
Nonce: "foo_2", ObsoleteToken: "bar",
ClientID: "client_id_2", Nonce: "foo_2",
ConnectorID: "client_secret", ClientID: "client_id_2",
Scopes: []string{"openid", "email", "profile"}, ConnectorID: "client_secret",
CreatedAt: time.Now().UTC().Round(time.Millisecond), Scopes: []string{"openid", "email", "profile"},
LastUsed: time.Now().UTC().Round(time.Millisecond), CreatedAt: time.Now().UTC().Round(time.Millisecond),
LastUsed: time.Now().UTC().Round(time.Millisecond),
Claims: storage.Claims{ Claims: storage.Claims{
UserID: "2", UserID: "2",
Username: "john", Username: "john",

View file

@ -132,7 +132,8 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest {
type RefreshToken struct { type RefreshToken struct {
ID string `json:"id"` ID string `json:"id"`
Token string `json:"token"` Token string `json:"token"`
ObsoleteToken string `json:"obsolete_token"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
LastUsed time.Time `json:"last_used"` LastUsed time.Time `json:"last_used"`
@ -152,6 +153,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
return storage.RefreshToken{ return storage.RefreshToken{
ID: r.ID, ID: r.ID,
Token: r.Token, Token: r.Token,
ObsoleteToken: r.ObsoleteToken,
CreatedAt: r.CreatedAt, CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed, LastUsed: r.LastUsed,
ClientID: r.ClientID, ClientID: r.ClientID,
@ -167,6 +169,7 @@ func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
return RefreshToken{ return RefreshToken{
ID: r.ID, ID: r.ID,
Token: r.Token, Token: r.Token,
ObsoleteToken: r.ObsoleteToken,
CreatedAt: r.CreatedAt, CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed, LastUsed: r.LastUsed,
ClientID: r.ClientID, ClientID: r.ClientID,

View file

@ -496,7 +496,8 @@ type RefreshToken struct {
ClientID string `json:"clientID"` ClientID string `json:"clientID"`
Scopes []string `json:"scopes,omitempty"` Scopes []string `json:"scopes,omitempty"`
Token string `json:"token,omitempty"` Token string `json:"token,omitempty"`
ObsoleteToken string `json:"obsoleteToken,omitempty"`
Nonce string `json:"nonce,omitempty"` Nonce string `json:"nonce,omitempty"`
@ -516,6 +517,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
return storage.RefreshToken{ return storage.RefreshToken{
ID: r.ObjectMeta.Name, ID: r.ObjectMeta.Name,
Token: r.Token, Token: r.Token,
ObsoleteToken: r.ObsoleteToken,
CreatedAt: r.CreatedAt, CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed, LastUsed: r.LastUsed,
ClientID: r.ClientID, ClientID: r.ClientID,
@ -538,6 +540,7 @@ func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken
Namespace: cli.namespace, Namespace: cli.namespace,
}, },
Token: r.Token, Token: r.Token,
ObsoleteToken: r.ObsoleteToken,
CreatedAt: r.CreatedAt, CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed, LastUsed: r.LastUsed,
ClientID: r.ClientID, ClientID: r.ClientID,

View file

@ -285,16 +285,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, connector_id, connector_data,
token, created_at, last_used token, obsolete_token, created_at, last_used
) )
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15); values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16);
`, `,
r.ID, r.ClientID, encoder(r.Scopes), r.Nonce, r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername, r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
r.Claims.Email, r.Claims.EmailVerified, r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups), encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData, r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed, r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) { if c.alreadyExistsCheck(err) {
@ -329,17 +329,18 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
connector_id = $10, connector_id = $10,
connector_data = $11, connector_data = $11,
token = $12, token = $12,
created_at = $13, obsolete_token = $13,
last_used = $14 created_at = $14,
last_used = $15
where where
id = $15 id = $16
`, `,
r.ClientID, encoder(r.Scopes), r.Nonce, r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername, r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
r.Claims.Email, r.Claims.EmailVerified, r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups), encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData, r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed, id, r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed, id,
) )
if err != nil { if err != nil {
return fmt.Errorf("update refresh token: %v", err) return fmt.Errorf("update refresh token: %v", err)
@ -360,7 +361,7 @@ func getRefresh(q querier, id string) (storage.RefreshToken, error) {
claims_email, claims_email_verified, claims_email, claims_email_verified,
claims_groups, claims_groups,
connector_id, connector_data, connector_id, connector_data,
token, created_at, last_used token, obsolete_token, created_at, last_used
from refresh_token where id = $1; from refresh_token where id = $1;
`, id)) `, id))
} }
@ -372,7 +373,7 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, connector_id, connector_data,
token, created_at, last_used token, obsolete_token, created_at, last_used
from refresh_token; from refresh_token;
`) `)
if err != nil { if err != nil {
@ -401,7 +402,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
&r.Claims.Email, &r.Claims.EmailVerified, &r.Claims.Email, &r.Claims.EmailVerified,
decoder(&r.Claims.Groups), decoder(&r.Claims.Groups),
&r.ConnectorID, &r.ConnectorData, &r.ConnectorID, &r.ConnectorData,
&r.Token, &r.CreatedAt, &r.LastUsed, &r.Token, &r.ObsoleteToken, &r.CreatedAt, &r.LastUsed,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {

View file

@ -176,6 +176,9 @@ var migrations = []migration{
alter table refresh_token alter table refresh_token
add column token text not null default '';`, add column token text not null default '';`,
` `
alter table refresh_token
add column obsolete_token text default '';`,
`
alter table refresh_token alter table refresh_token
add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';`, add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';`,
` `

View file

@ -271,7 +271,8 @@ type RefreshToken struct {
// A single token that's rotated every time the refresh token is refreshed. // A single token that's rotated every time the refresh token is refreshed.
// //
// May be empty. // May be empty.
Token string Token string
ObsoleteToken string
CreatedAt time.Time CreatedAt time.Time
LastUsed time.Time LastUsed time.Time