dex/storage/ent/postgres_test.go

184 lines
4.1 KiB
Go
Raw Normal View History

package ent
import (
"os"
"strconv"
"testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/conformance"
)
// getenv returns the value of the environment variable named by key. When
// the variable is unset or set to the empty string, defaultVal is returned
// instead.
func getenv(key, defaultVal string) string {
	val := os.Getenv(key)
	if val == "" {
		return defaultVal
	}
	return val
}
// postgresTestConfig builds a Postgres storage config that targets host:port.
// The database name, user and password come from DEX_POSTGRES_DATABASE,
// DEX_POSTGRES_USER and DEX_POSTGRES_PASSWORD, each defaulting to "postgres".
// SSL is disabled because the test Postgres container does not support it.
//
// port is converted to uint16 without a range check, so callers must pass a
// value that fits in a TCP port.
func postgresTestConfig(host string, port uint64) *Postgres {
	netDB := NetworkDB{
		Host: host,
		Port: uint16(port),
	}
	netDB.Database = getenv("DEX_POSTGRES_DATABASE", "postgres")
	netDB.User = getenv("DEX_POSTGRES_USER", "postgres")
	netDB.Password = getenv("DEX_POSTGRES_PASSWORD", "postgres")

	// Postgres container doesn't support SSL.
	tls := SSL{Mode: pgSSLDisable}

	return &Postgres{NetworkDB: netDB, SSL: tls}
}
// newPostgresStorage opens an ent-backed storage for the Postgres server at
// host:port, using postgresTestConfig for the remaining settings. It logs at
// debug level to stderr without colors.
//
// It panics if the storage cannot be opened. It is used as a storage factory
// for the conformance suites, which have no way to receive an error.
func newPostgresStorage(host string, port uint64) storage.Storage {
	cfg := postgresTestConfig(host, port)

	debugLogger := &logrus.Logger{
		Out:       os.Stderr,
		Formatter: &logrus.TextFormatter{DisableColors: true},
		Level:     logrus.DebugLevel,
	}

	st, err := cfg.Open(debugLogger)
	if err != nil {
		panic(err)
	}
	return st
}
// TestPostgres runs the storage conformance and transaction suites against a
// live Postgres server. It is skipped unless DEX_POSTGRES_HOST is set. The
// port defaults to 5432 and can be overridden with DEX_POSTGRES_PORT.
func TestPostgres(t *testing.T) {
	host := os.Getenv("DEX_POSTGRES_HOST")
	if host == "" {
		t.Skip("test environment variable DEX_POSTGRES_HOST not set, skipping")
	}

	port := uint64(5432)
	if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" {
		var err error
		// Limit parsing to 16 bits: the config stores the port as uint16, so a
		// wider value would otherwise pass validation and be silently truncated
		// to a different port.
		port, err = strconv.ParseUint(rawPort, 10, 16)
		require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err)
	}

	newStorage := func() storage.Storage {
		return newPostgresStorage(host, port)
	}
	conformance.RunTests(t, newStorage)
	conformance.RunTransactionTests(t, newStorage)
}
// TestPostgresDSN checks that Postgres.dsn renders the expected key/value
// connection string for several combinations of host, port, credentials,
// timeout and SSL settings. Fields left unset in a case are absent from its
// expected DSN, except connect_timeout and sslmode, which appear in every
// expectation (sslmode='verify-full' when no SSL mode is configured).
func TestPostgresDSN(t *testing.T) {
	tests := []struct {
		name       string
		cfg        *Postgres
		desiredDSN string
	}{
		{
			name: "Host port",
			cfg: &Postgres{
				NetworkDB: NetworkDB{
					Host: "localhost",
					Port: uint16(5432),
				},
			},
			desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'",
		},
		{
			name: "Host with port",
			cfg: &Postgres{
				NetworkDB: NetworkDB{
					Host: "localhost:5432",
				},
			},
			desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'",
		},
		{
			name: "Host ipv6 with port",
			cfg: &Postgres{
				NetworkDB: NetworkDB{
					Host: "[a:b:c:d]:5432",
				},
			},
			desiredDSN: "connect_timeout=0 host='a:b:c:d' port=5432 sslmode='verify-full'",
		},
		{
			name: "Credentials and timeout",
			cfg: &Postgres{
				NetworkDB: NetworkDB{
					Database:          "test",
					User:              "test",
					Password:          "test",
					ConnectionTimeout: 5,
				},
			},
			desiredDSN: "connect_timeout=5 user='test' password='test' dbname='test' sslmode='verify-full'",
		},
		{
			// The certificate and key paths are named after their roles, so
			// the expectation shows directly that CertFile maps to sslcert and
			// KeyFile maps to sslkey.
			name: "SSL",
			cfg: &Postgres{
				SSL: SSL{
					Mode:     pgSSLRequire,
					CAFile:   "/ca.crt",
					KeyFile:  "/cert.key",
					CertFile: "/cert.crt",
				},
			},
			desiredDSN: "connect_timeout=0 sslmode='require' sslrootcert='/ca.crt' sslcert='/cert.crt' sslkey='/cert.key'",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			require.Equal(t, tt.desiredDSN, tt.cfg.dsn())
		})
	}
}
// TestPostgresDriver checks that Postgres.driver applies the MaxOpenConns
// setting to the underlying database/sql pool. According to the "Defaults"
// case, the pool is expected to allow 5 open connections when MaxOpenConns is
// unset. It needs a live server and is skipped unless DEX_POSTGRES_HOST is
// set. The port defaults to 5432 and can be overridden with DEX_POSTGRES_PORT.
func TestPostgresDriver(t *testing.T) {
	host := os.Getenv("DEX_POSTGRES_HOST")
	if host == "" {
		t.Skip("test environment variable DEX_POSTGRES_HOST not set, skipping")
	}

	port := uint64(5432)
	if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" {
		var err error
		// Limit parsing to 16 bits: the config stores the port as uint16, so a
		// wider value would otherwise pass validation and be silently truncated
		// to a different port.
		port, err = strconv.ParseUint(rawPort, 10, 16)
		require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err)
	}

	tests := []struct {
		name         string
		cfg          func() *Postgres
		desiredConns int
	}{
		{
			name:         "Defaults",
			cfg:          func() *Postgres { return postgresTestConfig(host, port) },
			desiredConns: 5,
		},
		{
			name: "Tune",
			cfg: func() *Postgres {
				cfg := postgresTestConfig(host, port)
				cfg.MaxOpenConns = 101
				return cfg
			},
			desiredConns: 101,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			drv, err := tt.cfg().driver()
			require.NoError(t, err)
			// Close the driver when the subtest ends so its underlying database
			// handle is released instead of lingering for the rest of the run.
			t.Cleanup(func() { require.NoError(t, drv.Close()) })

			require.Equal(t, tt.desiredConns, drv.DB().Stats().MaxOpenConnections)
		})
	}
}