refresh token rotation

Update refresh token flow to revoke old refresh token and generates a new one.

Fixes #519
This commit is contained in:
Rubén Soleto Buenvarón 2016-08-08 12:17:01 +02:00
parent 44295706ea
commit c91b37aa9e
5 changed files with 198 additions and 132 deletions

View file

@ -91,103 +91,52 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
} }
func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) { func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
if userID == "" { return r.create(nil, userID, clientID, connectorID, scopes)
return "", refresh.ErrorInvalidUserID
}
if clientID == "" {
return "", refresh.ErrorInvalidClientID
}
// TODO(yifan): Check the number of tokens given to the client-user pair.
tokenPayload, err := r.tokenGenerator.Generate()
if err != nil {
return "", err
}
payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
if err != nil {
return "", err
}
record := &refreshTokenModel{
PayloadHash: payloadHash,
UserID: userID,
ClientID: clientID,
ConnectorID: connectorID,
Scopes: strings.Join(scopes, " "),
}
if err := r.executor(nil).Insert(record); err != nil {
return "", err
}
return buildToken(record.ID, tokenPayload), nil
} }
func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) { func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
tokenID, tokenPayload, err := parseToken(token) return r.verify(nil, clientID, token)
if err != nil {
return
}
record, err := r.get(nil, tokenID)
if err != nil {
return
}
if record.ClientID != clientID {
return "", "", nil, refresh.ErrorInvalidClientID
}
if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return
}
var scopes []string
if len(record.Scopes) > 0 {
scopes = strings.Split(record.Scopes, " ")
}
return record.UserID, record.ConnectorID, scopes, nil
} }
func (r *refreshTokenRepo) Revoke(userID, token string) error { func (r *refreshTokenRepo) Revoke(userID, token string) error {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return err
}
tx, err := r.begin() tx, err := r.begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
exec := r.executor(tx) if err := r.revoke(tx, userID, token); err != nil {
record, err := r.get(tx, tokenID)
if err != nil {
return err return err
} }
if record.UserID != userID {
return refresh.ErrorInvalidUserID
}
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return err
}
deleted, err := exec.Delete(record)
if err != nil {
return err
}
if deleted == 0 {
return refresh.ErrorInvalidToken
}
return tx.Commit() return tx.Commit()
} }
func (r *refreshTokenRepo) RenewRefreshToken(clientID, userID, oldToken string) (newRefreshToken string, err error) {
// Verify
userID, connectorID, scopes, err := r.verify(nil, clientID, oldToken)
if err != nil {
return "", err
}
// Revoke old refresh token
tx, err := r.begin()
if err != nil {
return "", err
}
defer tx.Rollback()
if err := r.revoke(tx, userID, oldToken); err != nil {
return "", err
}
// Renew refresh token
newRefreshToken, err = r.create(tx, userID, clientID, connectorID, scopes)
if err != nil {
return "", err
}
return newRefreshToken, tx.Commit()
}
func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error { func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error {
q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName)) q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName))
_, err := r.executor(nil).Exec(q, userID, clientID) _, err := r.executor(nil).Exec(q, userID, clientID)
@ -235,3 +184,97 @@ func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshToke
} }
return record, nil return record, nil
} }
func (r *refreshTokenRepo) verify(tx repo.Transaction, clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return
}
record, err := r.get(tx, tokenID)
if err != nil {
return
}
if record.ClientID != clientID {
return "", "", nil, refresh.ErrorInvalidClientID
}
// Check if the hash of token received is the same stored in database
if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return
}
var scopes []string
if len(record.Scopes) > 0 {
scopes = strings.Split(record.Scopes, " ")
}
return record.UserID, record.ConnectorID, scopes, nil
}
func (r *refreshTokenRepo) create(tx repo.Transaction, userID, clientID, connectorID string, scopes []string) (string, error) {
if userID == "" {
return "", refresh.ErrorInvalidUserID
}
if clientID == "" {
return "", refresh.ErrorInvalidClientID
}
// TODO(yifan): Check the number of tokens given to the client-user pair.
tokenPayload, err := r.tokenGenerator.Generate()
if err != nil {
return "", err
}
payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
if err != nil {
return "", err
}
record := &refreshTokenModel{
PayloadHash: payloadHash,
UserID: userID,
ClientID: clientID,
ConnectorID: connectorID,
Scopes: strings.Join(scopes, " "),
}
if err := r.executor(tx).Insert(record); err != nil {
return "", err
}
return buildToken(record.ID, tokenPayload), nil
}
func (r *refreshTokenRepo) revoke(tx repo.Transaction, userID, token string) error {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return err
}
exec := r.executor(tx)
record, err := r.get(tx, tokenID)
if err != nil {
return err
}
if record.UserID != userID {
return refresh.ErrorInvalidUserID
}
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return err
}
deleted, err := exec.Delete(record)
if err != nil {
return err
}
if deleted == 0 {
return refresh.ErrorInvalidToken
}
return nil
}

