package ent

import (
	"os"
	"strconv"
	"testing"

	"github.com/sirupsen/logrus"
	"github.com/stretchr/testify/require"

	"github.com/dexidp/dex/storage"
	"github.com/dexidp/dex/storage/conformance"
)

// Environment variables that configure the Postgres integration tests.
// The tests are skipped when PostgresEntHostEnv is unset.
const (
	PostgresEntHostEnv     = "DEX_POSTGRES_ENT_HOST"
	PostgresEntPortEnv     = "DEX_POSTGRES_ENT_PORT"
	PostgresEntDatabaseEnv = "DEX_POSTGRES_ENT_DATABASE"
	PostgresEntUserEnv     = "DEX_POSTGRES_ENT_USER"
	PostgresEntPasswordEnv = "DEX_POSTGRES_ENT_PASSWORD"
)

// getenv returns the value of the environment variable key, or defaultVal
// when the variable is unset or empty.
func getenv(key, defaultVal string) string {
	if val := os.Getenv(key); val != "" {
		return val
	}
	return defaultVal
}

// postgresTestHostPort returns the Postgres host and port the integration
// tests should connect to. It skips the calling test when PostgresEntHostEnv
// is not set, and fails it when PostgresEntPortEnv is not a valid 16-bit port
// number. The port defaults to 5432.
func postgresTestHostPort(t *testing.T) (string, uint64) {
	t.Helper()

	host := os.Getenv(PostgresEntHostEnv)
	if host == "" {
		t.Skipf("test environment variable %s not set, skipping", PostgresEntHostEnv)
	}

	port := uint64(5432)
	if rawPort := os.Getenv(PostgresEntPortEnv); rawPort != "" {
		var err error
		// Parse with a 16-bit size: postgresTestConfig converts the port to
		// uint16, so anything larger would otherwise wrap silently.
		port, err = strconv.ParseUint(rawPort, 10, 16)
		require.NoError(t, err, "invalid postgres port %q", rawPort)
	}
	return host, port
}

// postgresTestConfig builds a Postgres config pointing at host:port. Database,
// user and password come from the PostgresEnt* environment variables and
// default to "postgres". SSL is disabled.
func postgresTestConfig(host string, port uint64) *Postgres {
	return &Postgres{
		NetworkDB: NetworkDB{
			Database: getenv(PostgresEntDatabaseEnv, "postgres"),
			User:     getenv(PostgresEntUserEnv, "postgres"),
			Password: getenv(PostgresEntPasswordEnv, "postgres"),
			Host:     host,
			Port:     uint16(port),
		},
		SSL: SSL{
			Mode: pgSSLDisable, // Postgres container doesn't support SSL.
		},
	}
}

// newPostgresStorage opens a storage backed by the Postgres server at
// host:port, logging at debug level to stderr. The storage factory passed to
// the conformance suite returns no error, so an open failure panics.
func newPostgresStorage(host string, port uint64) storage.Storage {
	logger := &logrus.Logger{
		Out:       os.Stderr,
		Formatter: &logrus.TextFormatter{DisableColors: true},
		Level:     logrus.DebugLevel,
	}

	cfg := postgresTestConfig(host, port)
	s, err := cfg.Open(logger)
	if err != nil {
		panic(err)
	}
	return s
}

// TestPostgres runs the storage conformance and transaction suites against a
// live Postgres server configured through the PostgresEnt* variables.
func TestPostgres(t *testing.T) {
	host, port := postgresTestHostPort(t)

	newStorage := func() storage.Storage {
		return newPostgresStorage(host, port)
	}

	conformance.RunTests(t, newStorage)
	conformance.RunTransactionTests(t, newStorage)
}

// TestPostgresDSN checks the connection string produced by Postgres.dsn for
// several config shapes. It needs no database.
func TestPostgresDSN(t *testing.T) {
	tests := []struct {
		name       string
		cfg        *Postgres
		desiredDSN string
	}{
		{
			name: "Host port",
			cfg: &Postgres{
				NetworkDB: NetworkDB{
					Host: "localhost",
					Port: uint16(5432),
				},
			},
			desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'",
		},
		{
			name: "Host with port",
			cfg: &Postgres{
				NetworkDB: NetworkDB{
					Host: "localhost:5432",
				},
			},
			desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'",
		},
		{
			name: "Host ipv6 with port",
			cfg: &Postgres{
				NetworkDB: NetworkDB{
					Host: "[a:b:c:d]:5432",
				},
			},
			desiredDSN: "connect_timeout=0 host='a:b:c:d' port=5432 sslmode='verify-full'",
		},
		{
			name: "Credentials and timeout",
			cfg: &Postgres{
				NetworkDB: NetworkDB{
					Database:          "test",
					User:              "test",
					Password:          "test",
					ConnectionTimeout: 5,
				},
			},
			desiredDSN: "connect_timeout=5 user='test' password='test' dbname='test' sslmode='verify-full'",
		},
		{
			name: "SSL",
			cfg: &Postgres{
				SSL: SSL{
					Mode:     pgSSLRequire,
					CAFile:   "/ca.crt",
					KeyFile:  "/cert.key",
					CertFile: "/cert.crt",
				},
			},
			desiredDSN: "connect_timeout=0 sslmode='require' sslrootcert='/ca.crt' sslcert='/cert.crt' sslkey='/cert.key'",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			require.Equal(t, tt.desiredDSN, tt.cfg.dsn())
		})
	}
}

// TestPostgresDriver checks that the driver built from a config applies the
// expected MaxOpenConnections limit, both by default and when tuned. It
// requires a live Postgres server.
func TestPostgresDriver(t *testing.T) {
	host, port := postgresTestHostPort(t)

	tests := []struct {
		name         string
		cfg          func() *Postgres
		desiredConns int
	}{
		{
			name: "Defaults",
			cfg: func() *Postgres {
				return postgresTestConfig(host, port)
			},
			desiredConns: 5,
		},
		{
			name: "Tune",
			cfg: func() *Postgres {
				cfg := postgresTestConfig(host, port)
				cfg.MaxOpenConns = 101
				return cfg
			},
			desiredConns: 101,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			drv, err := tt.cfg().driver()
			require.NoError(t, err)
			// Release the connection pool opened for this subtest.
			t.Cleanup(func() {
				if err := drv.DB().Close(); err != nil {
					t.Errorf("closing postgres driver: %v", err)
				}
			})

			require.Equal(t, tt.desiredConns, drv.DB().Stats().MaxOpenConnections)
		})
	}
}