dex/session/manager/manager.go
2016-07-19 11:23:04 -07:00

176 lines
3.8 KiB
Go

package manager
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net/url"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/dex/session"
"github.com/coreos/go-oidc/oidc"
)
// GenerateCodeFunc produces a random, URL-safe identifier. SessionManager uses
// one to mint both session IDs (NewSession) and session keys (NewSessionKey).
type GenerateCodeFunc func() (string, error)

// DefaultGenerateCode returns 16 bytes read from crypto/rand, encoded with
// base64.URLEncoding (24 characters, including "==" padding).
//
// 16 bytes gives 128 bits of entropy. That satisfies RFC 6749 §10.10, which
// requires the probability of guessing a generated credential to be at most
// 2^-128. Because these codes serve as session IDs and session keys, they must
// be unguessable.
func DefaultGenerateCode() (string, error) {
	const entropyBytes = 16

	b := make([]byte, entropyBytes)
	n, err := rand.Read(b)
	if err != nil {
		return "", err
	}
	// crypto/rand.Read only returns a short read together with an error, but
	// guard anyway: a partially filled buffer would silently weaken the code.
	if n != len(b) {
		return "", errors.New("unable to read enough random bytes")
	}
	return base64.URLEncoding.EncodeToString(b), nil
}
// NewSessionManager builds a SessionManager that stores sessions in sRepo and
// session keys in skRepo.
//
// The exported knobs start out as DefaultGenerateCode for identifiers, a real
// wall clock, and session.DefaultSessionValidityWindow as the session
// lifetime. Callers may overwrite these fields after construction.
func NewSessionManager(sRepo session.SessionRepo, skRepo session.SessionKeyRepo) *SessionManager {
	mgr := &SessionManager{
		sessions: sRepo,
		keys:     skRepo,
	}
	mgr.GenerateCode = DefaultGenerateCode
	mgr.Clock = clockwork.NewRealClock()
	mgr.ValidityWindow = session.DefaultSessionValidityWindow
	return mgr
}
// SessionManager creates sessions, moves them between states, and issues
// short-lived session keys that map back to a session ID.
type SessionManager struct {
	// GenerateCode mints identifiers for new sessions and session keys.
	GenerateCode GenerateCodeFunc

	// Clock supplies "now" when a session is created.
	Clock clockwork.Clock

	// ValidityWindow is added to a new session's creation time to obtain
	// its expiry.
	ValidityWindow time.Duration

	// sessions persists Session records.
	sessions session.SessionRepo

	// keys stores session keys along with the session they belong to.
	keys session.SessionKeyRepo
}
// NewSession persists a fresh session in session.SessionStateNew and returns
// its generated ID.
//
// CreatedAt is taken from m.Clock, and ExpiresAt is CreatedAt plus
// m.ValidityWindow. An error from ID generation or from the repository's
// Create call is returned unchanged, together with an empty ID.
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
	id, err := m.GenerateCode()
	if err != nil {
		return "", err
	}

	created := m.Clock.Now()
	if err := m.sessions.Create(session.Session{
		ID:          id,
		State:       session.SessionStateNew,
		CreatedAt:   created,
		ExpiresAt:   created.Add(m.ValidityWindow),
		ConnectorID: connectorID,
		ClientID:    clientID,
		ClientState: clientState,
		RedirectURL: redirectURL,
		Nonce:       nonce,
		Register:    register,
		Scope:       scope,
	}); err != nil {
		return "", err
	}
	return id, nil
}
// NewSessionKey generates a new key bound to sessionID and pushes it into the
// key repository with a 10-minute lifetime. It returns the key.
//
// Ten minutes is the maximum authorization-code lifetime recommended by
// RFC 6749.
func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
	const keyLifetime = 10 * time.Minute // RFC 6749

	generated, err := m.GenerateCode()
	if err != nil {
		return "", err
	}

	sk := session.SessionKey{SessionID: sessionID, Key: generated}
	if err := m.keys.Push(sk, keyLifetime); err != nil {
		return "", err
	}
	return sk.Key, nil
}
// ExchangeKey looks up key through the key repository's Pop and returns its
// result. That result is presumably the session ID pushed by NewSessionKey.
// NOTE(review): whether Pop also removes the key, making it single-use,
// depends on the repository implementation — confirm.
func (m *SessionManager) ExchangeKey(key string) (string, error) {
	sessionID, err := m.keys.Pop(key)
	return sessionID, err
}
// getSessionInState loads the session stored under sessionID and checks that
// it is currently in state.
//
// A repository error is passed through unchanged. If the session is in a
// different state, the error names both the actual and the expected state.
func (m *SessionManager) getSessionInState(sessionID string, state session.SessionState) (*session.Session, error) {
	s, err := m.sessions.Get(sessionID)
	switch {
	case err != nil:
		return nil, err
	case s.State != state:
		return nil, fmt.Errorf("session state %s, expect %s", s.State, state)
	default:
		return s, nil
	}
}
// AttachRemoteIdentity stores ident on a session and advances it to
// session.SessionStateRemoteAttached. The session must currently be in
// session.SessionStateNew.
//
// The updated session is written back to the repository and returned. A
// state-check or update failure yields a nil session and the error.
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*session.Session, error) {
	sess, err := m.getSessionInState(sessionID, session.SessionStateNew)
	if err != nil {
		return nil, err
	}

	sess.Identity = ident
	sess.State = session.SessionStateRemoteAttached

	err = m.sessions.Update(*sess)
	if err != nil {
		return nil, err
	}
	return sess, nil
}
// AttachUser records userID on a session and advances it to
// session.SessionStateIdentified. The session must currently be in
// session.SessionStateRemoteAttached.
//
// The updated session is persisted and returned. A state-check or update
// failure yields a nil session and the error.
func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.Session, error) {
	sess, err := m.getSessionInState(sessionID, session.SessionStateRemoteAttached)
	if err != nil {
		return nil, err
	}

	sess.UserID = userID
	sess.State = session.SessionStateIdentified

	err = m.sessions.Update(*sess)
	if err != nil {
		return nil, err
	}
	return sess, nil
}
// AttachGroups replaces the session's Groups with groups, persists the
// session, and returns it.
//
// Unlike AttachRemoteIdentity and AttachUser, it does not check or change the
// session's state.
func (m *SessionManager) AttachGroups(sessionID string, groups []string) (*session.Session, error) {
	sess, err := m.sessions.Get(sessionID)
	if err != nil {
		return nil, err
	}

	sess.Groups = groups
	err = m.sessions.Update(*sess)
	if err != nil {
		return nil, err
	}
	return sess, nil
}
// Kill marks a session as session.SessionStateDead, whatever its current
// state, persists the change, and returns the updated session.
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
	sess, err := m.sessions.Get(sessionID)
	if err != nil {
		return nil, err
	}

	sess.State = session.SessionStateDead
	err = m.sessions.Update(*sess)
	if err != nil {
		return nil, err
	}
	return sess, nil
}
// Get returns the session stored under sessionID exactly as the session
// repository reports it. No state or expiry check is applied.
func (m *SessionManager) Get(sessionID string) (*session.Session, error) {
	sess, err := m.sessions.Get(sessionID)
	return sess, err
}