diff --git a/storage/sql/config.go b/storage/sql/config.go index ec001b2c..3be63b30 100644 --- a/storage/sql/config.go +++ b/storage/sql/config.go @@ -3,8 +3,9 @@ package sql import ( "database/sql" "fmt" - "net/url" + "regexp" "strconv" + "strings" "github.com/lib/pq" sqlite3 "github.com/mattn/go-sqlite3" @@ -81,6 +82,7 @@ type Postgres struct { User string Password string Host string + Port uint16 SSL PostgresSSL `json:"ssl" yaml:"ssl"` @@ -89,45 +91,75 @@ type Postgres struct { // Open creates a new storage implementation backed by Postgres. func (p *Postgres) Open(logger logrus.FieldLogger) (storage.Storage, error) { - conn, err := p.open(logger) + conn, err := p.open(logger, p.createDataSourceName()) if err != nil { return nil, err } return conn, nil } -func (p *Postgres) open(logger logrus.FieldLogger) (*conn, error) { - v := url.Values{} - set := func(key, val string) { - if val != "" { - v.Set(key, val) - } - } - set("connect_timeout", strconv.Itoa(p.ConnectionTimeout)) - set("sslkey", p.SSL.KeyFile) - set("sslcert", p.SSL.CertFile) - set("sslrootcert", p.SSL.CAFile) - if p.SSL.Mode == "" { - // Assume the strictest mode if unspecified. - p.SSL.Mode = sslVerifyFull - } - set("sslmode", p.SSL.Mode) +var strEsc = regexp.MustCompile(`([\\'])`) - u := url.URL{ - Scheme: "postgres", - Host: p.Host, - Path: "/" + p.Database, - RawQuery: v.Encode(), +func dataSourceStr(str string) string { + return "'" + strEsc.ReplaceAllString(str, `\$1`) + "'" +} + +// createDataSourceName takes the configuration provided via the Postgres +// struct to create a data-source name that Go's database/sql package can +// make use of. +func (p *Postgres) createDataSourceName() string { + parameters := []string{} + + addParam := func(key, val string) { + parameters = append(parameters, fmt.Sprintf("%s=%s", key, val)) + } + + addParam("connect_timeout", strconv.Itoa(p.ConnectionTimeout)) + + if p.Host != "" { + addParam("host", dataSourceStr(p.Host)) + } + + if p.Port != 0 { + addParam("port", strconv.Itoa(int(p.Port))) } if p.User != "" { - if p.Password != "" { - u.User = url.UserPassword(p.User, p.Password) - } else { - u.User = url.User(p.User) - } + addParam("user", dataSourceStr(p.User)) } - db, err := sql.Open("postgres", u.String()) + + if p.Password != "" { + addParam("password", dataSourceStr(p.Password)) + } + + if p.Database != "" { + addParam("dbname", dataSourceStr(p.Database)) + } + + if p.SSL.Mode == "" { + // Assume the strictest mode if unspecified. + addParam("sslmode", dataSourceStr(sslVerifyFull)) + } else { + addParam("sslmode", dataSourceStr(p.SSL.Mode)) + } + + if p.SSL.CAFile != "" { + addParam("sslrootcert", dataSourceStr(p.SSL.CAFile)) + } + + if p.SSL.CertFile != "" { + addParam("sslcert", dataSourceStr(p.SSL.CertFile)) + } + + if p.SSL.KeyFile != "" { + addParam("sslkey", dataSourceStr(p.SSL.KeyFile)) + } + + return strings.Join(parameters, " ") +} + +func (p *Postgres) open(logger logrus.FieldLogger, dataSourceName string) (*conn, error) { + db, err := sql.Open("postgres", dataSourceName) if err != nil { return nil, err } diff --git a/storage/sql/config_test.go b/storage/sql/config_test.go index 75b81b67..972fe973 100644 --- a/storage/sql/config_test.go +++ b/storage/sql/config_test.go @@ -77,6 +77,103 @@ func getenv(key, defaultVal string) string { const testPostgresEnv = "DEX_POSTGRES_HOST" +func TestCreateDataSourceName(t *testing.T) { + var testCases = []struct { + description string + input *Postgres + expected string + }{ + { + description: "with no configuration", + input: &Postgres{}, + expected: "connect_timeout=0 sslmode='verify-full'", + }, + { + description: "with typical configuration", + input: &Postgres{ + Host: "1.2.3.4", + Port: 6543, + User: "some-user", + Password: "some-password", + Database: "some-db", + }, + expected: "connect_timeout=0 host='1.2.3.4' port=6543 user='some-user' password='some-password' dbname='some-db' sslmode='verify-full'", + }, + { + description: "with unix socket host", + input: &Postgres{ + Host: "/var/run/postgres", + SSL: PostgresSSL{ + Mode: "disable", + }, + }, + expected: "connect_timeout=0 host='/var/run/postgres' sslmode='disable'", + }, + { + description: "with tcp host", + input: &Postgres{ + Host: "coreos.com", + SSL: PostgresSSL{ + Mode: "disable", + }, + }, + expected: "connect_timeout=0 host='coreos.com' sslmode='disable'", + }, + { + description: "with tcp host and port", + input: &Postgres{ + Host: "coreos.com", + Port: 6543, + }, + expected: "connect_timeout=0 host='coreos.com' port=6543 sslmode='verify-full'", + }, + { + description: "with ssl ca cert", + input: &Postgres{ + Host: "coreos.com", + SSL: PostgresSSL{ + Mode: "verify-ca", + CAFile: "/some/file/path", + }, + }, + expected: "connect_timeout=0 host='coreos.com' sslmode='verify-ca' sslrootcert='/some/file/path'", + }, + { + description: "with ssl client cert", + input: &Postgres{ + Host: "coreos.com", + SSL: PostgresSSL{ + Mode: "verify-ca", + CAFile: "/some/ca/path", + CertFile: "/some/cert/path", + KeyFile: "/some/key/path", + }, + }, + expected: "connect_timeout=0 host='coreos.com' sslmode='verify-ca' sslrootcert='/some/ca/path' sslcert='/some/cert/path' sslkey='/some/key/path'", + }, + { + description: "with funny characters in credentials", + input: &Postgres{ + Host: "coreos.com", + User: `some'user\slashed`, + Password: "some'password!", + }, + expected: `connect_timeout=0 host='coreos.com' user='some\'user\\slashed' password='some\'password!' sslmode='verify-full'`, + }, + } + + var actual string + for _, testCase := range testCases { + t.Run(testCase.description, func(t *testing.T) { + actual = testCase.input.createDataSourceName() + + if actual != testCase.expected { + t.Fatalf("%s != %s", actual, testCase.expected) + } + }) + } +} + func TestPostgres(t *testing.T) { host := os.Getenv(testPostgresEnv) if host == "" { @@ -100,7 +197,7 @@ func TestPostgres(t *testing.T) { } newStorage := func() storage.Storage { - conn, err := p.open(logger) + conn, err := p.open(logger, p.createDataSourceName()) if err != nil { fatal(err) }