forked from mystiq/dex
db: clean up quote and executor function calls, improve translate docs
This commit is contained in:
parent
2cc0ae6fac
commit
ed5dee9960
12 changed files with 169 additions and 126 deletions
36
db/client.go
36
db/client.go
|
@ -86,15 +86,21 @@ func (m *clientIdentityModel) ClientIdentity() (*oidc.ClientIdentity, error) {
|
|||
}
|
||||
|
||||
func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo {
|
||||
return &clientIdentityRepo{dbMap: dbm}
|
||||
return newClientIdentityRepo(dbm)
|
||||
}
|
||||
|
||||
func newClientIdentityRepo(dbm *gorp.DbMap) *clientIdentityRepo {
|
||||
return &clientIdentityRepo{db: &db{dbm}}
|
||||
}
|
||||
|
||||
func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) {
|
||||
tx, err := dbm.Begin()
|
||||
repo := newClientIdentityRepo(dbm)
|
||||
tx, err := repo.begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := repo.executor(tx)
|
||||
for _, c := range clients {
|
||||
dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret)
|
||||
if err != nil {
|
||||
|
@ -104,7 +110,7 @@ func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIden
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = tx.Insert(cm)
|
||||
err = exec.Insert(cm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -112,15 +118,15 @@ func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIden
|
|||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewClientIdentityRepo(dbm), nil
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
type clientIdentityRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
}
|
||||
|
||||
func (r *clientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
|
||||
m, err := r.dbMap.Get(clientIdentityModel{}, clientID)
|
||||
m, err := r.executor(nil).Get(clientIdentityModel{}, clientID)
|
||||
if err == sql.ErrNoRows || m == nil {
|
||||
return nil, client.ErrorNotFound
|
||||
}
|
||||
|
@ -143,7 +149,7 @@ func (r *clientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, er
|
|||
}
|
||||
|
||||
func (r *clientIdentityRepo) IsDexAdmin(clientID string) (bool, error) {
|
||||
m, err := r.dbMap.Get(clientIdentityModel{}, clientID)
|
||||
m, err := r.executor(nil).Get(clientIdentityModel{}, clientID)
|
||||
if m == nil || err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -158,15 +164,15 @@ func (r *clientIdentityRepo) IsDexAdmin(clientID string) (bool, error) {
|
|||
}
|
||||
|
||||
func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
||||
tx, err := r.dbMap.Begin()
|
||||
tx, err := r.begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := r.executor(tx)
|
||||
|
||||
m, err := tx.Get(clientIdentityModel{}, clientID)
|
||||
m, err := exec.Get(clientIdentityModel{}, clientID)
|
||||
if m == nil || err != nil {
|
||||
rollback(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -177,7 +183,7 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
|||
}
|
||||
|
||||
cim.DexAdmin = isAdmin
|
||||
_, err = tx.Update(cim)
|
||||
_, err = exec.Update(cim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -186,7 +192,7 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
|||
}
|
||||
|
||||
func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
|
||||
m, err := r.dbMap.Get(clientIdentityModel{}, creds.ID)
|
||||
m, err := r.executor(nil).Get(clientIdentityModel{}, creds.ID)
|
||||
if m == nil || err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -222,7 +228,7 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := r.dbMap.Insert(cim); err != nil {
|
||||
if err := r.executor(nil).Insert(cim); err != nil {
|
||||
switch sqlErr := err.(type) {
|
||||
case *pq.Error:
|
||||
if sqlErr.Code == pgErrorCodeUniqueViolation {
|
||||
|
@ -246,9 +252,9 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli
|
|||
}
|
||||
|
||||
func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", clientIdentityTableName)
|
||||
qt := r.quote(clientIdentityTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
||||
objs, err := r.dbMap.Select(&clientIdentityModel{}, q)
|
||||
objs, err := r.executor(nil).Select(&clientIdentityModel{}, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
|
||||
// Import database drivers
|
||||
|
@ -95,13 +94,6 @@ func TransactionFactory(conn *gorp.DbMap) repo.TransactionFactory {
|
|||
}
|
||||
}
|
||||
|
||||
func rollback(tx *gorp.Transaction) {
|
||||
err := tx.Rollback()
|
||||
if err != nil {
|
||||
log.Errorf("unable to rollback: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// NewMemDB creates a new in memory sqlite3 database.
|
||||
func NewMemDB() *gorp.DbMap {
|
||||
dbMap, err := NewConnection(Config{DSN: "sqlite3://:memory:"})
|
||||
|
|
|
@ -60,17 +60,17 @@ func (m *connectorConfigModel) ConnectorConfig() (connector.ConnectorConfig, err
|
|||
}
|
||||
|
||||
func NewConnectorConfigRepo(dbm *gorp.DbMap) *ConnectorConfigRepo {
|
||||
return &ConnectorConfigRepo{dbMap: dbm}
|
||||
return &ConnectorConfigRepo{&db{dbm}}
|
||||
}
|
||||
|
||||
type ConnectorConfigRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
|
||||
qt := r.quote(connectorConfigTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
||||
objs, err := r.dbMap.Select(&connectorConfigModel{}, q)
|
||||
objs, err := r.executor(nil).Select(&connectorConfigModel{}, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -93,10 +93,10 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
|||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
|
||||
qt := r.quote(connectorConfigTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
|
||||
var c connectorConfigModel
|
||||
if err := executor(r.dbMap, tx).SelectOne(&c, q, id); err != nil {
|
||||
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, connector.ErrorNotFound
|
||||
}
|
||||
|
@ -116,19 +116,20 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
|
|||
insert[i] = m
|
||||
}
|
||||
|
||||
tx, err := r.dbMap.Begin()
|
||||
tx, err := r.begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := r.executor(tx)
|
||||
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName)
|
||||
qt := r.quote(connectorConfigTableName)
|
||||
q := fmt.Sprintf("DELETE FROM %s", qt)
|
||||
if _, err = tx.Exec(q); err != nil {
|
||||
if _, err = exec.Exec(q); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.Insert(insert...); err != nil {
|
||||
if err = exec.Insert(insert...); err != nil {
|
||||
return fmt.Errorf("DB insert failed %#v: %v", insert, err)
|
||||
}
|
||||
|
||||
|
|
60
db/db.go
Normal file
60
db/db.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
// Package db provides SQL implementations of dex's storage interfaces.
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/db/translate"
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
// db is the connection type passed to repos.
|
||||
//
|
||||
// TODO(ericchiang): Eventually just return this instead of gorp.DbMap during Conn.
|
||||
// All actions should go through this type instead of dbMap.
|
||||
type db struct {
|
||||
dbMap *gorp.DbMap
|
||||
}
|
||||
|
||||
// executor returns a driver agnostic SQL executor.
|
||||
//
|
||||
// The expected flavor of all queries is the flavor used by github.com/lib/pq. All bind
|
||||
// parameters must be unique and in sequential order (e.g. $1, $2, ...).
|
||||
//
|
||||
// See github.com/coreos/dex/db/translate for details on the translation.
|
||||
//
|
||||
// If tx is nil, a non-transaction context is provided.
|
||||
func (conn *db) executor(tx repo.Transaction) gorp.SqlExecutor {
|
||||
var exec gorp.SqlExecutor
|
||||
if tx == nil {
|
||||
exec = conn.dbMap
|
||||
} else {
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
|
||||
// Check if the underlying value of the pointer is nil.
|
||||
// This is not caught by the initial comparison (tx == nil).
|
||||
if gorpTx == nil {
|
||||
exec = conn.dbMap
|
||||
} else {
|
||||
exec = gorpTx
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := conn.dbMap.Dialect.(gorp.SqliteDialect); ok {
|
||||
exec = translate.NewTranslatingExecutor(exec, translate.PostgresToSQLite)
|
||||
}
|
||||
return exec
|
||||
}
|
||||
|
||||
// quote escapes a table name for a driver specific SQL query. quote uses the
|
||||
// gorp's package underlying quote logic and should NOT be used on untrusted input.
|
||||
func (conn *db) quote(tableName string) string {
|
||||
return conn.dbMap.Dialect.QuotedTableForQuery("", tableName)
|
||||
}
|
||||
|
||||
func (conn *db) begin() (repo.Transaction, error) {
|
||||
return conn.dbMap.Begin()
|
||||
}
|
22
db/key.go
22
db/key.go
|
@ -98,7 +98,7 @@ func NewPrivateKeySetRepo(dbm *gorp.DbMap, useOldFormat bool, secrets ...[]byte)
|
|||
}
|
||||
|
||||
r := &PrivateKeySetRepo{
|
||||
dbMap: dbm,
|
||||
db: &db{dbm},
|
||||
useOldFormat: useOldFormat,
|
||||
secrets: secrets,
|
||||
}
|
||||
|
@ -107,17 +107,22 @@ func NewPrivateKeySetRepo(dbm *gorp.DbMap, useOldFormat bool, secrets ...[]byte)
|
|||
}
|
||||
|
||||
type PrivateKeySetRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
useOldFormat bool
|
||||
secrets [][]byte
|
||||
}
|
||||
|
||||
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName)
|
||||
_, err := r.dbMap.Exec(fmt.Sprintf("DELETE FROM %s", qt))
|
||||
qt := r.quote(keyTableName)
|
||||
tx, err := r.begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := r.executor(tx)
|
||||
if _, err := exec.Exec(fmt.Sprintf("DELETE FROM %s", qt)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pks, ok := ks.(*key.PrivateKeySet)
|
||||
if !ok {
|
||||
|
@ -147,12 +152,15 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
|||
}
|
||||
|
||||
b := &privateKeySetBlob{Value: v}
|
||||
return r.dbMap.Insert(b)
|
||||
if err := exec.Insert(b); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName)
|
||||
objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
|
||||
qt := r.quote(keyTableName)
|
||||
objs, err := r.executor(nil).Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ type passwordInfoModel struct {
|
|||
|
||||
func NewPasswordInfoRepo(dbm *gorp.DbMap) user.PasswordInfoRepo {
|
||||
return &passwordInfoRepo{
|
||||
dbMap: dbm,
|
||||
db: &db{dbm},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,7 @@ func NewPasswordInfoRepoFromPasswordInfos(dbm *gorp.DbMap, infos []user.Password
|
|||
}
|
||||
|
||||
type passwordInfoRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
}
|
||||
|
||||
func (r *passwordInfoRepo) Get(tx repo.Transaction, userID string) (user.PasswordInfo, error) {
|
||||
|
@ -101,7 +101,7 @@ func (r *passwordInfoRepo) Update(tx repo.Transaction, pw user.PasswordInfo) err
|
|||
}
|
||||
|
||||
func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) {
|
||||
ex := executor(r.dbMap, tx)
|
||||
ex := r.executor(tx)
|
||||
|
||||
m, err := ex.Get(passwordInfoModel{}, id)
|
||||
if err != nil {
|
||||
|
@ -122,7 +122,7 @@ func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInf
|
|||
}
|
||||
|
||||
func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) error {
|
||||
ex := executor(r.dbMap, tx)
|
||||
ex := r.executor(tx)
|
||||
pm, err := newPasswordInfoModel(&pw)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -131,7 +131,7 @@ func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) err
|
|||
}
|
||||
|
||||
func (r *passwordInfoRepo) update(tx repo.Transaction, pw user.PasswordInfo) error {
|
||||
ex := executor(r.dbMap, tx)
|
||||
ex := r.executor(tx)
|
||||
pm, err := newPasswordInfoModel(&pw)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/refresh"
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -29,7 +30,7 @@ func init() {
|
|||
}
|
||||
|
||||
type refreshTokenRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
tokenGenerator refresh.RefreshTokenGenerator
|
||||
}
|
||||
|
||||
|
@ -77,15 +78,12 @@ func checkTokenPayload(payloadHash, payload []byte) error {
|
|||
}
|
||||
|
||||
func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo {
|
||||
return &refreshTokenRepo{
|
||||
dbMap: dbm,
|
||||
tokenGenerator: refresh.DefaultRefreshTokenGenerator,
|
||||
}
|
||||
return NewRefreshTokenRepoWithGenerator(dbm, refresh.DefaultRefreshTokenGenerator)
|
||||
}
|
||||
|
||||
func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenGenerator) refresh.RefreshTokenRepo {
|
||||
return &refreshTokenRepo{
|
||||
dbMap: dbm,
|
||||
db: &db{dbm},
|
||||
tokenGenerator: gen,
|
||||
}
|
||||
}
|
||||
|
@ -115,7 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
|
|||
ClientID: clientID,
|
||||
}
|
||||
|
||||
if err := r.dbMap.Insert(record); err != nil {
|
||||
if err := r.executor(nil).Insert(record); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
|
@ -151,7 +149,13 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
record, err := r.get(nil, tokenID)
|
||||
tx, err := r.begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := r.executor(tx)
|
||||
record, err := r.get(tx, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -164,7 +168,7 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
deleted, err := r.dbMap.Delete(record)
|
||||
deleted, err := exec.Delete(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -172,10 +176,11 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
|
|||
return refresh.ErrorInvalidToken
|
||||
}
|
||||
|
||||
return nil
|
||||
return tx.Commit()
|
||||
}
|
||||
func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) {
|
||||
ex := executor(r.dbMap, tx)
|
||||
|
||||
func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) {
|
||||
ex := r.executor(tx)
|
||||
result, err := ex.Get(refreshTokenModel{}, tokenID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -123,16 +123,16 @@ func NewSessionRepo(dbm *gorp.DbMap) *SessionRepo {
|
|||
}
|
||||
|
||||
func NewSessionRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionRepo {
|
||||
return &SessionRepo{dbMap: dbm, clock: clock}
|
||||
return &SessionRepo{db: &db{dbm}, clock: clock}
|
||||
}
|
||||
|
||||
type SessionRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func (r *SessionRepo) Get(sessionID string) (*session.Session, error) {
|
||||
m, err := r.dbMap.Get(sessionModel{}, sessionID)
|
||||
m, err := r.executor(nil).Get(sessionModel{}, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -163,7 +163,7 @@ func (r *SessionRepo) Create(s session.Session) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.dbMap.Insert(sm)
|
||||
return r.executor(nil).Insert(sm)
|
||||
}
|
||||
|
||||
func (r *SessionRepo) Update(s session.Session) error {
|
||||
|
@ -171,7 +171,7 @@ func (r *SessionRepo) Update(s session.Session) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n, err := r.dbMap.Update(sm)
|
||||
n, err := r.executor(nil).Update(sm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -182,9 +182,9 @@ func (r *SessionRepo) Update(s session.Session) error {
|
|||
}
|
||||
|
||||
func (r *SessionRepo) purge() error {
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionTableName)
|
||||
qt := r.quote(sessionTableName)
|
||||
q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt)
|
||||
res, err := executor(r.dbMap, nil).Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
|
||||
res, err := r.executor(nil).Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -38,11 +38,11 @@ func NewSessionKeyRepo(dbm *gorp.DbMap) *SessionKeyRepo {
|
|||
}
|
||||
|
||||
func NewSessionKeyRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionKeyRepo {
|
||||
return &SessionKeyRepo{dbMap: dbm, clock: clock}
|
||||
return &SessionKeyRepo{db: &db{dbm}, clock: clock}
|
||||
}
|
||||
|
||||
type SessionKeyRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
|
@ -53,11 +53,11 @@ func (r *SessionKeyRepo) Push(sk session.SessionKey, exp time.Duration) error {
|
|||
ExpiresAt: r.clock.Now().Unix() + int64(exp.Seconds()),
|
||||
Stale: false,
|
||||
}
|
||||
return r.dbMap.Insert(skm)
|
||||
return r.executor(nil).Insert(skm)
|
||||
}
|
||||
|
||||
func (r *SessionKeyRepo) Pop(key string) (string, error) {
|
||||
m, err := r.dbMap.Get(sessionKeyModel{}, key)
|
||||
m, err := r.executor(nil).Get(sessionKeyModel{}, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -76,9 +76,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
|
|||
return "", errors.New("invalid session key")
|
||||
}
|
||||
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName)
|
||||
qt := r.quote(sessionKeyTableName)
|
||||
q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt)
|
||||
res, err := executor(r.dbMap, nil).Exec(q, true, key, false)
|
||||
res, err := r.executor(nil).Exec(q, true, key, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -94,9 +94,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) {
|
|||
}
|
||||
|
||||
func (r *SessionKeyRepo) purge() error {
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName)
|
||||
qt := r.quote(sessionKeyTableName)
|
||||
q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt)
|
||||
res, err := executor(r.dbMap, nil).Exec(q, true, r.clock.Now().Unix())
|
||||
res, err := r.executor(nil).Exec(q, true, r.clock.Now().Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/db/translate"
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
func executor(dbMap *gorp.DbMap, tx repo.Transaction) gorp.SqlExecutor {
|
||||
var exec gorp.SqlExecutor
|
||||
if tx == nil {
|
||||
exec = dbMap
|
||||
} else {
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
|
||||
// Check if the underlying value of the pointer is nil.
|
||||
// This is not caught by the initial comparison (tx == nil).
|
||||
if gorpTx == nil {
|
||||
exec = dbMap
|
||||
} else {
|
||||
exec = gorpTx
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := dbMap.Dialect.(gorp.SqliteDialect); ok {
|
||||
exec = translate.NewExecutor(exec, translate.PostgresToSQLite)
|
||||
}
|
||||
return exec
|
||||
}
|
|
@ -15,14 +15,18 @@ var (
|
|||
trueRegexp = regexp.MustCompile(`\btrue\b`)
|
||||
)
|
||||
|
||||
// PostgresToSQLite implements translation of the pq driver to sqlite3.
|
||||
// PostgresToSQLite translates github.com/lib/pq flavored SQL quries to github.com/mattn/go-sqlite3's flavor.
|
||||
//
|
||||
// It assumes that possitional bind arguements ($1, $2, etc.) are unqiue and in sequential order.
|
||||
func PostgresToSQLite(query string) string {
|
||||
query = bindRegexp.ReplaceAllString(query, "?")
|
||||
query = trueRegexp.ReplaceAllString(query, "1")
|
||||
return query
|
||||
}
|
||||
|
||||
func NewExecutor(exec gorp.SqlExecutor, translate func(string) string) gorp.SqlExecutor {
|
||||
// NewTranslatingExecutor returns an executor wrapping the existing executor. All query strings passed to
|
||||
// the executor will be run through the translate function before begin passed to the driver.
|
||||
func NewTranslatingExecutor(exec gorp.SqlExecutor, translate func(string) string) gorp.SqlExecutor {
|
||||
return &executor{exec, translate}
|
||||
}
|
||||
|
||||
|
|
38
db/user.go
38
db/user.go
|
@ -41,7 +41,7 @@ func init() {
|
|||
|
||||
func NewUserRepo(dbm *gorp.DbMap) user.UserRepo {
|
||||
return &userRepo{
|
||||
dbMap: dbm,
|
||||
db: &db{dbm},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -52,7 +52,7 @@ func NewUserRepoFromUsers(dbm *gorp.DbMap, us []user.UserWithRemoteIdentities) (
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = repo.dbMap.Insert(um)
|
||||
err = repo.executor(nil).Insert(um)
|
||||
for _, ri := range u.RemoteIdentities {
|
||||
err = repo.AddRemoteIdentity(nil, u.User.ID, ri)
|
||||
if err != nil {
|
||||
|
@ -64,7 +64,7 @@ func NewUserRepoFromUsers(dbm *gorp.DbMap, us []user.UserWithRemoteIdentities) (
|
|||
}
|
||||
|
||||
type userRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
*db
|
||||
}
|
||||
|
||||
func (r *userRepo) Get(tx repo.Transaction, userID string) (user.User, error) {
|
||||
|
@ -106,8 +106,8 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err
|
|||
return user.ErrorInvalidID
|
||||
}
|
||||
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||
ex := executor(r.dbMap, tx)
|
||||
qt := r.quote(userTableName)
|
||||
ex := r.executor(tx)
|
||||
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $1 WHERE id = $2;", qt), disable, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -220,7 +220,7 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
|
|||
return err
|
||||
}
|
||||
|
||||
ex := executor(r.dbMap, tx)
|
||||
ex := r.executor(tx)
|
||||
deleted, err := ex.Delete(rim)
|
||||
|
||||
if err != nil {
|
||||
|
@ -235,12 +235,12 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid
|
|||
}
|
||||
|
||||
func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]user.RemoteIdentity, error) {
|
||||
ex := executor(r.dbMap, tx)
|
||||
ex := r.executor(tx)
|
||||
if userID == "" {
|
||||
return nil, user.ErrorInvalidID
|
||||
}
|
||||
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", remoteIdentityMappingTableName)
|
||||
qt := r.quote(remoteIdentityMappingTableName)
|
||||
rims, err := ex.Select(&remoteIdentityMappingModel{}, fmt.Sprintf("SELECT * FROM %s WHERE user_id = $1", qt), userID)
|
||||
|
||||
if err != nil {
|
||||
|
@ -271,8 +271,8 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us
|
|||
}
|
||||
|
||||
func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) {
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||
ex := executor(r.dbMap, tx)
|
||||
qt := r.quote(userTableName)
|
||||
ex := r.executor(tx)
|
||||
i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s WHERE admin=true;", qt))
|
||||
return int(i), err
|
||||
}
|
||||
|
@ -286,9 +286,9 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
|
|||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ex := executor(r.dbMap, tx)
|
||||
ex := r.executor(tx)
|
||||
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||
qt := r.quote(userTableName)
|
||||
|
||||
// Ask for one more than needed so we know if there's more results, and
|
||||
// hence, whether a nextPageToken is necessary.
|
||||
|
@ -336,7 +336,7 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults
|
|||
}
|
||||
|
||||
func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
|
||||
ex := executor(r.dbMap, tx)
|
||||
ex := r.executor(tx)
|
||||
um, err := newUserModel(&usr)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -345,7 +345,7 @@ func (r *userRepo) insert(tx repo.Transaction, usr user.User) error {
|
|||
}
|
||||
|
||||
func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
|
||||
ex := executor(r.dbMap, tx)
|
||||
ex := r.executor(tx)
|
||||
um, err := newUserModel(&usr)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -355,7 +355,7 @@ func (r *userRepo) update(tx repo.Transaction, usr user.User) error {
|
|||
}
|
||||
|
||||
func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
|
||||
ex := executor(r.dbMap, tx)
|
||||
ex := r.executor(tx)
|
||||
|
||||
m, err := ex.Get(userModel{}, userID)
|
||||
if err != nil {
|
||||
|
@ -376,7 +376,7 @@ func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) {
|
|||
}
|
||||
|
||||
func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.RemoteIdentity) (string, error) {
|
||||
ex := executor(r.dbMap, tx)
|
||||
ex := r.executor(tx)
|
||||
|
||||
m, err := ex.Get(remoteIdentityMappingModel{}, ri.ConnectorID, ri.ID)
|
||||
if err != nil {
|
||||
|
@ -397,8 +397,8 @@ func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.Remot
|
|||
}
|
||||
|
||||
func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) {
|
||||
qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName)
|
||||
ex := executor(r.dbMap, tx)
|
||||
qt := r.quote(userTableName)
|
||||
ex := r.executor(tx)
|
||||
var um userModel
|
||||
err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email)
|
||||
|
||||
|
@ -412,7 +412,7 @@ func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, err
|
|||
}
|
||||
|
||||
func (r *userRepo) insertRemoteIdentity(tx repo.Transaction, userID string, ri user.RemoteIdentity) error {
|
||||
ex := executor(r.dbMap, tx)
|
||||
ex := r.executor(tx)
|
||||
rim, err := newRemoteIdentityMappingModel(userID, ri)
|
||||
if err != nil {
|
||||
|
||||
|
|
Loading…
Reference in a new issue