feat: Add refresh token expiration and rotation settings

Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
m.nabokikh 2020-10-28 10:26:34 +04:00
parent 10597cf09f
commit 91de99d57e
14 changed files with 226 additions and 42 deletions

View file

@ -304,6 +304,9 @@ type Expiry struct {
// DeviceRequests defines the duration of time for which the DeviceRequests will be valid. // DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
DeviceRequests string `json:"deviceRequests"` DeviceRequests string `json:"deviceRequests"`
// RefreshToken defines refresh tokens expiry policy
RefreshToken RefreshTokenExpiry `json:"refreshTokens"`
} }
// Logger holds configuration required to customize logging for dex. // Logger holds configuration required to customize logging for dex.
@ -314,3 +317,10 @@ type Logger struct {
// Format specifies the format to be used for logging. // Format specifies the format to be used for logging.
Format string `json:"format"` Format string `json:"format"`
} }
type RefreshTokenExpiry struct {
DisableRotation bool `json:"disableRotation"`
ReuseInterval string `json:"reuseInterval"`
AbsoluteLifetime string `json:"absoluteLifetime"`
ValidIfNotUsedFor string `json:"validIfNotUsedFor"`
}

View file

@ -317,6 +317,18 @@ func runServe(options serveOptions) error {
logger.Infof("config device requests valid for: %v", deviceRequests) logger.Infof("config device requests valid for: %v", deviceRequests)
serverConfig.DeviceRequestsValidFor = deviceRequests serverConfig.DeviceRequestsValidFor = deviceRequests
} }
refreshTokenPolicy, err := server.NewRefreshTokenPolicyFromConfig(
logger,
c.Expiry.RefreshToken.DisableRotation,
c.Expiry.RefreshToken.ValidIfNotUsedFor,
c.Expiry.RefreshToken.AbsoluteLifetime,
c.Expiry.RefreshToken.ReuseInterval,
)
if err != nil {
return fmt.Errorf("invalid refresh token expiration policy config: %v", err)
}
serverConfig.RefreshTokenPolicy = refreshTokenPolicy
serv, err := server.NewServer(context.Background(), serverConfig) serv, err := server.NewServer(context.Background(), serverConfig)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize server: %v", err) return fmt.Errorf("failed to initialize server: %v", err)

View file

@ -77,6 +77,10 @@ telemetry:
# deviceRequests: "5m" # deviceRequests: "5m"
# signingKeys: "6h" # signingKeys: "6h"
# idTokens: "24h" # idTokens: "24h"
# refreshTokens:
# reuseInterval: "3s"
# validIfNotUsedFor: "2190h"
# absoluteLifetime: "5000h"
# Options for controlling the logger. # Options for controlling the logger.
# logger: # logger:

View file

@ -1035,16 +1035,29 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
} }
return return
} }
if refresh.ClientID != client.ID { if refresh.ClientID != client.ID {
s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID) s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID)
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
return return
} }
if refresh.Token != token.Token { if refresh.Token != token.Token {
if !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) || refresh.ObsoleteToken != token.Token {
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
return return
} }
}
if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) {
s.logger.Errorf("refresh token with id %s expired", refresh.ID)
s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest)
return
}
if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) {
s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID)
s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest)
return
}
// Per the OAuth2 spec, if the client has omitted the scopes, default to the original // Per the OAuth2 spec, if the client has omitted the scopes, default to the original
// authorized scopes. // authorized scopes.
@ -1147,22 +1160,28 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return return
} }
newToken := &internal.RefreshToken{ newToken := token
if s.refreshTokenPolicy.RotationEnabled() {
newToken = &internal.RefreshToken{
RefreshId: refresh.ID, RefreshId: refresh.ID,
Token: storage.NewID(), Token: storage.NewID(),
} }
rawNewToken, err := internal.Marshal(newToken)
if err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
} }
lastUsed := s.now() lastUsed := s.now()
updater := func(old storage.RefreshToken) (storage.RefreshToken, error) { updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
if s.refreshTokenPolicy.RotationEnabled() {
if old.Token != refresh.Token { if old.Token != refresh.Token {
if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token {
newToken.Token = old.Token
return old, nil
}
return old, errors.New("refresh token claimed twice") return old, errors.New("refresh token claimed twice")
} }
old.ObsoleteToken = old.Token
}
old.Token = newToken.Token old.Token = newToken.Token
// Update the claims of the refresh token. // Update the claims of the refresh token.
// //
@ -1201,6 +1220,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return return
} }
rawNewToken, err := internal.Marshal(newToken)
if err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
s.writeAccessToken(w, resp) s.writeAccessToken(w, resp)
} }

