server: use time.Time instead of int64 for token expiration time

This commit is contained in:
Moto Ishizawa 2016-09-12 18:52:50 +09:00
parent 25e4228e35
commit dc979c1d6d
4 changed files with 56 additions and 65 deletions

View file

@ -283,7 +283,7 @@ func TestServerCodeTokenCrossClient(t *testing.T) {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
jwt, token, expiresIn, err := f.srv.CodeToken(f.clientCreds[tt.clientID], key)
jwt, token, expiresAt, err := f.srv.CodeToken(f.clientCreds[tt.clientID], key)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
@ -293,7 +293,7 @@ func TestServerCodeTokenCrossClient(t *testing.T) {
if token != tt.refreshToken {
t.Errorf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
}
if expiresIn == 0 {
if expiresAt.IsZero() {
t.Errorf("case %d: expect non-zero expiration time", i)
}

View file

@ -491,7 +491,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
var jwt *jose.JWT
var refreshToken string
var expiresIn int64
var expiresAt time.Time
grantType := r.PostForm.Get("grant_type")
switch grantType {
@ -502,14 +502,14 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
return
}
jwt, refreshToken, expiresIn, err = srv.CodeToken(creds, code)
jwt, refreshToken, expiresAt, err = srv.CodeToken(creds, code)
if err != nil {
log.Errorf("couldn't exchange code for token: %v", err)
writeTokenError(w, err, state)
return
}
case oauth2.GrantTypeClientCreds:
jwt, expiresIn, err = srv.ClientCredsToken(creds)
jwt, expiresAt, err = srv.ClientCredsToken(creds)
if err != nil {
log.Errorf("couldn't creds for token: %v", err)
writeTokenError(w, err, state)
@ -522,7 +522,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
return
}
jwt, refreshToken, expiresIn, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
jwt, refreshToken, expiresAt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
if err != nil {
writeTokenError(w, err, state)
return
@ -538,7 +538,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
IDToken: jwt.Encode(),
TokenType: "bearer",
RefreshToken: refreshToken,
ExpiresIn: expiresIn,
ExpiresIn: int64(expiresAt.Sub(time.Now()).Seconds()),
}
b, err := json.Marshal(t)

View file

@ -49,13 +49,13 @@ type OIDCServer interface {
Login(oidc.Identity, string) (string, error)
// CodeToken exchanges a code for an ID token and a refresh token string on success.
CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, int64, error)
CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, time.Time, error)
ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, int64, error)
ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, time.Time, error)
// RefreshToken takes a previously generated refresh token and returns a new ID token and new refresh token
// if the token is valid.
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, int64, error)
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, time.Time, error)
KillSession(string) error
@ -466,29 +466,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
return ru.String(), nil
}
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, int64, error) {
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, time.Time, error) {
cli, err := s.Client(creds.ID)
if err != nil {
return nil, 0, err
return nil, time.Time{}, err
}
if cli.Public {
return nil, 0, oauth2.NewError(oauth2.ErrorInvalidClient)
return nil, time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
}
ok, err := s.ClientManager.Authenticate(creds)
if err != nil {
log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
return nil, 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
if !ok {
return nil, 0, oauth2.NewError(oauth2.ErrorInvalidClient)
return nil, time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
}
signer, err := s.KeyManager.Signer()
if err != nil {
log.Errorf("Failed to generate ID token: %v", err)
return nil, 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
now := time.Now()
@ -499,52 +499,49 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, int6
jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil {
log.Errorf("Failed to generate ID token: %v", err)
return nil, 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
// Generate expires_in value
expiresIn := int64(exp.Sub(now).Seconds())
log.Infof("Client token sent: clientID=%s", creds.ID)
return jwt, expiresIn, nil
return jwt, exp, nil
}
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, int64, error) {
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, time.Time, error) {
ok, err := s.ClientManager.Authenticate(creds)
if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
if !ok {
log.Errorf("Failed to Authenticate client %s", creds.ID)
return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidClient)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
}
sessionID, err := s.SessionManager.ExchangeKey(sessionKey)
if err != nil {
return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidGrant)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidGrant)
}
ses, err := s.SessionManager.Kill(sessionID)
if err != nil {
return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidRequest)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest)
}
if ses.ClientID != creds.ID {
return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidGrant)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidGrant)
}
signer, err := s.KeyManager.Signer()
if err != nil {
log.Errorf("Failed to generate ID token: %v", err)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
user, err := s.UserRepo.Get(nil, ses.UserID)
if err != nil {
log.Errorf("Failed to fetch user %q from repo: %v: ", ses.UserID, err)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
claims := ses.Claims(s.IssuerURL.String())
@ -555,7 +552,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil {
log.Errorf("Failed to generate ID token: %v", err)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
// Generate refresh token when 'scope' contains 'offline_access'.
@ -571,28 +568,25 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
break
default:
log.Errorf("Failed to generate refresh token: %v", err)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
break
}
}
// Generate expires_in value
expiresIn := int64(ses.ExpiresAt.Sub(ses.CreatedAt).Seconds())
log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID)
return jwt, refreshToken, expiresIn, nil
return jwt, refreshToken, ses.ExpiresAt, nil
}
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, int64, error) {
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, time.Time, error) {
ok, err := s.ClientManager.Authenticate(creds)
if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
if !ok {
log.Errorf("Failed to Authenticate client %s", creds.ID)
return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidClient)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
}
userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
@ -600,18 +594,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
case nil:
break
case refresh.ErrorInvalidToken:
return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidRequest)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest)
case refresh.ErrorInvalidClientID:
return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidClient)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
default:
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
if len(scopes) == 0 {
scopes = rtScopes
} else {
if !rtScopes.Contains(scopes) {
return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidRequest)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest)
}
}
@ -620,7 +614,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
// The error can be user.ErrorNotFound, but we are not deleting
// user at this moment, so this shouldn't happen.
log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
var groups []string
@ -628,19 +622,19 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
conn, ok := s.connector(connectorID)
if !ok {
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
grouper, ok := conn.(connector.GroupsConnector)
if !ok {
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
if err != nil {
log.Errorf("failed to get remote identities: %v", err)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
for _, ri := range remoteIdentities {
@ -652,24 +646,24 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
}()
if !ok {
log.Errorf("failed to get remote identity for connector %s", connectorID)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
log.Errorf("failed to get groups for refresh token: %v", connectorID)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
}
signer, err := s.KeyManager.Signer()
if err != nil {
log.Errorf("Failed to refresh ID token: %v", err)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
now := time.Now()
expireAt := now.Add(session.DefaultSessionValidityWindow)
expiresAt := now.Add(session.DefaultSessionValidityWindow)
claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt)
claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expiresAt)
usr.AddToClaims(claims)
if rtScopes.HasScope(scope.ScopeGroups) {
if groups == nil {
@ -683,21 +677,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil {
log.Errorf("Failed to generate ID token: %v", err)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token)
if err != nil {
log.Errorf("Failed to generate new refresh token: %v", err)
return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
}
// Generate expires_in value
expiresIn := int64(expireAt.Sub(now).Seconds())
log.Infof("New token sent: clientID=%s", creds.ID)
return jwt, refreshToken, expiresIn, nil
return jwt, refreshToken, expiresAt, nil
}
func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) {

View file

@ -443,7 +443,7 @@ func TestServerCodeToken(t *testing.T) {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
jwt, token, expiresIn, err := f.srv.CodeToken(oidc.ClientCredentials{
jwt, token, expiresAt, err := f.srv.CodeToken(oidc.ClientCredentials{
ID: testClientID,
Secret: clientTestSecret}, key)
if err != nil {
@ -455,7 +455,7 @@ func TestServerCodeToken(t *testing.T) {
if token != tt.refreshToken {
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
}
if expiresIn == 0 {
if expiresAt.IsZero() {
t.Fatalf("case %d: expect non-zero expiration time", i)
}
}
@ -478,7 +478,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}
jwt, token, expiresIn, err := f.srv.CodeToken(testClientCredentials, "foo")
jwt, token, expiresAt, err := f.srv.CodeToken(testClientCredentials, "foo")
if err == nil {
t.Fatalf("Expected non-nil error")
}
@ -488,7 +488,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
if token != "" {
t.Fatalf("Expected empty refresh token")
}
if expiresIn != 0 {
if !expiresAt.IsZero() {
t.Fatalf("Expected zero expiration time")
}
}
@ -586,7 +586,7 @@ func TestServerTokenFail(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}
jwt, token, expiresIn, err := f.srv.CodeToken(tt.argCC, tt.argKey)
jwt, token, expiresAt, err := f.srv.CodeToken(tt.argCC, tt.argKey)
if token != tt.refreshToken {
fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token)
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
@ -601,8 +601,8 @@ func TestServerTokenFail(t *testing.T) {
if err != nil && jwt != nil {
t.Errorf("case %d: got non-nil JWT %v", i, jwt)
}
if err == nil && expiresIn == 0 {
t.Errorf("case %d: got zero expiration time %v", i, expiresIn)
if err == nil && expiresAt.IsZero() {
t.Errorf("case %d: got zero expiration time %v", i, expiresAt)
}
}
}
@ -885,7 +885,7 @@ func TestServerRefreshToken(t *testing.T) {
t.Errorf("Case %d: want=%v, got=%v", i, tt.expectedRefreshToken, refreshToken)
}
if err == nil && expiresIn == 0 {
if err == nil && expiresIn.IsZero() {
t.Errorf("case %d: got zero expiration time %v", i, expiresIn)
}
}