*: remove in memory session repos
Move manager to it's own package so it can import db. Move all references to the in memory session repos to use sqlite3.
This commit is contained in:
parent
5052d8007f
commit
7bac93aa20
15 changed files with 99 additions and 161 deletions
12
db/conn.go
12
db/conn.go
|
@ -101,3 +101,15 @@ func rollback(tx *gorp.Transaction) {
|
|||
log.Errorf("unable to rollback: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// NewMemDB creates a new in memory sqlite3 database.
|
||||
func NewMemDB() *gorp.DbMap {
|
||||
dbMap, err := NewConnection(Config{DSN: "sqlite3://:memory:"})
|
||||
if err != nil {
|
||||
panic("Failed to create in memory database: " + err.Error())
|
||||
}
|
||||
if _, err := MigrateToLatest(dbMap); err != nil {
|
||||
panic("In memory database migration failed: " + err.Error())
|
||||
}
|
||||
return dbMap
|
||||
}
|
||||
|
|
|
@ -65,7 +65,7 @@ CREATE TABLE session (
|
|||
);
|
||||
|
||||
CREATE TABLE session_key (
|
||||
key text NOT NULL UNIQUE,
|
||||
key text NOT NULL,
|
||||
session_id text,
|
||||
expires_at bigint,
|
||||
stale integer
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
|
||||
clock := clockwork.NewFakeClock()
|
||||
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||
return session.NewSessionRepoWithClock(clock), clock
|
||||
return db.NewSessionRepoWithClock(db.NewMemDB(), clock), clock
|
||||
}
|
||||
dbMap := connect(t)
|
||||
return db.NewSessionRepoWithClock(dbMap, clock), clock
|
||||
|
@ -24,7 +24,7 @@ func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
|
|||
func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) {
|
||||
clock := clockwork.NewFakeClock()
|
||||
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||
return session.NewSessionKeyRepoWithClock(clock), clock
|
||||
return db.NewSessionKeyRepoWithClock(db.NewMemDB(), clock), clock
|
||||
}
|
||||
dbMap := connect(t)
|
||||
return db.NewSessionKeyRepoWithClock(dbMap, clock), clock
|
||||
|
|
|
@ -10,10 +10,11 @@ import (
|
|||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
phttp "github.com/coreos/dex/pkg/http"
|
||||
"github.com/coreos/dex/refresh/refreshtest"
|
||||
"github.com/coreos/dex/server"
|
||||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
|
@ -33,7 +34,7 @@ func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
srv := &server.Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
KeyManager: km,
|
||||
|
@ -120,7 +121,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
|||
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
|
||||
|
||||
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
|
||||
k, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
|
|
@ -19,10 +19,10 @@ import (
|
|||
"github.com/coreos/dex/email"
|
||||
"github.com/coreos/dex/refresh"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/session"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
type ServerConfig struct {
|
||||
|
@ -128,9 +128,9 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
}
|
||||
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs)
|
||||
|
||||
sRepo := session.NewSessionRepo()
|
||||
skRepo := session.NewSessionKeyRepo()
|
||||
sm := session.NewSessionManager(sRepo, skRepo)
|
||||
sRepo := db.NewSessionRepo(db.NewMemDB())
|
||||
skRepo := db.NewSessionKeyRepo(db.NewMemDB())
|
||||
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
||||
|
||||
userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile)
|
||||
if err != nil {
|
||||
|
@ -142,7 +142,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
refTokRepo := refresh.NewRefreshTokenRepo()
|
||||
|
||||
txnFactory := repo.InMemTransactionFactory
|
||||
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
|
||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
||||
srv.ClientIdentityRepo = ciRepo
|
||||
srv.KeySetRepo = kRepo
|
||||
srv.ConnectorConfigRepo = cfgRepo
|
||||
|
@ -180,10 +180,10 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
|||
cfgRepo := db.NewConnectorConfigRepo(dbc)
|
||||
userRepo := db.NewUserRepo(dbc)
|
||||
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
||||
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
|
||||
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
|
||||
|
||||
sm := session.NewSessionManager(sRepo, skRepo)
|
||||
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
||||
|
||||
srv.ClientIdentityRepo = ciRepo
|
||||
srv.KeySetRepo = kRepo
|
||||
|
|
|
@ -17,7 +17,8 @@ import (
|
|||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
|
@ -75,7 +76,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
}
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()),
|
||||
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
||||
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
|
@ -198,7 +199,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
}
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()),
|
||||
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
||||
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||
oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
|
|
|
@ -9,10 +9,10 @@ import (
|
|||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/session"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
type sendResetPasswordEmailData struct {
|
||||
|
@ -28,7 +28,7 @@ type sendResetPasswordEmailData struct {
|
|||
type SendResetPasswordEmailHandler struct {
|
||||
tpl *template.Template
|
||||
emailer *useremail.UserEmailer
|
||||
sm *session.SessionManager
|
||||
sm *sessionmanager.SessionManager
|
||||
cr client.ClientIdentityRepo
|
||||
}
|
||||
|
||||
|
@ -182,7 +182,7 @@ type resetPasswordTemplateData struct {
|
|||
type ResetPasswordHandler struct {
|
||||
tpl *template.Template
|
||||
issuerURL url.URL
|
||||
um *manager.UserManager
|
||||
um *usermanager.UserManager
|
||||
keysFunc func() ([]key.PublicKey, error)
|
||||
}
|
||||
|
||||
|
@ -238,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() {
|
|||
cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case manager.ErrorPasswordAlreadyChanged:
|
||||
case usermanager.ErrorPasswordAlreadyChanged:
|
||||
r.data.Error = "Link Expired"
|
||||
r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email."
|
||||
r.data.DontShowForm = true
|
||||
|
|
|
@ -10,8 +10,9 @@ import (
|
|||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/session"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
|
@ -274,7 +275,7 @@ func makeClientRedirectURL(baseRedirURL url.URL, code, clientState string) *url.
|
|||
return &ru
|
||||
}
|
||||
|
||||
func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) {
|
||||
func registerFromLocalConnector(userManager *usermanager.UserManager, sessionManager *sessionmanager.SessionManager, ses *session.Session, email, password string) (string, error) {
|
||||
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -289,7 +290,7 @@ func registerFromLocalConnector(userManager *manager.UserManager, sessionManager
|
|||
return userID, nil
|
||||
}
|
||||
|
||||
func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
|
||||
func registerFromRemoteConnector(userManager *usermanager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
|
||||
if ses.Identity.ID == "" {
|
||||
return "", errors.New("No Identity found in session.")
|
||||
}
|
||||
|
|
|
@ -22,10 +22,11 @@ import (
|
|||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/refresh"
|
||||
"github.com/coreos/dex/session"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
usersapi "github.com/coreos/dex/user/api"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -57,7 +58,7 @@ type Server struct {
|
|||
IssuerURL url.URL
|
||||
KeyManager key.PrivateKeyManager
|
||||
KeySetRepo key.PrivateKeySetRepo
|
||||
SessionManager *session.SessionManager
|
||||
SessionManager *sessionmanager.SessionManager
|
||||
ClientIdentityRepo client.ClientIdentityRepo
|
||||
ConnectorConfigRepo connector.ConnectorConfigRepo
|
||||
Templates *template.Template
|
||||
|
@ -69,7 +70,7 @@ type Server struct {
|
|||
HealthChecks []health.Checkable
|
||||
Connectors []connector.Connector
|
||||
UserRepo user.UserRepo
|
||||
UserManager *manager.UserManager
|
||||
UserManager *usermanager.UserManager
|
||||
PasswordInfoRepo user.PasswordInfoRepo
|
||||
RefreshTokenRepo refresh.RefreshTokenRepo
|
||||
UserEmailer *useremail.UserEmailer
|
||||
|
|
|
@ -10,8 +10,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/refresh/refreshtest"
|
||||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
|
@ -68,7 +69,7 @@ func (ss *StaticSigner) JWK() jose.JWK {
|
|||
return jose.JWK{}
|
||||
}
|
||||
|
||||
func staticGenerateCodeFunc(code string) session.GenerateCodeFunc {
|
||||
func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc {
|
||||
return func() (string, error) {
|
||||
return code, nil
|
||||
}
|
||||
|
@ -120,7 +121,7 @@ func TestServerProviderConfig(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestServerNewSession(t *testing.T) {
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
srv := &Server{
|
||||
SessionManager: sm,
|
||||
}
|
||||
|
@ -197,7 +198,7 @@ func TestServerLogin(t *testing.T) {
|
|||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
}
|
||||
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
|
||||
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
|
@ -245,7 +246,7 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
|
|||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
|
||||
}
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
KeyManager: km,
|
||||
|
@ -286,7 +287,7 @@ func TestServerLoginDisabledUser(t *testing.T) {
|
|||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
}
|
||||
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
|
||||
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
|
@ -343,7 +344,7 @@ func TestServerCodeToken(t *testing.T) {
|
|||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
}
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
|
||||
userRepo, err := makeNewUserRepo()
|
||||
if err != nil {
|
||||
|
@ -424,7 +425,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
|
|||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
}
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
|
@ -518,7 +519,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
sm.GenerateCode = func() (string, error) { return keyFixture, nil }
|
||||
|
||||
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope)
|
||||
|
|
|
@ -10,12 +10,13 @@ import (
|
|||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/email"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/session"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -75,13 +76,13 @@ var (
|
|||
type testFixtures struct {
|
||||
srv *Server
|
||||
userRepo user.UserRepo
|
||||
sessionManager *session.SessionManager
|
||||
sessionManager *sessionmanager.SessionManager
|
||||
emailer *email.TemplatizedEmailer
|
||||
redirectURL url.URL
|
||||
clientIdentityRepo client.ClientIdentityRepo
|
||||
}
|
||||
|
||||
func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
|
||||
func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
|
||||
x := 0
|
||||
return func() (string, error) {
|
||||
x += 1
|
||||
|
@ -113,9 +114,9 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
}
|
||||
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
|
||||
|
||||
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, usermanager.ManagerOptions{})
|
||||
|
||||
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
||||
|
||||
emailer, err := email.NewTemplatizedEmailerFromGlobs(
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package session
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
|
@ -27,11 +28,11 @@ func DefaultGenerateCode() (string, error) {
|
|||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager {
|
||||
func NewSessionManager(sRepo session.SessionRepo, skRepo session.SessionKeyRepo) *SessionManager {
|
||||
return &SessionManager{
|
||||
GenerateCode: DefaultGenerateCode,
|
||||
Clock: clockwork.NewRealClock(),
|
||||
ValidityWindow: DefaultSessionValidityWindow,
|
||||
ValidityWindow: session.DefaultSessionValidityWindow,
|
||||
sessions: sRepo,
|
||||
keys: skRepo,
|
||||
}
|
||||
|
@ -41,8 +42,8 @@ type SessionManager struct {
|
|||
GenerateCode GenerateCodeFunc
|
||||
Clock clockwork.Clock
|
||||
ValidityWindow time.Duration
|
||||
sessions SessionRepo
|
||||
keys SessionKeyRepo
|
||||
sessions session.SessionRepo
|
||||
keys session.SessionKeyRepo
|
||||
}
|
||||
|
||||
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
|
||||
|
@ -52,10 +53,10 @@ func (m *SessionManager) NewSession(connectorID, clientID, clientState string, r
|
|||
}
|
||||
|
||||
now := m.Clock.Now()
|
||||
s := Session{
|
||||
s := session.Session{
|
||||
ConnectorID: connectorID,
|
||||
ID: sID,
|
||||
State: SessionStateNew,
|
||||
State: session.SessionStateNew,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(m.ValidityWindow),
|
||||
ClientID: clientID,
|
||||
|
@ -80,11 +81,12 @@ func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
k := SessionKey{
|
||||
k := session.SessionKey{
|
||||
Key: key,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
|
||||
sessionKeyValidityWindow := 10 * time.Minute //RFC6749
|
||||
err = m.keys.Push(k, sessionKeyValidityWindow)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -97,7 +99,7 @@ func (m *SessionManager) ExchangeKey(key string) (string, error) {
|
|||
return m.keys.Pop(key)
|
||||
}
|
||||
|
||||
func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) {
|
||||
func (m *SessionManager) getSessionInState(sessionID string, state session.SessionState) (*session.Session, error) {
|
||||
s, err := m.sessions.Get(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -110,14 +112,14 @@ func (m *SessionManager) getSessionInState(sessionID string, state SessionState)
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) {
|
||||
s, err := m.getSessionInState(sessionID, SessionStateNew)
|
||||
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*session.Session, error) {
|
||||
s, err := m.getSessionInState(sessionID, session.SessionStateNew)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.Identity = ident
|
||||
s.State = SessionStateRemoteAttached
|
||||
s.State = session.SessionStateRemoteAttached
|
||||
|
||||
if err = m.sessions.Update(*s); err != nil {
|
||||
return nil, err
|
||||
|
@ -126,14 +128,14 @@ func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Ident
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) {
|
||||
s, err := m.getSessionInState(sessionID, SessionStateRemoteAttached)
|
||||
func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.Session, error) {
|
||||
s, err := m.getSessionInState(sessionID, session.SessionStateRemoteAttached)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.UserID = userID
|
||||
s.State = SessionStateIdentified
|
||||
s.State = session.SessionStateIdentified
|
||||
|
||||
if err = m.sessions.Update(*s); err != nil {
|
||||
return nil, err
|
||||
|
@ -142,13 +144,13 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session,
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) Kill(sessionID string) (*Session, error) {
|
||||
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
|
||||
s, err := m.sessions.Get(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.State = SessionStateDead
|
||||
s.State = session.SessionStateDead
|
||||
|
||||
if err = m.sessions.Update(*s); err != nil {
|
||||
return nil, err
|
||||
|
@ -157,6 +159,6 @@ func (m *SessionManager) Kill(sessionID string) (*Session, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) Get(sessionID string) (*Session, error) {
|
||||
func (m *SessionManager) Get(sessionID string) (*session.Session, error) {
|
||||
return m.sessions.Get(sessionID)
|
||||
}
|
|
@ -1,9 +1,11 @@
|
|||
package session
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
|
@ -13,8 +15,13 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func newManager(t *testing.T) *SessionManager {
|
||||
dbMap := db.NewMemDB()
|
||||
return NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
|
||||
}
|
||||
|
||||
func TestSessionManagerNewSession(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sm := newManager(t)
|
||||
sm.GenerateCode = staticGenerateCodeFunc("boo")
|
||||
got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
|
@ -26,7 +33,7 @@ func TestSessionManagerNewSession(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sm := newManager(t)
|
||||
sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
|
@ -43,7 +50,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionManagerExchangeKey(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sm := newManager(t)
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
|
@ -68,8 +75,8 @@ func TestSessionManagerExchangeKey(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
ses, err := sm.getSessionInState("123", SessionStateNew)
|
||||
sm := newManager(t)
|
||||
ses, err := sm.getSessionInState("123", session.SessionStateNew)
|
||||
if err == nil {
|
||||
t.Errorf("Expected non-nil error")
|
||||
}
|
||||
|
@ -79,12 +86,12 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sm := newManager(t)
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
ses, err := sm.getSessionInState(sessionID, SessionStateDead)
|
||||
ses, err := sm.getSessionInState(sessionID, session.SessionStateDead)
|
||||
if err == nil {
|
||||
t.Errorf("Expected non-nil error")
|
||||
}
|
||||
|
@ -94,7 +101,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionManagerKill(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sm := newManager(t)
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
|
@ -1,11 +1,6 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
)
|
||||
import "time"
|
||||
|
||||
type SessionRepo interface {
|
||||
Get(string) (*Session, error)
|
||||
|
@ -17,87 +12,3 @@ type SessionKeyRepo interface {
|
|||
Push(SessionKey, time.Duration) error
|
||||
Pop(string) (string, error)
|
||||
}
|
||||
|
||||
func NewSessionRepo() SessionRepo {
|
||||
return NewSessionRepoWithClock(clockwork.NewRealClock())
|
||||
}
|
||||
|
||||
func NewSessionRepoWithClock(clock clockwork.Clock) SessionRepo {
|
||||
return &memSessionRepo{
|
||||
store: make(map[string]Session),
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
||||
type memSessionRepo struct {
|
||||
store map[string]Session
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func (m *memSessionRepo) Get(sessionID string) (*Session, error) {
|
||||
s, ok := m.store[sessionID]
|
||||
if !ok || s.ExpiresAt.Before(m.clock.Now()) {
|
||||
return nil, errors.New("unrecognized ID")
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func (m *memSessionRepo) Create(s Session) error {
|
||||
if _, ok := m.store[s.ID]; ok {
|
||||
return errors.New("ID exists")
|
||||
}
|
||||
|
||||
m.store[s.ID] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memSessionRepo) Update(s Session) error {
|
||||
if _, ok := m.store[s.ID]; !ok {
|
||||
return errors.New("unrecognized ID")
|
||||
}
|
||||
m.store[s.ID] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
type expiringSessionKey struct {
|
||||
SessionKey
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func NewSessionKeyRepo() SessionKeyRepo {
|
||||
return NewSessionKeyRepoWithClock(clockwork.NewRealClock())
|
||||
}
|
||||
|
||||
func NewSessionKeyRepoWithClock(clock clockwork.Clock) SessionKeyRepo {
|
||||
return &memSessionKeyRepo{
|
||||
store: make(map[string]expiringSessionKey),
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
||||
type memSessionKeyRepo struct {
|
||||
store map[string]expiringSessionKey
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func (m *memSessionKeyRepo) Pop(key string) (string, error) {
|
||||
esk, ok := m.store[key]
|
||||
if !ok {
|
||||
return "", errors.New("unrecognized key")
|
||||
}
|
||||
defer delete(m.store, key)
|
||||
|
||||
if esk.expiresAt.Before(m.clock.Now()) {
|
||||
return "", errors.New("expired key")
|
||||
}
|
||||
|
||||
return esk.SessionKey.SessionID, nil
|
||||
}
|
||||
|
||||
func (m *memSessionKeyRepo) Push(sk SessionKey, ttl time.Duration) error {
|
||||
m.store[sk.Key] = expiringSessionKey{
|
||||
SessionKey: sk,
|
||||
expiresAt: m.clock.Now().Add(ttl),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
2
test
2
test
|
@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
|
|||
|
||||
source ./build
|
||||
|
||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email admin"
|
||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session session/manager user user/api user/manager email admin"
|
||||
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
|
||||
|
||||
# user has not provided PKG override
|
||||
|
|
Reference in a new issue