forked from mystiq/dex
184 lines
4.1 KiB
Go
184 lines
4.1 KiB
Go
|
package ent
|
||
|
|
||
|
import (
|
||
|
"os"
|
||
|
"strconv"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/sirupsen/logrus"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
|
||
|
"github.com/dexidp/dex/storage"
|
||
|
"github.com/dexidp/dex/storage/conformance"
|
||
|
)
|
||
|
|
||
|
// getenv returns the value of the environment variable key. It falls back to
// defaultVal when the variable is unset or set to the empty string.
func getenv(key, defaultVal string) string {
	val := os.Getenv(key)
	if val == "" {
		return defaultVal
	}
	return val
}
|
||
|
|
||
|
func postgresTestConfig(host string, port uint64) *Postgres {
|
||
|
return &Postgres{
|
||
|
NetworkDB: NetworkDB{
|
||
|
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
|
||
|
User: getenv("DEX_POSTGRES_USER", "postgres"),
|
||
|
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
|
||
|
Host: host,
|
||
|
Port: uint16(port),
|
||
|
},
|
||
|
SSL: SSL{
|
||
|
Mode: pgSSLDisable, // Postgres container doesn't support SSL.
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func newPostgresStorage(host string, port uint64) storage.Storage {
|
||
|
logger := &logrus.Logger{
|
||
|
Out: os.Stderr,
|
||
|
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||
|
Level: logrus.DebugLevel,
|
||
|
}
|
||
|
|
||
|
cfg := postgresTestConfig(host, port)
|
||
|
s, err := cfg.Open(logger)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
func TestPostgres(t *testing.T) {
|
||
|
host := os.Getenv("DEX_POSTGRES_HOST")
|
||
|
if host == "" {
|
||
|
t.Skipf("test environment variable DEX_POSTGRES_HOST not set, skipping")
|
||
|
}
|
||
|
|
||
|
port := uint64(5432)
|
||
|
if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" {
|
||
|
var err error
|
||
|
|
||
|
port, err = strconv.ParseUint(rawPort, 10, 32)
|
||
|
require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err)
|
||
|
}
|
||
|
|
||
|
newStorage := func() storage.Storage {
|
||
|
return newPostgresStorage(host, port)
|
||
|
}
|
||
|
conformance.RunTests(t, newStorage)
|
||
|
conformance.RunTransactionTests(t, newStorage)
|
||
|
}
|
||
|
|
||
|
func TestPostgresDSN(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
cfg *Postgres
|
||
|
desiredDSN string
|
||
|
}{
|
||
|
{
|
||
|
name: "Host port",
|
||
|
cfg: &Postgres{
|
||
|
NetworkDB: NetworkDB{
|
||
|
Host: "localhost",
|
||
|
Port: uint16(5432),
|
||
|
},
|
||
|
},
|
||
|
desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'",
|
||
|
},
|
||
|
{
|
||
|
name: "Host with port",
|
||
|
cfg: &Postgres{
|
||
|
NetworkDB: NetworkDB{
|
||
|
Host: "localhost:5432",
|
||
|
},
|
||
|
},
|
||
|
desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'",
|
||
|
},
|
||
|
{
|
||
|
name: "Host ipv6 with port",
|
||
|
cfg: &Postgres{
|
||
|
NetworkDB: NetworkDB{
|
||
|
Host: "[a:b:c:d]:5432",
|
||
|
},
|
||
|
},
|
||
|
desiredDSN: "connect_timeout=0 host='a:b:c:d' port=5432 sslmode='verify-full'",
|
||
|
},
|
||
|
{
|
||
|
name: "Credentials and timeout",
|
||
|
cfg: &Postgres{
|
||
|
NetworkDB: NetworkDB{
|
||
|
Database: "test",
|
||
|
User: "test",
|
||
|
Password: "test",
|
||
|
ConnectionTimeout: 5,
|
||
|
},
|
||
|
},
|
||
|
desiredDSN: "connect_timeout=5 user='test' password='test' dbname='test' sslmode='verify-full'",
|
||
|
},
|
||
|
{
|
||
|
name: "SSL",
|
||
|
cfg: &Postgres{
|
||
|
SSL: SSL{
|
||
|
Mode: pgSSLRequire,
|
||
|
CAFile: "/ca.crt",
|
||
|
KeyFile: "/cert.crt",
|
||
|
CertFile: "/cert.key",
|
||
|
},
|
||
|
},
|
||
|
desiredDSN: "connect_timeout=0 sslmode='require' sslrootcert='/ca.crt' sslcert='/cert.key' sslkey='/cert.crt'",
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
require.Equal(t, tt.desiredDSN, tt.cfg.dsn())
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestPostgresDriver(t *testing.T) {
|
||
|
host := os.Getenv("DEX_POSTGRES_HOST")
|
||
|
if host == "" {
|
||
|
t.Skipf("test environment variable DEX_POSTGRES_HOST not set, skipping")
|
||
|
}
|
||
|
|
||
|
port := uint64(5432)
|
||
|
if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" {
|
||
|
var err error
|
||
|
|
||
|
port, err = strconv.ParseUint(rawPort, 10, 32)
|
||
|
require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err)
|
||
|
}
|
||
|
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
cfg func() *Postgres
|
||
|
desiredConns int
|
||
|
}{
|
||
|
{
|
||
|
name: "Defaults",
|
||
|
cfg: func() *Postgres { return postgresTestConfig(host, port) },
|
||
|
desiredConns: 5,
|
||
|
},
|
||
|
{
|
||
|
name: "Tune",
|
||
|
cfg: func() *Postgres {
|
||
|
cfg := postgresTestConfig(host, port)
|
||
|
cfg.MaxOpenConns = 101
|
||
|
return cfg
|
||
|
},
|
||
|
desiredConns: 101,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
drv, err := tt.cfg().driver()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
require.Equal(t, tt.desiredConns, drv.DB().Stats().MaxOpenConnections)
|
||
|
})
|
||
|
}
|
||
|
}
|