From dc979c1d6d1766c95867fc9138bd034a7a8910c5 Mon Sep 17 00:00:00 2001 From: Moto Ishizawa Date: Mon, 12 Sep 2016 18:52:50 +0900 Subject: [PATCH] server: use time.Time instead of int64 for token expiration time --- server/cross_client_test.go | 4 +- server/http.go | 10 ++-- server/server.go | 91 +++++++++++++++++-------------------- server/server_test.go | 16 +++---- 4 files changed, 56 insertions(+), 65 deletions(-) diff --git a/server/cross_client_test.go b/server/cross_client_test.go index cfc69c91..3468bde4 100644 --- a/server/cross_client_test.go +++ b/server/cross_client_test.go @@ -283,7 +283,7 @@ func TestServerCodeTokenCrossClient(t *testing.T) { t.Fatalf("case %d: unexpected error: %v", i, err) } - jwt, token, expiresIn, err := f.srv.CodeToken(f.clientCreds[tt.clientID], key) + jwt, token, expiresAt, err := f.srv.CodeToken(f.clientCreds[tt.clientID], key) if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } @@ -293,7 +293,7 @@ func TestServerCodeTokenCrossClient(t *testing.T) { if token != tt.refreshToken { t.Errorf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) } - if expiresIn == 0 { + if expiresAt.IsZero() { t.Errorf("case %d: expect non-zero expiration time", i) } diff --git a/server/http.go b/server/http.go index 7c71af30..218f1920 100644 --- a/server/http.go +++ b/server/http.go @@ -491,7 +491,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc { var jwt *jose.JWT var refreshToken string - var expiresIn int64 + var expiresAt time.Time grantType := r.PostForm.Get("grant_type") switch grantType { @@ -502,14 +502,14 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc { writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) return } - jwt, refreshToken, expiresIn, err = srv.CodeToken(creds, code) + jwt, refreshToken, expiresAt, err = srv.CodeToken(creds, code) if err != nil { log.Errorf("couldn't exchange code for token: %v", err) writeTokenError(w, err, state) return } case oauth2.GrantTypeClientCreds: - jwt, expiresIn, err = srv.ClientCredsToken(creds) + jwt, expiresAt, err = srv.ClientCredsToken(creds) if err != nil { log.Errorf("couldn't creds for token: %v", err) writeTokenError(w, err, state) @@ -522,7 +522,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc { writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) return } - jwt, refreshToken, expiresIn, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token) + jwt, refreshToken, expiresAt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token) if err != nil { writeTokenError(w, err, state) return @@ -538,7 +538,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc { IDToken: jwt.Encode(), TokenType: "bearer", RefreshToken: refreshToken, - ExpiresIn: expiresIn, + ExpiresIn: int64(expiresAt.Sub(time.Now()).Seconds()), } b, err := json.Marshal(t) diff --git a/server/server.go b/server/server.go index 20ca6de1..b3f2708e 100644 --- a/server/server.go +++ b/server/server.go @@ -49,13 +49,13 @@ type OIDCServer interface { Login(oidc.Identity, string) (string, error) // CodeToken exchanges a code for an ID token and a refresh token string on success. - CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, int64, error) + CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, time.Time, error) - ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, int64, error) + ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, time.Time, error) // RefreshToken takes a previously generated refresh token and returns a new ID token and new refresh token // if the token is valid. - RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, int64, error) + RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, time.Time, error) KillSession(string) error @@ -466,29 +466,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) { return ru.String(), nil } -func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, int64, error) { +func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, time.Time, error) { cli, err := s.Client(creds.ID) if err != nil { - return nil, 0, err + return nil, time.Time{}, err } if cli.Public { - return nil, 0, oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) } ok, err := s.ClientManager.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err) - return nil, 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if !ok { - return nil, 0, oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) } signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to generate ID token: %v", err) - return nil, 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } now := time.Now() @@ -499,52 +499,49 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, int6 jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) - return nil, 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } - // Generate expires_in value - expiresIn := int64(exp.Sub(now).Seconds()) - log.Infof("Client token sent: clientID=%s", creds.ID) - return jwt, expiresIn, nil + return jwt, exp, nil } -func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, int64, error) { +func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, time.Time, error) { ok, err := s.ClientManager.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if !ok { log.Errorf("Failed to Authenticate client %s", creds.ID) - return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) } sessionID, err := s.SessionManager.ExchangeKey(sessionKey) if err != nil { - return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidGrant) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidGrant) } ses, err := s.SessionManager.Kill(sessionID) if err != nil { - return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidRequest) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest) } if ses.ClientID != creds.ID { - return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidGrant) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidGrant) } signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to generate ID token: %v", err) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } user, err := s.UserRepo.Get(nil, ses.UserID) if err != nil { log.Errorf("Failed to fetch user %q from repo: %v: ", ses.UserID, err) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } claims := ses.Claims(s.IssuerURL.String()) @@ -555,7 +552,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } // Generate refresh token when 'scope' contains 'offline_access'. @@ -571,28 +568,25 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo break default: log.Errorf("Failed to generate refresh token: %v", err) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } break } } - // Generate expires_in value - expiresIn := int64(ses.ExpiresAt.Sub(ses.CreatedAt).Seconds()) - log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID) - return jwt, refreshToken, expiresIn, nil + return jwt, refreshToken, ses.ExpiresAt, nil } -func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, int64, error) { +func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, time.Time, error) { ok, err := s.ClientManager.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if !ok { log.Errorf("Failed to Authenticate client %s", creds.ID) - return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) } userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token) @@ -600,18 +594,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, case nil: break case refresh.ErrorInvalidToken: - return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidRequest) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest) case refresh.ErrorInvalidClientID: - return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) default: - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if len(scopes) == 0 { scopes = rtScopes } else { if !rtScopes.Contains(scopes) { - return nil, "", 0, oauth2.NewError(oauth2.ErrorInvalidRequest) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest) } } @@ -620,7 +614,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, // The error can be user.ErrorNotFound, but we are not deleting // user at this moment, so this shouldn't happen. log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } var groups []string @@ -628,19 +622,19 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, conn, ok := s.connector(connectorID) if !ok { log.Errorf("refresh token contained invalid connector ID (%s)", connectorID) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } grouper, ok := conn.(connector.GroupsConnector) if !ok { log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID) if err != nil { log.Errorf("failed to get remote identities: %v", err) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } remoteIdentity, ok := func() (user.RemoteIdentity, bool) { for _, ri := range remoteIdentities { @@ -652,24 +646,24 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, }() if !ok { log.Errorf("failed to get remote identity for connector %s", connectorID) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if groups, err = grouper.Groups(remoteIdentity.ID); err != nil { log.Errorf("failed to get groups for refresh token: %v", connectorID) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } } signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to refresh ID token: %v", err) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } now := time.Now() - expireAt := now.Add(session.DefaultSessionValidityWindow) + expiresAt := now.Add(session.DefaultSessionValidityWindow) - claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt) + claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expiresAt) usr.AddToClaims(claims) if rtScopes.HasScope(scope.ScopeGroups) { if groups == nil { @@ -683,21 +677,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token) if err != nil { log.Errorf("Failed to generate new refresh token: %v", err) - return nil, "", 0, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } - // Generate expires_in value - expiresIn := int64(expireAt.Sub(now).Seconds()) - log.Infof("New token sent: clientID=%s", creds.ID) - return jwt, refreshToken, expiresIn, nil + return jwt, refreshToken, expiresAt, nil } func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) { diff --git a/server/server_test.go b/server/server_test.go index 74e316b7..956bdd26 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -443,7 +443,7 @@ func TestServerCodeToken(t *testing.T) { t.Fatalf("case %d: unexpected error: %v", i, err) } - jwt, token, expiresIn, err := f.srv.CodeToken(oidc.ClientCredentials{ + jwt, token, expiresAt, err := f.srv.CodeToken(oidc.ClientCredentials{ ID: testClientID, Secret: clientTestSecret}, key) if err != nil { @@ -455,7 +455,7 @@ func TestServerCodeToken(t *testing.T) { if token != tt.refreshToken { t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) } - if expiresIn == 0 { + if expiresAt.IsZero() { t.Fatalf("case %d: expect non-zero expiration time", i) } } @@ -478,7 +478,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - jwt, token, expiresIn, err := f.srv.CodeToken(testClientCredentials, "foo") + jwt, token, expiresAt, err := f.srv.CodeToken(testClientCredentials, "foo") if err == nil { t.Fatalf("Expected non-nil error") } @@ -488,7 +488,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) { if token != "" { t.Fatalf("Expected empty refresh token") } - if expiresIn != 0 { + if !expiresAt.IsZero() { t.Fatalf("Expected zero expiration time") } } @@ -586,7 +586,7 @@ func TestServerTokenFail(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - jwt, token, expiresIn, err := f.srv.CodeToken(tt.argCC, tt.argKey) + jwt, token, expiresAt, err := f.srv.CodeToken(tt.argCC, tt.argKey) if token != tt.refreshToken { fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token) t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) @@ -601,8 +601,8 @@ func TestServerTokenFail(t *testing.T) { if err != nil && jwt != nil { t.Errorf("case %d: got non-nil JWT %v", i, jwt) } - if err == nil && expiresIn == 0 { - t.Errorf("case %d: got zero expiration time %v", i, expiresIn) + if err == nil && expiresAt.IsZero() { + t.Errorf("case %d: got zero expiration time %v", i, expiresAt) } } } @@ -885,7 +885,7 @@ func TestServerRefreshToken(t *testing.T) { t.Errorf("Case %d: want=%v, got=%v", i, tt.expectedRefreshToken, refreshToken) } - if err == nil && expiresIn == 0 { + if err == nil && expiresIn.IsZero() { t.Errorf("case %d: got zero expiration time %v", i, expiresIn) } }