Merge pull request #575 from summerwind/support-expires_in-field
server: add expires_in field to the response of token endpoint
This commit is contained in:
commit
9a78dca137
4 changed files with 68 additions and 49 deletions
|
@ -283,7 +283,7 @@ func TestServerCodeTokenCrossClient(t *testing.T) {
|
||||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwt, token, err := f.srv.CodeToken(f.clientCreds[tt.clientID], key)
|
jwt, token, expiresAt, err := f.srv.CodeToken(f.clientCreds[tt.clientID], key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||||
}
|
}
|
||||||
|
@ -293,6 +293,9 @@ func TestServerCodeTokenCrossClient(t *testing.T) {
|
||||||
if token != tt.refreshToken {
|
if token != tt.refreshToken {
|
||||||
t.Errorf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
|
t.Errorf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
|
||||||
}
|
}
|
||||||
|
if expiresAt.IsZero() {
|
||||||
|
t.Errorf("case %d: expect non-zero expiration time", i)
|
||||||
|
}
|
||||||
|
|
||||||
claims, err := jwt.Claims()
|
claims, err := jwt.Claims()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -491,6 +491,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
|
||||||
|
|
||||||
var jwt *jose.JWT
|
var jwt *jose.JWT
|
||||||
var refreshToken string
|
var refreshToken string
|
||||||
|
var expiresAt time.Time
|
||||||
grantType := r.PostForm.Get("grant_type")
|
grantType := r.PostForm.Get("grant_type")
|
||||||
|
|
||||||
switch grantType {
|
switch grantType {
|
||||||
|
@ -501,14 +502,14 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
|
||||||
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwt, refreshToken, err = srv.CodeToken(creds, code)
|
jwt, refreshToken, expiresAt, err = srv.CodeToken(creds, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("couldn't exchange code for token: %v", err)
|
log.Errorf("couldn't exchange code for token: %v", err)
|
||||||
writeTokenError(w, err, state)
|
writeTokenError(w, err, state)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case oauth2.GrantTypeClientCreds:
|
case oauth2.GrantTypeClientCreds:
|
||||||
jwt, err = srv.ClientCredsToken(creds)
|
jwt, expiresAt, err = srv.ClientCredsToken(creds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("couldn't creds for token: %v", err)
|
log.Errorf("couldn't creds for token: %v", err)
|
||||||
writeTokenError(w, err, state)
|
writeTokenError(w, err, state)
|
||||||
|
@ -521,7 +522,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
|
||||||
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwt, refreshToken, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
|
jwt, refreshToken, expiresAt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeTokenError(w, err, state)
|
writeTokenError(w, err, state)
|
||||||
return
|
return
|
||||||
|
@ -537,6 +538,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
|
||||||
IDToken: jwt.Encode(),
|
IDToken: jwt.Encode(),
|
||||||
TokenType: "bearer",
|
TokenType: "bearer",
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
|
ExpiresIn: int64(expiresAt.Sub(time.Now()).Seconds()),
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := json.Marshal(t)
|
b, err := json.Marshal(t)
|
||||||
|
@ -594,6 +596,7 @@ type oAuth2Token struct {
|
||||||
IDToken string `json:"id_token"`
|
IDToken string `json:"id_token"`
|
||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func createLastSeenCookie() *http.Cookie {
|
func createLastSeenCookie() *http.Cookie {
|
||||||
|
|
|
@ -49,13 +49,13 @@ type OIDCServer interface {
|
||||||
Login(oidc.Identity, string) (string, error)
|
Login(oidc.Identity, string) (string, error)
|
||||||
|
|
||||||
// CodeToken exchanges a code for an ID token and a refresh token string on success.
|
// CodeToken exchanges a code for an ID token and a refresh token string on success.
|
||||||
CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error)
|
CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, time.Time, error)
|
||||||
|
|
||||||
ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error)
|
ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, time.Time, error)
|
||||||
|
|
||||||
// RefreshToken takes a previously generated refresh token and returns a new ID token and new refresh token
|
// RefreshToken takes a previously generated refresh token and returns a new ID token and new refresh token
|
||||||
// if the token is valid.
|
// if the token is valid.
|
||||||
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error)
|
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, time.Time, error)
|
||||||
|
|
||||||
KillSession(string) error
|
KillSession(string) error
|
||||||
|
|
||||||
|
@ -466,29 +466,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
||||||
return ru.String(), nil
|
return ru.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
|
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, time.Time, error) {
|
||||||
cli, err := s.Client(creds.ID)
|
cli, err := s.Client(creds.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, time.Time{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if cli.Public {
|
if cli.Public {
|
||||||
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
|
return nil, time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := s.ClientManager.Authenticate(creds)
|
ok, err := s.ClientManager.Authenticate(creds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
|
log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
|
return nil, time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
signer, err := s.KeyManager.Signer()
|
signer, err := s.KeyManager.Signer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to generate ID token: %v", err)
|
log.Errorf("Failed to generate ID token: %v", err)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
@ -499,49 +499,49 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, erro
|
||||||
jwt, err := jose.NewSignedJWT(claims, signer)
|
jwt, err := jose.NewSignedJWT(claims, signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to generate ID token: %v", err)
|
log.Errorf("Failed to generate ID token: %v", err)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("Client token sent: clientID=%s", creds.ID)
|
log.Infof("Client token sent: clientID=%s", creds.ID)
|
||||||
|
|
||||||
return jwt, nil
|
return jwt, exp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
|
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, time.Time, error) {
|
||||||
ok, err := s.ClientManager.Authenticate(creds)
|
ok, err := s.ClientManager.Authenticate(creds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("Failed to Authenticate client %s", creds.ID)
|
log.Errorf("Failed to Authenticate client %s", creds.ID)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID, err := s.SessionManager.ExchangeKey(sessionKey)
|
sessionID, err := s.SessionManager.ExchangeKey(sessionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidGrant)
|
||||||
}
|
}
|
||||||
|
|
||||||
ses, err := s.SessionManager.Kill(sessionID)
|
ses, err := s.SessionManager.Kill(sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ses.ClientID != creds.ID {
|
if ses.ClientID != creds.ID {
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidGrant)
|
||||||
}
|
}
|
||||||
|
|
||||||
signer, err := s.KeyManager.Signer()
|
signer, err := s.KeyManager.Signer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to generate ID token: %v", err)
|
log.Errorf("Failed to generate ID token: %v", err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.UserRepo.Get(nil, ses.UserID)
|
user, err := s.UserRepo.Get(nil, ses.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to fetch user %q from repo: %v: ", ses.UserID, err)
|
log.Errorf("Failed to fetch user %q from repo: %v: ", ses.UserID, err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
claims := ses.Claims(s.IssuerURL.String())
|
claims := ses.Claims(s.IssuerURL.String())
|
||||||
|
@ -552,7 +552,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
||||||
jwt, err := jose.NewSignedJWT(claims, signer)
|
jwt, err := jose.NewSignedJWT(claims, signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to generate ID token: %v", err)
|
log.Errorf("Failed to generate ID token: %v", err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate refresh token when 'scope' contains 'offline_access'.
|
// Generate refresh token when 'scope' contains 'offline_access'.
|
||||||
|
@ -568,25 +568,25 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
log.Errorf("Failed to generate refresh token: %v", err)
|
log.Errorf("Failed to generate refresh token: %v", err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID)
|
log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID)
|
||||||
return jwt, refreshToken, nil
|
return jwt, refreshToken, ses.ExpiresAt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error) {
|
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, time.Time, error) {
|
||||||
ok, err := s.ClientManager.Authenticate(creds)
|
ok, err := s.ClientManager.Authenticate(creds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("Failed to Authenticate client %s", creds.ID)
|
log.Errorf("Failed to Authenticate client %s", creds.ID)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
|
userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
|
||||||
|
@ -594,18 +594,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
case nil:
|
case nil:
|
||||||
break
|
break
|
||||||
case refresh.ErrorInvalidToken:
|
case refresh.ErrorInvalidToken:
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest)
|
||||||
case refresh.ErrorInvalidClientID:
|
case refresh.ErrorInvalidClientID:
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||||
default:
|
default:
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(scopes) == 0 {
|
if len(scopes) == 0 {
|
||||||
scopes = rtScopes
|
scopes = rtScopes
|
||||||
} else {
|
} else {
|
||||||
if !rtScopes.Contains(scopes) {
|
if !rtScopes.Contains(scopes) {
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -614,7 +614,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
// The error can be user.ErrorNotFound, but we are not deleting
|
// The error can be user.ErrorNotFound, but we are not deleting
|
||||||
// user at this moment, so this shouldn't happen.
|
// user at this moment, so this shouldn't happen.
|
||||||
log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
|
log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
var groups []string
|
var groups []string
|
||||||
|
@ -622,19 +622,19 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
conn, ok := s.connector(connectorID)
|
conn, ok := s.connector(connectorID)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
|
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
grouper, ok := conn.(connector.GroupsConnector)
|
grouper, ok := conn.(connector.GroupsConnector)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
|
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
|
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to get remote identities: %v", err)
|
log.Errorf("failed to get remote identities: %v", err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
|
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
|
||||||
for _, ri := range remoteIdentities {
|
for _, ri := range remoteIdentities {
|
||||||
|
@ -646,24 +646,24 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
}()
|
}()
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("failed to get remote identity for connector %s", connectorID)
|
log.Errorf("failed to get remote identity for connector %s", connectorID)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
|
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
|
||||||
log.Errorf("failed to get groups for refresh token: %v", connectorID)
|
log.Errorf("failed to get groups for refresh token: %v", connectorID)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
signer, err := s.KeyManager.Signer()
|
signer, err := s.KeyManager.Signer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to refresh ID token: %v", err)
|
log.Errorf("Failed to refresh ID token: %v", err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
expireAt := now.Add(session.DefaultSessionValidityWindow)
|
expiresAt := now.Add(session.DefaultSessionValidityWindow)
|
||||||
|
|
||||||
claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt)
|
claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expiresAt)
|
||||||
usr.AddToClaims(claims)
|
usr.AddToClaims(claims)
|
||||||
if rtScopes.HasScope(scope.ScopeGroups) {
|
if rtScopes.HasScope(scope.ScopeGroups) {
|
||||||
if groups == nil {
|
if groups == nil {
|
||||||
|
@ -677,18 +677,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
|
||||||
jwt, err := jose.NewSignedJWT(claims, signer)
|
jwt, err := jose.NewSignedJWT(claims, signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to generate ID token: %v", err)
|
log.Errorf("Failed to generate ID token: %v", err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token)
|
refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to generate new refresh token: %v", err)
|
log.Errorf("Failed to generate new refresh token: %v", err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("New token sent: clientID=%s", creds.ID)
|
log.Infof("New token sent: clientID=%s", creds.ID)
|
||||||
|
|
||||||
return jwt, refreshToken, nil
|
return jwt, refreshToken, expiresAt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) {
|
func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) {
|
||||||
|
|
|
@ -443,7 +443,7 @@ func TestServerCodeToken(t *testing.T) {
|
||||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwt, token, err := f.srv.CodeToken(oidc.ClientCredentials{
|
jwt, token, expiresAt, err := f.srv.CodeToken(oidc.ClientCredentials{
|
||||||
ID: testClientID,
|
ID: testClientID,
|
||||||
Secret: clientTestSecret}, key)
|
Secret: clientTestSecret}, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -455,6 +455,9 @@ func TestServerCodeToken(t *testing.T) {
|
||||||
if token != tt.refreshToken {
|
if token != tt.refreshToken {
|
||||||
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
|
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
|
||||||
}
|
}
|
||||||
|
if expiresAt.IsZero() {
|
||||||
|
t.Fatalf("case %d: expect non-zero expiration time", i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -475,7 +478,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwt, token, err := f.srv.CodeToken(testClientCredentials, "foo")
|
jwt, token, expiresAt, err := f.srv.CodeToken(testClientCredentials, "foo")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Expected non-nil error")
|
t.Fatalf("Expected non-nil error")
|
||||||
}
|
}
|
||||||
|
@ -485,6 +488,9 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
|
||||||
if token != "" {
|
if token != "" {
|
||||||
t.Fatalf("Expected empty refresh token")
|
t.Fatalf("Expected empty refresh token")
|
||||||
}
|
}
|
||||||
|
if !expiresAt.IsZero() {
|
||||||
|
t.Fatalf("Expected zero expiration time")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerTokenFail(t *testing.T) {
|
func TestServerTokenFail(t *testing.T) {
|
||||||
|
@ -580,7 +586,7 @@ func TestServerTokenFail(t *testing.T) {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwt, token, err := f.srv.CodeToken(tt.argCC, tt.argKey)
|
jwt, token, expiresAt, err := f.srv.CodeToken(tt.argCC, tt.argKey)
|
||||||
if token != tt.refreshToken {
|
if token != tt.refreshToken {
|
||||||
fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token)
|
fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token)
|
||||||
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
|
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
|
||||||
|
@ -595,6 +601,9 @@ func TestServerTokenFail(t *testing.T) {
|
||||||
if err != nil && jwt != nil {
|
if err != nil && jwt != nil {
|
||||||
t.Errorf("case %d: got non-nil JWT %v", i, jwt)
|
t.Errorf("case %d: got non-nil JWT %v", i, jwt)
|
||||||
}
|
}
|
||||||
|
if err == nil && expiresAt.IsZero() {
|
||||||
|
t.Errorf("case %d: got zero expiration time %v", i, expiresAt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -835,7 +844,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwt, refreshToken, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token)
|
jwt, refreshToken, expiresIn, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token)
|
||||||
if !reflect.DeepEqual(err, tt.err) {
|
if !reflect.DeepEqual(err, tt.err) {
|
||||||
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
|
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
|
||||||
}
|
}
|
||||||
|
@ -875,5 +884,9 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
if diff := pretty.Compare(refreshToken, tt.expectedRefreshToken); diff != "" {
|
if diff := pretty.Compare(refreshToken, tt.expectedRefreshToken); diff != "" {
|
||||||
t.Errorf("Case %d: want=%v, got=%v", i, tt.expectedRefreshToken, refreshToken)
|
t.Errorf("Case %d: want=%v, got=%v", i, tt.expectedRefreshToken, refreshToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err == nil && expiresIn.IsZero() {
|
||||||
|
t.Errorf("case %d: got zero expiration time %v", i, expiresIn)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Reference in a new issue