forked from mystiq/dex
Merge pull request #1846 from flant/refresh-token-expiration-policy
feat: Add refresh token expiration and rotation settings
This commit is contained in:
commit
551229a986
16 changed files with 738 additions and 230 deletions
|
@ -304,6 +304,9 @@ type Expiry struct {
|
||||||
|
|
||||||
// DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
|
// DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
|
||||||
DeviceRequests string `json:"deviceRequests"`
|
DeviceRequests string `json:"deviceRequests"`
|
||||||
|
|
||||||
|
// RefreshTokens defines refresh tokens expiry policy
|
||||||
|
RefreshTokens RefreshToken `json:"refreshTokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logger holds configuration required to customize logging for dex.
|
// Logger holds configuration required to customize logging for dex.
|
||||||
|
@ -314,3 +317,10 @@ type Logger struct {
|
||||||
// Format specifies the format to be used for logging.
|
// Format specifies the format to be used for logging.
|
||||||
Format string `json:"format"`
|
Format string `json:"format"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RefreshToken struct {
|
||||||
|
DisableRotation bool `json:"disableRotation"`
|
||||||
|
ReuseInterval string `json:"reuseInterval"`
|
||||||
|
AbsoluteLifetime string `json:"absoluteLifetime"`
|
||||||
|
ValidIfNotUsedFor string `json:"validIfNotUsedFor"`
|
||||||
|
}
|
||||||
|
|
|
@ -304,6 +304,18 @@ func runServe(options serveOptions) error {
|
||||||
logger.Infof("config device requests valid for: %v", deviceRequests)
|
logger.Infof("config device requests valid for: %v", deviceRequests)
|
||||||
serverConfig.DeviceRequestsValidFor = deviceRequests
|
serverConfig.DeviceRequestsValidFor = deviceRequests
|
||||||
}
|
}
|
||||||
|
refreshTokenPolicy, err := server.NewRefreshTokenPolicy(
|
||||||
|
logger,
|
||||||
|
c.Expiry.RefreshTokens.DisableRotation,
|
||||||
|
c.Expiry.RefreshTokens.ValidIfNotUsedFor,
|
||||||
|
c.Expiry.RefreshTokens.AbsoluteLifetime,
|
||||||
|
c.Expiry.RefreshTokens.ReuseInterval,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid refresh token expiration policy config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverConfig.RefreshTokenPolicy = refreshTokenPolicy
|
||||||
serv, err := server.NewServer(context.Background(), serverConfig)
|
serv, err := server.NewServer(context.Background(), serverConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize server: %v", err)
|
return fmt.Errorf("failed to initialize server: %v", err)
|
||||||
|
|
|
@ -73,10 +73,15 @@ telemetry:
|
||||||
# tlsClientCA: examples/grpc-client/ca.crt
|
# tlsClientCA: examples/grpc-client/ca.crt
|
||||||
|
|
||||||
# Uncomment this block to enable configuration for the expiration time durations.
|
# Uncomment this block to enable configuration for the expiration time durations.
|
||||||
|
# Is possible to specify units using only s, m and h suffixes.
|
||||||
# expiry:
|
# expiry:
|
||||||
# deviceRequests: "5m"
|
# deviceRequests: "5m"
|
||||||
# signingKeys: "6h"
|
# signingKeys: "6h"
|
||||||
# idTokens: "24h"
|
# idTokens: "24h"
|
||||||
|
# refreshTokens:
|
||||||
|
# reuseInterval: "3s"
|
||||||
|
# validIfNotUsedFor: "2160h" # 90 days
|
||||||
|
# absoluteLifetime: "3960h" # 165 days
|
||||||
|
|
||||||
# Options for controlling the logger.
|
# Options for controlling the logger.
|
||||||
# logger:
|
# logger:
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -919,206 +918,6 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo
|
||||||
return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil
|
return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
|
||||||
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
|
|
||||||
code := r.PostFormValue("refresh_token")
|
|
||||||
scope := r.PostFormValue("scope")
|
|
||||||
if code == "" {
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "No refresh token in request.", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
token := new(internal.RefreshToken)
|
|
||||||
if err := internal.Unmarshal(code, token); err != nil {
|
|
||||||
// For backward compatibility, assume the refresh_token is a raw refresh token ID
|
|
||||||
// if it fails to decode.
|
|
||||||
//
|
|
||||||
// Because refresh_token values that aren't unmarshable were generated by servers
|
|
||||||
// that don't have a Token value, we'll still reject any attempts to claim a
|
|
||||||
// refresh_token twice.
|
|
||||||
token = &internal.RefreshToken{RefreshId: code, Token: ""}
|
|
||||||
}
|
|
||||||
|
|
||||||
refresh, err := s.storage.GetRefresh(token.RefreshId)
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("failed to get refresh token: %v", err)
|
|
||||||
if err == storage.ErrNotFound {
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
|
|
||||||
} else {
|
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if refresh.ClientID != client.ID {
|
|
||||||
s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID)
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if refresh.Token != token.Token {
|
|
||||||
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Per the OAuth2 spec, if the client has omitted the scopes, default to the original
|
|
||||||
// authorized scopes.
|
|
||||||
//
|
|
||||||
// https://tools.ietf.org/html/rfc6749#section-6
|
|
||||||
scopes := refresh.Scopes
|
|
||||||
if scope != "" {
|
|
||||||
requestedScopes := strings.Fields(scope)
|
|
||||||
var unauthorizedScopes []string
|
|
||||||
|
|
||||||
for _, s := range requestedScopes {
|
|
||||||
contains := func() bool {
|
|
||||||
for _, scope := range refresh.Scopes {
|
|
||||||
if s == scope {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}()
|
|
||||||
if !contains {
|
|
||||||
unauthorizedScopes = append(unauthorizedScopes, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(unauthorizedScopes) > 0 {
|
|
||||||
msg := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes)
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
scopes = requestedScopes
|
|
||||||
}
|
|
||||||
|
|
||||||
var connectorData []byte
|
|
||||||
|
|
||||||
session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
if err != storage.ErrNotFound {
|
|
||||||
s.logger.Errorf("failed to get offline session: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case len(refresh.ConnectorData) > 0:
|
|
||||||
// Use the old connector data if it exists, should be deleted once used
|
|
||||||
connectorData = refresh.ConnectorData
|
|
||||||
default:
|
|
||||||
connectorData = session.ConnectorData
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := s.getConnector(refresh.ConnectorID)
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
|
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ident := connector.Identity{
|
|
||||||
UserID: refresh.Claims.UserID,
|
|
||||||
Username: refresh.Claims.Username,
|
|
||||||
PreferredUsername: refresh.Claims.PreferredUsername,
|
|
||||||
Email: refresh.Claims.Email,
|
|
||||||
EmailVerified: refresh.Claims.EmailVerified,
|
|
||||||
Groups: refresh.Claims.Groups,
|
|
||||||
ConnectorData: connectorData,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Can the connector refresh the identity? If so, attempt to refresh the data
|
|
||||||
// in the connector.
|
|
||||||
//
|
|
||||||
// TODO(ericchiang): We may want a strict mode where connectors that don't implement
|
|
||||||
// this interface can't perform refreshing.
|
|
||||||
if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
|
|
||||||
newIdent, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident)
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("failed to refresh identity: %v", err)
|
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ident = newIdent
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := storage.Claims{
|
|
||||||
UserID: ident.UserID,
|
|
||||||
Username: ident.Username,
|
|
||||||
PreferredUsername: ident.PreferredUsername,
|
|
||||||
Email: ident.Email,
|
|
||||||
EmailVerified: ident.EmailVerified,
|
|
||||||
Groups: ident.Groups,
|
|
||||||
}
|
|
||||||
|
|
||||||
accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("failed to create new access token: %v", err)
|
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID)
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("failed to create ID token: %v", err)
|
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
newToken := &internal.RefreshToken{
|
|
||||||
RefreshId: refresh.ID,
|
|
||||||
Token: storage.NewID(),
|
|
||||||
}
|
|
||||||
rawNewToken, err := internal.Marshal(newToken)
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lastUsed := s.now()
|
|
||||||
updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
|
||||||
if old.Token != refresh.Token {
|
|
||||||
return old, errors.New("refresh token claimed twice")
|
|
||||||
}
|
|
||||||
old.Token = newToken.Token
|
|
||||||
// Update the claims of the refresh token.
|
|
||||||
//
|
|
||||||
// UserID intentionally ignored for now.
|
|
||||||
old.Claims.Username = ident.Username
|
|
||||||
old.Claims.PreferredUsername = ident.PreferredUsername
|
|
||||||
old.Claims.Email = ident.Email
|
|
||||||
old.Claims.EmailVerified = ident.EmailVerified
|
|
||||||
old.Claims.Groups = ident.Groups
|
|
||||||
old.LastUsed = lastUsed
|
|
||||||
|
|
||||||
// ConnectorData has been moved to OfflineSession
|
|
||||||
old.ConnectorData = []byte{}
|
|
||||||
return old, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update LastUsed time stamp in refresh token reference object
|
|
||||||
// in offline session for the user.
|
|
||||||
if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
|
|
||||||
if old.Refresh[refresh.ClientID].ID != refresh.ID {
|
|
||||||
return old, errors.New("refresh token invalid")
|
|
||||||
}
|
|
||||||
old.Refresh[refresh.ClientID].LastUsed = lastUsed
|
|
||||||
old.ConnectorData = ident.ConnectorData
|
|
||||||
return old, nil
|
|
||||||
}); err != nil {
|
|
||||||
s.logger.Errorf("failed to update offline session: %v", err)
|
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update refresh token in the storage.
|
|
||||||
if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil {
|
|
||||||
s.logger.Errorf("failed to update refresh token: %v", err)
|
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
|
|
||||||
s.writeAccessToken(w, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
const prefix = "Bearer "
|
const prefix = "Bearer "
|
||||||
|
|
||||||
|
|
339
server/refreshhandlers.go
Normal file
339
server/refreshhandlers.go
Normal file
|
@ -0,0 +1,339 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/connector"
|
||||||
|
"github.com/dexidp/dex/server/internal"
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func contains(arr []string, item string) bool {
|
||||||
|
for _, itemFromArray := range arr {
|
||||||
|
if itemFromArray == item {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type refreshError struct {
|
||||||
|
msg string
|
||||||
|
code int
|
||||||
|
desc string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInternalServerError() *refreshError {
|
||||||
|
return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBadRequestError(desc string) *refreshError {
|
||||||
|
return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) {
|
||||||
|
s.tokenErrHelper(w, err.msg, err.desc, err.code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) {
|
||||||
|
code := r.PostFormValue("refresh_token")
|
||||||
|
if code == "" {
|
||||||
|
return nil, newBadRequestError("No refresh token is found in request.")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := new(internal.RefreshToken)
|
||||||
|
if err := internal.Unmarshal(code, token); err != nil {
|
||||||
|
// For backward compatibility, assume the refresh_token is a raw refresh token ID
|
||||||
|
// if it fails to decode.
|
||||||
|
//
|
||||||
|
// Because refresh_token values that aren't unmarshable were generated by servers
|
||||||
|
// that don't have a Token value, we'll still reject any attempts to claim a
|
||||||
|
// refresh_token twice.
|
||||||
|
token = &internal.RefreshToken{RefreshId: code, Token: ""}
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info
|
||||||
|
func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*storage.RefreshToken, *refreshError) {
|
||||||
|
invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.")
|
||||||
|
|
||||||
|
refresh, err := s.storage.GetRefresh(token.RefreshId)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to get refresh token: %v", err)
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
return nil, newInternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, invalidErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if refresh.ClientID != clientID {
|
||||||
|
s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID)
|
||||||
|
return nil, invalidErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if refresh.Token != token.Token {
|
||||||
|
switch {
|
||||||
|
case !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed):
|
||||||
|
fallthrough
|
||||||
|
case refresh.ObsoleteToken != token.Token:
|
||||||
|
fallthrough
|
||||||
|
case refresh.ObsoleteToken == "":
|
||||||
|
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
|
||||||
|
return nil, invalidErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expiredErr := newBadRequestError("Refresh token expired.")
|
||||||
|
if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) {
|
||||||
|
s.logger.Errorf("refresh token with id %s expired", refresh.ID)
|
||||||
|
return nil, expiredErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) {
|
||||||
|
s.logger.Errorf("refresh token with id %s expired due to inactivity", refresh.ID)
|
||||||
|
return nil, expiredErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return &refresh, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) {
|
||||||
|
// Per the OAuth2 spec, if the client has omitted the scopes, default to the original
|
||||||
|
// authorized scopes.
|
||||||
|
//
|
||||||
|
// https://tools.ietf.org/html/rfc6749#section-6
|
||||||
|
scope := r.PostFormValue("scope")
|
||||||
|
|
||||||
|
if scope == "" {
|
||||||
|
return refresh.Scopes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
requestedScopes := strings.Fields(scope)
|
||||||
|
var unauthorizedScopes []string
|
||||||
|
|
||||||
|
// Per the OAuth2 spec, if the client has omitted the scopes, default to the original
|
||||||
|
// authorized scopes.
|
||||||
|
//
|
||||||
|
// https://tools.ietf.org/html/rfc6749#section-6
|
||||||
|
for _, requestScope := range requestedScopes {
|
||||||
|
if !contains(refresh.Scopes, requestScope) {
|
||||||
|
unauthorizedScopes = append(unauthorizedScopes, requestScope)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(unauthorizedScopes) > 0 {
|
||||||
|
desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes)
|
||||||
|
return nil, newBadRequestError(desc)
|
||||||
|
}
|
||||||
|
|
||||||
|
return requestedScopes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) {
|
||||||
|
var connectorData []byte
|
||||||
|
|
||||||
|
session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get offline session: %v", err)
|
||||||
|
return connector.Identity{}, newInternalServerError()
|
||||||
|
}
|
||||||
|
case len(refresh.ConnectorData) > 0:
|
||||||
|
// Use the old connector data if it exists, should be deleted once used
|
||||||
|
connectorData = refresh.ConnectorData
|
||||||
|
default:
|
||||||
|
connectorData = session.ConnectorData
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := s.getConnector(refresh.ConnectorID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
|
||||||
|
return connector.Identity{}, newInternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
ident := connector.Identity{
|
||||||
|
UserID: refresh.Claims.UserID,
|
||||||
|
Username: refresh.Claims.Username,
|
||||||
|
PreferredUsername: refresh.Claims.PreferredUsername,
|
||||||
|
Email: refresh.Claims.Email,
|
||||||
|
EmailVerified: refresh.Claims.EmailVerified,
|
||||||
|
Groups: refresh.Claims.Groups,
|
||||||
|
ConnectorData: connectorData,
|
||||||
|
}
|
||||||
|
|
||||||
|
// user's token was previously updated by a connector and is allowed to reuse
|
||||||
|
// it is excessive to refresh identity in upstream
|
||||||
|
if s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) && token.Token == refresh.ObsoleteToken {
|
||||||
|
return ident, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can the connector refresh the identity? If so, attempt to refresh the data
|
||||||
|
// in the connector.
|
||||||
|
//
|
||||||
|
// TODO(ericchiang): We may want a strict mode where connectors that don't implement
|
||||||
|
// this interface can't perform refreshing.
|
||||||
|
if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
|
||||||
|
newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to refresh identity: %v", err)
|
||||||
|
return connector.Identity{}, newInternalServerError()
|
||||||
|
}
|
||||||
|
ident = newIdent
|
||||||
|
}
|
||||||
|
|
||||||
|
return ident, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateOfflineSession updates offline session in the storage
|
||||||
|
func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError {
|
||||||
|
offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
|
||||||
|
if old.Refresh[refresh.ClientID].ID != refresh.ID {
|
||||||
|
return old, errors.New("refresh token invalid")
|
||||||
|
}
|
||||||
|
old.Refresh[refresh.ClientID].LastUsed = lastUsed
|
||||||
|
old.ConnectorData = ident.ConnectorData
|
||||||
|
return old, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update LastUsed time stamp in refresh token reference object
|
||||||
|
// in offline session for the user.
|
||||||
|
err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to update offline session: %v", err)
|
||||||
|
return newInternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateRefreshToken updates refresh token and offline session in the storage
|
||||||
|
func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) {
|
||||||
|
newToken := token
|
||||||
|
if s.refreshTokenPolicy.RotationEnabled() {
|
||||||
|
newToken = &internal.RefreshToken{
|
||||||
|
RefreshId: refresh.ID,
|
||||||
|
Token: storage.NewID(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lastUsed := s.now()
|
||||||
|
|
||||||
|
rerr := s.updateOfflineSession(refresh, ident, lastUsed)
|
||||||
|
if rerr != nil {
|
||||||
|
return nil, rerr
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
||||||
|
if s.refreshTokenPolicy.RotationEnabled() {
|
||||||
|
if old.Token != token.Token {
|
||||||
|
if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token {
|
||||||
|
newToken.Token = old.Token
|
||||||
|
return old, nil
|
||||||
|
}
|
||||||
|
return old, errors.New("refresh token claimed twice")
|
||||||
|
}
|
||||||
|
|
||||||
|
old.ObsoleteToken = old.Token
|
||||||
|
}
|
||||||
|
|
||||||
|
old.Token = newToken.Token
|
||||||
|
// Update the claims of the refresh token.
|
||||||
|
//
|
||||||
|
// UserID intentionally ignored for now.
|
||||||
|
old.Claims.Username = ident.Username
|
||||||
|
old.Claims.PreferredUsername = ident.PreferredUsername
|
||||||
|
old.Claims.Email = ident.Email
|
||||||
|
old.Claims.EmailVerified = ident.EmailVerified
|
||||||
|
old.Claims.Groups = ident.Groups
|
||||||
|
old.LastUsed = lastUsed
|
||||||
|
|
||||||
|
// ConnectorData has been moved to OfflineSession
|
||||||
|
old.ConnectorData = []byte{}
|
||||||
|
return old, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update refresh token in the storage.
|
||||||
|
err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to update refresh token: %v", err)
|
||||||
|
return nil, newInternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return newToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
||||||
|
// this method is the entrypoint for refresh tokens handling
|
||||||
|
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
|
||||||
|
token, rerr := s.extractRefreshTokenFromRequest(r)
|
||||||
|
if rerr != nil {
|
||||||
|
s.refreshTokenErrHelper(w, rerr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
refresh, rerr := s.getRefreshTokenFromStorage(client.ID, token)
|
||||||
|
if rerr != nil {
|
||||||
|
s.refreshTokenErrHelper(w, rerr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
scopes, rerr := s.getRefreshScopes(r, refresh)
|
||||||
|
if rerr != nil {
|
||||||
|
s.refreshTokenErrHelper(w, rerr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ident, rerr := s.refreshWithConnector(r.Context(), token, refresh, scopes)
|
||||||
|
if rerr != nil {
|
||||||
|
s.refreshTokenErrHelper(w, rerr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := storage.Claims{
|
||||||
|
UserID: ident.UserID,
|
||||||
|
Username: ident.Username,
|
||||||
|
PreferredUsername: ident.PreferredUsername,
|
||||||
|
Email: ident.Email,
|
||||||
|
EmailVerified: ident.EmailVerified,
|
||||||
|
Groups: ident.Groups,
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to create new access token: %v", err)
|
||||||
|
s.refreshTokenErrHelper(w, newInternalServerError())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to create ID token: %v", err)
|
||||||
|
s.refreshTokenErrHelper(w, newInternalServerError())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newToken, rerr := s.updateRefreshToken(token, refresh, ident)
|
||||||
|
if rerr != nil {
|
||||||
|
s.refreshTokenErrHelper(w, rerr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawNewToken, err := internal.Marshal(newToken)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||||
|
s.refreshTokenErrHelper(w, newInternalServerError())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
|
||||||
|
s.writeAccessToken(w, resp)
|
||||||
|
}
|
212
server/refreshhandlers_test.go
Normal file
212
server/refreshhandlers_test.go
Normal file
|
@ -0,0 +1,212 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/server/internal"
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) {
|
||||||
|
c := storage.Client{
|
||||||
|
ID: "test",
|
||||||
|
Secret: "barfoo",
|
||||||
|
RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"},
|
||||||
|
Name: "dex client",
|
||||||
|
LogoURL: "https://goo.gl/JIyzIC",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.CreateClient(c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
c1 := storage.Connector{
|
||||||
|
ID: "test",
|
||||||
|
Type: "mockCallback",
|
||||||
|
Name: "mockCallback",
|
||||||
|
Config: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.CreateConnector(c1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
refresh := storage.RefreshToken{
|
||||||
|
ID: "test",
|
||||||
|
Token: "bar",
|
||||||
|
ObsoleteToken: "",
|
||||||
|
Nonce: "foo",
|
||||||
|
ClientID: "test",
|
||||||
|
ConnectorID: "test",
|
||||||
|
Scopes: []string{"openid", "email", "profile"},
|
||||||
|
CreatedAt: time.Now().UTC().Round(time.Millisecond),
|
||||||
|
LastUsed: time.Now().UTC().Round(time.Millisecond),
|
||||||
|
Claims: storage.Claims{
|
||||||
|
UserID: "1",
|
||||||
|
Username: "jane",
|
||||||
|
Email: "jane.doe@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Groups: []string{"a", "b"},
|
||||||
|
},
|
||||||
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
if useObsolete {
|
||||||
|
refresh.Token = "testtest"
|
||||||
|
refresh.ObsoleteToken = "bar"
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.CreateRefresh(refresh)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
offlineSessions := storage.OfflineSessions{
|
||||||
|
UserID: "1",
|
||||||
|
ConnID: "test",
|
||||||
|
Refresh: map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}},
|
||||||
|
ConnectorData: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.CreateOfflineSessions(offlineSessions)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokenExpirationScenarios(t *testing.T) {
|
||||||
|
t0 := time.Now()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
policy *RefreshTokenPolicy
|
||||||
|
useObsolete bool
|
||||||
|
error string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Normal",
|
||||||
|
policy: &RefreshTokenPolicy{rotateRefreshTokens: true},
|
||||||
|
error: ``,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not expired because used",
|
||||||
|
policy: &RefreshTokenPolicy{
|
||||||
|
rotateRefreshTokens: false,
|
||||||
|
validIfNotUsedFor: time.Second * 60,
|
||||||
|
now: func() time.Time { return t0.Add(time.Second * 25) },
|
||||||
|
},
|
||||||
|
error: ``,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Expired because not used",
|
||||||
|
policy: &RefreshTokenPolicy{
|
||||||
|
rotateRefreshTokens: false,
|
||||||
|
validIfNotUsedFor: time.Second * 60,
|
||||||
|
now: func() time.Time { return t0.Add(time.Hour) },
|
||||||
|
},
|
||||||
|
error: `{"error":"invalid_request","error_description":"Refresh token expired."}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Absolutely expired",
|
||||||
|
policy: &RefreshTokenPolicy{
|
||||||
|
rotateRefreshTokens: true,
|
||||||
|
absoluteLifetime: time.Second * 60,
|
||||||
|
now: func() time.Time { return t0.Add(time.Hour) },
|
||||||
|
},
|
||||||
|
error: `{"error":"invalid_request","error_description":"Refresh token expired."}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Obsolete tokens are allowed",
|
||||||
|
useObsolete: true,
|
||||||
|
policy: &RefreshTokenPolicy{
|
||||||
|
rotateRefreshTokens: true,
|
||||||
|
reuseInterval: time.Second * 30,
|
||||||
|
now: func() time.Time { return t0.Add(time.Second * 25) },
|
||||||
|
},
|
||||||
|
error: ``,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Obsolete tokens are not allowed",
|
||||||
|
useObsolete: true,
|
||||||
|
policy: &RefreshTokenPolicy{
|
||||||
|
rotateRefreshTokens: true,
|
||||||
|
now: func() time.Time { return t0.Add(time.Second * 25) },
|
||||||
|
},
|
||||||
|
error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Obsolete tokens are allowed but token is expired globally",
|
||||||
|
useObsolete: true,
|
||||||
|
policy: &RefreshTokenPolicy{
|
||||||
|
rotateRefreshTokens: true,
|
||||||
|
reuseInterval: time.Second * 30,
|
||||||
|
absoluteLifetime: time.Second * 20,
|
||||||
|
now: func() time.Time { return t0.Add(time.Second * 25) },
|
||||||
|
},
|
||||||
|
error: `{"error":"invalid_request","error_description":"Refresh token expired."}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(*testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Setup a dex server.
|
||||||
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||||
|
c.RefreshTokenPolicy = tc.policy
|
||||||
|
c.Now = func() time.Time { return t0 }
|
||||||
|
})
|
||||||
|
defer httpServer.Close()
|
||||||
|
|
||||||
|
mockRefreshTokenTestStorage(t, s.storage, tc.useObsolete)
|
||||||
|
|
||||||
|
u, err := url.Parse(s.issuerURL.String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
u.Path = path.Join(u.Path, "/token")
|
||||||
|
v := url.Values{}
|
||||||
|
v.Add("grant_type", "refresh_token")
|
||||||
|
v.Add("refresh_token", tokenData)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||||
|
req.SetBasicAuth("test", "barfoo")
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if tc.error == "" {
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
} else {
|
||||||
|
require.Equal(t, rr.Body.String(), tc.error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that we received expected refresh token
|
||||||
|
var ref struct {
|
||||||
|
Token string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(rr.Body.Bytes(), &ref)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if tc.policy.rotateRefreshTokens == false {
|
||||||
|
require.Equal(t, tokenData, ref.Token)
|
||||||
|
} else {
|
||||||
|
require.NotEqual(t, tokenData, ref.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.useObsolete {
|
||||||
|
updatedTokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "testtest"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, updatedTokenData, ref.Token)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -177,3 +177,73 @@ func (k keyRotator) rotate() error {
|
||||||
k.logger.Infof("keys rotated, next rotation: %s", nextRotation)
|
k.logger.Infof("keys rotated, next rotation: %s", nextRotation)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RefreshTokenPolicy struct {
|
||||||
|
rotateRefreshTokens bool // enable rotation
|
||||||
|
|
||||||
|
absoluteLifetime time.Duration // interval from token creation to the end of its life
|
||||||
|
validIfNotUsedFor time.Duration // interval from last token update to the end of its life
|
||||||
|
reuseInterval time.Duration // interval within which old refresh token is allowed to be reused
|
||||||
|
|
||||||
|
now func() time.Time
|
||||||
|
|
||||||
|
logger log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) {
|
||||||
|
r := RefreshTokenPolicy{now: time.Now, logger: logger}
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if validIfNotUsedFor != "" {
|
||||||
|
r.validIfNotUsedFor, err = time.ParseDuration(validIfNotUsedFor)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err)
|
||||||
|
}
|
||||||
|
logger.Infof("config refresh tokens valid if not used for: %v", validIfNotUsedFor)
|
||||||
|
}
|
||||||
|
|
||||||
|
if absoluteLifetime != "" {
|
||||||
|
r.absoluteLifetime, err = time.ParseDuration(absoluteLifetime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err)
|
||||||
|
}
|
||||||
|
logger.Infof("config refresh tokens absolute lifetime: %v", absoluteLifetime)
|
||||||
|
}
|
||||||
|
|
||||||
|
if reuseInterval != "" {
|
||||||
|
r.reuseInterval, err = time.ParseDuration(reuseInterval)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err)
|
||||||
|
}
|
||||||
|
logger.Infof("config refresh tokens reuse interval: %v", reuseInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rotateRefreshTokens = !rotation
|
||||||
|
logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens)
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RefreshTokenPolicy) RotationEnabled() bool {
|
||||||
|
return r.rotateRefreshTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool {
|
||||||
|
if r.absoluteLifetime == 0 {
|
||||||
|
return false // expiration disabled
|
||||||
|
}
|
||||||
|
return r.now().After(lastUsed.Add(r.absoluteLifetime))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool {
|
||||||
|
if r.validIfNotUsedFor == 0 {
|
||||||
|
return false // expiration disabled
|
||||||
|
}
|
||||||
|
return r.now().After(lastUsed.Add(r.validIfNotUsedFor))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool {
|
||||||
|
if r.reuseInterval == 0 {
|
||||||
|
return false // expiration disabled
|
||||||
|
}
|
||||||
|
return !r.now().After(lastUsed.Add(r.reuseInterval))
|
||||||
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
"github.com/dexidp/dex/storage/memory"
|
"github.com/dexidp/dex/storage/memory"
|
||||||
|
@ -100,3 +101,29 @@ func TestKeyRotator(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokenPolicy(t *testing.T) {
|
||||||
|
lastTime := time.Now()
|
||||||
|
l := &logrus.Logger{
|
||||||
|
Out: os.Stderr,
|
||||||
|
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||||
|
Level: logrus.DebugLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := NewRefreshTokenPolicy(l, true, "1m", "1m", "1m")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("Allowed", func(t *testing.T) {
|
||||||
|
r.now = func() time.Time { return lastTime }
|
||||||
|
require.Equal(t, true, r.AllowedToReuse(lastTime))
|
||||||
|
require.Equal(t, false, r.ExpiredBecauseUnused(lastTime))
|
||||||
|
require.Equal(t, false, r.CompletelyExpired(lastTime))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Expired", func(t *testing.T) {
|
||||||
|
r.now = func() time.Time { return lastTime.Add(2 * time.Minute) }
|
||||||
|
require.Equal(t, false, r.AllowedToReuse(lastTime))
|
||||||
|
require.Equal(t, true, r.ExpiredBecauseUnused(lastTime))
|
||||||
|
require.Equal(t, true, r.CompletelyExpired(lastTime))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -84,6 +84,10 @@ type Config struct {
|
||||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||||
AuthRequestsValidFor time.Duration // Defaults to 24 hours
|
AuthRequestsValidFor time.Duration // Defaults to 24 hours
|
||||||
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
|
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
|
||||||
|
|
||||||
|
// Refresh token expiration settings
|
||||||
|
RefreshTokenPolicy *RefreshTokenPolicy
|
||||||
|
|
||||||
// If set, the server will use this connector to handle password grants
|
// If set, the server will use this connector to handle password grants
|
||||||
PasswordConnector string
|
PasswordConnector string
|
||||||
|
|
||||||
|
@ -171,6 +175,8 @@ type Server struct {
|
||||||
authRequestsValidFor time.Duration
|
authRequestsValidFor time.Duration
|
||||||
deviceRequestsValidFor time.Duration
|
deviceRequestsValidFor time.Duration
|
||||||
|
|
||||||
|
refreshTokenPolicy *RefreshTokenPolicy
|
||||||
|
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -246,6 +252,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||||
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
|
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
|
||||||
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
|
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
|
||||||
|
refreshTokenPolicy: c.RefreshTokenPolicy,
|
||||||
skipApproval: c.SkipApprovalScreen,
|
skipApproval: c.SkipApprovalScreen,
|
||||||
alwaysShowLogin: c.AlwaysShowLoginScreen,
|
alwaysShowLogin: c.AlwaysShowLoginScreen,
|
||||||
now: now,
|
now: now,
|
||||||
|
|
|
@ -119,6 +119,16 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
|
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
|
||||||
|
|
||||||
|
// Default rotation policy
|
||||||
|
if server.refreshTokenPolicy == nil {
|
||||||
|
server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to prepare rotation policy: %v", err)
|
||||||
|
}
|
||||||
|
server.refreshTokenPolicy.now = config.Now
|
||||||
|
}
|
||||||
|
|
||||||
return s, server
|
return s, server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -326,6 +326,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
||||||
refresh := storage.RefreshToken{
|
refresh := storage.RefreshToken{
|
||||||
ID: id,
|
ID: id,
|
||||||
Token: "bar",
|
Token: "bar",
|
||||||
|
ObsoleteToken: "",
|
||||||
Nonce: "foo",
|
Nonce: "foo",
|
||||||
ClientID: "client_id",
|
ClientID: "client_id",
|
||||||
ConnectorID: "client_secret",
|
ConnectorID: "client_secret",
|
||||||
|
@ -380,6 +381,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
||||||
refresh2 := storage.RefreshToken{
|
refresh2 := storage.RefreshToken{
|
||||||
ID: id2,
|
ID: id2,
|
||||||
Token: "bar_2",
|
Token: "bar_2",
|
||||||
|
ObsoleteToken: refresh.Token,
|
||||||
Nonce: "foo_2",
|
Nonce: "foo_2",
|
||||||
ClientID: "client_id_2",
|
ClientID: "client_id_2",
|
||||||
ConnectorID: "client_secret",
|
ConnectorID: "client_secret",
|
||||||
|
|
|
@ -133,6 +133,7 @@ type RefreshToken struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
|
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
|
ObsoleteToken string `json:"obsolete_token"`
|
||||||
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
LastUsed time.Time `json:"last_used"`
|
LastUsed time.Time `json:"last_used"`
|
||||||
|
@ -152,6 +153,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
|
||||||
return storage.RefreshToken{
|
return storage.RefreshToken{
|
||||||
ID: r.ID,
|
ID: r.ID,
|
||||||
Token: r.Token,
|
Token: r.Token,
|
||||||
|
ObsoleteToken: r.ObsoleteToken,
|
||||||
CreatedAt: r.CreatedAt,
|
CreatedAt: r.CreatedAt,
|
||||||
LastUsed: r.LastUsed,
|
LastUsed: r.LastUsed,
|
||||||
ClientID: r.ClientID,
|
ClientID: r.ClientID,
|
||||||
|
@ -167,6 +169,7 @@ func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
|
||||||
return RefreshToken{
|
return RefreshToken{
|
||||||
ID: r.ID,
|
ID: r.ID,
|
||||||
Token: r.Token,
|
Token: r.Token,
|
||||||
|
ObsoleteToken: r.ObsoleteToken,
|
||||||
CreatedAt: r.CreatedAt,
|
CreatedAt: r.CreatedAt,
|
||||||
LastUsed: r.LastUsed,
|
LastUsed: r.LastUsed,
|
||||||
ClientID: r.ClientID,
|
ClientID: r.ClientID,
|
||||||
|
|
|
@ -497,6 +497,7 @@ type RefreshToken struct {
|
||||||
Scopes []string `json:"scopes,omitempty"`
|
Scopes []string `json:"scopes,omitempty"`
|
||||||
|
|
||||||
Token string `json:"token,omitempty"`
|
Token string `json:"token,omitempty"`
|
||||||
|
ObsoleteToken string `json:"obsoleteToken,omitempty"`
|
||||||
|
|
||||||
Nonce string `json:"nonce,omitempty"`
|
Nonce string `json:"nonce,omitempty"`
|
||||||
|
|
||||||
|
@ -516,6 +517,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
|
||||||
return storage.RefreshToken{
|
return storage.RefreshToken{
|
||||||
ID: r.ObjectMeta.Name,
|
ID: r.ObjectMeta.Name,
|
||||||
Token: r.Token,
|
Token: r.Token,
|
||||||
|
ObsoleteToken: r.ObsoleteToken,
|
||||||
CreatedAt: r.CreatedAt,
|
CreatedAt: r.CreatedAt,
|
||||||
LastUsed: r.LastUsed,
|
LastUsed: r.LastUsed,
|
||||||
ClientID: r.ClientID,
|
ClientID: r.ClientID,
|
||||||
|
@ -538,6 +540,7 @@ func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken
|
||||||
Namespace: cli.namespace,
|
Namespace: cli.namespace,
|
||||||
},
|
},
|
||||||
Token: r.Token,
|
Token: r.Token,
|
||||||
|
ObsoleteToken: r.ObsoleteToken,
|
||||||
CreatedAt: r.CreatedAt,
|
CreatedAt: r.CreatedAt,
|
||||||
LastUsed: r.LastUsed,
|
LastUsed: r.LastUsed,
|
||||||
ClientID: r.ClientID,
|
ClientID: r.ClientID,
|
||||||
|
|
|
@ -285,16 +285,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
||||||
claims_user_id, claims_username, claims_preferred_username,
|
claims_user_id, claims_username, claims_preferred_username,
|
||||||
claims_email, claims_email_verified, claims_groups,
|
claims_email, claims_email_verified, claims_groups,
|
||||||
connector_id, connector_data,
|
connector_id, connector_data,
|
||||||
token, created_at, last_used
|
token, obsolete_token, created_at, last_used
|
||||||
)
|
)
|
||||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15);
|
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16);
|
||||||
`,
|
`,
|
||||||
r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
|
r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
|
||||||
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
|
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
|
||||||
r.Claims.Email, r.Claims.EmailVerified,
|
r.Claims.Email, r.Claims.EmailVerified,
|
||||||
encoder(r.Claims.Groups),
|
encoder(r.Claims.Groups),
|
||||||
r.ConnectorID, r.ConnectorData,
|
r.ConnectorID, r.ConnectorData,
|
||||||
r.Token, r.CreatedAt, r.LastUsed,
|
r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if c.alreadyExistsCheck(err) {
|
if c.alreadyExistsCheck(err) {
|
||||||
|
@ -329,17 +329,18 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
|
||||||
connector_id = $10,
|
connector_id = $10,
|
||||||
connector_data = $11,
|
connector_data = $11,
|
||||||
token = $12,
|
token = $12,
|
||||||
created_at = $13,
|
obsolete_token = $13,
|
||||||
last_used = $14
|
created_at = $14,
|
||||||
|
last_used = $15
|
||||||
where
|
where
|
||||||
id = $15
|
id = $16
|
||||||
`,
|
`,
|
||||||
r.ClientID, encoder(r.Scopes), r.Nonce,
|
r.ClientID, encoder(r.Scopes), r.Nonce,
|
||||||
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
|
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
|
||||||
r.Claims.Email, r.Claims.EmailVerified,
|
r.Claims.Email, r.Claims.EmailVerified,
|
||||||
encoder(r.Claims.Groups),
|
encoder(r.Claims.Groups),
|
||||||
r.ConnectorID, r.ConnectorData,
|
r.ConnectorID, r.ConnectorData,
|
||||||
r.Token, r.CreatedAt, r.LastUsed, id,
|
r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed, id,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("update refresh token: %v", err)
|
return fmt.Errorf("update refresh token: %v", err)
|
||||||
|
@ -360,7 +361,7 @@ func getRefresh(q querier, id string) (storage.RefreshToken, error) {
|
||||||
claims_email, claims_email_verified,
|
claims_email, claims_email_verified,
|
||||||
claims_groups,
|
claims_groups,
|
||||||
connector_id, connector_data,
|
connector_id, connector_data,
|
||||||
token, created_at, last_used
|
token, obsolete_token, created_at, last_used
|
||||||
from refresh_token where id = $1;
|
from refresh_token where id = $1;
|
||||||
`, id))
|
`, id))
|
||||||
}
|
}
|
||||||
|
@ -372,7 +373,7 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
||||||
claims_user_id, claims_username, claims_preferred_username,
|
claims_user_id, claims_username, claims_preferred_username,
|
||||||
claims_email, claims_email_verified, claims_groups,
|
claims_email, claims_email_verified, claims_groups,
|
||||||
connector_id, connector_data,
|
connector_id, connector_data,
|
||||||
token, created_at, last_used
|
token, obsolete_token, created_at, last_used
|
||||||
from refresh_token;
|
from refresh_token;
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -401,7 +402,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
|
||||||
&r.Claims.Email, &r.Claims.EmailVerified,
|
&r.Claims.Email, &r.Claims.EmailVerified,
|
||||||
decoder(&r.Claims.Groups),
|
decoder(&r.Claims.Groups),
|
||||||
&r.ConnectorID, &r.ConnectorData,
|
&r.ConnectorID, &r.ConnectorData,
|
||||||
&r.Token, &r.CreatedAt, &r.LastUsed,
|
&r.Token, &r.ObsoleteToken, &r.CreatedAt, &r.LastUsed,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
|
|
|
@ -274,4 +274,11 @@ var migrations = []migration{
|
||||||
add column code_challenge_method text not null default '';`,
|
add column code_challenge_method text not null default '';`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
stmts: []string{
|
||||||
|
`
|
||||||
|
alter table refresh_token
|
||||||
|
add column obsolete_token text default '';`,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -272,6 +272,7 @@ type RefreshToken struct {
|
||||||
//
|
//
|
||||||
// May be empty.
|
// May be empty.
|
||||||
Token string
|
Token string
|
||||||
|
ObsoleteToken string
|
||||||
|
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
LastUsed time.Time
|
LastUsed time.Time
|
||||||
|
|
Loading…
Reference in a new issue