dex/db/session.go

204 lines
4.3 KiB
Go
Raw Permalink Normal View History

2015-08-18 05:57:27 +05:30
package db
import (
"encoding/json"
"errors"
"fmt"
"net/url"
"reflect"
2015-08-29 04:33:51 +05:30
"strings"
2015-08-18 05:57:27 +05:30
"time"
"github.com/go-gorp/gorp"
2015-08-18 05:57:27 +05:30
"github.com/jonboulle/clockwork"
"github.com/lib/pq"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session"
"github.com/coreos/go-oidc/oidc"
)
// sessionTableName is the name of the database table that holds session rows.
const sessionTableName = "session"
func init() {
register(table{
name: sessionTableName,
model: sessionModel{},
autoinc: false,
pkey: []string{"id"},
})
}
// sessionModel is the flat row form of a session.Session as stored in the
// session table. Each field carries a `db` tag naming its column.
type sessionModel struct {
	ID    string `db:"id"`
	State string `db:"state"`

	// CreatedAt and ExpiresAt are Unix timestamps in seconds; 0 stands for
	// an unset (zero) time.
	CreatedAt int64 `db:"created_at"`
	ExpiresAt int64 `db:"expires_at"`

	ClientID    string `db:"client_id"`
	ClientState string `db:"client_state"`

	// RedirectURL is the string form of the session's redirect URL.
	RedirectURL string `db:"redirect_url"`

	// Identity is the JSON encoding of the session's oidc.Identity.
	Identity string `db:"identity"`

	ConnectorID string `db:"connector_id"`
	UserID      string `db:"user_id"`
	Register    bool   `db:"register"`
	Nonce       string `db:"nonce"`

	// Scope holds the session's scope values joined by single spaces.
	Scope string `db:"scope"`
}
func (s *sessionModel) session() (*session.Session, error) {
ru, err := url.Parse(s.RedirectURL)
if err != nil {
return nil, err
}
var ident oidc.Identity
if err = json.Unmarshal([]byte(s.Identity), &ident); err != nil {
return nil, err
}
// If this is not here, then ExpiresAt is unmarshaled with a "loc" field,
// which breaks tests.
if ident.ExpiresAt.IsZero() {
ident.ExpiresAt = time.Time{}
}
ses := session.Session{
ID: s.ID,
State: session.SessionState(s.State),
ClientID: s.ClientID,
ClientState: s.ClientState,
RedirectURL: *ru,
Identity: ident,
ConnectorID: s.ConnectorID,
UserID: s.UserID,
Register: s.Register,
Nonce: s.Nonce,
2015-08-29 04:33:51 +05:30
Scope: strings.Fields(s.Scope),
2015-08-18 05:57:27 +05:30
}
if s.CreatedAt != 0 {
ses.CreatedAt = time.Unix(s.CreatedAt, 0).UTC()
}
if s.ExpiresAt != 0 {
ses.ExpiresAt = time.Unix(s.ExpiresAt, 0).UTC()
}
return &ses, nil
}
func newSessionModel(s *session.Session) (*sessionModel, error) {
b, err := json.Marshal(s.Identity)
if err != nil {
return nil, err
}
sm := sessionModel{
ID: s.ID,
State: string(s.State),
ClientID: s.ClientID,
ClientState: s.ClientState,
RedirectURL: s.RedirectURL.String(),
Identity: string(b),
ConnectorID: s.ConnectorID,
UserID: s.UserID,
Register: s.Register,
Nonce: s.Nonce,
2015-08-29 04:33:51 +05:30
Scope: strings.Join(s.Scope, " "),
2015-08-18 05:57:27 +05:30
}
if !s.CreatedAt.IsZero() {
sm.CreatedAt = s.CreatedAt.Unix()
}
if !s.ExpiresAt.IsZero() {
sm.ExpiresAt = s.ExpiresAt.Unix()
}
return &sm, nil
}
// NewSessionRepo returns a SessionRepo backed by dbm that reads the current
// time from a real clock.
func NewSessionRepo(dbm *gorp.DbMap) *SessionRepo {
	realClock := clockwork.NewRealClock()
	return NewSessionRepoWithClock(dbm, realClock)
}
// NewSessionRepoWithClock returns a SessionRepo backed by dbm that reads the
// current time from the supplied clock.
func NewSessionRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionRepo {
	repo := new(SessionRepo)
	repo.dbMap = dbm
	repo.clock = clock
	return repo
}
// SessionRepo stores and retrieves sessions through a gorp DbMap.
type SessionRepo struct {
// dbMap is the database handle used for all reads and writes.
dbMap *gorp.DbMap
// clock supplies the current time for expiry checks and purging.
clock clockwork.Clock
}
func (r *SessionRepo) Get(sessionID string) (*session.Session, error) {
m, err := r.dbMap.Get(sessionModel{}, sessionID)
if err != nil {
return nil, err
}
if m == nil {
return nil, errors.New("session does not exist")
}
2015-08-18 05:57:27 +05:30
sm, ok := m.(*sessionModel)
if !ok {
log.Errorf("expected sessionModel but found %v", reflect.TypeOf(m))
2015-08-18 05:57:27 +05:30
return nil, errors.New("unrecognized model")
}
ses, err := sm.session()
if err != nil {
return nil, err
}
if ses.ExpiresAt.Before(r.clock.Now()) {
return nil, errors.New("session does not exist")
}
return ses, nil
}
// Create converts s to its row form and inserts it. A conversion error or an
// insert error is returned unchanged.
func (r *SessionRepo) Create(s session.Session) error {
	row, err := newSessionModel(&s)
	if err == nil {
		err = r.dbMap.Insert(row)
	}
	return err
}
// Update converts s to its row form and writes it over the existing row.
// It fails when the conversion or the update itself fails, and also when the
// update reports any row count other than one.
func (r *SessionRepo) Update(s session.Session) error {
	row, err := newSessionModel(&s)
	if err != nil {
		return err
	}

	affected, err := r.dbMap.Update(row)
	switch {
	case err != nil:
		return err
	case affected != 1:
		return errors.New("update affected unexpected number of rows")
	default:
		return nil
	}
}
// purge deletes every session row whose expires_at is earlier than the repo
// clock's current Unix time or whose state is session.SessionStateDead.
//
// When rows were removed (or the count is unavailable) an info line is
// logged; when the driver reports zero affected rows nothing is logged. Only
// a failure of the DELETE statement itself is returned as an error.
func (r *SessionRepo) purge() error {
	stmt := fmt.Sprintf(
		"DELETE FROM %s WHERE expires_at < $1 OR state = $2",
		pq.QuoteIdentifier(sessionTableName),
	)
	cutoff := r.clock.Now().Unix()
	deadState := string(session.SessionStateDead)

	result, err := r.dbMap.Exec(stmt, cutoff, deadState)
	if err != nil {
		return err
	}

	deleted := "unknown # of"
	affected, countErr := result.RowsAffected()
	switch {
	case countErr != nil:
		// Row count unavailable; keep the placeholder text for the log line.
	case affected == 0:
		return nil
	default:
		deleted = fmt.Sprintf("%d", affected)
	}

	log.Infof("Deleted %s stale row(s) from %s table", deleted, sessionTableName)
	return nil
}