package sql import ( "crypto/tls" "crypto/x509" "database/sql" "fmt" "io/ioutil" "net" "regexp" "strconv" "strings" "time" "github.com/go-sql-driver/mysql" "github.com/lib/pq" "github.com/dexidp/dex/pkg/log" "github.com/dexidp/dex/storage" ) const ( // postgres error codes pgErrUniqueViolation = "23505" // unique_violation ) const ( // MySQL error codes mysqlErrDupEntry = 1062 mysqlErrDupEntryWithKeyName = 1586 mysqlErrUnknownSysVar = 1193 ) // nolint const ( // postgres SSL modes pgSSLDisable = "disable" pgSSLRequire = "require" pgSSLVerifyCA = "verify-ca" pgSSLVerifyFull = "verify-full" ) // nolint const ( // MySQL SSL modes mysqlSSLTrue = "true" mysqlSSLFalse = "false" mysqlSSLSkipVerify = "skip-verify" mysqlSSLCustom = "custom" ) // NetworkDB contains options common to SQL databases accessed over network. type NetworkDB struct { Database string User string Password string Host string Port uint16 ConnectionTimeout int // Seconds // database/sql tunables, see // https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime and below // Note: defaults will be set if these are 0 MaxOpenConns int // default: 5 MaxIdleConns int // default: 5 ConnMaxLifetime int // Seconds, default: not set } // SSL represents SSL options for network databases. type SSL struct { Mode string CAFile string // Files for client auth. KeyFile string CertFile string } // Postgres options for creating an SQL db. type Postgres struct { NetworkDB SSL SSL `json:"ssl" yaml:"ssl"` } // Open creates a new storage implementation backed by Postgres. func (p *Postgres) Open(logger log.Logger) (storage.Storage, error) { conn, err := p.open(logger) if err != nil { return nil, err } return conn, nil } var strEsc = regexp.MustCompile(`([\\'])`) func dataSourceStr(str string) string { return "'" + strEsc.ReplaceAllString(str, `\$1`) + "'" } // createDataSourceName takes the configuration provided via the Postgres // struct to create a data-source name that Go's database/sql package can // make use of. func (p *Postgres) createDataSourceName() string { parameters := []string{} addParam := func(key, val string) { parameters = append(parameters, fmt.Sprintf("%s=%s", key, val)) } addParam("connect_timeout", strconv.Itoa(p.ConnectionTimeout)) // detect host:port for backwards-compatibility host, port, err := net.SplitHostPort(p.Host) if err != nil { // not host:port, probably unix socket or bare address host = p.Host if p.Port != 0 { port = strconv.Itoa(int(p.Port)) } } if host != "" { addParam("host", dataSourceStr(host)) } if port != "" { addParam("port", port) } if p.User != "" { addParam("user", dataSourceStr(p.User)) } if p.Password != "" { addParam("password", dataSourceStr(p.Password)) } if p.Database != "" { addParam("dbname", dataSourceStr(p.Database)) } if p.SSL.Mode == "" { // Assume the strictest mode if unspecified. addParam("sslmode", dataSourceStr(pgSSLVerifyFull)) } else { addParam("sslmode", dataSourceStr(p.SSL.Mode)) } if p.SSL.CAFile != "" { addParam("sslrootcert", dataSourceStr(p.SSL.CAFile)) } if p.SSL.CertFile != "" { addParam("sslcert", dataSourceStr(p.SSL.CertFile)) } if p.SSL.KeyFile != "" { addParam("sslkey", dataSourceStr(p.SSL.KeyFile)) } return strings.Join(parameters, " ") } func (p *Postgres) open(logger log.Logger) (*conn, error) { dataSourceName := p.createDataSourceName() db, err := sql.Open("postgres", dataSourceName) if err != nil { return nil, err } // set database/sql tunables if configured if p.ConnMaxLifetime != 0 { db.SetConnMaxLifetime(time.Duration(p.ConnMaxLifetime) * time.Second) } if p.MaxIdleConns == 0 { db.SetMaxIdleConns(5) } else { db.SetMaxIdleConns(p.MaxIdleConns) } if p.MaxOpenConns == 0 { db.SetMaxOpenConns(5) } else { db.SetMaxOpenConns(p.MaxOpenConns) } errCheck := func(err error) bool { sqlErr, ok := err.(*pq.Error) if !ok { return false } return sqlErr.Code == pgErrUniqueViolation } c := &conn{db, &flavorPostgres, logger, errCheck} if _, err := c.migrate(); err != nil { return nil, fmt.Errorf("failed to perform migrations: %v", err) } return c, nil } // MySQL options for creating a MySQL db. type MySQL struct { NetworkDB SSL SSL `json:"ssl" yaml:"ssl"` // TODO(pborzenkov): used by tests to reduce lock wait timeout. Should // we make it exported and allow users to provide arbitrary params? params map[string]string } // Open creates a new storage implementation backed by MySQL. func (s *MySQL) Open(logger log.Logger) (storage.Storage, error) { conn, err := s.open(logger) if err != nil { return nil, err } return conn, nil } func (s *MySQL) open(logger log.Logger) (*conn, error) { cfg := mysql.Config{ User: s.User, Passwd: s.Password, DBName: s.Database, AllowNativePasswords: true, Timeout: time.Second * time.Duration(s.ConnectionTimeout), ParseTime: true, Params: map[string]string{ "transaction_isolation": "'SERIALIZABLE'", }, } if s.Host != "" { if s.Host[0] != '/' { cfg.Net = "tcp" cfg.Addr = s.Host } else { cfg.Net = "unix" cfg.Addr = s.Host } } switch { case s.SSL.CAFile != "" || s.SSL.CertFile != "" || s.SSL.KeyFile != "": if err := s.makeTLSConfig(); err != nil { return nil, fmt.Errorf("failed to make TLS config: %v", err) } cfg.TLSConfig = mysqlSSLCustom case s.SSL.Mode == "": cfg.TLSConfig = mysqlSSLTrue default: cfg.TLSConfig = s.SSL.Mode } for k, v := range s.params { cfg.Params[k] = v } db, err := sql.Open("mysql", cfg.FormatDSN()) if err != nil { return nil, err } if s.MaxIdleConns == 0 { /*Override default behaviour to fix https://github.com/dexidp/dex/issues/1608*/ db.SetMaxIdleConns(0) } else { db.SetMaxIdleConns(s.MaxIdleConns) } err = db.Ping() if err != nil { if mysqlErr, ok := err.(*mysql.MySQLError); ok && mysqlErr.Number == mysqlErrUnknownSysVar { logger.Info("reconnecting with MySQL pre-5.7.20 compatibility mode") // MySQL 5.7.20 introduced transaction_isolation and deprecated tx_isolation. // MySQL 8.0 doesn't have tx_isolation at all. // https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation delete(cfg.Params, "transaction_isolation") cfg.Params["tx_isolation"] = "'SERIALIZABLE'" db, err = sql.Open("mysql", cfg.FormatDSN()) if err != nil { return nil, err } } else { return nil, err } } errCheck := func(err error) bool { sqlErr, ok := err.(*mysql.MySQLError) if !ok { return false } return sqlErr.Number == mysqlErrDupEntry || sqlErr.Number == mysqlErrDupEntryWithKeyName } c := &conn{db, &flavorMySQL, logger, errCheck} if _, err := c.migrate(); err != nil { return nil, fmt.Errorf("failed to perform migrations: %v", err) } return c, nil } func (s *MySQL) makeTLSConfig() error { cfg := &tls.Config{} if s.SSL.CAFile != "" { rootCertPool := x509.NewCertPool() pem, err := ioutil.ReadFile(s.SSL.CAFile) if err != nil { return err } if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { return fmt.Errorf("failed to append PEM") } cfg.RootCAs = rootCertPool } if s.SSL.CertFile != "" && s.SSL.KeyFile != "" { clientCert := make([]tls.Certificate, 0, 1) certs, err := tls.LoadX509KeyPair(s.SSL.CertFile, s.SSL.KeyFile) if err != nil { return err } clientCert = append(clientCert, certs) cfg.Certificates = clientCert } mysql.RegisterTLSConfig(mysqlSSLCustom, cfg) return nil }