View file

@ -177,3 +177,73 @@ func (k keyRotator) rotate() error {
k.logger.Infof("keys rotated, next rotation: %s", nextRotation) k.logger.Infof("keys rotated, next rotation: %s", nextRotation)
return nil return nil
} }
type RefreshTokenPolicy struct {
rotateRefreshTokens bool // enable rotation
absoluteLifetime time.Duration // interval from token creation to the end of its life
validIfNotUsedFor time.Duration // interval from last token update to the end of its life
reuseInterval time.Duration // interval within which old refresh token is allowed to be reused
Clock func() time.Time
logger log.Logger
}
func NewRefreshTokenPolicyFromConfig(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) {
r := RefreshTokenPolicy{Clock: time.Now, logger: logger}
var err error
if validIfNotUsedFor != "" {
r.validIfNotUsedFor, err = time.ParseDuration(validIfNotUsedFor)
if err != nil {
return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err)
}
logger.Infof("config refresh tokens valid if not used for: %v", validIfNotUsedFor)
}
if absoluteLifetime != "" {
r.absoluteLifetime, err = time.ParseDuration(absoluteLifetime)
if err != nil {
return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err)
}
logger.Infof("config refresh tokens absolute lifetime: %v", absoluteLifetime)
}
if reuseInterval != "" {
r.reuseInterval, err = time.ParseDuration(reuseInterval)
if err != nil {
return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err)
}
logger.Infof("config refresh tokens reuse interval: %v", reuseInterval)
}
r.rotateRefreshTokens = !rotation
logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens)
return &r, nil
}
func (r *RefreshTokenPolicy) RotationEnabled() bool {
return r.rotateRefreshTokens
}
func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool {
if r.absoluteLifetime == 0 {
return false // expiration disabled
}
return r.Clock().After(lastUsed.Add(r.absoluteLifetime))
}
func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool {
if r.validIfNotUsedFor == 0 {
return false // expiration disabled
}
return r.Clock().After(lastUsed.Add(r.validIfNotUsedFor))
}
func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool {
if r.reuseInterval == 0 {
return false // expiration disabled
}
return !r.Clock().After(lastUsed.Add(r.reuseInterval))
}

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory" "github.com/dexidp/dex/storage/memory"
@ -100,3 +101,30 @@ func TestKeyRotator(t *testing.T) {
} }
} }
} }
func TestRefreshTokenPolicy(t *testing.T) {
lastTime := time.Now()
l := &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
r, err := NewRefreshTokenPolicyFromConfig(l, true, "1m", "1m", "1m")
require.NoError(t, err)
t.Run("Allowed", func(t *testing.T) {
r.Clock = func() time.Time { return lastTime }
require.Equal(t, true, r.AllowedToReuse(lastTime))
require.Equal(t, false, r.ExpiredBecauseUnused(lastTime))
require.Equal(t, false, r.CompletelyExpired(lastTime))
})
t.Run("Expired", func(t *testing.T) {
r.Clock = func() time.Time { return lastTime.Add(2 * time.Minute) }
time.Sleep(1 * time.Second)
require.Equal(t, false, r.AllowedToReuse(lastTime))
require.Equal(t, true, r.ExpiredBecauseUnused(lastTime))
require.Equal(t, true, r.CompletelyExpired(lastTime))
})
}

View file

