diff --git a/db/refresh.go b/db/refresh.go index 552cc1e6..86495748 100644 --- a/db/refresh.go +++ b/db/refresh.go @@ -14,6 +14,7 @@ import ( "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/refresh" "github.com/coreos/dex/repo" + "github.com/coreos/go-oidc/oidc" ) const ( @@ -179,6 +180,35 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error { return tx.Commit() } +func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error { + q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName)) + _, err := r.executor(nil).Exec(q, userID, clientID) + return err +} + +func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]oidc.ClientIdentity, error) { + q := `SELECT c.* FROM %s as c + INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;` + q = fmt.Sprintf(q, r.quote(clientIdentityTableName), r.quote(refreshTokenTableName)) + + var clients []clientIdentityModel + if _, err := r.executor(nil).Select(&clients, q, userID); err != nil { + return nil, err + } + + c := make([]oidc.ClientIdentity, len(clients)) + for i, client := range clients { + ident, err := client.ClientIdentity() + if err != nil { + return nil, err + } + c[i] = *ident + // Do not share the secret. + c[i].Credentials.Secret = "" + } + return c, nil +} + func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) { ex := r.executor(tx) result, err := ex.Get(refreshTokenModel{}, tokenID) diff --git a/functional/repo/client_repo_test.go b/functional/repo/client_repo_test.go index 6c580d68..fff50ca1 100644 --- a/functional/repo/client_repo_test.go +++ b/functional/repo/client_repo_test.go @@ -24,7 +24,8 @@ var ( RedirectURIs: []url.URL{ url.URL{ Scheme: "https", - Host: "client1.example.com/callback", + Host: "client1.example.com", + Path: "/callback", }, }, }, @@ -38,7 +39,8 @@ var ( RedirectURIs: []url.URL{ url.URL{ Scheme: "https", - Host: "client2.example.com/callback", + Host: "client2.example.com", + Path: "/callback", }, }, }, diff --git a/functional/repo/refresh_repo_test.go b/functional/repo/refresh_repo_test.go new file mode 100644 index 00000000..197f59ad --- /dev/null +++ b/functional/repo/refresh_repo_test.go @@ -0,0 +1,93 @@ +package repo + +import ( + "encoding/base64" + "net/url" + "os" + "testing" + "time" + + "github.com/coreos/go-oidc/oidc" + "github.com/go-gorp/gorp" + "github.com/kylelemons/godebug/pretty" + + "github.com/coreos/dex/db" + "github.com/coreos/dex/refresh" + "github.com/coreos/dex/user" +) + +func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []oidc.ClientIdentity) refresh.RefreshTokenRepo { + var dbMap *gorp.DbMap + if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" { + dbMap = db.NewMemDB() + } else { + dbMap = connect(t) + } + if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil { + t.Fatalf("Unable to add users: %v", err) + } + if _, err := db.NewClientIdentityRepoFromClients(dbMap, clients); err != nil { + t.Fatalf("Unable to add clients: %v", err) + } + return db.NewRefreshTokenRepo(dbMap) +} + +func TestRefreshTokenRepo(t *testing.T) { + clientID := "client1" + userID := "user1" + clients := []oidc.ClientIdentity{ + { + Credentials: oidc.ClientCredentials{ + ID: clientID, + Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")), + }, + Metadata: oidc.ClientMetadata{ + RedirectURIs: []url.URL{ + url.URL{Scheme: "https", Host: "client1.example.com", Path: "/callback"}, + }, + }, + }, + } + users := []user.UserWithRemoteIdentities{ + { + User: user.User{ + ID: userID, + Email: "Email-1@example.com", + CreatedAt: time.Now().Truncate(time.Second), + }, + RemoteIdentities: []user.RemoteIdentity{ + { + ConnectorID: "IDPC-1", + ID: "RID-1", + }, + }, + }, + } + + repo := newRefreshRepo(t, users, clients) + tok, err := repo.Create(userID, clientID) + if err != nil { + t.Fatalf("failed to create refresh token: %v", err) + } + if tokUserID, err := repo.Verify(clientID, tok); err != nil { + t.Errorf("Could not verify token: %v", err) + } else if tokUserID != userID { + t.Errorf("Verified token returned wrong user id, want=%s, got=%s", userID, tokUserID) + } + + if userClients, err := repo.ClientsWithRefreshTokens(userID); err != nil { + t.Errorf("Failed to get the list of clients the user was logged into: %v", err) + } else { + if diff := pretty.Compare(userClients, clients); diff == "" { + t.Errorf("Clients user logged into: want did not equal got %s", diff) + } + } + + if err := repo.RevokeTokensForClient(userID, clientID); err != nil { + t.Errorf("Failed to revoke refresh token: %v", err) + } + + if _, err := repo.Verify(clientID, tok); err == nil { + t.Errorf("Token which should have been revoked was verified") + } +} diff --git a/refresh/repo.go b/refresh/repo.go index 0c65c0e6..607169fe 100644 --- a/refresh/repo.go +++ b/refresh/repo.go @@ -3,6 +3,8 @@ package refresh import ( "crypto/rand" "errors" + + "github.com/coreos/go-oidc/oidc" ) const ( @@ -47,4 +49,10 @@ type RefreshTokenRepo interface { // Revoke deletes the refresh token if the token belongs to the given userID. Revoke(userID, token string) error + + // RevokeTokensForClient revokes all tokens issued for the userID for the provided client. + RevokeTokensForClient(userID, clientID string) error + + // ClientsWithRefreshTokens returns a list of all clients the user has an outstanding client with. + ClientsWithRefreshTokens(userID string) ([]oidc.ClientIdentity, error) }