db: add sqlite3 support
This commit is contained in:
parent
8f16279f49
commit
bfd63b7514
14 changed files with 345 additions and 145 deletions
43
db/client.go
43
db/client.go
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/mattn/go-sqlite3"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
@ -89,23 +90,29 @@ func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) {
|
func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) {
|
||||||
repo := NewClientIdentityRepo(dbm).(*clientIdentityRepo)
|
tx, err := dbm.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
for _, c := range clients {
|
for _, c := range clients {
|
||||||
dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret)
|
dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cm, err := newClientIdentityModel(c.Credentials.ID, dec, &c.Metadata)
|
cm, err := newClientIdentityModel(c.Credentials.ID, dec, &c.Metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = repo.dbMap.Insert(cm)
|
err = tx.Insert(cm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return repo, nil
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewClientIdentityRepo(dbm), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientIdentityRepo struct {
|
type clientIdentityRepo struct {
|
||||||
|
@ -155,8 +162,9 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
m, err := r.dbMap.Get(clientIdentityModel{}, clientID)
|
m, err := tx.Get(clientIdentityModel{}, clientID)
|
||||||
if m == nil || err != nil {
|
if m == nil || err != nil {
|
||||||
rollback(tx)
|
rollback(tx)
|
||||||
return err
|
return err
|
||||||
|
@ -164,25 +172,17 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
||||||
|
|
||||||
cim, ok := m.(*clientIdentityModel)
|
cim, ok := m.(*clientIdentityModel)
|
||||||
if !ok {
|
if !ok {
|
||||||
rollback(tx)
|
|
||||||
log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m))
|
log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m))
|
||||||
return errors.New("unrecognized model")
|
return errors.New("unrecognized model")
|
||||||
}
|
}
|
||||||
|
|
||||||
cim.DexAdmin = isAdmin
|
cim.DexAdmin = isAdmin
|
||||||
_, err = r.dbMap.Update(cim)
|
_, err = tx.Update(cim)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rollback(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit()
|
return tx.Commit()
|
||||||
if err != nil {
|
|
||||||
rollback(tx)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
|
func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
|
||||||
|
@ -223,8 +223,15 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.dbMap.Insert(cim); err != nil {
|
if err := r.dbMap.Insert(cim); err != nil {
|
||||||
if perr, ok := err.(*pq.Error); ok && perr.Code == pgErrorCodeUniqueViolation {
|
switch sqlErr := err.(type) {
|
||||||
err = errors.New("client ID already exists")
|
case *pq.Error:
|
||||||
|
if sqlErr.Code == pgErrorCodeUniqueViolation {
|
||||||
|
err = errors.New("client ID already exists")
|
||||||
|
}
|
||||||
|
case *sqlite3.Error:
|
||||||
|
if sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique {
|
||||||
|
err = errors.New("client ID already exists")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -239,7 +246,7 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
|
func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
|
||||||
qt := pq.QuoteIdentifier(clientIdentityTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", clientIdentityTableName)
|
||||||
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
||||||
objs, err := r.dbMap.Select(&clientIdentityModel{}, q)
|
objs, err := r.dbMap.Select(&clientIdentityModel{}, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
47
db/conn.go
47
db/conn.go
|
@ -4,13 +4,16 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"net/url"
|
||||||
|
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
_ "github.com/lib/pq"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
|
|
||||||
|
// Import database drivers
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type table struct {
|
type table struct {
|
||||||
|
@ -43,22 +46,35 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnection(cfg Config) (*gorp.DbMap, error) {
|
func NewConnection(cfg Config) (*gorp.DbMap, error) {
|
||||||
if !strings.HasPrefix(cfg.DSN, "postgres://") {
|
u, err := url.Parse(cfg.DSN)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse DSN: %v", err)
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
db *sql.DB
|
||||||
|
dialect gorp.Dialect
|
||||||
|
)
|
||||||
|
switch u.Scheme {
|
||||||
|
case "postgres":
|
||||||
|
db, err = sql.Open("postgres", cfg.DSN)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
db.SetMaxIdleConns(cfg.MaxIdleConnections)
|
||||||
|
db.SetMaxOpenConns(cfg.MaxOpenConnections)
|
||||||
|
dialect = gorp.PostgresDialect{}
|
||||||
|
case "sqlite3":
|
||||||
|
db, err = sql.Open("sqlite3", u.Host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// NOTE(ericchiang): sqlite does NOT work with SetMaxIdleConns.
|
||||||
|
dialect = gorp.SqliteDialect{}
|
||||||
|
default:
|
||||||
return nil, errors.New("unrecognized database driver")
|
return nil, errors.New("unrecognized database driver")
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := sql.Open("postgres", cfg.DSN)
|
dbm := gorp.DbMap{Db: db, Dialect: dialect}
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetMaxIdleConns(cfg.MaxIdleConnections)
|
|
||||||
db.SetMaxOpenConns(cfg.MaxOpenConnections)
|
|
||||||
|
|
||||||
dbm := gorp.DbMap{
|
|
||||||
Db: db,
|
|
||||||
Dialect: gorp.PostgresDialect{},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
tm := dbm.AddTableWithName(t.model, t.name).SetKeys(t.autoinc, t.pkey...)
|
tm := dbm.AddTableWithName(t.model, t.name).SetKeys(t.autoinc, t.pkey...)
|
||||||
|
@ -70,7 +86,6 @@ func NewConnection(cfg Config) (*gorp.DbMap, error) {
|
||||||
cm.SetUnique(true)
|
cm.SetUnique(true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &dbm, nil
|
return &dbm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
"github.com/lib/pq"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
|
@ -69,7 +68,7 @@ type ConnectorConfigRepo struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
||||||
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
|
||||||
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
||||||
objs, err := r.dbMap.Select(&connectorConfigModel{}, q)
|
objs, err := r.dbMap.Select(&connectorConfigModel{}, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -94,10 +93,10 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
|
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
|
||||||
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
|
||||||
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
|
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
|
||||||
var c connectorConfigModel
|
var c connectorConfigModel
|
||||||
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
|
if err := executor(r.dbMap, tx).SelectOne(&c, q, id); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, connector.ErrorNotFound
|
return nil, connector.ErrorNotFound
|
||||||
}
|
}
|
||||||
|
@ -121,28 +120,17 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
|
||||||
q := fmt.Sprintf("DELETE FROM %s", qt)
|
q := fmt.Sprintf("DELETE FROM %s", qt)
|
||||||
if _, err = r.dbMap.Exec(q); err != nil {
|
if _, err = tx.Exec(q); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = r.dbMap.Insert(insert...); err != nil {
|
if err = tx.Insert(insert...); err != nil {
|
||||||
return fmt.Errorf("DB insert failed %#v: %v", insert, err)
|
return fmt.Errorf("DB insert failed %#v: %v", insert, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
|
||||||
if tx == nil {
|
|
||||||
return r.dbMap
|
|
||||||
}
|
|
||||||
|
|
||||||
gorpTx, ok := tx.(*gorp.Transaction)
|
|
||||||
if !ok {
|
|
||||||
panic("wrong kind of transaction passed to a DB repo")
|
|
||||||
}
|
|
||||||
return gorpTx
|
|
||||||
}
|
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
"github.com/lib/pq"
|
|
||||||
|
|
||||||
pcrypto "github.com/coreos/dex/pkg/crypto"
|
pcrypto "github.com/coreos/dex/pkg/crypto"
|
||||||
"github.com/coreos/go-oidc/key"
|
"github.com/coreos/go-oidc/key"
|
||||||
|
@ -114,7 +113,7 @@ type PrivateKeySetRepo struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
||||||
qt := pq.QuoteIdentifier(keyTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName)
|
||||||
_, err := r.dbMap.Exec(fmt.Sprintf("DELETE FROM %s", qt))
|
_, err := r.dbMap.Exec(fmt.Sprintf("DELETE FROM %s", qt))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -152,7 +151,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
|
func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
|
||||||
qt := pq.QuoteIdentifier(keyTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName)
|
||||||
objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
|
objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -1,19 +1,18 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
"github.com/lib/pq"
|
"github.com/rubenv/sql-migrate"
|
||||||
migrate "github.com/rubenv/sql-migrate"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/db/migrations"
|
"github.com/coreos/dex/db/migrations"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
migrationDialect = "postgres"
|
migrationTable = "dex_migrations"
|
||||||
migrationTable = "dex_migrations"
|
migrationDir = "db/migrations"
|
||||||
migrationDir = "db/migrations"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -21,32 +20,57 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func MigrateToLatest(dbMap *gorp.DbMap) (int, error) {
|
func MigrateToLatest(dbMap *gorp.DbMap) (int, error) {
|
||||||
source := getSource()
|
source, dialect, err := migrationSource(dbMap)
|
||||||
|
if err != nil {
|
||||||
return migrate.Exec(dbMap.Db, migrationDialect, source, migrate.Up)
|
return 0, err
|
||||||
|
}
|
||||||
|
return migrate.Exec(dbMap.Db, dialect, source, migrate.Up)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MigrateMaxMigrations(dbMap *gorp.DbMap, max int) (int, error) {
|
func MigrateMaxMigrations(dbMap *gorp.DbMap, max int) (int, error) {
|
||||||
source := getSource()
|
source, dialect, err := migrationSource(dbMap)
|
||||||
|
if err != nil {
|
||||||
return migrate.ExecMax(dbMap.Db, migrationDialect, source, migrate.Up, max)
|
return 0, err
|
||||||
|
}
|
||||||
|
return migrate.ExecMax(dbMap.Db, dialect, source, migrate.Up, max)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPlannedMigrations(dbMap *gorp.DbMap) ([]*migrate.PlannedMigration, error) {
|
func GetPlannedMigrations(dbMap *gorp.DbMap) ([]*migrate.PlannedMigration, error) {
|
||||||
migrations, _, err := migrate.PlanMigration(dbMap.Db, migrationDialect, getSource(), migrate.Up, 0)
|
source, dialect, err := migrationSource(dbMap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
migrations, _, err := migrate.PlanMigration(dbMap.Db, dialect, source, migrate.Up, 0)
|
||||||
return migrations, err
|
return migrations, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func DropMigrationsTable(dbMap *gorp.DbMap) error {
|
func DropMigrationsTable(dbMap *gorp.DbMap) error {
|
||||||
qt := pq.QuoteIdentifier(migrationTable)
|
qt := fmt.Sprintf("DROP TABLE IF EXISTS %s;", dbMap.Dialect.QuotedTableForQuery("", migrationTable))
|
||||||
_, err := dbMap.Exec(fmt.Sprintf("drop table if exists %s ;", qt))
|
_, err := dbMap.Exec(qt)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSource() migrate.MigrationSource {
|
func migrationSource(dbMap *gorp.DbMap) (src migrate.MigrationSource, dialect string, err error) {
|
||||||
return &migrate.AssetMigrationSource{
|
switch dbMap.Dialect.(type) {
|
||||||
Dir: migrationDir,
|
case gorp.PostgresDialect:
|
||||||
Asset: migrations.Asset,
|
src = &migrate.AssetMigrationSource{
|
||||||
AssetDir: migrations.AssetDir,
|
Dir: migrationDir,
|
||||||
|
Asset: migrations.Asset,
|
||||||
|
AssetDir: migrations.AssetDir,
|
||||||
|
}
|
||||||
|
return src, "postgres", nil
|
||||||
|
case gorp.SqliteDialect:
|
||||||
|
src = &migrate.MemoryMigrationSource{
|
||||||
|
Migrations: []*migrate.Migration{
|
||||||
|
{
|
||||||
|
Id: "dex.sql",
|
||||||
|
Up: []string{sqlite3Migration},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return src, "sqlite3", nil
|
||||||
|
default:
|
||||||
|
return nil, "", errors.New("unsupported migration driver")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
73
db/migrate_sqlite3.go
Normal file
73
db/migrate_sqlite3.go
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
// SQLite3 is a test only database. There is only one migration because we do not support migrations.
|
||||||
|
const sqlite3Migration = `
|
||||||
|
CREATE TABLE authd_user (
|
||||||
|
id text NOT NULL UNIQUE,
|
||||||
|
email text,
|
||||||
|
email_verified integer,
|
||||||
|
display_name text,
|
||||||
|
admin integer,
|
||||||
|
created_at bigint,
|
||||||
|
disabled integer
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE client_identity (
|
||||||
|
id text NOT NULL UNIQUE,
|
||||||
|
secret blob,
|
||||||
|
metadata text,
|
||||||
|
dex_admin integer
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE connector_config (
|
||||||
|
id text NOT NULL UNIQUE,
|
||||||
|
type text,
|
||||||
|
config text
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE key (
|
||||||
|
value blob
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE password_info (
|
||||||
|
user_id text NOT NULL UNIQUE,
|
||||||
|
password text,
|
||||||
|
password_expires bigint
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE refresh_token (
|
||||||
|
id integer PRIMARY KEY,
|
||||||
|
payload_hash blob,
|
||||||
|
user_id text,
|
||||||
|
client_id text
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE remote_identity_mapping (
|
||||||
|
connector_id text NOT NULL,
|
||||||
|
user_id text,
|
||||||
|
remote_id text NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE session (
|
||||||
|
id text NOT NULL UNIQUE,
|
||||||
|
state text,
|
||||||
|
created_at bigint,
|
||||||
|
expires_at bigint,
|
||||||
|
client_id text,
|
||||||
|
client_state text,
|
||||||
|
redirect_url text,
|
||||||
|
identity text,
|
||||||
|
connector_id text,
|
||||||
|
user_id text,
|
||||||
|
register integer,
|
||||||
|
nonce text,
|
||||||
|
scope text
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE session_key (
|
||||||
|
key text NOT NULL UNIQUE,
|
||||||
|
session_id text,
|
||||||
|
expires_at bigint,
|
||||||
|
stale integer
|
||||||
|
);
|
||||||
|
`
|
|
@ -5,10 +5,11 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-gorp/gorp"
|
||||||
|
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/go-gorp/gorp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -89,20 +90,8 @@ func (r *passwordInfoRepo) Update(tx repo.Transaction, pw user.PasswordInfo) err
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *passwordInfoRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
|
||||||
if tx == nil {
|
|
||||||
return r.dbMap
|
|
||||||
}
|
|
||||||
|
|
||||||
gorpTx, ok := tx.(*gorp.Transaction)
|
|
||||||
if !ok {
|
|
||||||
panic("wrong kind of transaction passed to a DB repo")
|
|
||||||
}
|
|
||||||
return gorpTx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) {
|
func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) {
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
|
|
||||||
m, err := ex.Get(passwordInfoModel{}, id)
|
m, err := ex.Get(passwordInfoModel{}, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -123,7 +112,7 @@ func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInf
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) error {
|
func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) error {
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
pm, err := newPasswordInfoModel(&pw)
|
pm, err := newPasswordInfoModel(&pw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -132,7 +121,7 @@ func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *passwordInfoRepo) update(tx repo.Transaction, pw user.PasswordInfo) error {
|
func (r *passwordInfoRepo) update(tx repo.Transaction, pw user.PasswordInfo) error {
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
pm, err := newPasswordInfoModel(&pw)
|
pm, err := newPasswordInfoModel(&pw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -8,10 +8,11 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/coreos/dex/pkg/log"
|
|
||||||
"github.com/coreos/dex/refresh"
|
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/pkg/log"
|
||||||
|
"github.com/coreos/dex/refresh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -166,16 +167,8 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *refreshTokenRepo) executor(tx *gorp.Transaction) gorp.SqlExecutor {
|
|
||||||
if tx == nil {
|
|
||||||
return r.dbMap
|
|
||||||
}
|
|
||||||
return tx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) {
|
func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) {
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
result, err := ex.Get(refreshTokenModel{}, tokenID)
|
result, err := ex.Get(refreshTokenModel{}, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
|
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
"github.com/jonboulle/clockwork"
|
"github.com/jonboulle/clockwork"
|
||||||
"github.com/lib/pq"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/session"
|
"github.com/coreos/dex/session"
|
||||||
|
@ -183,9 +182,9 @@ func (r *SessionRepo) Update(s session.Session) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SessionRepo) purge() error {
|
func (r *SessionRepo) purge() error {
|
||||||
qt := pq.QuoteIdentifier(sessionTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionTableName)
|
||||||
q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt)
|
q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt)
|
||||||
res, err := r.dbMap.Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
|
res, err := executor(r.dbMap, nil).Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
|
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
"github.com/jonboulle/clockwork"
|
"github.com/jonboulle/clockwork"
|
||||||
"github.com/lib/pq"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/session"
|
"github.com/coreos/dex/session"
|
||||||
|
@ -77,9 +76,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
|
||||||
return "", errors.New("invalid session key")
|
return "", errors.New("invalid session key")
|
||||||
}
|
}
|
||||||
|
|
||||||
qt := pq.QuoteIdentifier(sessionKeyTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName)
|
||||||
q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt)
|
q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt)
|
||||||
res, err := r.dbMap.Exec(q, true, key, false)
|
res, err := executor(r.dbMap, nil).Exec(q, true, key, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -95,9 +94,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SessionKeyRepo) purge() error {
|
func (r *SessionKeyRepo) purge() error {
|
||||||
qt := pq.QuoteIdentifier(sessionKeyTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName)
|
||||||
q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt)
|
q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt)
|
||||||
res, err := r.dbMap.Exec(q, true, r.clock.Now().Unix())
|
res, err := executor(r.dbMap, nil).Exec(q, true, r.clock.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
33
db/transaction.go
Normal file
33
db/transaction.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-gorp/gorp"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/db/translate"
|
||||||
|
"github.com/coreos/dex/repo"
|
||||||
|
)
|
||||||
|
|
||||||
|
func executor(dbMap *gorp.DbMap, tx repo.Transaction) gorp.SqlExecutor {
|
||||||
|
var exec gorp.SqlExecutor
|
||||||
|
if tx == nil {
|
||||||
|
exec = dbMap
|
||||||
|
} else {
|
||||||
|
gorpTx, ok := tx.(*gorp.Transaction)
|
||||||
|
if !ok {
|
||||||
|
panic("wrong kind of transaction passed to a DB repo")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the underlying value of the pointer is nil.
|
||||||
|
// This is not caught by the initial comparison (tx == nil).
|
||||||
|
if gorpTx == nil {
|
||||||
|
exec = dbMap
|
||||||
|
} else {
|
||||||
|
exec = gorpTx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := dbMap.Dialect.(gorp.SqliteDialect); ok {
|
||||||
|
exec = translate.NewExecutor(exec, translate.PostgresToSQLite)
|
||||||
|
}
|
||||||
|
return exec
|
||||||
|
}
|
68
db/translate/translate.go
Normal file
68
db/translate/translate.go
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
/*
|
||||||
|
Package translate implements translation of driver specific SQL queries.
|
||||||
|
*/
|
||||||
|
package translate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"github.com/go-gorp/gorp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
bindRegexp = regexp.MustCompile(`\$\d+`)
|
||||||
|
trueRegexp = regexp.MustCompile(`\btrue\b`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgresToSQLite implements translation of the pq driver to sqlite3.
|
||||||
|
func PostgresToSQLite(query string) string {
|
||||||
|
query = bindRegexp.ReplaceAllString(query, "?")
|
||||||
|
query = trueRegexp.ReplaceAllString(query, "1")
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewExecutor(exec gorp.SqlExecutor, translate func(string) string) gorp.SqlExecutor {
|
||||||
|
return &executor{exec, translate}
|
||||||
|
}
|
||||||
|
|
||||||
|
type executor struct {
|
||||||
|
gorp.SqlExecutor
|
||||||
|
Translate func(string) string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *executor) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
return e.SqlExecutor.Exec(e.Translate(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *executor) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
|
||||||
|
return e.SqlExecutor.Select(i, e.Translate(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *executor) SelectInt(query string, args ...interface{}) (int64, error) {
|
||||||
|
return e.SqlExecutor.SelectInt(e.Translate(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *executor) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
|
||||||
|
return e.SqlExecutor.SelectNullInt(e.Translate(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *executor) SelectFloat(query string, args ...interface{}) (float64, error) {
|
||||||
|
return e.SqlExecutor.SelectFloat(e.Translate(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *executor) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
|
||||||
|
return e.SqlExecutor.SelectNullFloat(e.Translate(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *executor) SelectStr(query string, args ...interface{}) (string, error) {
|
||||||
|
return e.SqlExecutor.SelectStr(e.Translate(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *executor) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
|
||||||
|
return e.SqlExecutor.SelectNullStr(e.Translate(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *executor) SelectOne(holder interface{}, query string, args ...interface{}) error {
|
||||||
|
return e.SqlExecutor.SelectOne(holder, e.Translate(query), args...)
|
||||||
|
}
|
28
db/translate/translate_test.go
Normal file
28
db/translate/translate_test.go
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
package translate
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestPostgresToSQLite(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
query string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"SELECT * FROM foo", "SELECT * FROM foo"},
|
||||||
|
{"SELECT * FROM %s", "SELECT * FROM %s"},
|
||||||
|
{"SELECT * FROM foo WHERE is_admin=true", "SELECT * FROM foo WHERE is_admin=1"},
|
||||||
|
{"SELECT * FROM foo WHERE is_admin=true;", "SELECT * FROM foo WHERE is_admin=1;"},
|
||||||
|
{"SELECT * FROM foo WHERE is_admin=$10", "SELECT * FROM foo WHERE is_admin=?"},
|
||||||
|
{"SELECT * FROM foo WHERE is_admin=$10;", "SELECT * FROM foo WHERE is_admin=?;"},
|
||||||
|
{"SELECT * FROM foo WHERE name=$1 AND is_admin=$2;", "SELECT * FROM foo WHERE name=? AND is_admin=?;"},
|
||||||
|
{"$1", "?"},
|
||||||
|
{"$", "$"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := PostgresToSQLite(tt.query)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("PostgresToSQLite(%q): want=%q, got=%q", tt.query, tt.want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
55
db/user.go
55
db/user.go
|
@ -8,7 +8,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
"github.com/lib/pq"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
|
@ -107,9 +106,9 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err
|
||||||
return user.ErrorInvalidID
|
return user.ErrorInvalidID
|
||||||
}
|
}
|
||||||
|
|
||||||
qt := pq.QuoteIdentifier(userTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable)
|
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $1 WHERE id = $2;", qt), disable, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -221,7 +220,7 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
deleted, err := ex.Delete(rim)
|
deleted, err := ex.Delete(rim)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -236,14 +235,13 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]user.RemoteIdentity, error) {
|
func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]user.RemoteIdentity, error) {
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
return nil, user.ErrorInvalidID
|
return nil, user.ErrorInvalidID
|
||||||
}
|
}
|
||||||
|
|
||||||
qt := pq.QuoteIdentifier(remoteIdentityMappingTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", remoteIdentityMappingTableName)
|
||||||
rims, err := ex.Select(&remoteIdentityMappingModel{},
|
rims, err := ex.Select(&remoteIdentityMappingModel{}, fmt.Sprintf("SELECT * FROM %s WHERE user_id = $1", qt), userID)
|
||||||
fmt.Sprintf("select * from %s where user_id = $1", qt), userID)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
|
@ -273,9 +271,9 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) {
|
func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) {
|
||||||
qt := pq.QuoteIdentifier(userTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s where admin=true", qt))
|
i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s WHERE admin=true;", qt))
|
||||||
return int(i), err
|
return int(i), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -288,14 +286,13 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
|
|
||||||
qt := pq.QuoteIdentifier(userTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||||
|
|
||||||
// Ask for one more than needed so we know if there's more results, and
|
// Ask for one more than needed so we know if there's more results, and
|
||||||
// hence, whether a nextPageToken is necessary.
|
// hence, whether a nextPageToken is necessary.
|
||||||
ums, err := ex.Select(&userModel{},
|
ums, err := ex.Select(&userModel{}, fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2", qt), maxResults+1, offset)
|
||||||
fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2 ", qt), maxResults+1, offset)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
@ -338,20 +335,8 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
|
||||||
if tx == nil {
|
|
||||||
return r.dbMap
|
|
||||||
}
|
|
||||||
|
|
||||||
gorpTx, ok := tx.(*gorp.Transaction)
|
|
||||||
if !ok {
|
|
||||||
panic("wrong kind of transaction passed to a DB repo")
|
|
||||||
}
|
|
||||||
return gorpTx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
|
func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
um, err := newUserModel(&usr)
|
um, err := newUserModel(&usr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -360,7 +345,7 @@ func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
|
func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
um, err := newUserModel(&usr)
|
um, err := newUserModel(&usr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -370,7 +355,7 @@ func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
|
func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
|
|
||||||
m, err := ex.Get(userModel{}, userID)
|
m, err := ex.Get(userModel{}, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -391,7 +376,7 @@ func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.RemoteIdentity) (string, error) {
|
func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.RemoteIdentity) (string, error) {
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
|
|
||||||
m, err := ex.Get(remoteIdentityMappingModel{}, ri.ConnectorID, ri.ID)
|
m, err := ex.Get(remoteIdentityMappingModel{}, ri.ConnectorID, ri.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -412,8 +397,8 @@ func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.Remot
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) {
|
func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) {
|
||||||
qt := pq.QuoteIdentifier(userTableName)
|
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
var um userModel
|
var um userModel
|
||||||
err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email)
|
err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email)
|
||||||
|
|
||||||
|
@ -427,7 +412,7 @@ func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepo) insertRemoteIdentity(tx repo.Transaction, userID string, ri user.RemoteIdentity) error {
|
func (r *userRepo) insertRemoteIdentity(tx repo.Transaction, userID string, ri user.RemoteIdentity) error {
|
||||||
ex := r.executor(tx)
|
ex := executor(r.dbMap, tx)
|
||||||
rim, err := newRemoteIdentityMappingModel(userID, ri)
|
rim, err := newRemoteIdentityMappingModel(userID, ri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
||||||
|
|
Reference in a new issue