@ -80,6 +80,10 @@ type Config struct {
IDTokensValidFor time.Duration // Defaults to 24 hours IDTokensValidFor time.Duration // Defaults to 24 hours
AuthRequestsValidFor time.Duration // Defaults to 24 hours AuthRequestsValidFor time.Duration // Defaults to 24 hours
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
// Refresh token expiration settings
RefreshTokenPolicy *RefreshTokenPolicy
// If set, the server will use this connector to handle password grants // If set, the server will use this connector to handle password grants
PasswordConnector string PasswordConnector string
@ -159,6 +163,8 @@ type Server struct {
authRequestsValidFor time.Duration authRequestsValidFor time.Duration
deviceRequestsValidFor time.Duration deviceRequestsValidFor time.Duration
refreshTokenPolicy *RefreshTokenPolicy
logger log.Logger logger log.Logger
} }
@ -227,6 +233,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour), authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute), deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
refreshTokenPolicy: c.RefreshTokenPolicy,
skipApproval: c.SkipApprovalScreen, skipApproval: c.SkipApprovalScreen,
alwaysShowLogin: c.AlwaysShowLoginScreen, alwaysShowLogin: c.AlwaysShowLoginScreen,
now: now, now: now,

View file

@ -677,6 +677,13 @@ func TestOAuth2CodeFlow(t *testing.T) {
}) })
defer httpServer.Close() defer httpServer.Close()
policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "")
if err != nil {
t.Fatalf("failed to prepare rotation policy: %v", err)
}
policy.Clock = now
s.refreshTokenPolicy = policy
mockConn := s.connectors["mock"] mockConn := s.connectors["mock"]
conn = mockConn.Connector.(*mock.Callback) conn = mockConn.Connector.(*mock.Callback)
@ -1508,6 +1515,13 @@ func TestOAuth2DeviceFlow(t *testing.T) {
}) })
defer httpServer.Close() defer httpServer.Close()
policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "")
if err != nil {
t.Fatalf("failed to prepare rotation policy: %v", err)
}
policy.Clock = now
s.refreshTokenPolicy = policy
mockConn := s.connectors["mock"] mockConn := s.connectors["mock"]
conn = mockConn.Connector.(*mock.Callback) conn = mockConn.Connector.(*mock.Callback)

View file

