Merge pull request #1659 from dexidp/sql-specific-migrations
storage/sql: allow specifying sql flavor specific migrations
This commit is contained in:
commit
edd3a40141
4 changed files with 36 additions and 9 deletions
|
@ -66,7 +66,7 @@ func (s *SQLite3) open(logger log.Logger) (*conn, error) {
|
||||||
return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey
|
return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &conn{db, flavorSQLite3, logger, errCheck}
|
c := &conn{db, &flavorSQLite3, logger, errCheck}
|
||||||
if _, err := c.migrate(); err != nil {
|
if _, err := c.migrate(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -239,7 +239,7 @@ func (p *Postgres) open(logger log.Logger) (*conn, error) {
|
||||||
return sqlErr.Code == pgErrUniqueViolation
|
return sqlErr.Code == pgErrUniqueViolation
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &conn{db, flavorPostgres, logger, errCheck}
|
c := &conn{db, &flavorPostgres, logger, errCheck}
|
||||||
if _, err := c.migrate(); err != nil {
|
if _, err := c.migrate(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -344,7 +344,7 @@ func (s *MySQL) open(logger log.Logger) (*conn, error) {
|
||||||
sqlErr.Number == mysqlErrDupEntryWithKeyName
|
sqlErr.Number == mysqlErrDupEntryWithKeyName
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &conn{db, flavorMySQL, logger, errCheck}
|
c := &conn{db, &flavorMySQL, logger, errCheck}
|
||||||
if _, err := c.migrate(); err != nil {
|
if _, err := c.migrate(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,14 @@ func (c *conn) migrate() (int, error) {
|
||||||
|
|
||||||
i := 0
|
i := 0
|
||||||
done := false
|
done := false
|
||||||
|
|
||||||
|
var flavorMigrations []migration
|
||||||
|
for _, m := range migrations {
|
||||||
|
if m.flavor == nil || m.flavor == c.flavor {
|
||||||
|
flavorMigrations = append(flavorMigrations, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
err := c.ExecTx(func(tx *trans) error {
|
err := c.ExecTx(func(tx *trans) error {
|
||||||
// Within a transaction, perform a single migration.
|
// Within a transaction, perform a single migration.
|
||||||
|
@ -31,13 +39,13 @@ func (c *conn) migrate() (int, error) {
|
||||||
if num.Valid {
|
if num.Valid {
|
||||||
n = int(num.Int64)
|
n = int(num.Int64)
|
||||||
}
|
}
|
||||||
if n >= len(migrations) {
|
if n >= len(flavorMigrations) {
|
||||||
done = true
|
done = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
migrationNum := n + 1
|
migrationNum := n + 1
|
||||||
m := migrations[n]
|
m := flavorMigrations[n]
|
||||||
for i := range m.stmts {
|
for i := range m.stmts {
|
||||||
if _, err := tx.Exec(m.stmts[i]); err != nil {
|
if _, err := tx.Exec(m.stmts[i]); err != nil {
|
||||||
return fmt.Errorf("migration %d statement %d failed: %v", migrationNum, i+1, err)
|
return fmt.Errorf("migration %d statement %d failed: %v", migrationNum, i+1, err)
|
||||||
|
@ -64,7 +72,11 @@ func (c *conn) migrate() (int, error) {
|
||||||
|
|
||||||
type migration struct {
|
type migration struct {
|
||||||
stmts []string
|
stmts []string
|
||||||
// TODO(ericchiang): consider adding additional fields like "forDrivers"
|
|
||||||
|
// If flavor is nil the migration will take place for all database backend flavors.
|
||||||
|
// If specified, only for that corresponding flavor, in that case stmts can be written
|
||||||
|
// in the specific SQL dialect.
|
||||||
|
flavor *flavor
|
||||||
}
|
}
|
||||||
|
|
||||||
// All SQL flavors share migration strategies.
|
// All SQL flavors share migration strategies.
|
||||||
|
@ -209,4 +221,12 @@ var migrations = []migration{
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
stmts: []string{`
|
||||||
|
alter table auth_request
|
||||||
|
modify column state varchar(4096);
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
flavor: &flavorMySQL,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,8 +30,15 @@ func TestMigrate(t *testing.T) {
|
||||||
return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique
|
return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &conn{db, flavorSQLite3, logger, errCheck}
|
var sqliteMigrations []migration
|
||||||
for _, want := range []int{len(migrations), 0} {
|
for _, m := range migrations {
|
||||||
|
if m.flavor == nil || m.flavor == &flavorSQLite3 {
|
||||||
|
sqliteMigrations = append(sqliteMigrations, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &conn{db, &flavorSQLite3, logger, errCheck}
|
||||||
|
for _, want := range []int{len(sqliteMigrations), 0} {
|
||||||
got, err := c.migrate()
|
got, err := c.migrate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
|
@ -130,7 +130,7 @@ func (c *conn) translateArgs(args []interface{}) []interface{} {
|
||||||
// conn is the main database connection.
|
// conn is the main database connection.
|
||||||
type conn struct {
|
type conn struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
flavor flavor
|
flavor *flavor
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
alreadyExistsCheck func(err error) bool
|
alreadyExistsCheck func(err error) bool
|
||||||
}
|
}
|
||||||
|
|
Reference in a new issue