View file

@ -54,6 +54,9 @@ type RefreshTokenRepo interface {
// Revoke deletes the refresh token if the token belongs to the given userID. // Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error Revoke(userID, token string) error
// Revoke old refresh token and generates a new one
RenewRefreshToken(clientID, userID, oldToken string) (newRefreshToken string, err error)
// RevokeTokensForClient revokes all tokens issued for the userID for the provided client. // RevokeTokensForClient revokes all tokens issued for the userID for the provided client.
RevokeTokensForClient(userID, clientID string) error RevokeTokensForClient(userID, clientID string) error

View file

@ -520,7 +520,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
return return
} }
jwt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token) jwt, refreshToken, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
if err != nil { if err != nil {
writeTokenError(w, err, state) writeTokenError(w, err, state)
return return

View file

@ -53,9 +53,9 @@ type OIDCServer interface {
ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error)
// RefreshToken takes a previously generated refresh token and returns a new ID token // RefreshToken takes a previously generated refresh token and returns a new ID token and new refresh token
// if the token is valid. // if the token is valid.
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error)
KillSession(string) error KillSession(string) error
@ -567,15 +567,15 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
return jwt, refreshToken, nil return jwt, refreshToken, nil
} }
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) { func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error) {
ok, err := s.ClientManager.Authenticate(creds) ok, err := s.ClientManager.Authenticate(creds)
if err != nil { if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
} }
if !ok { if !ok {
log.Errorf("Failed to Authenticate client %s", creds.ID) log.Errorf("Failed to Authenticate client %s", creds.ID)
return nil, oauth2.NewError(oauth2.ErrorInvalidClient) return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
} }
userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token) userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
@ -583,18 +583,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
case nil: case nil:
break break
case refresh.ErrorInvalidToken: case refresh.ErrorInvalidToken:
return nil, oauth2.NewError(oauth2.ErrorInvalidRequest) return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
case refresh.ErrorInvalidClientID: case refresh.ErrorInvalidClientID:
return nil, oauth2.NewError(oauth2.ErrorInvalidClient) return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
default: default:
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
} }
if len(scopes) == 0 { if len(scopes) == 0 {
scopes = rtScopes scopes = rtScopes
} else { } else {
if !rtScopes.Contains(scopes) { if !rtScopes.Contains(scopes) {
return nil, oauth2.NewError(oauth2.ErrorInvalidRequest) return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
} }
} }
@ -603,7 +603,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
// The error can be user.ErrorNotFound, but we are not deleting // The error can be user.ErrorNotFound, but we are not deleting
// user at this moment, so this shouldn't happen. // user at this moment, so this shouldn't happen.
log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err) log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
} }
var groups []string var groups []string
@ -611,19 +611,19 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
conn, ok := s.connector(connectorID) conn, ok := s.connector(connectorID)
if !ok { if !ok {
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID) log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
} }
grouper, ok := conn.(connector.GroupsConnector) grouper, ok := conn.(connector.GroupsConnector)
if !ok { if !ok {
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID) log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
} }
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID) remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
if err != nil { if err != nil {
log.Errorf("failed to get remote identities: %v", err) log.Errorf("failed to get remote identities: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
} }
remoteIdentity, ok := func() (user.RemoteIdentity, bool) { remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
for _, ri := range remoteIdentities { for _, ri := range remoteIdentities {
@ -635,18 +635,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
}() }()
if !ok { if !ok {
log.Errorf("failed to get remote identity for connector %s", connectorID) log.Errorf("failed to get remote identity for connector %s", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
} }
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil { if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
log.Errorf("failed to get groups for refresh token: %v", connectorID) log.Errorf("failed to get groups for refresh token: %v", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
} }
} }
signer, err := s.KeyManager.Signer() signer, err := s.KeyManager.Signer()
if err != nil { if err != nil {
log.Errorf("Failed to refresh ID token: %v", err) log.Errorf("Failed to refresh ID token: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
} }
now := time.Now() now := time.Now()
@ -666,12 +666,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
jwt, err := jose.NewSignedJWT(claims, signer) jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil { if err != nil {
log.Errorf("Failed to generate ID token: %v", err) log.Errorf("Failed to generate ID token: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token)
if err != nil {
log.Errorf("Failed to generate new refresh token: %v", err)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
} }
log.Infof("New token sent: clientID=%s", creds.ID) log.Infof("New token sent: clientID=%s", creds.ID)
return jwt, nil return jwt, refreshToken, nil
} }
func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) { func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) {

View file

@ -107,6 +107,10 @@ func makeNewUserRepo() (user.UserRepo, error) {
return userRepo, nil return userRepo, nil
} }
func getRefreshTokenEncoded(id, value string) string {
return fmt.Sprintf("%v/%s", id, base64.URLEncoding.EncodeToString([]byte(value)))
}
func TestServerProviderConfig(t *testing.T) { func TestServerProviderConfig(t *testing.T) {
srv := &Server{IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}} srv := &Server{IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}}
@ -612,27 +616,29 @@ func TestServerRefreshToken(t *testing.T) {
// NOTE(ericchiang): These tests assume that the database ID of the first // NOTE(ericchiang): These tests assume that the database ID of the first
// refresh token will be "1". // refresh token will be "1".
tests := []struct { tests := []struct {
token string token string
clientID string // The client that associates with the token. expectedRefreshToken string
creds oidc.ClientCredentials clientID string // The client that associates with the token.
signer jose.Signer creds oidc.ClientCredentials
createScopes []string signer jose.Signer
refreshScopes []string createScopes []string
expectedAud []string refreshScopes []string
err error expectedAud []string
err error
}{ }{
// Everything is good. // Everything is good.
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: testClientID, expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"),
creds: testClientCredentials, clientID: testClientID,
signer: signerFixture, creds: testClientCredentials,
createScopes: []string{"openid", "profile"}, signer: signerFixture,
refreshScopes: []string{"openid", "profile"}, createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile"},
}, },
// Asking for a scope not originally granted to you. // Asking for a scope not originally granted to you.
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: testClientID, clientID: testClientID,
creds: testClientCredentials, creds: testClientCredentials,
signer: signerFixture, signer: signerFixture,
@ -652,7 +658,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid refresh token(invalid payload content). // Invalid refresh token(invalid payload content).
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), token: getRefreshTokenEncoded("1", "refresh-2"),
clientID: testClientID, clientID: testClientID,
creds: testClientCredentials, creds: testClientCredentials,
signer: signerFixture, signer: signerFixture,
@ -662,7 +668,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid refresh token(invalid ID content). // Invalid refresh token(invalid ID content).
{ {
token: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("0", "refresh-1"),
clientID: testClientID, clientID: testClientID,
creds: testClientCredentials, creds: testClientCredentials,
signer: signerFixture, signer: signerFixture,
@ -672,7 +678,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(client is not associated with the token). // Invalid client(client is not associated with the token).
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: testClientID, clientID: testClientID,
creds: clientB.Credentials, creds: clientB.Credentials,
signer: signerFixture, signer: signerFixture,
@ -682,7 +688,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no client ID). // Invalid client(no client ID).
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: testClientID, clientID: testClientID,
creds: oidc.ClientCredentials{ID: "", Secret: "aaa"}, creds: oidc.ClientCredentials{ID: "", Secret: "aaa"},
signer: signerFixture, signer: signerFixture,
@ -692,7 +698,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no such client). // Invalid client(no such client).
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: testClientID, clientID: testClientID,
creds: oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, creds: oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
signer: signerFixture, signer: signerFixture,
@ -702,7 +708,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no secrets). // Invalid client(no secrets).
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: testClientID, clientID: testClientID,
creds: oidc.ClientCredentials{ID: testClientID}, creds: oidc.ClientCredentials{ID: testClientID},
signer: signerFixture, signer: signerFixture,
@ -712,7 +718,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(invalid secret). // Invalid client(invalid secret).
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: testClientID, clientID: testClientID,
creds: oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"}, creds: oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
signer: signerFixture, signer: signerFixture,
@ -722,7 +728,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Signing operation fails. // Signing operation fails.
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: testClientID, clientID: testClientID,
creds: testClientCredentials, creds: testClientCredentials,
signer: &StaticSigner{sig: nil, err: errors.New("fail")}, signer: &StaticSigner{sig: nil, err: errors.New("fail")},
@ -732,8 +738,9 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Valid Cross-Client // Valid Cross-Client
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: "client_a", expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"),
clientID: "client_a",
creds: oidc.ClientCredentials{ creds: oidc.ClientCredentials{
ID: "client_a", ID: "client_a",
Secret: base64.URLEncoding.EncodeToString( Secret: base64.URLEncoding.EncodeToString(
@ -748,8 +755,9 @@ func TestServerRefreshToken(t *testing.T) {
// refresh request, which should result in the original stored scopes // refresh request, which should result in the original stored scopes
// being used. // being used.
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: "client_a", expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"),
clientID: "client_a",
creds: oidc.ClientCredentials{ creds: oidc.ClientCredentials{
ID: "client_a", ID: "client_a",
Secret: base64.URLEncoding.EncodeToString( Secret: base64.URLEncoding.EncodeToString(
@ -763,8 +771,9 @@ func TestServerRefreshToken(t *testing.T) {
// Valid Cross-Client - asking for fewer scopes than originally used // Valid Cross-Client - asking for fewer scopes than originally used
// when creating the refresh token, which is ok. // when creating the refresh token, which is ok.
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: "client_a", expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"),
clientID: "client_a",
creds: oidc.ClientCredentials{ creds: oidc.ClientCredentials{
ID: "client_a", ID: "client_a",
Secret: base64.URLEncoding.EncodeToString( Secret: base64.URLEncoding.EncodeToString(
@ -777,8 +786,9 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Valid Cross-Client - asking for multiple clients in the audience. // Valid Cross-Client - asking for multiple clients in the audience.
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: "client_a", expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"),
clientID: "client_a",
creds: oidc.ClientCredentials{ creds: oidc.ClientCredentials{
ID: "client_a", ID: "client_a",
Secret: base64.URLEncoding.EncodeToString( Secret: base64.URLEncoding.EncodeToString(
@ -792,7 +802,7 @@ func TestServerRefreshToken(t *testing.T) {
// Invalid Cross-Client - didn't orignally request cross-client when // Invalid Cross-Client - didn't orignally request cross-client when
// refresh token was created. // refresh token was created.
{ {
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: getRefreshTokenEncoded("1", "refresh-1"),
clientID: "client_a", clientID: "client_a",
creds: oidc.ClientCredentials{ creds: oidc.ClientCredentials{
ID: "client_a", ID: "client_a",
@ -825,7 +835,7 @@ func TestServerRefreshToken(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
jwt, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token) jwt, refreshToken, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token)
if !reflect.DeepEqual(err, tt.err) { if !reflect.DeepEqual(err, tt.err) {
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err) t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
} }
@ -861,5 +871,9 @@ func TestServerRefreshToken(t *testing.T) {
expectedAud, claims["aud"]) expectedAud, claims["aud"])
} }
} }
if diff := pretty.Compare(refreshToken, tt.expectedRefreshToken); diff != "" {
t.Errorf("Case %d: want=%v, got=%v", i, tt.expectedRefreshToken, refreshToken)
}
} }
} }