// Forked from mystiq/dex.

package db
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/go-gorp/gorp"
|
|
"github.com/jonboulle/clockwork"
|
|
"github.com/lib/pq"
|
|
|
|
"github.com/coreos/dex/pkg/log"
|
|
"github.com/coreos/dex/session"
|
|
)
|
|
|
|
// sessionKeyTableName is the name of the table that holds session keys.
const sessionKeyTableName = "session_key"
|
|
|
|
func init() {
|
|
register(table{
|
|
name: sessionKeyTableName,
|
|
model: sessionKeyModel{},
|
|
autoinc: false,
|
|
pkey: []string{"key"},
|
|
})
|
|
}
|
|
|
|
// sessionKeyModel is the row layout of the session_key table; the db tags
// name the corresponding columns.
type sessionKeyModel struct {
	// Key is the session key itself and the table's primary key.
	Key string `db:"key"`

	// SessionID is the value Pop hands back for this key.
	SessionID string `db:"session_id"`

	// ExpiresAt is the expiry time in seconds since the Unix epoch.
	ExpiresAt int64 `db:"expires_at"`

	// Stale is false when the key is pushed and set to true once it has
	// been popped.
	Stale bool `db:"stale"`
}
|
|
|
|
func NewSessionKeyRepo(dbm *gorp.DbMap) *SessionKeyRepo {
|
|
return NewSessionKeyRepoWithClock(dbm, clockwork.NewRealClock())
|
|
}
|
|
|
|
func NewSessionKeyRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionKeyRepo {
|
|
return &SessionKeyRepo{dbMap: dbm, clock: clock}
|
|
}
|
|
|
|
type SessionKeyRepo struct {
|
|
dbMap *gorp.DbMap
|
|
clock clockwork.Clock
|
|
}
|
|
|
|
func (r *SessionKeyRepo) Push(sk session.SessionKey, exp time.Duration) error {
|
|
skm := &sessionKeyModel{
|
|
Key: sk.Key,
|
|
SessionID: sk.SessionID,
|
|
ExpiresAt: r.clock.Now().Unix() + int64(exp.Seconds()),
|
|
Stale: false,
|
|
}
|
|
return r.dbMap.Insert(skm)
|
|
}
|
|
|
|
func (r *SessionKeyRepo) Pop(key string) (string, error) {
|
|
m, err := r.dbMap.Get(sessionKeyModel{}, key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if m == nil {
|
|
return "", errors.New("session key does not exist")
|
|
}
|
|
|
|
skm, ok := m.(*sessionKeyModel)
|
|
if !ok {
|
|
log.Errorf("expected sessionKeyModel but found %v", reflect.TypeOf(m))
|
|
return "", errors.New("unrecognized model")
|
|
}
|
|
|
|
if skm.Stale || skm.ExpiresAt < r.clock.Now().Unix() {
|
|
return "", errors.New("invalid session key")
|
|
}
|
|
|
|
qt := pq.QuoteIdentifier(sessionKeyTableName)
|
|
q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt)
|
|
res, err := r.dbMap.Exec(q, true, key, false)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if n, err := res.RowsAffected(); n != 1 {
|
|
if err != nil {
|
|
log.Errorf("Failed determining rows affected by UPDATE session_key query: %v", err)
|
|
}
|
|
return "", fmt.Errorf("failed to pop entity")
|
|
}
|
|
|
|
return skm.SessionID, nil
|
|
}
|
|
|
|
func (r *SessionKeyRepo) purge() error {
|
|
qt := pq.QuoteIdentifier(sessionKeyTableName)
|
|
q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt)
|
|
res, err := r.dbMap.Exec(q, true, r.clock.Now().Unix())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
d := "unknown # of"
|
|
if n, err := res.RowsAffected(); err == nil {
|
|
if n == 0 {
|
|
return nil
|
|
}
|
|
d = fmt.Sprintf("%d", n)
|
|
}
|
|
|
|
log.Infof("Deleted %s stale row(s) from %s table", d, sessionKeyTableName)
|
|
return nil
|
|
}
|