@ -326,6 +326,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
refresh := storage.RefreshToken{ refresh := storage.RefreshToken{
ID: id, ID: id,
Token: "bar", Token: "bar",
ObsoleteToken: "",
Nonce: "foo", Nonce: "foo",
ClientID: "client_id", ClientID: "client_id",
ConnectorID: "client_secret", ConnectorID: "client_secret",
@ -380,6 +381,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
refresh2 := storage.RefreshToken{ refresh2 := storage.RefreshToken{
ID: id2, ID: id2,
Token: "bar_2", Token: "bar_2",
ObsoleteToken: "bar",
Nonce: "foo_2", Nonce: "foo_2",
ClientID: "client_id_2", ClientID: "client_id_2",
ConnectorID: "client_secret", ConnectorID: "client_secret",

View file

@ -133,6 +133,7 @@ type RefreshToken struct {
ID string `json:"id"` ID string `json:"id"`
Token string `json:"token"` Token string `json:"token"`
ObsoleteToken string `json:"obsolete_token"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
LastUsed time.Time `json:"last_used"` LastUsed time.Time `json:"last_used"`
@ -152,6 +153,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
return storage.RefreshToken{ return storage.RefreshToken{
ID: r.ID, ID: r.ID,
Token: r.Token, Token: r.Token,
ObsoleteToken: r.ObsoleteToken,
CreatedAt: r.CreatedAt, CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed, LastUsed: r.LastUsed,
ClientID: r.ClientID, ClientID: r.ClientID,
@ -167,6 +169,7 @@ func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
return RefreshToken{ return RefreshToken{
ID: r.ID, ID: r.ID,
Token: r.Token, Token: r.Token,
ObsoleteToken: r.ObsoleteToken,
CreatedAt: r.CreatedAt, CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed, LastUsed: r.LastUsed,
ClientID: r.ClientID, ClientID: r.ClientID,

View file

@ -497,6 +497,7 @@ type RefreshToken struct {
Scopes []string `json:"scopes,omitempty"` Scopes []string `json:"scopes,omitempty"`
Token string `json:"token,omitempty"` Token string `json:"token,omitempty"`
ObsoleteToken string `json:"obsoleteToken,omitempty"`
Nonce string `json:"nonce,omitempty"` Nonce string `json:"nonce,omitempty"`
@ -516,6 +517,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
return storage.RefreshToken{ return storage.RefreshToken{
ID: r.ObjectMeta.Name, ID: r.ObjectMeta.Name,
Token: r.Token, Token: r.Token,
ObsoleteToken: r.ObsoleteToken,
CreatedAt: r.CreatedAt, CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed, LastUsed: r.LastUsed,
ClientID: r.ClientID, ClientID: r.ClientID,
@ -538,6 +540,7 @@ func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken
Namespace: cli.namespace, Namespace: cli.namespace,
}, },
Token: r.Token, Token: r.Token,
ObsoleteToken: r.ObsoleteToken,
CreatedAt: r.CreatedAt, CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed, LastUsed: r.LastUsed,
ClientID: r.ClientID, ClientID: r.ClientID,

View file

@ -285,16 +285,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, connector_id, connector_data,
token, created_at, last_used token, obsolete_token, created_at, last_used
) )
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15); values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16);
`, `,
r.ID, r.ClientID, encoder(r.Scopes), r.Nonce, r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername, r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
r.Claims.Email, r.Claims.EmailVerified, r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups), encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData, r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed, r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) { if c.alreadyExistsCheck(err) {
@ -329,17 +329,18 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
connector_id = $10, connector_id = $10,
connector_data = $11, connector_data = $11,
token = $12, token = $12,
created_at = $13, obsolete_token = $13,
last_used = $14 created_at = $14,
last_used = $15
where where
id = $15 id = $16
`, `,
r.ClientID, encoder(r.Scopes), r.Nonce, r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername, r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
r.Claims.Email, r.Claims.EmailVerified, r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups), encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData, r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed, id, r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed, id,
) )
if err != nil { if err != nil {
return fmt.Errorf("update refresh token: %v", err) return fmt.Errorf("update refresh token: %v", err)
@ -360,7 +361,7 @@ func getRefresh(q querier, id string) (storage.RefreshToken, error) {
claims_email, claims_email_verified, claims_email, claims_email_verified,
claims_groups, claims_groups,
connector_id, connector_data, connector_id, connector_data,
token, created_at, last_used token, obsolete_token, created_at, last_used
from refresh_token where id = $1; from refresh_token where id = $1;
`, id)) `, id))
} }
@ -372,7 +373,7 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, connector_id, connector_data,
token, created_at, last_used token, obsolete_token, created_at, last_used
from refresh_token; from refresh_token;
`) `)
if err != nil { if err != nil {
@ -401,7 +402,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
&r.Claims.Email, &r.Claims.EmailVerified, &r.Claims.Email, &r.Claims.EmailVerified,
decoder(&r.Claims.Groups), decoder(&r.Claims.Groups),
&r.ConnectorID, &r.ConnectorData, &r.ConnectorID, &r.ConnectorData,
&r.Token, &r.CreatedAt, &r.LastUsed, &r.Token, &r.ObsoleteToken, &r.CreatedAt, &r.LastUsed,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {

View file

@ -176,6 +176,9 @@ var migrations = []migration{
alter table refresh_token alter table refresh_token
add column token text not null default '';`, add column token text not null default '';`,
` `
alter table refresh_token
add column obsolete_token text default '';`,
`
alter table refresh_token alter table refresh_token
add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';`, add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';`,
` `

View file

@ -272,6 +272,7 @@ type RefreshToken struct {
// //
// May be empty. // May be empty.
Token string Token string
ObsoleteToken string
CreatedAt time.Time CreatedAt time.Time
LastUsed time.Time LastUsed time.Time