2021-09-13 15:55:17 +05:30
|
|
|
package ent
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"crypto/sha256"
|
|
|
|
"crypto/tls"
|
|
|
|
"crypto/x509"
|
|
|
|
"database/sql"
|
|
|
|
"fmt"
|
|
|
|
"net"
|
2021-09-21 10:05:25 +05:30
|
|
|
"os"
|
2021-09-13 15:55:17 +05:30
|
|
|
"strconv"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
entSQL "entgo.io/ent/dialect/sql"
|
|
|
|
"github.com/go-sql-driver/mysql"
|
|
|
|
|
|
|
|
// Register postgres driver.
|
|
|
|
_ "github.com/lib/pq"
|
|
|
|
|
|
|
|
"github.com/dexidp/dex/pkg/log"
|
|
|
|
"github.com/dexidp/dex/storage"
|
|
|
|
"github.com/dexidp/dex/storage/ent/client"
|
|
|
|
"github.com/dexidp/dex/storage/ent/db"
|
|
|
|
)
|
|
|
|
|
|
|
|
// MySQL SSL modes: candidate strings for the TLSConfig field of mysql.Config
// used when building the DSN.
const (
	// mysqlSSLTrue is the mode driver falls back to when no CA/cert/key file
	// and no explicit SSL mode are configured.
	mysqlSSLTrue = "true"

	// mysqlSSLFalse and mysqlSSLSkipVerify name further modes; they are not
	// referenced directly in this section (an explicit SSL.Mode is passed
	// through to the driver as-is).
	mysqlSSLFalse      = "false"
	mysqlSSLSkipVerify = "skip-verify"

	// mysqlSSLCustom is the key under which makeTLSConfig registers the
	// *tls.Config built from the configured CA/cert/key files.
	mysqlSSLCustom = "custom"
)
|
|
|
|
|
|
|
|
// MySQL options for creating an SQL db.
//
// The embedded NetworkDB supplies the connection settings this file reads
// (User, Password, Database, Host, Port, ConnectionTimeout, MaxIdleConns);
// the remaining fields are MySQL-specific.
type MySQL struct {
	NetworkDB

	// SSL selects the TLS mode sent to the MySQL driver. Setting any of its
	// CA/cert/key file paths makes driver register and use a custom TLS
	// config; otherwise SSL.Mode is used, defaulting to "true" when empty.
	SSL SSL `json:"ssl"`

	// params holds extra DSN parameters that dsn copies into the connection
	// string. It is unexported, so encoding/json never populates it.
	params map[string]string
}
|
|
|
|
|
|
|
|
// Open always returns a new in sqlite3 storage.
|
|
|
|
func (m *MySQL) Open(logger log.Logger) (storage.Storage, error) {
|
|
|
|
logger.Debug("experimental ent-based storage driver is enabled")
|
|
|
|
drv, err := m.driver()
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
databaseClient := client.NewDatabase(
|
|
|
|
client.WithClient(db.NewClient(db.Driver(drv))),
|
|
|
|
client.WithHasher(sha256.New),
|
|
|
|
// Set tx isolation leve for each transaction as dex does for postgres
|
|
|
|
client.WithTxIsolationLevel(sql.LevelSerializable),
|
|
|
|
)
|
|
|
|
|
|
|
|
if err := databaseClient.Schema().Create(context.TODO()); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return databaseClient, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *MySQL) driver() (*entSQL.Driver, error) {
|
|
|
|
var tlsConfig string
|
|
|
|
|
|
|
|
switch {
|
|
|
|
case m.SSL.CAFile != "" || m.SSL.CertFile != "" || m.SSL.KeyFile != "":
|
|
|
|
if err := m.makeTLSConfig(); err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to make TLS config: %v", err)
|
|
|
|
}
|
|
|
|
tlsConfig = mysqlSSLCustom
|
|
|
|
case m.SSL.Mode == "":
|
|
|
|
tlsConfig = mysqlSSLTrue
|
|
|
|
default:
|
|
|
|
tlsConfig = m.SSL.Mode
|
|
|
|
}
|
|
|
|
|
|
|
|
drv, err := entSQL.Open("mysql", m.dsn(tlsConfig))
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
if m.MaxIdleConns == 0 {
|
|
|
|
/* Override default behaviour to fix https://github.com/dexidp/dex/issues/1608 */
|
|
|
|
drv.DB().SetMaxIdleConns(0)
|
|
|
|
} else {
|
|
|
|
drv.DB().SetMaxIdleConns(m.MaxIdleConns)
|
|
|
|
}
|
|
|
|
|
|
|
|
return drv, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *MySQL) dsn(tlsConfig string) string {
|
|
|
|
cfg := mysql.Config{
|
|
|
|
User: m.User,
|
|
|
|
Passwd: m.Password,
|
|
|
|
DBName: m.Database,
|
|
|
|
AllowNativePasswords: true,
|
|
|
|
|
|
|
|
Timeout: time.Second * time.Duration(m.ConnectionTimeout),
|
|
|
|
|
|
|
|
TLSConfig: tlsConfig,
|
|
|
|
|
|
|
|
ParseTime: true,
|
2021-09-13 19:18:02 +05:30
|
|
|
Params: make(map[string]string),
|
2021-09-13 15:55:17 +05:30
|
|
|
}
|
|
|
|
|
|
|
|
if m.Host != "" {
|
|
|
|
if m.Host[0] != '/' {
|
|
|
|
cfg.Net = "tcp"
|
|
|
|
cfg.Addr = m.Host
|
|
|
|
|
|
|
|
if m.Port != 0 {
|
|
|
|
cfg.Addr = net.JoinHostPort(m.Host, strconv.Itoa(int(m.Port)))
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
cfg.Net = "unix"
|
|
|
|
cfg.Addr = m.Host
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
for k, v := range m.params {
|
|
|
|
cfg.Params[k] = v
|
|
|
|
}
|
|
|
|
|
|
|
|
return cfg.FormatDSN()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *MySQL) makeTLSConfig() error {
|
|
|
|
cfg := &tls.Config{}
|
|
|
|
|
|
|
|
if m.SSL.CAFile != "" {
|
|
|
|
rootCertPool := x509.NewCertPool()
|
|
|
|
|
2021-09-21 10:05:25 +05:30
|
|
|
pem, err := os.ReadFile(m.SSL.CAFile)
|
2021-09-13 15:55:17 +05:30
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
|
|
|
return fmt.Errorf("failed to append PEM")
|
|
|
|
}
|
|
|
|
cfg.RootCAs = rootCertPool
|
|
|
|
}
|
|
|
|
|
|
|
|
if m.SSL.CertFile != "" && m.SSL.KeyFile != "" {
|
|
|
|
clientCert := make([]tls.Certificate, 0, 1)
|
|
|
|
certs, err := tls.LoadX509KeyPair(m.SSL.CertFile, m.SSL.KeyFile)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
clientCert = append(clientCert, certs)
|
|
|
|
cfg.Certificates = clientCert
|
|
|
|
}
|
|
|
|
|
|
|
|
mysql.RegisterTLSConfig(mysqlSSLCustom, cfg)
|
|
|
|
return nil
|
|
|
|
}
|