dex/session/manager.go
2015-08-31 13:51:59 -07:00

161 lines
3.3 KiB
Go

package session
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net/url"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/go-oidc/oidc"
)
// GenerateCodeFunc produces a new opaque code (used for session IDs and
// session keys), or an error if one could not be generated.
type GenerateCodeFunc func() (code string, err error)
// DefaultGenerateCode is the GenerateCodeFunc installed by NewSessionManager.
// It reads 8 bytes from crypto/rand and returns them encoded with URL-safe,
// padded base64 (a 12-character string). It returns an error if the random
// source fails or yields fewer bytes than requested.
func DefaultGenerateCode() (string, error) {
	b := make([]byte, 8)
	n, err := rand.Read(b)
	if err != nil {
		return "", err
	}
	// Defensive: crypto/rand.Read documents that n == len(b) iff err == nil,
	// so this should be unreachable, but a short read must never silently
	// produce a lower-entropy code.
	if n != len(b) {
		return "", errors.New("unable to read enough random bytes")
	}
	return base64.URLEncoding.EncodeToString(b), nil
}
// NewSessionManager returns a SessionManager backed by the given session and
// session-key repositories. It uses DefaultGenerateCode for codes and the
// real wall clock. New sessions get DefaultSessionValidityWindow as their
// validity window.
func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager {
	mgr := new(SessionManager)
	mgr.sessions, mgr.keys = sRepo, skRepo
	mgr.GenerateCode = DefaultGenerateCode
	mgr.Clock = clockwork.NewRealClock()
	mgr.ValidityWindow = DefaultSessionValidityWindow
	return mgr
}
// SessionManager creates sessions, advances them through their states
// (New -> RemoteAttached -> Identified, or Dead via Kill), and issues session
// keys that refer back to a session ID. NewSessionManager populates every
// field. The exported fields may be reassigned by callers after construction.
type SessionManager struct {
// GenerateCode produces the IDs of new sessions and the values of new session keys.
GenerateCode GenerateCodeFunc
// Clock supplies "now" for a new session's CreatedAt and ExpiresAt.
Clock clockwork.Clock
// ValidityWindow is added to CreatedAt to compute a new session's ExpiresAt.
ValidityWindow time.Duration
// sessions is the backing store for Session records.
sessions SessionRepo
// keys is the backing store for SessionKey records.
keys SessionKeyRepo
}
// NewSession creates a session in SessionStateNew for the given connector and
// client, stores it in the session repository, and returns its generated ID.
// CreatedAt is the manager clock's current time, and ExpiresAt is that time
// plus ValidityWindow. An error from GenerateCode or from the repository's
// Create is returned unchanged, together with an empty ID.
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
	id, err := m.GenerateCode()
	if err != nil {
		return "", err
	}

	created := m.Clock.Now()
	sess := Session{
		ID:          id,
		State:       SessionStateNew,
		ConnectorID: connectorID,
		ClientID:    clientID,
		ClientState: clientState,
		RedirectURL: redirectURL,
		Nonce:       nonce,
		Register:    register,
		Scope:       scope,
		CreatedAt:   created,
		ExpiresAt:   created.Add(m.ValidityWindow),
	}

	if err := m.sessions.Create(sess); err != nil {
		return "", err
	}
	return id, nil
}
// NewSessionKey generates a fresh key referring to sessionID and pushes it
// into the key repository with sessionKeyValidityWindow. That value is
// presumably the key's lifetime; confirm against the SessionKeyRepo
// implementation. It returns the new key, or an empty string and the first
// error encountered.
func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
	code, err := m.GenerateCode()
	if err != nil {
		return "", err
	}
	sk := SessionKey{SessionID: sessionID, Key: code}
	if err := m.keys.Push(sk, sessionKeyValidityWindow); err != nil {
		return "", err
	}
	return sk.Key, nil
}
// ExchangeKey passes key to the key repository's Pop and returns its result
// unchanged. The result is presumably the session ID the key was issued for
// (see NewSessionKey). Whether Pop also consumes the key is up to the
// repository implementation.
func (m *SessionManager) ExchangeKey(key string) (string, error) {
	sessionID, err := m.keys.Pop(key)
	return sessionID, err
}
// getSessionInState loads the session with the given ID and checks that it is
// currently in the expected state. Repository errors are passed through as-is.
// A state mismatch yields an error naming both the actual and expected state.
func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) {
	sess, err := m.sessions.Get(sessionID)
	switch {
	case err != nil:
		return nil, err
	case sess.State != state:
		return nil, fmt.Errorf("session state %s, expect %s", sess.State, state)
	default:
		return sess, nil
	}
}
// AttachRemoteIdentity stores ident on a session that must currently be in
// SessionStateNew and advances it to SessionStateRemoteAttached. The change is
// written back through the session repository, and the updated session is
// returned. Any lookup, state-mismatch, or update error is returned with a nil
// session.
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) {
	sess, err := m.getSessionInState(sessionID, SessionStateNew)
	if err != nil {
		return nil, err
	}
	sess.Identity, sess.State = ident, SessionStateRemoteAttached
	if err := m.sessions.Update(*sess); err != nil {
		return nil, err
	}
	return sess, nil
}
// AttachUser records userID on a session that must currently be in
// SessionStateRemoteAttached and advances it to SessionStateIdentified. The
// change is written back through the session repository, and the updated
// session is returned. Any lookup, state-mismatch, or update error is returned
// with a nil session.
func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) {
	sess, err := m.getSessionInState(sessionID, SessionStateRemoteAttached)
	if err == nil {
		sess.UserID = userID
		sess.State = SessionStateIdentified
		err = m.sessions.Update(*sess)
	}
	// Single exit for both the lookup and the update failure paths.
	if err != nil {
		return nil, err
	}
	return sess, nil
}
// Kill marks the session as SessionStateDead without checking its current
// state. It writes the change back through the session repository and returns
// the updated session. Any lookup or update error is returned with a nil
// session.
func (m *SessionManager) Kill(sessionID string) (*Session, error) {
	sess, err := m.sessions.Get(sessionID)
	if err == nil {
		sess.State = SessionStateDead
		err = m.sessions.Update(*sess)
	}
	// Single exit for both the lookup and the update failure paths.
	if err != nil {
		return nil, err
	}
	return sess, nil
}
// Get returns the session stored under sessionID exactly as the session
// repository reports it. No state check is applied here.
func (m *SessionManager) Get(sessionID string) (*Session, error) {
	sess, err := m.sessions.Get(sessionID)
	return sess, err
}