forked from mystiq/dex
160 lines
3.2 KiB
Go
160 lines
3.2 KiB
Go
package session
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/jonboulle/clockwork"
|
|
|
|
"github.com/coreos/go-oidc/oidc"
|
|
)
|
|
|
|
// GenerateCodeFunc produces a code string, or an error if no code could be
// produced. SessionManager calls it to create both session IDs (NewSession)
// and session keys (NewSessionKey).
type GenerateCodeFunc func() (string, error)
|
|
|
|
// DefaultGenerateCode returns a random code built from 8 bytes read from
// crypto/rand and encoded with base64.URLEncoding. The result is always 12
// characters long, the last of which is the '=' padding character.
//
// It is the GenerateCodeFunc installed by NewSessionManager. Any error from
// the random source is returned unchanged together with an empty string.
func DefaultGenerateCode() (string, error) {
	// Number of random bytes per code; the allocation and the short-read
	// guard below must agree, so the size is named once here.
	const codeLen = 8

	b := make([]byte, codeLen)
	n, err := rand.Read(b)
	if err != nil {
		return "", err
	}
	// crypto/rand.Read fills the whole buffer whenever err == nil; this guard
	// is defensive, so a short read can never yield a low-entropy code.
	if n != codeLen {
		return "", errors.New("unable to read enough random bytes")
	}
	return base64.URLEncoding.EncodeToString(b), nil
}
|
|
|
|
func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager {
|
|
return &SessionManager{
|
|
GenerateCode: DefaultGenerateCode,
|
|
Clock: clockwork.NewRealClock(),
|
|
ValidityWindow: DefaultSessionValidityWindow,
|
|
sessions: sRepo,
|
|
keys: skRepo,
|
|
}
|
|
}
|
|
|
|
// SessionManager creates sessions and session keys and moves sessions between
// states, persisting them through its session and session-key repositories.
type SessionManager struct {
	// GenerateCode is called to produce both session IDs (NewSession) and
	// session keys (NewSessionKey).
	GenerateCode GenerateCodeFunc
	// Clock supplies the current time recorded as a new session's CreatedAt.
	Clock clockwork.Clock
	// ValidityWindow is added to CreatedAt to compute a new session's
	// ExpiresAt.
	ValidityWindow time.Duration
	// sessions persists Session records (Create, Get, Update are used here).
	sessions SessionRepo
	// keys persists SessionKey records (Push and Pop are used here).
	keys SessionKeyRepo
}
|
|
|
|
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool) (string, error) {
|
|
sID, err := m.GenerateCode()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
now := m.Clock.Now()
|
|
s := Session{
|
|
ConnectorID: connectorID,
|
|
ID: sID,
|
|
State: SessionStateNew,
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(m.ValidityWindow),
|
|
ClientID: clientID,
|
|
ClientState: clientState,
|
|
RedirectURL: redirectURL,
|
|
Register: register,
|
|
Nonce: nonce,
|
|
}
|
|
|
|
err = m.sessions.Create(s)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return sID, nil
|
|
}
|
|
|
|
func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
|
|
key, err := m.GenerateCode()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
k := SessionKey{
|
|
Key: key,
|
|
SessionID: sessionID,
|
|
}
|
|
|
|
err = m.keys.Push(k, sessionKeyValidityWindow)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return k.Key, nil
|
|
}
|
|
|
|
func (m *SessionManager) ExchangeKey(key string) (string, error) {
|
|
return m.keys.Pop(key)
|
|
}
|
|
|
|
func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) {
|
|
s, err := m.sessions.Get(sessionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if s.State != state {
|
|
return nil, fmt.Errorf("session state %s, expect %s", s.State, state)
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) {
|
|
s, err := m.getSessionInState(sessionID, SessionStateNew)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.Identity = ident
|
|
s.State = SessionStateRemoteAttached
|
|
|
|
if err = m.sessions.Update(*s); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) {
|
|
s, err := m.getSessionInState(sessionID, SessionStateRemoteAttached)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.UserID = userID
|
|
s.State = SessionStateIdentified
|
|
|
|
if err = m.sessions.Update(*s); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (m *SessionManager) Kill(sessionID string) (*Session, error) {
|
|
s, err := m.sessions.Get(sessionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.State = SessionStateDead
|
|
|
|
if err = m.sessions.Update(*s); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (m *SessionManager) Get(sessionID string) (*Session, error) {
|
|
return m.sessions.Get(sessionID)
|
|
}
|