refresh token rotation
Update refresh token flow to revoke old refresh token and generates a new one. Fixes #519
This commit is contained in:
parent
44295706ea
commit
c91b37aa9e
5 changed files with 198 additions and 132 deletions
203
db/refresh.go
203
db/refresh.go
|
@ -91,103 +91,52 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
|
func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
|
||||||
if userID == "" {
|
return r.create(nil, userID, clientID, connectorID, scopes)
|
||||||
return "", refresh.ErrorInvalidUserID
|
|
||||||
}
|
|
||||||
if clientID == "" {
|
|
||||||
return "", refresh.ErrorInvalidClientID
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(yifan): Check the number of tokens given to the client-user pair.
|
|
||||||
tokenPayload, err := r.tokenGenerator.Generate()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
record := &refreshTokenModel{
|
|
||||||
PayloadHash: payloadHash,
|
|
||||||
UserID: userID,
|
|
||||||
ClientID: clientID,
|
|
||||||
ConnectorID: connectorID,
|
|
||||||
Scopes: strings.Join(scopes, " "),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.executor(nil).Insert(record); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return buildToken(record.ID, tokenPayload), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
|
func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
|
||||||
tokenID, tokenPayload, err := parseToken(token)
|
return r.verify(nil, clientID, token)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
record, err := r.get(nil, tokenID)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if record.ClientID != clientID {
|
|
||||||
return "", "", nil, refresh.ErrorInvalidClientID
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var scopes []string
|
|
||||||
if len(record.Scopes) > 0 {
|
|
||||||
scopes = strings.Split(record.Scopes, " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
return record.UserID, record.ConnectorID, scopes, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
||||||
tokenID, tokenPayload, err := parseToken(token)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
tx, err := r.begin()
|
tx, err := r.begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
exec := r.executor(tx)
|
if err := r.revoke(tx, userID, token); err != nil {
|
||||||
record, err := r.get(tx, tokenID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if record.UserID != userID {
|
|
||||||
return refresh.ErrorInvalidUserID
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
deleted, err := exec.Delete(record)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if deleted == 0 {
|
|
||||||
return refresh.ErrorInvalidToken
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *refreshTokenRepo) RenewRefreshToken(clientID, userID, oldToken string) (newRefreshToken string, err error) {
|
||||||
|
// Verify
|
||||||
|
userID, connectorID, scopes, err := r.verify(nil, clientID, oldToken)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke old refresh token
|
||||||
|
tx, err := r.begin()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if err := r.revoke(tx, userID, oldToken); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Renew refresh token
|
||||||
|
newRefreshToken, err = r.create(tx, userID, clientID, connectorID, scopes)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return newRefreshToken, tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error {
|
func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error {
|
||||||
q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName))
|
q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName))
|
||||||
_, err := r.executor(nil).Exec(q, userID, clientID)
|
_, err := r.executor(nil).Exec(q, userID, clientID)
|
||||||
|
@ -235,3 +184,97 @@ func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshToke
|
||||||
}
|
}
|
||||||
return record, nil
|
return record, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *refreshTokenRepo) verify(tx repo.Transaction, clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
|
||||||
|
tokenID, tokenPayload, err := parseToken(token)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err := r.get(tx, tokenID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if record.ClientID != clientID {
|
||||||
|
return "", "", nil, refresh.ErrorInvalidClientID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the hash of token received is the same stored in database
|
||||||
|
if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var scopes []string
|
||||||
|
if len(record.Scopes) > 0 {
|
||||||
|
scopes = strings.Split(record.Scopes, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
return record.UserID, record.ConnectorID, scopes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refreshTokenRepo) create(tx repo.Transaction, userID, clientID, connectorID string, scopes []string) (string, error) {
|
||||||
|
if userID == "" {
|
||||||
|
return "", refresh.ErrorInvalidUserID
|
||||||
|
}
|
||||||
|
if clientID == "" {
|
||||||
|
return "", refresh.ErrorInvalidClientID
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(yifan): Check the number of tokens given to the client-user pair.
|
||||||
|
tokenPayload, err := r.tokenGenerator.Generate()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
record := &refreshTokenModel{
|
||||||
|
PayloadHash: payloadHash,
|
||||||
|
UserID: userID,
|
||||||
|
ClientID: clientID,
|
||||||
|
ConnectorID: connectorID,
|
||||||
|
Scopes: strings.Join(scopes, " "),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.executor(tx).Insert(record); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildToken(record.ID, tokenPayload), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refreshTokenRepo) revoke(tx repo.Transaction, userID, token string) error {
|
||||||
|
tokenID, tokenPayload, err := parseToken(token)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
exec := r.executor(tx)
|
||||||
|
record, err := r.get(tx, tokenID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if record.UserID != userID {
|
||||||
|
return refresh.ErrorInvalidUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
deleted, err := exec.Delete(record)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if deleted == 0 {
|
||||||
|
return refresh.ErrorInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -54,6 +54,9 @@ type RefreshTokenRepo interface {
|
||||||
// Revoke deletes the refresh token if the token belongs to the given userID.
|
// Revoke deletes the refresh token if the token belongs to the given userID.
|
||||||
Revoke(userID, token string) error
|
Revoke(userID, token string) error
|
||||||
|
|
||||||
|
// Revoke old refresh token and generates a new one
|
||||||
|
RenewRefreshToken(clientID, userID, oldToken string) (newRefreshToken string, err error)
|
||||||
|
|
||||||
// RevokeTokensForClient revokes all tokens issued for the userID for the provided client.
|
// RevokeTokensForClient revokes all tokens issued for the userID for the provided client.
|
||||||
RevokeTokensForClient(userID, clientID string) error
|
RevokeTokensForClient(userID, clientID string) error
|
||||||
|
|
||||||
|
|
|
@ -520,7 +520,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
|
||||||
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
|
jwt, refreshToken, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeTokenError(w, err, state)
|
writeTokenError(w, err, state)
|
||||||
return
|
return
|
||||||
|
|
|
@ -53,9 +53,9 @@ type OIDCServer interface {
|
||||||
|
|
||||||
ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error)
|
ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error)
|
||||||
|
|
||||||
// RefreshToken takes a previously generated refresh token and returns a new ID token
|
// RefreshToken takes a previously generated refresh token and returns a new ID token and new refresh token
|
||||||
// if the token is valid.
|
// if the token is valid.
|
||||||
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error)
|
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error)
|
||||||
|
|
||||||
KillSession(string) error
|
KillSession(string) error
|
||||||
|
|
||||||
|
@ -567,15 +567,15 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
||||||
return jwt, refreshToken, nil
|
return jwt, refreshToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) {
|
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error) {
|
||||||
ok, err := s.ClientManager.Authenticate(creds)
|
ok, err := s.ClientManager.Authenticate(creds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("Failed to Authenticate client %s", creds.ID)
|
log.Errorf("Failed to Authenticate client %s", creds.ID)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
|
return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
|
userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
|
||||||
|
@ -583,18 +583,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
case nil:
|
case nil:
|
||||||
break
|
break
|
||||||
case refresh.ErrorInvalidToken:
|
case refresh.ErrorInvalidToken:
|
||||||
return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
|
return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
|
||||||
case refresh.ErrorInvalidClientID:
|
case refresh.ErrorInvalidClientID:
|
||||||
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
|
return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||||
default:
|
default:
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(scopes) == 0 {
|
if len(scopes) == 0 {
|
||||||
scopes = rtScopes
|
scopes = rtScopes
|
||||||
} else {
|
} else {
|
||||||
if !rtScopes.Contains(scopes) {
|
if !rtScopes.Contains(scopes) {
|
||||||
return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
|
return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -603,7 +603,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
// The error can be user.ErrorNotFound, but we are not deleting
|
// The error can be user.ErrorNotFound, but we are not deleting
|
||||||
// user at this moment, so this shouldn't happen.
|
// user at this moment, so this shouldn't happen.
|
||||||
log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
|
log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
var groups []string
|
var groups []string
|
||||||
|
@ -611,19 +611,19 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
conn, ok := s.connector(connectorID)
|
conn, ok := s.connector(connectorID)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
|
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
grouper, ok := conn.(connector.GroupsConnector)
|
grouper, ok := conn.(connector.GroupsConnector)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
|
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
|
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to get remote identities: %v", err)
|
log.Errorf("failed to get remote identities: %v", err)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
|
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
|
||||||
for _, ri := range remoteIdentities {
|
for _, ri := range remoteIdentities {
|
||||||
|
@ -635,18 +635,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
}()
|
}()
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("failed to get remote identity for connector %s", connectorID)
|
log.Errorf("failed to get remote identity for connector %s", connectorID)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
|
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
|
||||||
log.Errorf("failed to get groups for refresh token: %v", connectorID)
|
log.Errorf("failed to get groups for refresh token: %v", connectorID)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
signer, err := s.KeyManager.Signer()
|
signer, err := s.KeyManager.Signer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to refresh ID token: %v", err)
|
log.Errorf("Failed to refresh ID token: %v", err)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
@ -666,12 +666,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
jwt, err := jose.NewSignedJWT(claims, signer)
|
jwt, err := jose.NewSignedJWT(claims, signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to generate ID token: %v", err)
|
log.Errorf("Failed to generate ID token: %v", err)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to generate new refresh token: %v", err)
|
||||||
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("New token sent: clientID=%s", creds.ID)
|
log.Infof("New token sent: clientID=%s", creds.ID)
|
||||||
|
|
||||||
return jwt, nil
|
return jwt, refreshToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) {
|
func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) {
|
||||||
|
|
|
@ -107,6 +107,10 @@ func makeNewUserRepo() (user.UserRepo, error) {
|
||||||
return userRepo, nil
|
return userRepo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getRefreshTokenEncoded(id, value string) string {
|
||||||
|
return fmt.Sprintf("%v/%s", id, base64.URLEncoding.EncodeToString([]byte(value)))
|
||||||
|
}
|
||||||
|
|
||||||
func TestServerProviderConfig(t *testing.T) {
|
func TestServerProviderConfig(t *testing.T) {
|
||||||
srv := &Server{IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}}
|
srv := &Server{IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}}
|
||||||
|
|
||||||
|
@ -613,6 +617,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
// refresh token will be "1".
|
// refresh token will be "1".
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
token string
|
token string
|
||||||
|
expectedRefreshToken string
|
||||||
clientID string // The client that associates with the token.
|
clientID string // The client that associates with the token.
|
||||||
creds oidc.ClientCredentials
|
creds oidc.ClientCredentials
|
||||||
signer jose.Signer
|
signer jose.Signer
|
||||||
|
@ -623,7 +628,8 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
// Everything is good.
|
// Everything is good.
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
|
expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"),
|
||||||
clientID: testClientID,
|
clientID: testClientID,
|
||||||
creds: testClientCredentials,
|
creds: testClientCredentials,
|
||||||
signer: signerFixture,
|
signer: signerFixture,
|
||||||
|
@ -632,7 +638,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Asking for a scope not originally granted to you.
|
// Asking for a scope not originally granted to you.
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
clientID: testClientID,
|
clientID: testClientID,
|
||||||
creds: testClientCredentials,
|
creds: testClientCredentials,
|
||||||
signer: signerFixture,
|
signer: signerFixture,
|
||||||
|
@ -652,7 +658,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid refresh token(invalid payload content).
|
// Invalid refresh token(invalid payload content).
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
token: getRefreshTokenEncoded("1", "refresh-2"),
|
||||||
clientID: testClientID,
|
clientID: testClientID,
|
||||||
creds: testClientCredentials,
|
creds: testClientCredentials,
|
||||||
signer: signerFixture,
|
signer: signerFixture,
|
||||||
|
@ -662,7 +668,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid refresh token(invalid ID content).
|
// Invalid refresh token(invalid ID content).
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("0", "refresh-1"),
|
||||||
clientID: testClientID,
|
clientID: testClientID,
|
||||||
creds: testClientCredentials,
|
creds: testClientCredentials,
|
||||||
signer: signerFixture,
|
signer: signerFixture,
|
||||||
|
@ -672,7 +678,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid client(client is not associated with the token).
|
// Invalid client(client is not associated with the token).
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
clientID: testClientID,
|
clientID: testClientID,
|
||||||
creds: clientB.Credentials,
|
creds: clientB.Credentials,
|
||||||
signer: signerFixture,
|
signer: signerFixture,
|
||||||
|
@ -682,7 +688,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid client(no client ID).
|
// Invalid client(no client ID).
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
clientID: testClientID,
|
clientID: testClientID,
|
||||||
creds: oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
creds: oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
||||||
signer: signerFixture,
|
signer: signerFixture,
|
||||||
|
@ -692,7 +698,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid client(no such client).
|
// Invalid client(no such client).
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
clientID: testClientID,
|
clientID: testClientID,
|
||||||
creds: oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
creds: oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
||||||
signer: signerFixture,
|
signer: signerFixture,
|
||||||
|
@ -702,7 +708,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid client(no secrets).
|
// Invalid client(no secrets).
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
clientID: testClientID,
|
clientID: testClientID,
|
||||||
creds: oidc.ClientCredentials{ID: testClientID},
|
creds: oidc.ClientCredentials{ID: testClientID},
|
||||||
signer: signerFixture,
|
signer: signerFixture,
|
||||||
|
@ -712,7 +718,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Invalid client(invalid secret).
|
// Invalid client(invalid secret).
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
clientID: testClientID,
|
clientID: testClientID,
|
||||||
creds: oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
|
creds: oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
|
||||||
signer: signerFixture,
|
signer: signerFixture,
|
||||||
|
@ -722,7 +728,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Signing operation fails.
|
// Signing operation fails.
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
clientID: testClientID,
|
clientID: testClientID,
|
||||||
creds: testClientCredentials,
|
creds: testClientCredentials,
|
||||||
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
|
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
|
||||||
|
@ -732,7 +738,8 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Valid Cross-Client
|
// Valid Cross-Client
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
|
expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"),
|
||||||
clientID: "client_a",
|
clientID: "client_a",
|
||||||
creds: oidc.ClientCredentials{
|
creds: oidc.ClientCredentials{
|
||||||
ID: "client_a",
|
ID: "client_a",
|
||||||
|
@ -748,7 +755,8 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
// refresh request, which should result in the original stored scopes
|
// refresh request, which should result in the original stored scopes
|
||||||
// being used.
|
// being used.
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
|
expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"),
|
||||||
clientID: "client_a",
|
clientID: "client_a",
|
||||||
creds: oidc.ClientCredentials{
|
creds: oidc.ClientCredentials{
|
||||||
ID: "client_a",
|
ID: "client_a",
|
||||||
|
@ -763,7 +771,8 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
// Valid Cross-Client - asking for fewer scopes than originally used
|
// Valid Cross-Client - asking for fewer scopes than originally used
|
||||||
// when creating the refresh token, which is ok.
|
// when creating the refresh token, which is ok.
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
|
expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"),
|
||||||
clientID: "client_a",
|
clientID: "client_a",
|
||||||
creds: oidc.ClientCredentials{
|
creds: oidc.ClientCredentials{
|
||||||
ID: "client_a",
|
ID: "client_a",
|
||||||
|
@ -777,7 +786,8 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
// Valid Cross-Client - asking for multiple clients in the audience.
|
// Valid Cross-Client - asking for multiple clients in the audience.
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
|
expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"),
|
||||||
clientID: "client_a",
|
clientID: "client_a",
|
||||||
creds: oidc.ClientCredentials{
|
creds: oidc.ClientCredentials{
|
||||||
ID: "client_a",
|
ID: "client_a",
|
||||||
|
@ -792,7 +802,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
// Invalid Cross-Client - didn't orignally request cross-client when
|
// Invalid Cross-Client - didn't orignally request cross-client when
|
||||||
// refresh token was created.
|
// refresh token was created.
|
||||||
{
|
{
|
||||||
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
token: getRefreshTokenEncoded("1", "refresh-1"),
|
||||||
clientID: "client_a",
|
clientID: "client_a",
|
||||||
creds: oidc.ClientCredentials{
|
creds: oidc.ClientCredentials{
|
||||||
ID: "client_a",
|
ID: "client_a",
|
||||||
|
@ -825,7 +835,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwt, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token)
|
jwt, refreshToken, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token)
|
||||||
if !reflect.DeepEqual(err, tt.err) {
|
if !reflect.DeepEqual(err, tt.err) {
|
||||||
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
|
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
|
||||||
}
|
}
|
||||||
|
@ -861,5 +871,9 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
expectedAud, claims["aud"])
|
expectedAud, claims["aud"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if diff := pretty.Compare(refreshToken, tt.expectedRefreshToken); diff != "" {
|
||||||
|
t.Errorf("Case %d: want=%v, got=%v", i, tt.expectedRefreshToken, refreshToken)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Reference in a new issue