db: add sqlite3 support
This commit is contained in:
parent
8f16279f49
commit
bfd63b7514
14 changed files with 345 additions and 145 deletions
41
db/client.go
41
db/client.go
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/coreos/go-oidc/oidc"
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/lib/pq"
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
|
@ -89,23 +90,29 @@ func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo {
|
|||
}
|
||||
|
||||
func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) {
|
||||
repo := NewClientIdentityRepo(dbm).(*clientIdentityRepo)
|
||||
tx, err := dbm.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, c := range clients {
|
||||
dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cm, err := newClientIdentityModel(c.Credentials.ID, dec, &c.Metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = repo.dbMap.Insert(cm)
|
||||
err = tx.Insert(cm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return repo, nil
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewClientIdentityRepo(dbm), nil
|
||||
}
|
||||
|
||||
type clientIdentityRepo struct {
|
||||
|
@ -155,8 +162,9 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
m, err := r.dbMap.Get(clientIdentityModel{}, clientID)
|
||||
m, err := tx.Get(clientIdentityModel{}, clientID)
|
||||
if m == nil || err != nil {
|
||||
rollback(tx)
|
||||
return err
|
||||
|
@ -164,25 +172,17 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
|||
|
||||
cim, ok := m.(*clientIdentityModel)
|
||||
if !ok {
|
||||
rollback(tx)
|
||||
log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m))
|
||||
return errors.New("unrecognized model")
|
||||
}
|
||||
|
||||
cim.DexAdmin = isAdmin
|
||||
_, err = r.dbMap.Update(cim)
|
||||
_, err = tx.Update(cim)
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
|
||||
|
@ -223,9 +223,16 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
|
|||
}
|
||||
|
||||
if err := r.dbMap.Insert(cim); err != nil {
|
||||
if perr, ok := err.(*pq.Error); ok && perr.Code == pgErrorCodeUniqueViolation {
|
||||
switch sqlErr := err.(type) {
|
||||
case *pq.Error:
|
||||
if sqlErr.Code == pgErrorCodeUniqueViolation {
|
||||
err = errors.New("client ID already exists")
|
||||
}
|
||||
case *sqlite3.Error:
|
||||
if sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique {
|
||||
err = errors.New("client ID already exists")
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
@ -239,7 +246,7 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
|
|||
}
|
||||
|
||||
func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
|
||||
qt := pq.QuoteIdentifier(clientIdentityTableName)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", clientIdentityTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
||||
objs, err := r.dbMap.Select(&clientIdentityModel{}, q)
|
||||
if err != nil {
|
||||
|
|
39
db/conn.go
39
db/conn.go
|
@ -4,13 +4,16 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
|
||||
// Import database drivers
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type table struct {
|
||||
|
@ -43,22 +46,35 @@ type Config struct {
|
|||
}
|
||||
|
||||
func NewConnection(cfg Config) (*gorp.DbMap, error) {
|
||||
if !strings.HasPrefix(cfg.DSN, "postgres://") {
|
||||
return nil, errors.New("unrecognized database driver")
|
||||
u, err := url.Parse(cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse DSN: %v", err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", cfg.DSN)
|
||||
var (
|
||||
db *sql.DB
|
||||
dialect gorp.Dialect
|
||||
)
|
||||
switch u.Scheme {
|
||||
case "postgres":
|
||||
db, err = sql.Open("postgres", cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.SetMaxIdleConns(cfg.MaxIdleConnections)
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConnections)
|
||||
|
||||
dbm := gorp.DbMap{
|
||||
Db: db,
|
||||
Dialect: gorp.PostgresDialect{},
|
||||
dialect = gorp.PostgresDialect{}
|
||||
case "sqlite3":
|
||||
db, err = sql.Open("sqlite3", u.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// NOTE(ericchiang): sqlite does NOT work with SetMaxIdleConns.
|
||||
dialect = gorp.SqliteDialect{}
|
||||
default:
|
||||
return nil, errors.New("unrecognized database driver")
|
||||
}
|
||||
|
||||
dbm := gorp.DbMap{Db: db, Dialect: dialect}
|
||||
|
||||
for _, t := range tables {
|
||||
tm := dbm.AddTableWithName(t.model, t.name).SetKeys(t.autoinc, t.pkey...)
|
||||
|
@ -70,7 +86,6 @@ func NewConnection(cfg Config) (*gorp.DbMap, error) {
|
|||
cm.SetUnique(true)
|
||||
}
|
||||
}
|
||||
|
||||
return &dbm, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
|
@ -69,7 +68,7 @@ type ConnectorConfigRepo struct {
|
|||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
||||
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
||||
objs, err := r.dbMap.Select(&connectorConfigModel{}, q)
|
||||
if err != nil {
|
||||
|
@ -94,10 +93,10 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
|||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
|
||||
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
|
||||
var c connectorConfigModel
|
||||
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
|
||||
if err := executor(r.dbMap, tx).SelectOne(&c, q, id); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, connector.ErrorNotFound
|
||||
}
|
||||
|
@ -121,28 +120,17 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
|
||||
q := fmt.Sprintf("DELETE FROM %s", qt)
|
||||
if _, err = r.dbMap.Exec(q); err != nil {
|
||||
if _, err = tx.Exec(q); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = r.dbMap.Insert(insert...); err != nil {
|
||||
if err = tx.Insert(insert...); err != nil {
|
||||
return fmt.Errorf("DB insert failed %#v: %v", insert, err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
||||
if tx == nil {
|
||||
return r.dbMap
|
||||
}
|
||||
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
return gorpTx
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/lib/pq"
|
||||
|
||||
pcrypto "github.com/coreos/dex/pkg/crypto"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
|
@ -114,7 +113,7 @@ type PrivateKeySetRepo struct {
|
|||
}
|
||||
|
||||
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
||||
qt := pq.QuoteIdentifier(keyTableName)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName)
|
||||
_, err := r.dbMap.Exec(fmt.Sprintf("DELETE FROM %s", qt))
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -152,7 +151,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
|||
}
|
||||
|
||||
func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
|
||||
qt := pq.QuoteIdentifier(keyTableName)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName)
|
||||
objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -1,17 +1,16 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/lib/pq"
|
||||
migrate "github.com/rubenv/sql-migrate"
|
||||
"github.com/rubenv/sql-migrate"
|
||||
|
||||
"github.com/coreos/dex/db/migrations"
|
||||
)
|
||||
|
||||
const (
|
||||
migrationDialect = "postgres"
|
||||
migrationTable = "dex_migrations"
|
||||
migrationDir = "db/migrations"
|
||||
)
|
||||
|
@ -21,32 +20,57 @@ func init() {
|
|||
}
|
||||
|
||||
func MigrateToLatest(dbMap *gorp.DbMap) (int, error) {
|
||||
source := getSource()
|
||||
|
||||
return migrate.Exec(dbMap.Db, migrationDialect, source, migrate.Up)
|
||||
source, dialect, err := migrationSource(dbMap)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return migrate.Exec(dbMap.Db, dialect, source, migrate.Up)
|
||||
}
|
||||
|
||||
func MigrateMaxMigrations(dbMap *gorp.DbMap, max int) (int, error) {
|
||||
source := getSource()
|
||||
|
||||
return migrate.ExecMax(dbMap.Db, migrationDialect, source, migrate.Up, max)
|
||||
source, dialect, err := migrationSource(dbMap)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return migrate.ExecMax(dbMap.Db, dialect, source, migrate.Up, max)
|
||||
}
|
||||
|
||||
func GetPlannedMigrations(dbMap *gorp.DbMap) ([]*migrate.PlannedMigration, error) {
|
||||
migrations, _, err := migrate.PlanMigration(dbMap.Db, migrationDialect, getSource(), migrate.Up, 0)
|
||||
source, dialect, err := migrationSource(dbMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
migrations, _, err := migrate.PlanMigration(dbMap.Db, dialect, source, migrate.Up, 0)
|
||||
return migrations, err
|
||||
}
|
||||
|
||||
func DropMigrationsTable(dbMap *gorp.DbMap) error {
|
||||
qt := pq.QuoteIdentifier(migrationTable)
|
||||
_, err := dbMap.Exec(fmt.Sprintf("drop table if exists %s ;", qt))
|
||||
qt := fmt.Sprintf("DROP TABLE IF EXISTS %s;", dbMap.Dialect.QuotedTableForQuery("", migrationTable))
|
||||
_, err := dbMap.Exec(qt)
|
||||
return err
|
||||
}
|
||||
|
||||
func getSource() migrate.MigrationSource {
|
||||
return &migrate.AssetMigrationSource{
|
||||
func migrationSource(dbMap *gorp.DbMap) (src migrate.MigrationSource, dialect string, err error) {
|
||||
switch dbMap.Dialect.(type) {
|
||||
case gorp.PostgresDialect:
|
||||
src = &migrate.AssetMigrationSource{
|
||||
Dir: migrationDir,
|
||||
Asset: migrations.Asset,
|
||||
AssetDir: migrations.AssetDir,
|
||||
}
|
||||
return src, "postgres", nil
|
||||
case gorp.SqliteDialect:
|
||||
src = &migrate.MemoryMigrationSource{
|
||||
Migrations: []*migrate.Migration{
|
||||
{
|
||||
Id: "dex.sql",
|
||||
Up: []string{sqlite3Migration},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return src, "sqlite3", nil
|
||||
default:
|
||||
return nil, "", errors.New("unsupported migration driver")
|
||||
}
|
||||
}
|
||||
|
|
73
db/migrate_sqlite3.go
Normal file
73
db/migrate_sqlite3.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package db
|
||||
|
||||
// SQLite3 is a test only database. There is only one migration because we do not support migrations.
|
||||
const sqlite3Migration = `
|
||||
CREATE TABLE authd_user (
|
||||
id text NOT NULL UNIQUE,
|
||||
email text,
|
||||
email_verified integer,
|
||||
display_name text,
|
||||
admin integer,
|
||||
created_at bigint,
|
||||
disabled integer
|
||||
);
|
||||
|
||||
CREATE TABLE client_identity (
|
||||
id text NOT NULL UNIQUE,
|
||||
secret blob,
|
||||
metadata text,
|
||||
dex_admin integer
|
||||
);
|
||||
|
||||
CREATE TABLE connector_config (
|
||||
id text NOT NULL UNIQUE,
|
||||
type text,
|
||||
config text
|
||||
);
|
||||
|
||||
CREATE TABLE key (
|
||||
value blob
|
||||
);
|
||||
|
||||
CREATE TABLE password_info (
|
||||
user_id text NOT NULL UNIQUE,
|
||||
password text,
|
||||
password_expires bigint
|
||||
);
|
||||
|
||||
CREATE TABLE refresh_token (
|
||||
id integer PRIMARY KEY,
|
||||
payload_hash blob,
|
||||
user_id text,
|
||||
client_id text
|
||||
);
|
||||
|
||||
CREATE TABLE remote_identity_mapping (
|
||||
connector_id text NOT NULL,
|
||||
user_id text,
|
||||
remote_id text NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE session (
|
||||
id text NOT NULL UNIQUE,
|
||||
state text,
|
||||
created_at bigint,
|
||||
expires_at bigint,
|
||||
client_id text,
|
||||
client_state text,
|
||||
redirect_url text,
|
||||
identity text,
|
||||
connector_id text,
|
||||
user_id text,
|
||||
register integer,
|
||||
nonce text,
|
||||
scope text
|
||||
);
|
||||
|
||||
CREATE TABLE session_key (
|
||||
key text NOT NULL UNIQUE,
|
||||
session_id text,
|
||||
expires_at bigint,
|
||||
stale integer
|
||||
);
|
||||
`
|
|
@ -5,10 +5,11 @@ import (
|
|||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/go-gorp/gorp"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -89,20 +90,8 @@ func (r *passwordInfoRepo) Update(tx repo.Transaction, pw user.PasswordInfo) err
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *passwordInfoRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
||||
if tx == nil {
|
||||
return r.dbMap
|
||||
}
|
||||
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
return gorpTx
|
||||
}
|
||||
|
||||
func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) {
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
|
||||
m, err := ex.Get(passwordInfoModel{}, id)
|
||||
if err != nil {
|
||||
|
@ -123,7 +112,7 @@ func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInf
|
|||
}
|
||||
|
||||
func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) error {
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
pm, err := newPasswordInfoModel(&pw)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -132,7 +121,7 @@ func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) err
|
|||
}
|
||||
|
||||
func (r *passwordInfoRepo) update(tx repo.Transaction, pw user.PasswordInfo) error {
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
pm, err := newPasswordInfoModel(&pw)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -8,10 +8,11 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/refresh"
|
||||
"github.com/go-gorp/gorp"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/refresh"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -166,16 +167,8 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *refreshTokenRepo) executor(tx *gorp.Transaction) gorp.SqlExecutor {
|
||||
if tx == nil {
|
||||
return r.dbMap
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) {
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
result, err := ex.Get(refreshTokenModel{}, tokenID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/session"
|
||||
|
@ -183,9 +182,9 @@ func (r *SessionRepo) Update(s session.Session) error {
|
|||
}
|
||||
|
||||
func (r *SessionRepo) purge() error {
|
||||
qt := pq.QuoteIdentifier(sessionTableName)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionTableName)
|
||||
q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt)
|
||||
res, err := r.dbMap.Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
|
||||
res, err := executor(r.dbMap, nil).Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/session"
|
||||
|
@ -77,9 +76,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
|
|||
return "", errors.New("invalid session key")
|
||||
}
|
||||
|
||||
qt := pq.QuoteIdentifier(sessionKeyTableName)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName)
|
||||
q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt)
|
||||
res, err := r.dbMap.Exec(q, true, key, false)
|
||||
res, err := executor(r.dbMap, nil).Exec(q, true, key, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -95,9 +94,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
|
|||
}
|
||||
|
||||
func (r *SessionKeyRepo) purge() error {
|
||||
qt := pq.QuoteIdentifier(sessionKeyTableName)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName)
|
||||
q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt)
|
||||
res, err := r.dbMap.Exec(q, true, r.clock.Now().Unix())
|
||||
res, err := executor(r.dbMap, nil).Exec(q, true, r.clock.Now().Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
33
db/transaction.go
Normal file
33
db/transaction.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/db/translate"
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
func executor(dbMap *gorp.DbMap, tx repo.Transaction) gorp.SqlExecutor {
|
||||
var exec gorp.SqlExecutor
|
||||
if tx == nil {
|
||||
exec = dbMap
|
||||
} else {
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
|
||||
// Check if the underlying value of the pointer is nil.
|
||||
// This is not caught by the initial comparison (tx == nil).
|
||||
if gorpTx == nil {
|
||||
exec = dbMap
|
||||
} else {
|
||||
exec = gorpTx
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := dbMap.Dialect.(gorp.SqliteDialect); ok {
|
||||
exec = translate.NewExecutor(exec, translate.PostgresToSQLite)
|
||||
}
|
||||
return exec
|
||||
}
|
68
db/translate/translate.go
Normal file
68
db/translate/translate.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
Package translate implements translation of driver specific SQL queries.
|
||||
*/
|
||||
package translate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"regexp"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
)
|
||||
|
||||
var (
|
||||
bindRegexp = regexp.MustCompile(`\$\d+`)
|
||||
trueRegexp = regexp.MustCompile(`\btrue\b`)
|
||||
)
|
||||
|
||||
// PostgresToSQLite implements translation of the pq driver to sqlite3.
|
||||
func PostgresToSQLite(query string) string {
|
||||
query = bindRegexp.ReplaceAllString(query, "?")
|
||||
query = trueRegexp.ReplaceAllString(query, "1")
|
||||
return query
|
||||
}
|
||||
|
||||
func NewExecutor(exec gorp.SqlExecutor, translate func(string) string) gorp.SqlExecutor {
|
||||
return &executor{exec, translate}
|
||||
}
|
||||
|
||||
type executor struct {
|
||||
gorp.SqlExecutor
|
||||
Translate func(string) string
|
||||
}
|
||||
|
||||
func (e *executor) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return e.SqlExecutor.Exec(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
|
||||
return e.SqlExecutor.Select(i, e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectInt(query string, args ...interface{}) (int64, error) {
|
||||
return e.SqlExecutor.SelectInt(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
|
||||
return e.SqlExecutor.SelectNullInt(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectFloat(query string, args ...interface{}) (float64, error) {
|
||||
return e.SqlExecutor.SelectFloat(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
|
||||
return e.SqlExecutor.SelectNullFloat(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectStr(query string, args ...interface{}) (string, error) {
|
||||
return e.SqlExecutor.SelectStr(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
|
||||
return e.SqlExecutor.SelectNullStr(e.Translate(query), args...)
|
||||
}
|
||||
|
||||
func (e *executor) SelectOne(holder interface{}, query string, args ...interface{}) error {
|
||||
return e.SqlExecutor.SelectOne(holder, e.Translate(query), args...)
|
||||
}
|
28
db/translate/translate_test.go
Normal file
28
db/translate/translate_test.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package translate
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPostgresToSQLite(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
query string
|
||||
want string
|
||||
}{
|
||||
{"SELECT * FROM foo", "SELECT * FROM foo"},
|
||||
{"SELECT * FROM %s", "SELECT * FROM %s"},
|
||||
{"SELECT * FROM foo WHERE is_admin=true", "SELECT * FROM foo WHERE is_admin=1"},
|
||||
{"SELECT * FROM foo WHERE is_admin=true;", "SELECT * FROM foo WHERE is_admin=1;"},
|
||||
{"SELECT * FROM foo WHERE is_admin=$10", "SELECT * FROM foo WHERE is_admin=?"},
|
||||
{"SELECT * FROM foo WHERE is_admin=$10;", "SELECT * FROM foo WHERE is_admin=?;"},
|
||||
{"SELECT * FROM foo WHERE name=$1 AND is_admin=$2;", "SELECT * FROM foo WHERE name=? AND is_admin=?;"},
|
||||
{"$1", "?"},
|
||||
{"$", "$"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := PostgresToSQLite(tt.query)
|
||||
if got != tt.want {
|
||||
t.Errorf("PostgresToSQLite(%q): want=%q, got=%q", tt.query, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
55
db/user.go
55
db/user.go
|
@ -8,7 +8,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
|
@ -107,9 +106,9 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err
|
|||
return user.ErrorInvalidID
|
||||
}
|
||||
|
||||
qt := pq.QuoteIdentifier(userTableName)
|
||||
ex := r.executor(tx)
|
||||
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||
ex := executor(r.dbMap, tx)
|
||||
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $1 WHERE id = $2;", qt), disable, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -221,7 +220,7 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
|
|||
return err
|
||||
}
|
||||
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
deleted, err := ex.Delete(rim)
|
||||
|
||||
if err != nil {
|
||||
|
@ -236,14 +235,13 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
|
|||
}
|
||||
|
||||
func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]user.RemoteIdentity, error) {
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
if userID == "" {
|
||||
return nil, user.ErrorInvalidID
|
||||
}
|
||||
|
||||
qt := pq.QuoteIdentifier(remoteIdentityMappingTableName)
|
||||
rims, err := ex.Select(&remoteIdentityMappingModel{},
|
||||
fmt.Sprintf("select * from %s where user_id = $1", qt), userID)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", remoteIdentityMappingTableName)
|
||||
rims, err := ex.Select(&remoteIdentityMappingModel{}, fmt.Sprintf("SELECT * FROM %s WHERE user_id = $1", qt), userID)
|
||||
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
|
@ -273,9 +271,9 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us
|
|||
}
|
||||
|
||||
func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) {
|
||||
qt := pq.QuoteIdentifier(userTableName)
|
||||
ex := r.executor(tx)
|
||||
i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s where admin=true", qt))
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||
ex := executor(r.dbMap, tx)
|
||||
i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s WHERE admin=true;", qt))
|
||||
return int(i), err
|
||||
}
|
||||
|
||||
|
@ -288,14 +286,13 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
|
|||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
|
||||
qt := pq.QuoteIdentifier(userTableName)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||
|
||||
// Ask for one more than needed so we know if there's more results, and
|
||||
// hence, whether a nextPageToken is necessary.
|
||||
ums, err := ex.Select(&userModel{},
|
||||
fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2 ", qt), maxResults+1, offset)
|
||||
ums, err := ex.Select(&userModel{}, fmt.Sprintf("SELECT * FROM %s ORDER BY email LIMIT $1 OFFSET $2", qt), maxResults+1, offset)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -338,20 +335,8 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
|
|||
|
||||
}
|
||||
|
||||
func (r *userRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
||||
if tx == nil {
|
||||
return r.dbMap
|
||||
}
|
||||
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
return gorpTx
|
||||
}
|
||||
|
||||
func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
um, err := newUserModel(&usr)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -360,7 +345,7 @@ func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
|
|||
}
|
||||
|
||||
func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
um, err := newUserModel(&usr)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -370,7 +355,7 @@ func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
|
|||
}
|
||||
|
||||
func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
|
||||
m, err := ex.Get(userModel{}, userID)
|
||||
if err != nil {
|
||||
|
@ -391,7 +376,7 @@ func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
|
|||
}
|
||||
|
||||
func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.RemoteIdentity) (string, error) {
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
|
||||
m, err := ex.Get(remoteIdentityMappingModel{}, ri.ConnectorID, ri.ID)
|
||||
if err != nil {
|
||||
|
@ -412,8 +397,8 @@ func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.Remot
|
|||
}
|
||||
|
||||
func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) {
|
||||
qt := pq.QuoteIdentifier(userTableName)
|
||||
ex := r.executor(tx)
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||
ex := executor(r.dbMap, tx)
|
||||
var um userModel
|
||||
err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email)
|
||||
|
||||
|
@ -427,7 +412,7 @@ func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, err
|
|||
}
|
||||
|
||||
func (r *userRepo) insertRemoteIdentity(tx repo.Transaction, userID string, ri user.RemoteIdentity) error {
|
||||
ex := r.executor(tx)
|
||||
ex := executor(r.dbMap, tx)
|
||||
rim, err := newRemoteIdentityMappingModel(userID, ri)
|
||||
if err != nil {
|
||||
|
||||
|
|
Reference in a new issue