package ent

import (
	"context"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"database/sql"
	"fmt"
	"net"
	"os"
	"strconv"
	"time"

	entSQL "entgo.io/ent/dialect/sql"
	"entgo.io/ent/dialect/sql/schema"
	"github.com/go-sql-driver/mysql"
	// Register postgres driver.
	_ "github.com/lib/pq"

	"github.com/dexidp/dex/pkg/log"
	"github.com/dexidp/dex/storage"
	"github.com/dexidp/dex/storage/ent/client"
	"github.com/dexidp/dex/storage/ent/db"
)

const (
	// MySQL SSL modes. These are the values passed to the driver as the DSN
	// "tls" setting; mysqlSSLCustom is the name under which makeTLSConfig
	// registers a TLS configuration built from the SSL file options.
	mysqlSSLTrue       = "true"
	mysqlSSLFalse      = "false"
	mysqlSSLSkipVerify = "skip-verify"
	mysqlSSLCustom     = "custom"
)

// MySQL options for creating an SQL db.
type MySQL struct {
	NetworkDB

	SSL SSL `json:"ssl"`

	// params holds extra DSN parameters that dsn copies into the driver config.
	params map[string]string
}

// Open connects to MySQL, creates the ent schema, and returns the resulting
// storage. If schema creation fails, the connection pool opened for it is
// closed before the error is returned.
func (m *MySQL) Open(logger log.Logger) (storage.Storage, error) {
	logger.Debug("experimental ent-based storage driver is enabled")

	drv, err := m.driver()
	if err != nil {
		return nil, err
	}

	databaseClient := client.NewDatabase(
		client.WithClient(db.NewClient(db.Driver(drv))),
		client.WithHasher(sha256.New),
		// Set tx isolation level for each transaction as dex does for postgres
		client.WithTxIsolationLevel(sql.LevelSerializable),
	)

	if err := databaseClient.Schema().Create(context.TODO(), schema.WithAtlas(false)); err != nil {
		// The caller never receives the client on this path, so nothing else
		// could release the pool. The close error is deliberately dropped:
		// the schema error is the one worth reporting.
		_ = drv.DB().Close()
		return nil, err
	}

	return databaseClient, nil
}

// driver selects the TLS mode, opens an ent SQL driver for the resulting DSN,
// and applies the idle-connection limit.
//
// TLS mode selection:
//   - any of SSL.CAFile, SSL.CertFile or SSL.KeyFile set: a custom TLS config
//     is built and registered, and the "custom" mode is used;
//   - SSL.Mode empty: "true";
//   - otherwise: SSL.Mode verbatim.
func (m *MySQL) driver() (*entSQL.Driver, error) {
	var tlsConfig string

	switch {
	case m.SSL.CAFile != "" || m.SSL.CertFile != "" || m.SSL.KeyFile != "":
		if err := m.makeTLSConfig(); err != nil {
			return nil, fmt.Errorf("failed to make TLS config: %w", err)
		}
		tlsConfig = mysqlSSLCustom
	case m.SSL.Mode == "":
		tlsConfig = mysqlSSLTrue
	default:
		tlsConfig = m.SSL.Mode
	}

	drv, err := entSQL.Open("mysql", m.dsn(tlsConfig))
	if err != nil {
		return nil, err
	}

	// Always set the limit explicitly, even when it is zero, instead of
	// falling back to the database/sql default. This overrides the default
	// behaviour to fix https://github.com/dexidp/dex/issues/1608
	drv.DB().SetMaxIdleConns(m.MaxIdleConns)

	return drv, nil
}

// dsn renders the MySQL data source name for the configured connection
// options, using tlsConfig as the driver's TLS setting.
func (m *MySQL) dsn(tlsConfig string) string {
	cfg := mysql.Config{
		User:                 m.User,
		Passwd:               m.Password,
		DBName:               m.Database,
		AllowNativePasswords: true,

		// ConnectionTimeout is multiplied by time.Second, i.e. it is read as
		// a number of seconds.
		Timeout: time.Second * time.Duration(m.ConnectionTimeout),

		TLSConfig: tlsConfig,

		ParseTime: true,
		Params:    make(map[string]string, len(m.params)),
	}

	switch {
	case m.Host == "":
		// No host configured: leave Net and Addr unset.
	case m.Host[0] == '/':
		// A host starting with '/' is treated as a unix socket path.
		cfg.Net = "unix"
		cfg.Addr = m.Host
	default:
		cfg.Net = "tcp"
		cfg.Addr = m.Host
		if m.Port != 0 {
			cfg.Addr = net.JoinHostPort(m.Host, strconv.Itoa(int(m.Port)))
		}
	}

	for k, v := range m.params {
		cfg.Params[k] = v
	}

	return cfg.FormatDSN()
}

// makeTLSConfig builds a TLS configuration from the SSL file options and
// registers it with the MySQL driver under the "custom" name.
//
// It returns an error if the CA file cannot be read or contains no usable PEM
// certificates, if the client key pair cannot be loaded, or if the driver
// rejects the registration.
func (m *MySQL) makeTLSConfig() error {
	cfg := &tls.Config{}

	if m.SSL.CAFile != "" {
		pem, err := os.ReadFile(m.SSL.CAFile)
		if err != nil {
			return err
		}

		rootCertPool := x509.NewCertPool()
		if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
			return fmt.Errorf("failed to append PEM")
		}
		cfg.RootCAs = rootCertPool
	}

	// NOTE(review): a client certificate is only loaded when both CertFile and
	// KeyFile are set; one without the other is silently ignored. Confirm
	// whether that should be reported as a configuration error.
	if m.SSL.CertFile != "" && m.SSL.KeyFile != "" {
		cert, err := tls.LoadX509KeyPair(m.SSL.CertFile, m.SSL.KeyFile)
		if err != nil {
			return err
		}
		cfg.Certificates = []tls.Certificate{cert}
	}

	return mysql.RegisterTLSConfig(mysqlSSLCustom, cfg)
}