This repository has been archived on 2022-08-17. You can view files and clone it, but cannot push or open issues or pull requests.
dex/session/manager/manager.go

165 lines
3.5 KiB
Go
Raw Normal View History

package manager
2015-08-18 05:57:27 +05:30
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net/url"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/dex/session"
2015-08-18 05:57:27 +05:30
"github.com/coreos/go-oidc/oidc"
)
// GenerateCodeFunc produces a new opaque code string, or an error if one
// cannot be generated. SessionManager calls it to mint both session IDs and
// session keys.
type GenerateCodeFunc func() (string, error)
// DefaultGenerateCode returns a code built from 8 bytes read from
// crypto/rand and encoded with base64.URLEncoding (12 characters, the last
// being a '=' pad). NewSessionManager installs it as the default GenerateCode.
func DefaultGenerateCode() (string, error) {
	b := make([]byte, 8)
	n, err := rand.Read(b)
	if err != nil {
		return "", err
	}
	// crypto/rand documents n == len(b) iff err == nil, so this is a
	// defensive guard. Compare against len(b) rather than a repeated literal
	// so the check cannot drift out of sync with the buffer size above.
	if n != len(b) {
		return "", errors.New("unable to read enough random bytes")
	}
	return base64.URLEncoding.EncodeToString(b), nil
}
func NewSessionManager(sRepo session.SessionRepo, skRepo session.SessionKeyRepo) *SessionManager {
2015-08-18 05:57:27 +05:30
return &SessionManager{
GenerateCode: DefaultGenerateCode,
Clock: clockwork.NewRealClock(),
ValidityWindow: session.DefaultSessionValidityWindow,
2015-08-18 05:57:27 +05:30
sessions: sRepo,
keys: skRepo,
}
}
type SessionManager struct {
GenerateCode GenerateCodeFunc
Clock clockwork.Clock
ValidityWindow time.Duration
sessions session.SessionRepo
keys session.SessionKeyRepo
2015-08-18 05:57:27 +05:30
}
2015-08-29 04:33:51 +05:30
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
2015-08-18 05:57:27 +05:30
sID, err := m.GenerateCode()
if err != nil {
return "", err
}
now := m.Clock.Now()
s := session.Session{
2015-08-18 05:57:27 +05:30
ConnectorID: connectorID,
ID: sID,
State: session.SessionStateNew,
2015-08-18 05:57:27 +05:30
CreatedAt: now,
ExpiresAt: now.Add(m.ValidityWindow),
ClientID: clientID,
ClientState: clientState,
RedirectURL: redirectURL,
Register: register,
Nonce: nonce,
2015-08-29 04:33:51 +05:30
Scope: scope,
2015-08-18 05:57:27 +05:30
}
err = m.sessions.Create(s)
if err != nil {
return "", err
}
return sID, nil
}
func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
key, err := m.GenerateCode()
if err != nil {
return "", err
}
k := session.SessionKey{
2015-08-18 05:57:27 +05:30
Key: key,
SessionID: sessionID,
}
sessionKeyValidityWindow := 10 * time.Minute //RFC6749
2015-08-18 05:57:27 +05:30
err = m.keys.Push(k, sessionKeyValidityWindow)
if err != nil {
return "", err
}
return k.Key, nil
}
// ExchangeKey pops key from the session-key repository and returns whatever
// session ID and error the repository reports for it.
func (m *SessionManager) ExchangeKey(key string) (string, error) {
	sessionID, err := m.keys.Pop(key)
	return sessionID, err
}
func (m *SessionManager) getSessionInState(sessionID string, state session.SessionState) (*session.Session, error) {
2015-08-18 05:57:27 +05:30
s, err := m.sessions.Get(sessionID)
if err != nil {
return nil, err
}
if s.State != state {
return nil, fmt.Errorf("session state %s, expect %s", s.State, state)
}
return s, nil
}
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*session.Session, error) {
s, err := m.getSessionInState(sessionID, session.SessionStateNew)
2015-08-18 05:57:27 +05:30
if err != nil {
return nil, err
}
s.Identity = ident
s.State = session.SessionStateRemoteAttached
2015-08-18 05:57:27 +05:30
if err = m.sessions.Update(*s); err != nil {
return nil, err
}
return s, nil
}
func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.Session, error) {
s, err := m.getSessionInState(sessionID, session.SessionStateRemoteAttached)
2015-08-18 05:57:27 +05:30
if err != nil {
return nil, err
}
s.UserID = userID
s.State = session.SessionStateIdentified
2015-08-18 05:57:27 +05:30
if err = m.sessions.Update(*s); err != nil {
return nil, err
}
return s, nil
}
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
2015-08-18 05:57:27 +05:30
s, err := m.sessions.Get(sessionID)
if err != nil {
return nil, err
}
s.State = session.SessionStateDead
2015-08-18 05:57:27 +05:30
if err = m.sessions.Update(*s); err != nil {
return nil, err
}
return s, nil
}
func (m *SessionManager) Get(sessionID string) (*session.Session, error) {
2015-08-18 05:57:27 +05:30
return m.sessions.Get(sessionID)
}