Fixes of naming and code style

Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
m.nabokikh 2020-11-15 22:26:34 +04:00
parent 91de99d57e
commit 06c8ab5aa7
7 changed files with 37 additions and 35 deletions

View file

@ -305,8 +305,8 @@ type Expiry struct {
// DeviceRequests defines the duration of time for which the DeviceRequests will be valid. // DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
DeviceRequests string `json:"deviceRequests"` DeviceRequests string `json:"deviceRequests"`
// RefreshToken defines refresh tokens expiry policy // RefreshTokens defines refresh tokens expiry policy
RefreshToken RefreshTokenExpiry `json:"refreshTokens"` RefreshTokens RefreshTokenExpiry `json:"refreshTokens"`
} }
// Logger holds configuration required to customize logging for dex. // Logger holds configuration required to customize logging for dex.

View file

@ -317,12 +317,12 @@ func runServe(options serveOptions) error {
logger.Infof("config device requests valid for: %v", deviceRequests) logger.Infof("config device requests valid for: %v", deviceRequests)
serverConfig.DeviceRequestsValidFor = deviceRequests serverConfig.DeviceRequestsValidFor = deviceRequests
} }
refreshTokenPolicy, err := server.NewRefreshTokenPolicyFromConfig( refreshTokenPolicy, err := server.NewRefreshTokenPolicy(
logger, logger,
c.Expiry.RefreshToken.DisableRotation, c.Expiry.RefreshTokens.DisableRotation,
c.Expiry.RefreshToken.ValidIfNotUsedFor, c.Expiry.RefreshTokens.ValidIfNotUsedFor,
c.Expiry.RefreshToken.AbsoluteLifetime, c.Expiry.RefreshTokens.AbsoluteLifetime,
c.Expiry.RefreshToken.ReuseInterval, c.Expiry.RefreshTokens.ReuseInterval,
) )
if err != nil { if err != nil {
return fmt.Errorf("invalid refresh token expiration policy config: %v", err) return fmt.Errorf("invalid refresh token expiration policy config: %v", err)

View file

@ -1042,7 +1042,12 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return return
} }
if refresh.Token != token.Token { if refresh.Token != token.Token {
if !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) || refresh.ObsoleteToken != token.Token { switch {
case !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed):
fallthrough
case refresh.ObsoleteToken != token.Token:
fallthrough
case refresh.ObsoleteToken == "":
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
return return

View file

@ -185,13 +185,13 @@ type RefreshTokenPolicy struct {
validIfNotUsedFor time.Duration // interval from last token update to the end of its life validIfNotUsedFor time.Duration // interval from last token update to the end of its life
reuseInterval time.Duration // interval within which old refresh token is allowed to be reused reuseInterval time.Duration // interval within which old refresh token is allowed to be reused
Clock func() time.Time now func() time.Time
logger log.Logger logger log.Logger
} }
func NewRefreshTokenPolicyFromConfig(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) { func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) {
r := RefreshTokenPolicy{Clock: time.Now, logger: logger} r := RefreshTokenPolicy{now: time.Now, logger: logger}
var err error var err error
if validIfNotUsedFor != "" { if validIfNotUsedFor != "" {
@ -231,19 +231,19 @@ func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool {
if r.absoluteLifetime == 0 { if r.absoluteLifetime == 0 {
return false // expiration disabled return false // expiration disabled
} }
return r.Clock().After(lastUsed.Add(r.absoluteLifetime)) return r.now().After(lastUsed.Add(r.absoluteLifetime))
} }
func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool { func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool {
if r.validIfNotUsedFor == 0 { if r.validIfNotUsedFor == 0 {
return false // expiration disabled return false // expiration disabled
} }
return r.Clock().After(lastUsed.Add(r.validIfNotUsedFor)) return r.now().After(lastUsed.Add(r.validIfNotUsedFor))
} }
func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool { func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool {
if r.reuseInterval == 0 { if r.reuseInterval == 0 {
return false // expiration disabled return false // expiration disabled
} }
return !r.Clock().After(lastUsed.Add(r.reuseInterval)) return !r.now().After(lastUsed.Add(r.reuseInterval))
} }

View file

@ -110,19 +110,18 @@ func TestRefreshTokenPolicy(t *testing.T) {
Level: logrus.DebugLevel, Level: logrus.DebugLevel,
} }
r, err := NewRefreshTokenPolicyFromConfig(l, true, "1m", "1m", "1m") r, err := NewRefreshTokenPolicy(l, true, "1m", "1m", "1m")
require.NoError(t, err) require.NoError(t, err)
t.Run("Allowed", func(t *testing.T) { t.Run("Allowed", func(t *testing.T) {
r.Clock = func() time.Time { return lastTime } r.now = func() time.Time { return lastTime }
require.Equal(t, true, r.AllowedToReuse(lastTime)) require.Equal(t, true, r.AllowedToReuse(lastTime))
require.Equal(t, false, r.ExpiredBecauseUnused(lastTime)) require.Equal(t, false, r.ExpiredBecauseUnused(lastTime))
require.Equal(t, false, r.CompletelyExpired(lastTime)) require.Equal(t, false, r.CompletelyExpired(lastTime))
}) })
t.Run("Expired", func(t *testing.T) { t.Run("Expired", func(t *testing.T) {
r.Clock = func() time.Time { return lastTime.Add(2 * time.Minute) } r.now = func() time.Time { return lastTime.Add(2 * time.Minute) }
time.Sleep(1 * time.Second)
require.Equal(t, false, r.AllowedToReuse(lastTime)) require.Equal(t, false, r.AllowedToReuse(lastTime))
require.Equal(t, true, r.ExpiredBecauseUnused(lastTime)) require.Equal(t, true, r.ExpiredBecauseUnused(lastTime))
require.Equal(t, true, r.CompletelyExpired(lastTime)) require.Equal(t, true, r.CompletelyExpired(lastTime))

View file

@ -117,6 +117,14 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
t.Fatal(err) t.Fatal(err)
} }
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
// Default rotation policy
server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "")
if err != nil {
t.Fatalf("failed to prepare rotation policy: %v", err)
}
server.refreshTokenPolicy.now = config.Now
return s, server return s, server
} }
@ -677,13 +685,6 @@ func TestOAuth2CodeFlow(t *testing.T) {
}) })
defer httpServer.Close() defer httpServer.Close()
policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "")
if err != nil {
t.Fatalf("failed to prepare rotation policy: %v", err)
}
policy.Clock = now
s.refreshTokenPolicy = policy
mockConn := s.connectors["mock"] mockConn := s.connectors["mock"]
conn = mockConn.Connector.(*mock.Callback) conn = mockConn.Connector.(*mock.Callback)
@ -1515,13 +1516,6 @@ func TestOAuth2DeviceFlow(t *testing.T) {
}) })
defer httpServer.Close() defer httpServer.Close()
policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "")
if err != nil {
t.Fatalf("failed to prepare rotation policy: %v", err)
}
policy.Clock = now
s.refreshTokenPolicy = policy
mockConn := s.connectors["mock"] mockConn := s.connectors["mock"]
conn = mockConn.Connector.(*mock.Callback) conn = mockConn.Connector.(*mock.Callback)

View file

@ -176,9 +176,6 @@ var migrations = []migration{
alter table refresh_token alter table refresh_token
add column token text not null default '';`, add column token text not null default '';`,
` `
alter table refresh_token
add column obsolete_token text default '';`,
`
alter table refresh_token alter table refresh_token
add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';`, add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';`,
` `
@ -277,4 +274,11 @@ var migrations = []migration{
add column code_challenge_method text not null default '';`, add column code_challenge_method text not null default '';`,
}, },
}, },
{
stmts: []string{
`
alter table refresh_token
add column obsolete_token text default '';`,
},
},
} }