From 19884d92ac5177e42d3c206589c545067a6e4d54 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Fri, 14 May 2021 23:14:38 +0400 Subject: [PATCH] feat: Add ent-based postgres storage Signed-off-by: m.nabokikh --- cmd/dex/config.go | 19 +-- storage/ent/client/authrequest.go | 2 +- storage/ent/client/client.go | 2 +- storage/ent/client/connector.go | 2 +- storage/ent/client/devicetoken.go | 2 +- storage/ent/client/keys.go | 2 +- storage/ent/client/main.go | 17 ++- storage/ent/client/offlinesession.go | 2 +- storage/ent/client/password.go | 2 +- storage/ent/client/refreshtoken.go | 2 +- storage/ent/postgres.go | 155 +++++++++++++++++++++++ storage/ent/postgres_test.go | 183 +++++++++++++++++++++++++++ storage/ent/sqlite_test.go | 6 +- storage/ent/types.go | 25 ++++ 14 files changed, 401 insertions(+), 20 deletions(-) create mode 100644 storage/ent/postgres.go create mode 100644 storage/ent/postgres_test.go create mode 100644 storage/ent/types.go diff --git a/cmd/dex/config.go b/cmd/dex/config.go index bec6f620..309fc52c 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -182,14 +182,17 @@ var ( _ StorageConfig = (*sql.Postgres)(nil) _ StorageConfig = (*sql.MySQL)(nil) _ StorageConfig = (*ent.SQLite3)(nil) + _ StorageConfig = (*ent.Postgres)(nil) ) -func getORMBasedSQLiteStorage() StorageConfig { - switch os.Getenv("DEX_ENT_ENABLED") { - case "true", "yes": - return new(ent.SQLite3) - default: - return new(sql.SQLite3) +func getORMBasedSQLStorage(normal, entBased StorageConfig) func() StorageConfig { + return func() StorageConfig { + switch os.Getenv("DEX_ENT_ENABLED") { + case "true", "yes": + return entBased + default: + return normal + } } } @@ -197,9 +200,9 @@ var storages = map[string]func() StorageConfig{ "etcd": func() StorageConfig { return new(etcd.Etcd) }, "kubernetes": func() StorageConfig { return new(kubernetes.Config) }, "memory": func() StorageConfig { return new(memory.Config) }, - "postgres": func() StorageConfig { return new(sql.Postgres) }, "mysql": func() StorageConfig { return new(sql.MySQL) }, - "sqlite3": getORMBasedSQLiteStorage, + "sqlite3": getORMBasedSQLStorage(&sql.SQLite3{}, &ent.SQLite3{}), + "postgres": getORMBasedSQLStorage(&sql.Postgres{}, &ent.Postgres{}), } // isExpandEnvEnabled returns if os.ExpandEnv should be used for each storage and connector config. diff --git a/storage/ent/client/authrequest.go b/storage/ent/client/authrequest.go index 4cbb8b4e..bde37adc 100644 --- a/storage/ent/client/authrequest.go +++ b/storage/ent/client/authrequest.go @@ -58,7 +58,7 @@ func (d *Database) DeleteAuthRequest(id string) error { // UpdateAuthRequest changes an auth request by id using an updater function and saves it to the database. func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error { - tx, err := d.client.Tx(context.TODO()) + tx, err := d.BeginTx(context.TODO()) if err != nil { return fmt.Errorf("update auth request tx: %w", err) } diff --git a/storage/ent/client/client.go b/storage/ent/client/client.go index 577508d6..07434bd6 100644 --- a/storage/ent/client/client.go +++ b/storage/ent/client/client.go @@ -57,7 +57,7 @@ func (d *Database) DeleteClient(id string) error { // UpdateClient changes an oauth2 client by id using an updater function and saves it to the database. func (d *Database) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { - tx, err := d.client.Tx(context.TODO()) + tx, err := d.BeginTx(context.TODO()) if err != nil { return convertDBError("update client tx: %w", err) } diff --git a/storage/ent/client/connector.go b/storage/ent/client/connector.go index ebba3f58..bfec4418 100644 --- a/storage/ent/client/connector.go +++ b/storage/ent/client/connector.go @@ -55,7 +55,7 @@ func (d *Database) DeleteConnector(id string) error { // UpdateConnector changes a connector by id using an updater function and saves it to the database. func (d *Database) UpdateConnector(id string, updater func(old storage.Connector) (storage.Connector, error)) error { - tx, err := d.client.Tx(context.TODO()) + tx, err := d.BeginTx(context.TODO()) if err != nil { return convertDBError("update connector tx: %w", err) } diff --git a/storage/ent/client/devicetoken.go b/storage/ent/client/devicetoken.go index 89de1cb3..d8870787 100644 --- a/storage/ent/client/devicetoken.go +++ b/storage/ent/client/devicetoken.go @@ -37,7 +37,7 @@ func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error // UpdateDeviceToken changes a token by device code using an updater function and saves it to the database. func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { - tx, err := d.client.Tx(context.TODO()) + tx, err := d.BeginTx(context.TODO()) if err != nil { return convertDBError("update device token tx: %w", err) } diff --git a/storage/ent/client/keys.go b/storage/ent/client/keys.go index d9f32048..f65d40fc 100644 --- a/storage/ent/client/keys.go +++ b/storage/ent/client/keys.go @@ -26,7 +26,7 @@ func (d *Database) GetKeys() (storage.Keys, error) { func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { firstUpdate := false - tx, err := d.client.Tx(context.TODO()) + tx, err := d.BeginTx(context.TODO()) if err != nil { return convertDBError("update keys tx: %w", err) } diff --git a/storage/ent/client/main.go b/storage/ent/client/main.go index 84dc7d97..bc4c1600 100644 --- a/storage/ent/client/main.go +++ b/storage/ent/client/main.go @@ -2,6 +2,7 @@ package client import ( "context" + "database/sql" "hash" "time" @@ -17,7 +18,9 @@ import ( var _ storage.Storage = (*Database)(nil) type Database struct { - client *db.Client + client *db.Client + txOptions *sql.TxOptions + hasher func() hash.Hash } @@ -44,6 +47,13 @@ func WithHasher(h func() hash.Hash) func(*Database) { } } +// WithTxIsolationLevel sets correct isolation level for database transactions. +func WithTxIsolationLevel(level sql.IsolationLevel) func(*Database) { + return func(s *Database) { + s.txOptions = &sql.TxOptions{Isolation: level} + } +} + // Schema exposes migration schema to perform migrations. func (d *Database) Schema() *migrate.Schema { return d.client.Schema @@ -54,6 +64,11 @@ func (d *Database) Close() error { return d.client.Close() } +// BeginTx is a wrapper to begin transaction with defined options. +func (d *Database) BeginTx(ctx context.Context) (*db.Tx, error) { + return d.client.BeginTx(ctx, d.txOptions) +} + // GarbageCollect removes expired entities from the database. func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) { result := storage.GCResult{} diff --git a/storage/ent/client/offlinesession.go b/storage/ent/client/offlinesession.go index cee415b6..6c698547 100644 --- a/storage/ent/client/offlinesession.go +++ b/storage/ent/client/offlinesession.go @@ -55,7 +55,7 @@ func (d *Database) DeleteOfflineSessions(userID, connID string) error { func (d *Database) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { id := offlineSessionID(userID, connID, d.hasher) - tx, err := d.client.Tx(context.TODO()) + tx, err := d.BeginTx(context.TODO()) if err != nil { return convertDBError("update offline session tx: %w", err) } diff --git a/storage/ent/client/password.go b/storage/ent/client/password.go index 003cbd1a..daaae30c 100644 --- a/storage/ent/client/password.go +++ b/storage/ent/client/password.go @@ -64,7 +64,7 @@ func (d *Database) DeletePassword(email string) error { func (d *Database) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error { email = strings.ToLower(email) - tx, err := d.client.Tx(context.TODO()) + tx, err := d.BeginTx(context.TODO()) if err != nil { return convertDBError("update connector tx: %w", err) } diff --git a/storage/ent/client/refreshtoken.go b/storage/ent/client/refreshtoken.go index 0b90233d..eca048f4 100644 --- a/storage/ent/client/refreshtoken.go +++ b/storage/ent/client/refreshtoken.go @@ -67,7 +67,7 @@ func (d *Database) DeleteRefresh(id string) error { // UpdateRefreshToken changes a refresh token by id using an updater function and saves it to the database. func (d *Database) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { - tx, err := d.client.Tx(context.TODO()) + tx, err := d.BeginTx(context.TODO()) if err != nil { return convertDBError("update refresh token tx: %w", err) } diff --git a/storage/ent/postgres.go b/storage/ent/postgres.go new file mode 100644 index 00000000..d2197893 --- /dev/null +++ b/storage/ent/postgres.go @@ -0,0 +1,155 @@ +package ent + +import ( + "context" + "crypto/sha256" + "database/sql" + "fmt" + "net" + "regexp" + "strconv" + "strings" + "time" + + // Register postgres driver. + _ "github.com/lib/pq" + + entSQL "entgo.io/ent/dialect/sql" + "github.com/dexidp/dex/pkg/log" + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent/client" + "github.com/dexidp/dex/storage/ent/db" +) + +// nolint +const ( + // postgres SSL modes + pgSSLDisable = "disable" + pgSSLRequire = "require" + pgSSLVerifyCA = "verify-ca" + pgSSLVerifyFull = "verify-full" +) + +// Postgres options for creating an SQL db. +type Postgres struct { + NetworkDB + + SSL SSL `json:"ssl"` +} + +// Open always returns a new in sqlite3 storage. +func (p *Postgres) Open(logger log.Logger) (storage.Storage, error) { + logger.Debug("experimental ent-based storage driver is enabled") + drv, err := p.driver() + if err != nil { + return nil, err + } + + databaseClient := client.NewDatabase( + client.WithClient(db.NewClient(db.Driver(drv))), + client.WithHasher(sha256.New), + // The default behavior for Postgres transactions is consistent reads, not consistent writes. + // For each transaction opened, ensure it has the correct isolation level. + // + // See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html + client.WithTxIsolationLevel(sql.LevelSerializable), + ) + + if err := databaseClient.Schema().Create(context.TODO()); err != nil { + return nil, err + } + + return databaseClient, nil +} + +func (p *Postgres) driver() (*entSQL.Driver, error) { + drv, err := entSQL.Open("postgres", p.dsn()) + if err != nil { + return nil, err + } + + // set database/sql tunables if configured + if p.ConnMaxLifetime != 0 { + drv.DB().SetConnMaxLifetime(time.Duration(p.ConnMaxLifetime) * time.Second) + } + + if p.MaxIdleConns == 0 { + drv.DB().SetMaxIdleConns(5) + } else { + drv.DB().SetMaxIdleConns(p.MaxIdleConns) + } + + if p.MaxOpenConns == 0 { + drv.DB().SetMaxOpenConns(5) + } else { + drv.DB().SetMaxOpenConns(p.MaxOpenConns) + } + + return drv, nil +} + +func (p *Postgres) dsn() string { + // detect host:port for backwards-compatibility + host, port, err := net.SplitHostPort(p.Host) + if err != nil { + // not host:port, probably unix socket or bare address + host = p.Host + if p.Port != 0 { + port = strconv.Itoa(int(p.Port)) + } + } + + var parameters []string + addParam := func(key, val string) { + parameters = append(parameters, fmt.Sprintf("%s=%s", key, val)) + } + + addParam("connect_timeout", strconv.Itoa(p.ConnectionTimeout)) + + if host != "" { + addParam("host", dataSourceStr(host)) + } + + if port != "" { + addParam("port", port) + } + + if p.User != "" { + addParam("user", dataSourceStr(p.User)) + } + + if p.Password != "" { + addParam("password", dataSourceStr(p.Password)) + } + + if p.Database != "" { + addParam("dbname", dataSourceStr(p.Database)) + } + + if p.SSL.Mode == "" { + // Assume the strictest mode if unspecified. + addParam("sslmode", dataSourceStr(pgSSLVerifyFull)) + } else { + addParam("sslmode", dataSourceStr(p.SSL.Mode)) + } + + if p.SSL.CAFile != "" { + addParam("sslrootcert", dataSourceStr(p.SSL.CAFile)) + } + + if p.SSL.CertFile != "" { + addParam("sslcert", dataSourceStr(p.SSL.CertFile)) + } + + if p.SSL.KeyFile != "" { + addParam("sslkey", dataSourceStr(p.SSL.KeyFile)) + } + + return strings.Join(parameters, " ") +} + +var strEsc = regexp.MustCompile(`([\\'])`) + +func dataSourceStr(str string) string { + return "'" + strEsc.ReplaceAllString(str, `\$1`) + "'" +} diff --git a/storage/ent/postgres_test.go b/storage/ent/postgres_test.go new file mode 100644 index 00000000..d9395880 --- /dev/null +++ b/storage/ent/postgres_test.go @@ -0,0 +1,183 @@ +package ent + +import ( + "os" + "strconv" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/conformance" +) + +func getenv(key, defaultVal string) string { + if val := os.Getenv(key); val != "" { + return val + } + return defaultVal +} + +func postgresTestConfig(host string, port uint64) *Postgres { + return &Postgres{ + NetworkDB: NetworkDB{ + Database: getenv("DEX_POSTGRES_DATABASE", "postgres"), + User: getenv("DEX_POSTGRES_USER", "postgres"), + Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"), + Host: host, + Port: uint16(port), + }, + SSL: SSL{ + Mode: pgSSLDisable, // Postgres container doesn't support SSL. + }, + } +} + +func newPostgresStorage(host string, port uint64) storage.Storage { + logger := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + + cfg := postgresTestConfig(host, port) + s, err := cfg.Open(logger) + if err != nil { + panic(err) + } + return s +} + +func TestPostgres(t *testing.T) { + host := os.Getenv("DEX_POSTGRES_HOST") + if host == "" { + t.Skipf("test environment variable DEX_POSTGRES_HOST not set, skipping") + } + + port := uint64(5432) + if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" { + var err error + + port, err = strconv.ParseUint(rawPort, 10, 32) + require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err) + } + + newStorage := func() storage.Storage { + return newPostgresStorage(host, port) + } + conformance.RunTests(t, newStorage) + conformance.RunTransactionTests(t, newStorage) +} + +func TestPostgresDSN(t *testing.T) { + tests := []struct { + name string + cfg *Postgres + desiredDSN string + }{ + { + name: "Host port", + cfg: &Postgres{ + NetworkDB: NetworkDB{ + Host: "localhost", + Port: uint16(5432), + }, + }, + desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'", + }, + { + name: "Host with port", + cfg: &Postgres{ + NetworkDB: NetworkDB{ + Host: "localhost:5432", + }, + }, + desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'", + }, + { + name: "Host ipv6 with port", + cfg: &Postgres{ + NetworkDB: NetworkDB{ + Host: "[a:b:c:d]:5432", + }, + }, + desiredDSN: "connect_timeout=0 host='a:b:c:d' port=5432 sslmode='verify-full'", + }, + { + name: "Credentials and timeout", + cfg: &Postgres{ + NetworkDB: NetworkDB{ + Database: "test", + User: "test", + Password: "test", + ConnectionTimeout: 5, + }, + }, + desiredDSN: "connect_timeout=5 user='test' password='test' dbname='test' sslmode='verify-full'", + }, + { + name: "SSL", + cfg: &Postgres{ + SSL: SSL{ + Mode: pgSSLRequire, + CAFile: "/ca.crt", + KeyFile: "/cert.crt", + CertFile: "/cert.key", + }, + }, + desiredDSN: "connect_timeout=0 sslmode='require' sslrootcert='/ca.crt' sslcert='/cert.key' sslkey='/cert.crt'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.desiredDSN, tt.cfg.dsn()) + }) + } +} + +func TestPostgresDriver(t *testing.T) { + host := os.Getenv("DEX_POSTGRES_HOST") + if host == "" { + t.Skipf("test environment variable DEX_POSTGRES_HOST not set, skipping") + } + + port := uint64(5432) + if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" { + var err error + + port, err = strconv.ParseUint(rawPort, 10, 32) + require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err) + } + + tests := []struct { + name string + cfg func() *Postgres + desiredConns int + }{ + { + name: "Defaults", + cfg: func() *Postgres { return postgresTestConfig(host, port) }, + desiredConns: 5, + }, + { + name: "Tune", + cfg: func() *Postgres { + cfg := postgresTestConfig(host, port) + cfg.MaxOpenConns = 101 + return cfg + }, + desiredConns: 101, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + drv, err := tt.cfg().driver() + require.NoError(t, err) + + require.Equal(t, tt.desiredConns, drv.DB().Stats().MaxOpenConnections) + }) + } +} diff --git a/storage/ent/sqlite_test.go b/storage/ent/sqlite_test.go index 053c827d..10047b7f 100644 --- a/storage/ent/sqlite_test.go +++ b/storage/ent/sqlite_test.go @@ -10,7 +10,7 @@ import ( "github.com/dexidp/dex/storage/conformance" ) -func newStorage() storage.Storage { +func newSQLiteStorage() storage.Storage { logger := &logrus.Logger{ Out: os.Stderr, Formatter: &logrus.TextFormatter{DisableColors: true}, @@ -26,6 +26,6 @@ func newStorage() storage.Storage { } func TestSQLite3(t *testing.T) { - conformance.RunTests(t, newStorage) - conformance.RunTransactionTests(t, newStorage) + conformance.RunTests(t, newSQLiteStorage) + conformance.RunTransactionTests(t, newSQLiteStorage) } diff --git a/storage/ent/types.go b/storage/ent/types.go new file mode 100644 index 00000000..062f8640 --- /dev/null +++ b/storage/ent/types.go @@ -0,0 +1,25 @@ +package ent + +// NetworkDB contains options common to SQL databases accessed over network. +type NetworkDB struct { + Database string + User string + Password string + Host string + Port uint16 + + ConnectionTimeout int // Seconds + + MaxOpenConns int // default: 5 + MaxIdleConns int // default: 5 + ConnMaxLifetime int // Seconds, default: not set +} + +// SSL represents SSL options for network databases. +type SSL struct { + Mode string + CAFile string + // Files for client auth. + KeyFile string + CertFile string +}