// dex/db/session_key.go
// (Extracted from a repository blame view; blame timestamp 2015-08-18 05:57:27 +05:30.)

package db

import (
	"errors"
	"fmt"
	"reflect"
	"time"

	"github.com/go-gorp/gorp"
	"github.com/jonboulle/clockwork"

	"github.com/coreos/dex/pkg/log"
	"github.com/coreos/dex/session"
)

// sessionKeyTableName is the name of the database table holding session keys.
const sessionKeyTableName = "session_key"
func init() {
register(table{
name: sessionKeyTableName,
model: sessionKeyModel{},
autoinc: false,
pkey: []string{"key"},
})
}
// sessionKeyModel is the row type stored in the session_key table.
// NOTE(review): field order is kept as-is in case the ORM derives column
// order from it.
type sessionKeyModel struct {
	// Key is the session key value; it is the table's primary key.
	Key string `db:"key"`

	// SessionID is the session the key resolves to when popped.
	SessionID string `db:"session_id"`

	// ExpiresAt is the expiry instant in Unix seconds.
	ExpiresAt int64 `db:"expires_at"`

	// Stale is set to true once the key has been consumed by Pop.
	Stale bool `db:"stale"`
}
// NewSessionKeyRepo returns a SessionKeyRepo backed by dbm whose expiry
// checks use the real wall clock.
func NewSessionKeyRepo(dbm *gorp.DbMap) *SessionKeyRepo {
	realClock := clockwork.NewRealClock()
	return NewSessionKeyRepoWithClock(dbm, realClock)
}
func NewSessionKeyRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionKeyRepo {
return &SessionKeyRepo{db: &db{dbm}, clock: clock}
2015-08-18 05:57:27 +05:30
}
type SessionKeyRepo struct {
*db
2015-08-18 05:57:27 +05:30
clock clockwork.Clock
}
func (r *SessionKeyRepo) Push(sk session.SessionKey, exp time.Duration) error {
skm := &sessionKeyModel{
Key: sk.Key,
SessionID: sk.SessionID,
ExpiresAt: r.clock.Now().Unix() + int64(exp.Seconds()),
Stale: false,
}
return r.executor(nil).Insert(skm)
2015-08-18 05:57:27 +05:30
}
func (r *SessionKeyRepo) Pop(key string) (string, error) {
m, err := r.executor(nil).Get(sessionKeyModel{}, key)
2015-08-18 05:57:27 +05:30
if err != nil {
return "", err
}
if m == nil {
return "", errors.New("session key does not exist")
}
2015-08-18 05:57:27 +05:30
skm, ok := m.(*sessionKeyModel)
if !ok {
log.Errorf("expected sessionKeyModel but found %v", reflect.TypeOf(m))
2015-08-18 05:57:27 +05:30
return "", errors.New("unrecognized model")
}
if skm.Stale || skm.ExpiresAt < r.clock.Now().Unix() {
return "", errors.New("invalid session key")
}
qt := r.quote(sessionKeyTableName)
2015-08-18 05:57:27 +05:30
q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt)
res, err := r.executor(nil).Exec(q, true, key, false)
2015-08-18 05:57:27 +05:30
if err != nil {
return "", err
}
if n, err := res.RowsAffected(); n != 1 {
if err != nil {
log.Errorf("Failed determining rows affected by UPDATE session_key query: %v", err)
}
return "", fmt.Errorf("failed to pop entity")
}
return skm.SessionID, nil
}
func (r *SessionKeyRepo) purge() error {
qt := r.quote(sessionKeyTableName)
2015-08-18 05:57:27 +05:30
q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt)
res, err := r.executor(nil).Exec(q, true, r.clock.Now().Unix())
2015-08-18 05:57:27 +05:30
if err != nil {
return err
}
d := "unknown # of"
if n, err := res.RowsAffected(); err == nil {
if n == 0 {
return nil
}
d = fmt.Sprintf("%d", n)
}
log.Infof("Deleted %s stale row(s) from %s table", d, sessionKeyTableName)
return nil
}