diff --git a/server/cross_client_test.go b/server/cross_client_test.go index 00099622..3468bde4 100644 --- a/server/cross_client_test.go +++ b/server/cross_client_test.go @@ -283,7 +283,7 @@ func TestServerCodeTokenCrossClient(t *testing.T) { t.Fatalf("case %d: unexpected error: %v", i, err) } - jwt, token, err := f.srv.CodeToken(f.clientCreds[tt.clientID], key) + jwt, token, expiresAt, err := f.srv.CodeToken(f.clientCreds[tt.clientID], key) if err != nil { t.Fatalf("case %d: unexpected error: %v", i, err) } @@ -293,6 +293,9 @@ func TestServerCodeTokenCrossClient(t *testing.T) { if token != tt.refreshToken { t.Errorf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) } + if expiresAt.IsZero() { + t.Errorf("case %d: expect non-zero expiration time", i) + } claims, err := jwt.Claims() if err != nil { diff --git a/server/http.go b/server/http.go index 773246ac..218f1920 100644 --- a/server/http.go +++ b/server/http.go @@ -491,6 +491,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc { var jwt *jose.JWT var refreshToken string + var expiresAt time.Time grantType := r.PostForm.Get("grant_type") switch grantType { @@ -501,14 +502,14 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc { writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) return } - jwt, refreshToken, err = srv.CodeToken(creds, code) + jwt, refreshToken, expiresAt, err = srv.CodeToken(creds, code) if err != nil { log.Errorf("couldn't exchange code for token: %v", err) writeTokenError(w, err, state) return } case oauth2.GrantTypeClientCreds: - jwt, err = srv.ClientCredsToken(creds) + jwt, expiresAt, err = srv.ClientCredsToken(creds) if err != nil { log.Errorf("couldn't creds for token: %v", err) writeTokenError(w, err, state) @@ -521,7 +522,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc { writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) return } - jwt, refreshToken, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token) + jwt, refreshToken, expiresAt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token) if err != nil { writeTokenError(w, err, state) return @@ -537,6 +538,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc { IDToken: jwt.Encode(), TokenType: "bearer", RefreshToken: refreshToken, + ExpiresIn: int64(expiresAt.Sub(time.Now()).Seconds()), } b, err := json.Marshal(t) @@ -594,6 +596,7 @@ type oAuth2Token struct { IDToken string `json:"id_token"` TokenType string `json:"token_type"` RefreshToken string `json:"refresh_token,omitempty"` + ExpiresIn int64 `json:"expires_in"` } func createLastSeenCookie() *http.Cookie { diff --git a/server/server.go b/server/server.go index de461f89..b3f2708e 100644 --- a/server/server.go +++ b/server/server.go @@ -49,13 +49,13 @@ type OIDCServer interface { Login(oidc.Identity, string) (string, error) // CodeToken exchanges a code for an ID token and a refresh token string on success. - CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) + CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, time.Time, error) - ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) + ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, time.Time, error) // RefreshToken takes a previously generated refresh token and returns a new ID token and new refresh token // if the token is valid. - RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error) + RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, time.Time, error) KillSession(string) error @@ -466,29 +466,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) { return ru.String(), nil } -func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) { +func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, time.Time, error) { cli, err := s.Client(creds.ID) if err != nil { - return nil, err + return nil, time.Time{}, err } if cli.Public { - return nil, oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) } ok, err := s.ClientManager.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if !ok { - return nil, oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) } signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to generate ID token: %v", err) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } now := time.Now() @@ -499,49 +499,49 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, erro jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } log.Infof("Client token sent: clientID=%s", creds.ID) - return jwt, nil + return jwt, exp, nil } -func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) { +func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, time.Time, error) { ok, err := s.ClientManager.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if !ok { log.Errorf("Failed to Authenticate client %s", creds.ID) - return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) } sessionID, err := s.SessionManager.ExchangeKey(sessionKey) if err != nil { - return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidGrant) } ses, err := s.SessionManager.Kill(sessionID) if err != nil { - return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest) } if ses.ClientID != creds.ID { - return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidGrant) } signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to generate ID token: %v", err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } user, err := s.UserRepo.Get(nil, ses.UserID) if err != nil { log.Errorf("Failed to fetch user %q from repo: %v: ", ses.UserID, err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } claims := ses.Claims(s.IssuerURL.String()) @@ -552,7 +552,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } // Generate refresh token when 'scope' contains 'offline_access'. @@ -568,25 +568,25 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo break default: log.Errorf("Failed to generate refresh token: %v", err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } break } } log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID) - return jwt, refreshToken, nil + return jwt, refreshToken, ses.ExpiresAt, nil } -func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error) { +func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, time.Time, error) { ok, err := s.ClientManager.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if !ok { log.Errorf("Failed to Authenticate client %s", creds.ID) - return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) } userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token) @@ -594,18 +594,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, case nil: break case refresh.ErrorInvalidToken: - return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest) case refresh.ErrorInvalidClientID: - return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient) default: - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if len(scopes) == 0 { scopes = rtScopes } else { if !rtScopes.Contains(scopes) { - return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest) } } @@ -614,7 +614,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, // The error can be user.ErrorNotFound, but we are not deleting // user at this moment, so this shouldn't happen. log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } var groups []string @@ -622,19 +622,19 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, conn, ok := s.connector(connectorID) if !ok { log.Errorf("refresh token contained invalid connector ID (%s)", connectorID) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } grouper, ok := conn.(connector.GroupsConnector) if !ok { log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID) if err != nil { log.Errorf("failed to get remote identities: %v", err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } remoteIdentity, ok := func() (user.RemoteIdentity, bool) { for _, ri := range remoteIdentities { @@ -646,24 +646,24 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, }() if !ok { log.Errorf("failed to get remote identity for connector %s", connectorID) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } if groups, err = grouper.Groups(remoteIdentity.ID); err != nil { log.Errorf("failed to get groups for refresh token: %v", connectorID) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } } signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to refresh ID token: %v", err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } now := time.Now() - expireAt := now.Add(session.DefaultSessionValidityWindow) + expiresAt := now.Add(session.DefaultSessionValidityWindow) - claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt) + claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expiresAt) usr.AddToClaims(claims) if rtScopes.HasScope(scope.ScopeGroups) { if groups == nil { @@ -677,18 +677,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token) if err != nil { log.Errorf("Failed to generate new refresh token: %v", err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError) } log.Infof("New token sent: clientID=%s", creds.ID) - return jwt, refreshToken, nil + return jwt, refreshToken, expiresAt, nil } func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) { diff --git a/server/server_test.go b/server/server_test.go index 280bbd00..956bdd26 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -443,7 +443,7 @@ func TestServerCodeToken(t *testing.T) { t.Fatalf("case %d: unexpected error: %v", i, err) } - jwt, token, err := f.srv.CodeToken(oidc.ClientCredentials{ + jwt, token, expiresAt, err := f.srv.CodeToken(oidc.ClientCredentials{ ID: testClientID, Secret: clientTestSecret}, key) if err != nil { @@ -455,6 +455,9 @@ func TestServerCodeToken(t *testing.T) { if token != tt.refreshToken { t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) } + if expiresAt.IsZero() { + t.Fatalf("case %d: expect non-zero expiration time", i) + } } } @@ -475,7 +478,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - jwt, token, err := f.srv.CodeToken(testClientCredentials, "foo") + jwt, token, expiresAt, err := f.srv.CodeToken(testClientCredentials, "foo") if err == nil { t.Fatalf("Expected non-nil error") } @@ -485,6 +488,9 @@ func TestServerTokenUnrecognizedKey(t *testing.T) { if token != "" { t.Fatalf("Expected empty refresh token") } + if !expiresAt.IsZero() { + t.Fatalf("Expected zero expiration time") + } } func TestServerTokenFail(t *testing.T) { @@ -580,7 +586,7 @@ func TestServerTokenFail(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - jwt, token, err := f.srv.CodeToken(tt.argCC, tt.argKey) + jwt, token, expiresAt, err := f.srv.CodeToken(tt.argCC, tt.argKey) if token != tt.refreshToken { fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token) t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) @@ -595,6 +601,9 @@ func TestServerTokenFail(t *testing.T) { if err != nil && jwt != nil { t.Errorf("case %d: got non-nil JWT %v", i, jwt) } + if err == nil && expiresAt.IsZero() { + t.Errorf("case %d: got zero expiration time %v", i, expiresAt) + } } } @@ -835,7 +844,7 @@ func TestServerRefreshToken(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - jwt, refreshToken, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token) + jwt, refreshToken, expiresIn, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token) if !reflect.DeepEqual(err, tt.err) { t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err) } @@ -875,5 +884,9 @@ func TestServerRefreshToken(t *testing.T) { if diff := pretty.Compare(refreshToken, tt.expectedRefreshToken); diff != "" { t.Errorf("Case %d: want=%v, got=%v", i, tt.expectedRefreshToken, refreshToken) } + + if err == nil && expiresIn.IsZero() { + t.Errorf("case %d: got zero expiration time %v", i, expiresIn) + } } }