7bac93aa20
Move manager to its own package so it can import db. Move all references to the in-memory session repos to use sqlite3.
164 lines
3.5 KiB
Go
164 lines
3.5 KiB
Go
package manager
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/jonboulle/clockwork"
|
|
|
|
"github.com/coreos/dex/session"
|
|
"github.com/coreos/go-oidc/oidc"
|
|
)
|
|
|
|
// GenerateCodeFunc is the signature of a code generator: it takes no
// arguments and yields either a freshly generated code or an error.
type GenerateCodeFunc func() (code string, err error)
|
|
|
|
// DefaultGenerateCode reads 8 bytes from crypto/rand and returns them encoded
// with the URL-safe base64 alphabet (padding included).
//
// An error is returned when the random source fails or hands back fewer bytes
// than were requested.
func DefaultGenerateCode() (string, error) {
	const size = 8

	var raw [size]byte
	switch read, err := rand.Read(raw[:]); {
	case err != nil:
		return "", err
	case read != size:
		return "", errors.New("unable to read enough random bytes")
	}

	return base64.URLEncoding.EncodeToString(raw[:]), nil
}
|
|
|
|
func NewSessionManager(sRepo session.SessionRepo, skRepo session.SessionKeyRepo) *SessionManager {
|
|
return &SessionManager{
|
|
GenerateCode: DefaultGenerateCode,
|
|
Clock: clockwork.NewRealClock(),
|
|
ValidityWindow: session.DefaultSessionValidityWindow,
|
|
sessions: sRepo,
|
|
keys: skRepo,
|
|
}
|
|
}
|
|
|
|
type SessionManager struct {
|
|
GenerateCode GenerateCodeFunc
|
|
Clock clockwork.Clock
|
|
ValidityWindow time.Duration
|
|
sessions session.SessionRepo
|
|
keys session.SessionKeyRepo
|
|
}
|
|
|
|
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
|
|
sID, err := m.GenerateCode()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
now := m.Clock.Now()
|
|
s := session.Session{
|
|
ConnectorID: connectorID,
|
|
ID: sID,
|
|
State: session.SessionStateNew,
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(m.ValidityWindow),
|
|
ClientID: clientID,
|
|
ClientState: clientState,
|
|
RedirectURL: redirectURL,
|
|
Register: register,
|
|
Nonce: nonce,
|
|
Scope: scope,
|
|
}
|
|
|
|
err = m.sessions.Create(s)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return sID, nil
|
|
}
|
|
|
|
func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
|
|
key, err := m.GenerateCode()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
k := session.SessionKey{
|
|
Key: key,
|
|
SessionID: sessionID,
|
|
}
|
|
|
|
sessionKeyValidityWindow := 10 * time.Minute //RFC6749
|
|
err = m.keys.Push(k, sessionKeyValidityWindow)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return k.Key, nil
|
|
}
|
|
|
|
func (m *SessionManager) ExchangeKey(key string) (string, error) {
|
|
return m.keys.Pop(key)
|
|
}
|
|
|
|
func (m *SessionManager) getSessionInState(sessionID string, state session.SessionState) (*session.Session, error) {
|
|
s, err := m.sessions.Get(sessionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if s.State != state {
|
|
return nil, fmt.Errorf("session state %s, expect %s", s.State, state)
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*session.Session, error) {
|
|
s, err := m.getSessionInState(sessionID, session.SessionStateNew)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.Identity = ident
|
|
s.State = session.SessionStateRemoteAttached
|
|
|
|
if err = m.sessions.Update(*s); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.Session, error) {
|
|
s, err := m.getSessionInState(sessionID, session.SessionStateRemoteAttached)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.UserID = userID
|
|
s.State = session.SessionStateIdentified
|
|
|
|
if err = m.sessions.Update(*s); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
|
|
s, err := m.sessions.Get(sessionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.State = session.SessionStateDead
|
|
|
|
if err = m.sessions.Update(*s); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (m *SessionManager) Get(sessionID string) (*session.Session, error) {
|
|
return m.sessions.Get(sessionID)
|
|
}
|