diff --git a/db/refresh.go b/db/refresh.go index 0baa655f..08735416 100644 --- a/db/refresh.go +++ b/db/refresh.go @@ -91,103 +91,52 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG } func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) { - if userID == "" { - return "", refresh.ErrorInvalidUserID - } - if clientID == "" { - return "", refresh.ErrorInvalidClientID - } - - // TODO(yifan): Check the number of tokens given to the client-user pair. - tokenPayload, err := r.tokenGenerator.Generate() - if err != nil { - return "", err - } - - payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost) - if err != nil { - return "", err - } - - record := &refreshTokenModel{ - PayloadHash: payloadHash, - UserID: userID, - ClientID: clientID, - ConnectorID: connectorID, - Scopes: strings.Join(scopes, " "), - } - - if err := r.executor(nil).Insert(record); err != nil { - return "", err - } - - return buildToken(record.ID, tokenPayload), nil + return r.create(nil, userID, clientID, connectorID, scopes) } func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) { - tokenID, tokenPayload, err := parseToken(token) - - if err != nil { - return - } - - record, err := r.get(nil, tokenID) - if err != nil { - return - } - - if record.ClientID != clientID { - return "", "", nil, refresh.ErrorInvalidClientID - } - - if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil { - return - } - - var scopes []string - if len(record.Scopes) > 0 { - scopes = strings.Split(record.Scopes, " ") - } - - return record.UserID, record.ConnectorID, scopes, nil + return r.verify(nil, clientID, token) } func (r *refreshTokenRepo) Revoke(userID, token string) error { - tokenID, tokenPayload, err := parseToken(token) - if err != nil { - return err - } - tx, err := r.begin() if err != nil { return err } defer tx.Rollback() - exec := r.executor(tx) - record, err := r.get(tx, tokenID) - if err != nil { + if err := r.revoke(tx, userID, token); err != nil { return err } - if record.UserID != userID { - return refresh.ErrorInvalidUserID - } - - if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil { - return err - } - - deleted, err := exec.Delete(record) - if err != nil { - return err - } - if deleted == 0 { - return refresh.ErrorInvalidToken - } - return tx.Commit() } +func (r *refreshTokenRepo) RenewRefreshToken(clientID, userID, oldToken string) (newRefreshToken string, err error) { + // Verify + userID, connectorID, scopes, err := r.verify(nil, clientID, oldToken) + if err != nil { + return "", err + } + + // Revoke old refresh token + tx, err := r.begin() + if err != nil { + return "", err + } + defer tx.Rollback() + if err := r.revoke(tx, userID, oldToken); err != nil { + return "", err + } + + // Renew refresh token + newRefreshToken, err = r.create(tx, userID, clientID, connectorID, scopes) + if err != nil { + return "", err + } + + return newRefreshToken, tx.Commit() +} + func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error { q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName)) _, err := r.executor(nil).Exec(q, userID, clientID) @@ -235,3 +184,97 @@ func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshToke } return record, nil } + +func (r *refreshTokenRepo) verify(tx repo.Transaction, clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) { + tokenID, tokenPayload, err := parseToken(token) + + if err != nil { + return + } + + record, err := r.get(tx, tokenID) + if err != nil { + return + } + + if record.ClientID != clientID { + return "", "", nil, refresh.ErrorInvalidClientID + } + + // Check if the hash of token received is the same stored in database + if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil { + return + } + + var scopes []string + if len(record.Scopes) > 0 { + scopes = strings.Split(record.Scopes, " ") + } + + return record.UserID, record.ConnectorID, scopes, nil +} + +func (r *refreshTokenRepo) create(tx repo.Transaction, userID, clientID, connectorID string, scopes []string) (string, error) { + if userID == "" { + return "", refresh.ErrorInvalidUserID + } + if clientID == "" { + return "", refresh.ErrorInvalidClientID + } + + // TODO(yifan): Check the number of tokens given to the client-user pair. + tokenPayload, err := r.tokenGenerator.Generate() + if err != nil { + return "", err + } + + payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost) + if err != nil { + return "", err + } + + record := &refreshTokenModel{ + PayloadHash: payloadHash, + UserID: userID, + ClientID: clientID, + ConnectorID: connectorID, + Scopes: strings.Join(scopes, " "), + } + + if err := r.executor(tx).Insert(record); err != nil { + return "", err + } + + return buildToken(record.ID, tokenPayload), nil +} + +func (r *refreshTokenRepo) revoke(tx repo.Transaction, userID, token string) error { + tokenID, tokenPayload, err := parseToken(token) + if err != nil { + return err + } + + exec := r.executor(tx) + record, err := r.get(tx, tokenID) + if err != nil { + return err + } + + if record.UserID != userID { + return refresh.ErrorInvalidUserID + } + + if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil { + return err + } + + deleted, err := exec.Delete(record) + if err != nil { + return err + } + if deleted == 0 { + return refresh.ErrorInvalidToken + } + + return nil +} diff --git a/refresh/repo.go b/refresh/repo.go index 23ccda85..b2fdc054 100644 --- a/refresh/repo.go +++ b/refresh/repo.go @@ -54,6 +54,9 @@ type RefreshTokenRepo interface { // Revoke deletes the refresh token if the token belongs to the given userID. Revoke(userID, token string) error + // Revoke old refresh token and generates a new one + RenewRefreshToken(clientID, userID, oldToken string) (newRefreshToken string, err error) + // RevokeTokensForClient revokes all tokens issued for the userID for the provided client. RevokeTokensForClient(userID, clientID string) error diff --git a/server/http.go b/server/http.go index b3a7cc86..007a9d4e 100644 --- a/server/http.go +++ b/server/http.go @@ -520,7 +520,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc { writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) return } - jwt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token) + jwt, refreshToken, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token) if err != nil { writeTokenError(w, err, state) return diff --git a/server/server.go b/server/server.go index 48ff8dc8..2ca65b82 100644 --- a/server/server.go +++ b/server/server.go @@ -53,9 +53,9 @@ type OIDCServer interface { ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) - // RefreshToken takes a previously generated refresh token and returns a new ID token + // RefreshToken takes a previously generated refresh token and returns a new ID token and new refresh token // if the token is valid. - RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) + RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error) KillSession(string) error @@ -567,15 +567,15 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo return jwt, refreshToken, nil } -func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) { +func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error) { ok, err := s.ClientManager.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) } if !ok { log.Errorf("Failed to Authenticate client %s", creds.ID) - return nil, oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient) } userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token) @@ -583,18 +583,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, case nil: break case refresh.ErrorInvalidToken: - return nil, oauth2.NewError(oauth2.ErrorInvalidRequest) + return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest) case refresh.ErrorInvalidClientID: - return nil, oauth2.NewError(oauth2.ErrorInvalidClient) + return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient) default: - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) } if len(scopes) == 0 { scopes = rtScopes } else { if !rtScopes.Contains(scopes) { - return nil, oauth2.NewError(oauth2.ErrorInvalidRequest) + return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest) } } @@ -603,7 +603,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, // The error can be user.ErrorNotFound, but we are not deleting // user at this moment, so this shouldn't happen. log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) } var groups []string @@ -611,19 +611,19 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, conn, ok := s.connector(connectorID) if !ok { log.Errorf("refresh token contained invalid connector ID (%s)", connectorID) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) } grouper, ok := conn.(connector.GroupsConnector) if !ok { log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) } remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID) if err != nil { log.Errorf("failed to get remote identities: %v", err) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) } remoteIdentity, ok := func() (user.RemoteIdentity, bool) { for _, ri := range remoteIdentities { @@ -635,18 +635,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, }() if !ok { log.Errorf("failed to get remote identity for connector %s", connectorID) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) } if groups, err = grouper.Groups(remoteIdentity.ID); err != nil { log.Errorf("failed to get groups for refresh token: %v", connectorID) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) } } signer, err := s.KeyManager.Signer() if err != nil { log.Errorf("Failed to refresh ID token: %v", err) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) } now := time.Now() @@ -666,12 +666,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, jwt, err := jose.NewSignedJWT(claims, signer) if err != nil { log.Errorf("Failed to generate ID token: %v", err) - return nil, oauth2.NewError(oauth2.ErrorServerError) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) + } + + refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token) + if err != nil { + log.Errorf("Failed to generate new refresh token: %v", err) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) } log.Infof("New token sent: clientID=%s", creds.ID) - return jwt, nil + return jwt, refreshToken, nil } func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) { diff --git a/server/server_test.go b/server/server_test.go index f724a260..280bbd00 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -107,6 +107,10 @@ func makeNewUserRepo() (user.UserRepo, error) { return userRepo, nil } +func getRefreshTokenEncoded(id, value string) string { + return fmt.Sprintf("%v/%s", id, base64.URLEncoding.EncodeToString([]byte(value))) +} + func TestServerProviderConfig(t *testing.T) { srv := &Server{IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}} @@ -612,27 +616,29 @@ func TestServerRefreshToken(t *testing.T) { // NOTE(ericchiang): These tests assume that the database ID of the first // refresh token will be "1". tests := []struct { - token string - clientID string // The client that associates with the token. - creds oidc.ClientCredentials - signer jose.Signer - createScopes []string - refreshScopes []string - expectedAud []string - err error + token string + expectedRefreshToken string + clientID string // The client that associates with the token. + creds oidc.ClientCredentials + signer jose.Signer + createScopes []string + refreshScopes []string + expectedAud []string + err error }{ // Everything is good. { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - clientID: testClientID, - creds: testClientCredentials, - signer: signerFixture, - createScopes: []string{"openid", "profile"}, - refreshScopes: []string{"openid", "profile"}, + token: getRefreshTokenEncoded("1", "refresh-1"), + expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"), + clientID: testClientID, + creds: testClientCredentials, + signer: signerFixture, + createScopes: []string{"openid", "profile"}, + refreshScopes: []string{"openid", "profile"}, }, // Asking for a scope not originally granted to you. { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: testClientCredentials, signer: signerFixture, @@ -652,7 +658,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid refresh token(invalid payload content). { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), + token: getRefreshTokenEncoded("1", "refresh-2"), clientID: testClientID, creds: testClientCredentials, signer: signerFixture, @@ -662,7 +668,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid refresh token(invalid ID content). { - token: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + token: getRefreshTokenEncoded("0", "refresh-1"), clientID: testClientID, creds: testClientCredentials, signer: signerFixture, @@ -672,7 +678,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid client(client is not associated with the token). { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: clientB.Credentials, signer: signerFixture, @@ -682,7 +688,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid client(no client ID). { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: oidc.ClientCredentials{ID: "", Secret: "aaa"}, signer: signerFixture, @@ -692,7 +698,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid client(no such client). { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, signer: signerFixture, @@ -702,7 +708,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid client(no secrets). { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: oidc.ClientCredentials{ID: testClientID}, signer: signerFixture, @@ -712,7 +718,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Invalid client(invalid secret). { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"}, signer: signerFixture, @@ -722,7 +728,7 @@ func TestServerRefreshToken(t *testing.T) { }, // Signing operation fails. { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + token: getRefreshTokenEncoded("1", "refresh-1"), clientID: testClientID, creds: testClientCredentials, signer: &StaticSigner{sig: nil, err: errors.New("fail")}, @@ -732,8 +738,9 @@ func TestServerRefreshToken(t *testing.T) { }, // Valid Cross-Client { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - clientID: "client_a", + token: getRefreshTokenEncoded("1", "refresh-1"), + expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"), + clientID: "client_a", creds: oidc.ClientCredentials{ ID: "client_a", Secret: base64.URLEncoding.EncodeToString( @@ -748,8 +755,9 @@ func TestServerRefreshToken(t *testing.T) { // refresh request, which should result in the original stored scopes // being used. { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - clientID: "client_a", + token: getRefreshTokenEncoded("1", "refresh-1"), + expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"), + clientID: "client_a", creds: oidc.ClientCredentials{ ID: "client_a", Secret: base64.URLEncoding.EncodeToString( @@ -763,8 +771,9 @@ func TestServerRefreshToken(t *testing.T) { // Valid Cross-Client - asking for fewer scopes than originally used // when creating the refresh token, which is ok. { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - clientID: "client_a", + token: getRefreshTokenEncoded("1", "refresh-1"), + expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"), + clientID: "client_a", creds: oidc.ClientCredentials{ ID: "client_a", Secret: base64.URLEncoding.EncodeToString( @@ -777,8 +786,9 @@ func TestServerRefreshToken(t *testing.T) { }, // Valid Cross-Client - asking for multiple clients in the audience. { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), - clientID: "client_a", + token: getRefreshTokenEncoded("1", "refresh-1"), + expectedRefreshToken: getRefreshTokenEncoded("1", "refresh-2"), + clientID: "client_a", creds: oidc.ClientCredentials{ ID: "client_a", Secret: base64.URLEncoding.EncodeToString( @@ -792,7 +802,7 @@ func TestServerRefreshToken(t *testing.T) { // Invalid Cross-Client - didn't orignally request cross-client when // refresh token was created. { - token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), + token: getRefreshTokenEncoded("1", "refresh-1"), clientID: "client_a", creds: oidc.ClientCredentials{ ID: "client_a", @@ -825,7 +835,7 @@ func TestServerRefreshToken(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - jwt, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token) + jwt, refreshToken, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token) if !reflect.DeepEqual(err, tt.err) { t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err) } @@ -861,5 +871,9 @@ func TestServerRefreshToken(t *testing.T) { expectedAud, claims["aud"]) } } + + if diff := pretty.Compare(refreshToken, tt.expectedRefreshToken); diff != "" { + t.Errorf("Case %d: want=%v, got=%v", i, tt.expectedRefreshToken, refreshToken) + } } }