diff --git a/db/client.go b/db/client.go index 4d4ea86d..22fcae85 100644 --- a/db/client.go +++ b/db/client.go @@ -11,6 +11,7 @@ import ( "github.com/coreos/go-oidc/oidc" "github.com/go-gorp/gorp" "github.com/lib/pq" + "github.com/mattn/go-sqlite3" "golang.org/x/crypto/bcrypt" "github.com/coreos/dex/client" @@ -89,23 +90,29 @@ func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo { } func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) { - repo := NewClientIdentityRepo(dbm).(*clientIdentityRepo) + tx, err := dbm.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() for _, c := range clients { dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret) if err != nil { return nil, err } - cm, err := newClientIdentityModel(c.Credentials.ID, dec, &c.Metadata) if err != nil { return nil, err } - err = repo.dbMap.Insert(cm) + err = tx.Insert(cm) if err != nil { return nil, err } } - return repo, nil + if err := tx.Commit(); err != nil { + return nil, err + } + return NewClientIdentityRepo(dbm), nil } type clientIdentityRepo struct { @@ -155,8 +162,9 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error { if err != nil { return err } + defer tx.Rollback() - m, err := r.dbMap.Get(clientIdentityModel{}, clientID) + m, err := tx.Get(clientIdentityModel{}, clientID) if m == nil || err != nil { rollback(tx) return err @@ -164,25 +172,17 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error { cim, ok := m.(*clientIdentityModel) if !ok { - rollback(tx) log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m)) return errors.New("unrecognized model") } cim.DexAdmin = isAdmin - _, err = r.dbMap.Update(cim) + _, err = tx.Update(cim) if err != nil { - rollback(tx) return err } - err = tx.Commit() - if err != nil { - rollback(tx) - return err - } - - return nil + return tx.Commit() } func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) { @@ -223,8 +223,15 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli } if err := r.dbMap.Insert(cim); err != nil { - if perr, ok := err.(*pq.Error); ok && perr.Code == pgErrorCodeUniqueViolation { - err = errors.New("client ID already exists") + switch sqlErr := err.(type) { + case *pq.Error: + if sqlErr.Code == pgErrorCodeUniqueViolation { + err = errors.New("client ID already exists") + } + case *sqlite3.Error: + if sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique { + err = errors.New("client ID already exists") + } } return nil, err @@ -239,7 +246,7 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli } func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) { - qt := pq.QuoteIdentifier(clientIdentityTableName) + qt := r.dbMap.Dialect.QuotedTableForQuery("", clientIdentityTableName) q := fmt.Sprintf("SELECT * FROM %s", qt) objs, err := r.dbMap.Select(&clientIdentityModel{}, q) if err != nil { diff --git a/db/conn.go b/db/conn.go index 6bff1349..8ff115f1 100644 --- a/db/conn.go +++ b/db/conn.go @@ -4,13 +4,16 @@ import ( "database/sql" "errors" "fmt" - "strings" + "net/url" "github.com/go-gorp/gorp" - _ "github.com/lib/pq" "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/repo" + + // Import database drivers + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" ) type table struct { @@ -43,22 +46,35 @@ type Config struct { } func NewConnection(cfg Config) (*gorp.DbMap, error) { - if !strings.HasPrefix(cfg.DSN, "postgres://") { + u, err := url.Parse(cfg.DSN) + if err != nil { + return nil, fmt.Errorf("parse DSN: %v", err) + } + var ( + db *sql.DB + dialect gorp.Dialect + ) + switch u.Scheme { + case "postgres": + db, err = sql.Open("postgres", cfg.DSN) + if err != nil { + return nil, err + } + db.SetMaxIdleConns(cfg.MaxIdleConnections) + db.SetMaxOpenConns(cfg.MaxOpenConnections) + dialect = gorp.PostgresDialect{} + case "sqlite3": + db, err = sql.Open("sqlite3", u.Host) + if err != nil { + return nil, err + } + // NOTE(ericchiang): sqlite does NOT work with SetMaxIdleConns. + dialect = gorp.SqliteDialect{} + default: return nil, errors.New("unrecognized database driver") } - db, err := sql.Open("postgres", cfg.DSN) - if err != nil { - return nil, err - } - - db.SetMaxIdleConns(cfg.MaxIdleConnections) - db.SetMaxOpenConns(cfg.MaxOpenConnections) - - dbm := gorp.DbMap{ - Db: db, - Dialect: gorp.PostgresDialect{}, - } + dbm := gorp.DbMap{Db: db, Dialect: dialect} for _, t := range tables { tm := dbm.AddTableWithName(t.model, t.name).SetKeys(t.autoinc, t.pkey...) @@ -70,7 +86,6 @@ func NewConnection(cfg Config) (*gorp.DbMap, error) { cm.SetUnique(true) } } - return &dbm, nil } diff --git a/db/connector_config.go b/db/connector_config.go index 6be2832a..5f5a1a17 100644 --- a/db/connector_config.go +++ b/db/connector_config.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/go-gorp/gorp" - "github.com/lib/pq" "github.com/coreos/dex/connector" "github.com/coreos/dex/repo" @@ -69,7 +68,7 @@ type ConnectorConfigRepo struct { } func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) { - qt := pq.QuoteIdentifier(connectorConfigTableName) + qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName) q := fmt.Sprintf("SELECT * FROM %s", qt) objs, err := r.dbMap.Select(&connectorConfigModel{}, q) if err != nil { @@ -94,10 +93,10 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) { } func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) { - qt := pq.QuoteIdentifier(connectorConfigTableName) + qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName) q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt) var c connectorConfigModel - if err := r.executor(tx).SelectOne(&c, q, id); err != nil { + if err := executor(r.dbMap, tx).SelectOne(&c, q, id); err != nil { if err == sql.ErrNoRows { return nil, connector.ErrorNotFound } @@ -121,28 +120,17 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error { if err != nil { return err } + defer tx.Rollback() - qt := pq.QuoteIdentifier(connectorConfigTableName) + qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName) q := fmt.Sprintf("DELETE FROM %s", qt) - if _, err = r.dbMap.Exec(q); err != nil { + if _, err = tx.Exec(q); err != nil { return err } - if err = r.dbMap.Insert(insert...); err != nil { + if err = tx.Insert(insert...); err != nil { return fmt.Errorf("DB insert failed %#v: %v", insert, err) } return tx.Commit() } - -func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor { - if tx == nil { - return r.dbMap - } - - gorpTx, ok := tx.(*gorp.Transaction) - if !ok { - panic("wrong kind of transaction passed to a DB repo") - } - return gorpTx -} diff --git a/db/key.go b/db/key.go index 18bdfdc4..8c4072c3 100644 --- a/db/key.go +++ b/db/key.go @@ -8,7 +8,6 @@ import ( "time" "github.com/go-gorp/gorp" - "github.com/lib/pq" pcrypto "github.com/coreos/dex/pkg/crypto" "github.com/coreos/go-oidc/key" @@ -114,7 +113,7 @@ type PrivateKeySetRepo struct { } func (r *PrivateKeySetRepo) Set(ks key.KeySet) error { - qt := pq.QuoteIdentifier(keyTableName) + qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName) _, err := r.dbMap.Exec(fmt.Sprintf("DELETE FROM %s", qt)) if err != nil { return err @@ -152,7 +151,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error { } func (r *PrivateKeySetRepo) Get() (key.KeySet, error) { - qt := pq.QuoteIdentifier(keyTableName) + qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName) objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt)) if err != nil { return nil, err diff --git a/db/migrate.go b/db/migrate.go index 504732f6..d288b62f 100644 --- a/db/migrate.go +++ b/db/migrate.go @@ -1,19 +1,18 @@ package db import ( + "errors" "fmt" "github.com/go-gorp/gorp" - "github.com/lib/pq" - migrate "github.com/rubenv/sql-migrate" + "github.com/rubenv/sql-migrate" "github.com/coreos/dex/db/migrations" ) const ( - migrationDialect = "postgres" - migrationTable = "dex_migrations" - migrationDir = "db/migrations" + migrationTable = "dex_migrations" + migrationDir = "db/migrations" ) func init() { @@ -21,32 +20,57 @@ func init() { } func MigrateToLatest(dbMap *gorp.DbMap) (int, error) { - source := getSource() - - return migrate.Exec(dbMap.Db, migrationDialect, source, migrate.Up) + source, dialect, err := migrationSource(dbMap) + if err != nil { + return 0, err + } + return migrate.Exec(dbMap.Db, dialect, source, migrate.Up) } func MigrateMaxMigrations(dbMap *gorp.DbMap, max int) (int, error) { - source := getSource() - - return migrate.ExecMax(dbMap.Db, migrationDialect, source, migrate.Up, max) + source, dialect, err := migrationSource(dbMap) + if err != nil { + return 0, err + } + return migrate.ExecMax(dbMap.Db, dialect, source, migrate.Up, max) } func GetPlannedMigrations(dbMap *gorp.DbMap) ([]*migrate.PlannedMigration, error) { - migrations, _, err := migrate.PlanMigration(dbMap.Db, migrationDialect, getSource(), migrate.Up, 0) + source, dialect, err := migrationSource(dbMap) + if err != nil { + return nil, err + } + migrations, _, err := migrate.PlanMigration(dbMap.Db, dialect, source, migrate.Up, 0) return migrations, err } func DropMigrationsTable(dbMap *gorp.DbMap) error { - qt := pq.QuoteIdentifier(migrationTable) - _, err := dbMap.Exec(fmt.Sprintf("drop table if exists %s ;", qt)) + qt := fmt.Sprintf("DROP TABLE IF EXISTS %s;", dbMap.Dialect.QuotedTableForQuery("", migrationTable)) + _, err := dbMap.Exec(qt) return err } -func getSource() migrate.MigrationSource { - return &migrate.AssetMigrationSource{ - Dir: migrationDir, - Asset: migrations.Asset, - AssetDir: migrations.AssetDir, +func migrationSource(dbMap *gorp.DbMap) (src migrate.MigrationSource, dialect string, err error) { + switch dbMap.Dialect.(type) { + case gorp.PostgresDialect: + src = &migrate.AssetMigrationSource{ + Dir: migrationDir, + Asset: migrations.Asset, + AssetDir: migrations.AssetDir, + } + return src, "postgres", nil + case gorp.SqliteDialect: + src = &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "dex.sql", + Up: []string{sqlite3Migration}, + }, + }, + } + + return src, "sqlite3", nil + default: + return nil, "", errors.New("unsupported migration driver") } } diff --git a/db/migrate_sqlite3.go b/db/migrate_sqlite3.go new file mode 100644 index 00000000..3cbfc7c3 --- /dev/null +++ b/db/migrate_sqlite3.go @@ -0,0 +1,73 @@ +package db + +// SQLite3 is a test only database. There is only one migration because we do not support migrations. +const sqlite3Migration = ` +CREATE TABLE authd_user ( + id text NOT NULL UNIQUE, + email text, + email_verified integer, + display_name text, + admin integer, + created_at bigint, + disabled integer +); + +CREATE TABLE client_identity ( + id text NOT NULL UNIQUE, + secret blob, + metadata text, + dex_admin integer +); + +CREATE TABLE connector_config ( + id text NOT NULL UNIQUE, + type text, + config text +); + +CREATE TABLE key ( + value blob +); + +CREATE TABLE password_info ( + user_id text NOT NULL UNIQUE, + password text, + password_expires bigint +); + +CREATE TABLE refresh_token ( + id integer PRIMARY KEY, + payload_hash blob, + user_id text, + client_id text +); + +CREATE TABLE remote_identity_mapping ( + connector_id text NOT NULL, + user_id text, + remote_id text NOT NULL +); + +CREATE TABLE session ( + id text NOT NULL UNIQUE, + state text, + created_at bigint, + expires_at bigint, + client_id text, + client_state text, + redirect_url text, + identity text, + connector_id text, + user_id text, + register integer, + nonce text, + scope text +); + +CREATE TABLE session_key ( + key text NOT NULL UNIQUE, + session_id text, + expires_at bigint, + stale integer +); +` diff --git a/db/password.go b/db/password.go index 4cc0b785..be278c2c 100644 --- a/db/password.go +++ b/db/password.go @@ -5,10 +5,11 @@ import ( "reflect" "time" + "github.com/go-gorp/gorp" + "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/repo" "github.com/coreos/dex/user" - "github.com/go-gorp/gorp" ) const ( @@ -89,20 +90,8 @@ func (r *passwordInfoRepo) Update(tx repo.Transaction, pw user.PasswordInfo) err return nil } -func (r *passwordInfoRepo) executor(tx repo.Transaction) gorp.SqlExecutor { - if tx == nil { - return r.dbMap - } - - gorpTx, ok := tx.(*gorp.Transaction) - if !ok { - panic("wrong kind of transaction passed to a DB repo") - } - return gorpTx -} - func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) { - ex := r.executor(tx) + ex := executor(r.dbMap, tx) m, err := ex.Get(passwordInfoModel{}, id) if err != nil { @@ -123,7 +112,7 @@ func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInf } func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) error { - ex := r.executor(tx) + ex := executor(r.dbMap, tx) pm, err := newPasswordInfoModel(&pw) if err != nil { return err @@ -132,7 +121,7 @@ func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) err } func (r *passwordInfoRepo) update(tx repo.Transaction, pw user.PasswordInfo) error { - ex := r.executor(tx) + ex := executor(r.dbMap, tx) pm, err := newPasswordInfoModel(&pw) if err != nil { return err diff --git a/db/refresh.go b/db/refresh.go index 15195a46..66ad4ae2 100644 --- a/db/refresh.go +++ b/db/refresh.go @@ -8,10 +8,11 @@ import ( "strconv" "strings" - "github.com/coreos/dex/pkg/log" - "github.com/coreos/dex/refresh" "github.com/go-gorp/gorp" "golang.org/x/crypto/bcrypt" + + "github.com/coreos/dex/pkg/log" + "github.com/coreos/dex/refresh" ) const ( @@ -166,16 +167,8 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error { return nil } - -func (r *refreshTokenRepo) executor(tx *gorp.Transaction) gorp.SqlExecutor { - if tx == nil { - return r.dbMap - } - return tx -} - func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) { - ex := r.executor(tx) + ex := executor(r.dbMap, tx) result, err := ex.Get(refreshTokenModel{}, tokenID) if err != nil { return nil, err diff --git a/db/session.go b/db/session.go index dd46d18e..172985fc 100644 --- a/db/session.go +++ b/db/session.go @@ -11,7 +11,6 @@ import ( "github.com/go-gorp/gorp" "github.com/jonboulle/clockwork" - "github.com/lib/pq" "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/session" @@ -183,9 +182,9 @@ func (r *SessionRepo) Update(s session.Session) error { } func (r *SessionRepo) purge() error { - qt := pq.QuoteIdentifier(sessionTableName) + qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionTableName) q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt) - res, err := r.dbMap.Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead)) + res, err := executor(r.dbMap, nil).Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead)) if err != nil { return err } diff --git a/db/session_key.go b/db/session_key.go index ee90a117..f58e7610 100644 --- a/db/session_key.go +++ b/db/session_key.go @@ -8,7 +8,6 @@ import ( "github.com/go-gorp/gorp" "github.com/jonboulle/clockwork" - "github.com/lib/pq" "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/session" @@ -77,9 +76,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) { return "", errors.New("invalid session key") } - qt := pq.QuoteIdentifier(sessionKeyTableName) + qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName) q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt) - res, err := r.dbMap.Exec(q, true, key, false) + res, err := executor(r.dbMap, nil).Exec(q, true, key, false) if err != nil { return "", err } @@ -95,9 +94,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) { } func (r *SessionKeyRepo) purge() error { - qt := pq.QuoteIdentifier(sessionKeyTableName) + qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName) q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt) - res, err := r.dbMap.Exec(q, true, r.clock.Now().Unix()) + res, err := executor(r.dbMap, nil).Exec(q, true, r.clock.Now().Unix()) if err != nil { return err } diff --git a/db/transaction.go b/db/transaction.go new file mode 100644 index 00000000..19da5a13 --- /dev/null +++ b/db/transaction.go @@ -0,0 +1,33 @@ +package db + +import ( + "github.com/go-gorp/gorp" + + "github.com/coreos/dex/db/translate" + "github.com/coreos/dex/repo" +) + +func executor(dbMap *gorp.DbMap, tx repo.Transaction) gorp.SqlExecutor { + var exec gorp.SqlExecutor + if tx == nil { + exec = dbMap + } else { + gorpTx, ok := tx.(*gorp.Transaction) + if !ok { + panic("wrong kind of transaction passed to a DB repo") + } + + // Check if the underlying value of the pointer is nil. + // This is not caught by the initial comparison (tx == nil). + if gorpTx == nil { + exec = dbMap + } else { + exec = gorpTx + } + } + + if _, ok := dbMap.Dialect.(gorp.SqliteDialect); ok { + exec = translate.NewExecutor(exec, translate.PostgresToSQLite) + } + return exec +} diff --git a/db/translate/translate.go b/db/translate/translate.go new file mode 100644 index 00000000..390b95ce --- /dev/null +++ b/db/translate/translate.go @@ -0,0 +1,68 @@ +/* +Package translate implements translation of driver specific SQL queries. +*/ +package translate + +import ( + "database/sql" + "regexp" + + "github.com/go-gorp/gorp" +) + +var ( + bindRegexp = regexp.MustCompile(`\$\d+`) + trueRegexp = regexp.MustCompile(`\btrue\b`) +) + +// PostgresToSQLite implements translation of the pq driver to sqlite3. +func PostgresToSQLite(query string) string { + query = bindRegexp.ReplaceAllString(query, "?") + query = trueRegexp.ReplaceAllString(query, "1") + return query +} + +func NewExecutor(exec gorp.SqlExecutor, translate func(string) string) gorp.SqlExecutor { + return &executor{exec, translate} +} + +type executor struct { + gorp.SqlExecutor + Translate func(string) string +} + +func (e *executor) Exec(query string, args ...interface{}) (sql.Result, error) { + return e.SqlExecutor.Exec(e.Translate(query), args...) +} + +func (e *executor) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { + return e.SqlExecutor.Select(i, e.Translate(query), args...) +} + +func (e *executor) SelectInt(query string, args ...interface{}) (int64, error) { + return e.SqlExecutor.SelectInt(e.Translate(query), args...) +} + +func (e *executor) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) { + return e.SqlExecutor.SelectNullInt(e.Translate(query), args...) +} + +func (e *executor) SelectFloat(query string, args ...interface{}) (float64, error) { + return e.SqlExecutor.SelectFloat(e.Translate(query), args...) +} + +func (e *executor) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) { + return e.SqlExecutor.SelectNullFloat(e.Translate(query), args...) +} + +func (e *executor) SelectStr(query string, args ...interface{}) (string, error) { + return e.SqlExecutor.SelectStr(e.Translate(query), args...) +} + +func (e *executor) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) { + return e.SqlExecutor.SelectNullStr(e.Translate(query), args...) +} + +func (e *executor) SelectOne(holder interface{}, query string, args ...interface{}) error { + return e.SqlExecutor.SelectOne(holder, e.Translate(query), args...) +} diff --git a/db/translate/translate_test.go b/db/translate/translate_test.go new file mode 100644 index 00000000..983c3a59 --- /dev/null +++ b/db/translate/translate_test.go @@ -0,0 +1,28 @@ +package translate + +import "testing" + +func TestPostgresToSQLite(t *testing.T) { + + tests := []struct { + query string + want string + }{ + {"SELECT * FROM foo", "SELECT * FROM foo"}, + {"SELECT * FROM %s", "SELECT * FROM %s"}, + {"SELECT * FROM foo WHERE is_admin=true", "SELECT * FROM foo WHERE is_admin=1"}, + {"SELECT * FROM foo WHERE is_admin=true;", "SELECT * FROM foo WHERE is_admin=1;"}, + {"SELECT * FROM foo WHERE is_admin=$10", "SELECT * FROM foo WHERE is_admin=?"}, + {"SELECT * FROM foo WHERE is_admin=$10;", "SELECT * FROM foo WHERE is_admin=?;"}, + {"SELECT * FROM foo WHERE name=$1 AND is_admin=$2;", "SELECT * FROM foo WHERE name=? AND is_admin=?;"}, + {"$1", "?"}, + {"$", "$"}, + } + + for _, tt := range tests { + got := PostgresToSQLite(tt.query) + if got != tt.want { + t.Errorf("PostgresToSQLite(%q): want=%q, got=%q", tt.query, tt.want, got) + } + } +} diff --git a/db/user.go b/db/user.go index ab0f8626..86f4638a 100644 --- a/db/user.go +++ b/db/user.go @@ -8,7 +8,6 @@ import ( "time" "github.com/go-gorp/gorp" - "github.com/lib/pq" "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/repo" @@ -107,9 +106,9 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err return user.ErrorInvalidID } - qt := pq.QuoteIdentifier(userTableName) - ex := r.executor(tx) - result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable) + qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName) + ex := executor(r.dbMap, tx) + result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $1 WHERE id = $2;", qt), disable, userID) if err != nil { return err } @@ -221,7 +220,7 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid return err } - ex := r.executor(tx) + ex := executor(r.dbMap, tx) deleted, err := ex.Delete(rim) if err != nil { @@ -236,14 +235,13 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid } func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]user.RemoteIdentity, error) { - ex := r.executor(tx) + ex := executor(r.dbMap, tx) if userID == "" { return nil, user.ErrorInvalidID } - qt := pq.QuoteIdentifier(remoteIdentityMappingTableName) - rims, err := ex.Select(&remoteIdentityMappingModel{}, - fmt.Sprintf("select * from %s where user_id = $1", qt), userID) + qt := r.dbMap.Dialect.QuotedTableForQuery("", remoteIdentityMappingTableName) + rims, err := ex.Select(&remoteIdentityMappingModel{}, fmt.Sprintf("SELECT * FROM %s WHERE user_id = $1", qt), userID) if err != nil { if err != sql.ErrNoRows { @@ -273,9 +271,9 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us } func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) { - qt := pq.QuoteIdentifier(userTableName) - ex := r.executor(tx) - i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s where admin=true", qt)) + qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName) + ex := executor(r.dbMap, tx) + i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s WHERE admin=true;", qt)) return int(i), err } @@ -288,14 +286,13 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults if err != nil { return nil, "", err } - ex := r.executor(tx) + ex := executor(r.dbMap, tx) - qt := pq.QuoteIdentifier(userTableName) + qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName) // Ask for one more than needed so we know if there's more results, and // hence, whether a nextPageToken is necessary. - ums, err := ex.Select(&userModel{}, - fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2 ", qt), maxResults+1, offset) + ums, err := ex.Select(&userModel{}, fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2", qt), maxResults+1, offset) if err != nil { return nil, "", err } @@ -338,20 +335,8 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults } -func (r *userRepo) executor(tx repo.Transaction) gorp.SqlExecutor { - if tx == nil { - return r.dbMap - } - - gorpTx, ok := tx.(*gorp.Transaction) - if !ok { - panic("wrong kind of transaction passed to a DB repo") - } - return gorpTx -} - func (r *userRepo) insert(tx repo.Transaction, usr user.User) error { - ex := r.executor(tx) + ex := executor(r.dbMap, tx) um, err := newUserModel(&usr) if err != nil { return err @@ -360,7 +345,7 @@ func (r *userRepo) insert(tx repo.Transaction, usr user.User) error { } func (r *userRepo) update(tx repo.Transaction, usr user.User) error { - ex := r.executor(tx) + ex := executor(r.dbMap, tx) um, err := newUserModel(&usr) if err != nil { return err @@ -370,7 +355,7 @@ func (r *userRepo) update(tx repo.Transaction, usr user.User) error { } func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) { - ex := r.executor(tx) + ex := executor(r.dbMap, tx) m, err := ex.Get(userModel{}, userID) if err != nil { @@ -391,7 +376,7 @@ func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) { } func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.RemoteIdentity) (string, error) { - ex := r.executor(tx) + ex := executor(r.dbMap, tx) m, err := ex.Get(remoteIdentityMappingModel{}, ri.ConnectorID, ri.ID) if err != nil { @@ -412,8 +397,8 @@ func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.Remot } func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) { - qt := pq.QuoteIdentifier(userTableName) - ex := r.executor(tx) + qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName) + ex := executor(r.dbMap, tx) var um userModel err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email) @@ -427,7 +412,7 @@ func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, err } func (r *userRepo) insertRemoteIdentity(tx repo.Transaction, userID string, ri user.RemoteIdentity) error { - ex := r.executor(tx) + ex := executor(r.dbMap, tx) rim, err := newRemoteIdentityMappingModel(userID, ri) if err != nil {