From 7bac93aa2087abc110302123d5d9eb3182ac675f Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Tue, 9 Feb 2016 11:10:28 -0800 Subject: [PATCH] *: remove in memory session repos Move manager to it's own package so it can import db. Move all references to the in memory session repos to use sqlite3. --- db/conn.go | 12 ++++ db/migrate_sqlite3.go | 2 +- functional/repo/session_repo_test.go | 4 +- integration/oidc_test.go | 7 ++- server/config.go | 16 ++--- server/http_test.go | 7 ++- server/password.go | 10 +-- server/register.go | 7 ++- server/server.go | 7 ++- server/server_test.go | 19 +++--- server/testutil.go | 13 ++-- session/{ => manager}/manager.go | 38 +++++------ session/{ => manager}/manager_test.go | 25 +++++--- session/repo.go | 91 +-------------------------- test | 2 +- 15 files changed, 99 insertions(+), 161 deletions(-) rename session/{ => manager}/manager.go (72%) rename session/{ => manager}/manager_test.go (84%) diff --git a/db/conn.go b/db/conn.go index 8ff115f1..7d31bd52 100644 --- a/db/conn.go +++ b/db/conn.go @@ -101,3 +101,15 @@ func rollback(tx *gorp.Transaction) { log.Errorf("unable to rollback: %v", err) } } + +// NewMemDB creates a new in memory sqlite3 database. +func NewMemDB() *gorp.DbMap { + dbMap, err := NewConnection(Config{DSN: "sqlite3://:memory:"}) + if err != nil { + panic("Failed to create in memory database: " + err.Error()) + } + if _, err := MigrateToLatest(dbMap); err != nil { + panic("In memory database migration failed: " + err.Error()) + } + return dbMap +} diff --git a/db/migrate_sqlite3.go b/db/migrate_sqlite3.go index 3cbfc7c3..af00fcf8 100644 --- a/db/migrate_sqlite3.go +++ b/db/migrate_sqlite3.go @@ -65,7 +65,7 @@ CREATE TABLE session ( ); CREATE TABLE session_key ( - key text NOT NULL UNIQUE, + key text NOT NULL, session_id text, expires_at bigint, stale integer diff --git a/functional/repo/session_repo_test.go b/functional/repo/session_repo_test.go index 2be8672b..4f939e52 100644 --- a/functional/repo/session_repo_test.go +++ b/functional/repo/session_repo_test.go @@ -15,7 +15,7 @@ import ( func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) { clock := clockwork.NewFakeClock() if os.Getenv("DEX_TEST_DSN") == "" { - return session.NewSessionRepoWithClock(clock), clock + return db.NewSessionRepoWithClock(db.NewMemDB(), clock), clock } dbMap := connect(t) return db.NewSessionRepoWithClock(dbMap, clock), clock @@ -24,7 +24,7 @@ func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) { func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) { clock := clockwork.NewFakeClock() if os.Getenv("DEX_TEST_DSN") == "" { - return session.NewSessionKeyRepoWithClock(clock), clock + return db.NewSessionKeyRepoWithClock(db.NewMemDB(), clock), clock } dbMap := connect(t) return db.NewSessionKeyRepoWithClock(dbMap, clock), clock diff --git a/integration/oidc_test.go b/integration/oidc_test.go index 51bf0288..e4fe1802 100644 --- a/integration/oidc_test.go +++ b/integration/oidc_test.go @@ -10,10 +10,11 @@ import ( "github.com/coreos/dex/client" "github.com/coreos/dex/connector" + "github.com/coreos/dex/db" phttp "github.com/coreos/dex/pkg/http" "github.com/coreos/dex/refresh/refreshtest" "github.com/coreos/dex/server" - "github.com/coreos/dex/session" + "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/key" @@ -33,7 +34,7 @@ func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) { return nil, err } - sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) + sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) srv := &server.Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, @@ -120,7 +121,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} - sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) + sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) k, err := key.GeneratePrivateKey() if err != nil { diff --git a/server/config.go b/server/config.go index 51e467b2..278369a0 100644 --- a/server/config.go +++ b/server/config.go @@ -19,10 +19,10 @@ import ( "github.com/coreos/dex/email" "github.com/coreos/dex/refresh" "github.com/coreos/dex/repo" - "github.com/coreos/dex/session" + sessionmanager "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" useremail "github.com/coreos/dex/user/email" - "github.com/coreos/dex/user/manager" + usermanager "github.com/coreos/dex/user/manager" ) type ServerConfig struct { @@ -128,9 +128,9 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { } cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs) - sRepo := session.NewSessionRepo() - skRepo := session.NewSessionKeyRepo() - sm := session.NewSessionManager(sRepo, skRepo) + sRepo := db.NewSessionRepo(db.NewMemDB()) + skRepo := db.NewSessionKeyRepo(db.NewMemDB()) + sm := sessionmanager.NewSessionManager(sRepo, skRepo) userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile) if err != nil { @@ -142,7 +142,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { refTokRepo := refresh.NewRefreshTokenRepo() txnFactory := repo.InMemTransactionFactory - userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{}) + userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{}) srv.ClientIdentityRepo = ciRepo srv.KeySetRepo = kRepo srv.ConnectorConfigRepo = cfgRepo @@ -180,10 +180,10 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error { cfgRepo := db.NewConnectorConfigRepo(dbc) userRepo := db.NewUserRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc) - userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) + userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{}) refreshTokenRepo := db.NewRefreshTokenRepo(dbc) - sm := session.NewSessionManager(sRepo, skRepo) + sm := sessionmanager.NewSessionManager(sRepo, skRepo) srv.ClientIdentityRepo = ciRepo srv.KeySetRepo = kRepo diff --git a/server/http_test.go b/server/http_test.go index 7b4ec11a..0d2d5516 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -17,7 +17,8 @@ import ( "github.com/coreos/dex/client" "github.com/coreos/dex/connector" - "github.com/coreos/dex/session" + "github.com/coreos/dex/db" + "github.com/coreos/dex/session/manager" "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/oauth2" "github.com/coreos/go-oidc/oidc" @@ -75,7 +76,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { } srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, - SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()), + SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())), ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ @@ -198,7 +199,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { } srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, - SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()), + SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())), ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ diff --git a/server/password.go b/server/password.go index 4d517f05..41b8780e 100644 --- a/server/password.go +++ b/server/password.go @@ -9,10 +9,10 @@ import ( "github.com/coreos/dex/client" "github.com/coreos/dex/pkg/log" - "github.com/coreos/dex/session" + sessionmanager "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" useremail "github.com/coreos/dex/user/email" - "github.com/coreos/dex/user/manager" + usermanager "github.com/coreos/dex/user/manager" ) type sendResetPasswordEmailData struct { @@ -28,7 +28,7 @@ type sendResetPasswordEmailData struct { type SendResetPasswordEmailHandler struct { tpl *template.Template emailer *useremail.UserEmailer - sm *session.SessionManager + sm *sessionmanager.SessionManager cr client.ClientIdentityRepo } @@ -182,7 +182,7 @@ type resetPasswordTemplateData struct { type ResetPasswordHandler struct { tpl *template.Template issuerURL url.URL - um *manager.UserManager + um *usermanager.UserManager keysFunc func() ([]key.PublicKey, error) } @@ -238,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() { cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext) if err != nil { switch err { - case manager.ErrorPasswordAlreadyChanged: + case usermanager.ErrorPasswordAlreadyChanged: r.data.Error = "Link Expired" r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email." r.data.DontShowForm = true diff --git a/server/register.go b/server/register.go index d4d4dc14..013c0b22 100644 --- a/server/register.go +++ b/server/register.go @@ -10,8 +10,9 @@ import ( "github.com/coreos/dex/connector" "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/session" + sessionmanager "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" - "github.com/coreos/dex/user/manager" + usermanager "github.com/coreos/dex/user/manager" "github.com/coreos/go-oidc/oidc" ) @@ -274,7 +275,7 @@ func makeClientRedirectURL(baseRedirURL url.URL, code, clientState string) *url. return &ru } -func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) { +func registerFromLocalConnector(userManager *usermanager.UserManager, sessionManager *sessionmanager.SessionManager, ses *session.Session, email, password string) (string, error) { userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID) if err != nil { return "", err @@ -289,7 +290,7 @@ func registerFromLocalConnector(userManager *manager.UserManager, sessionManager return userID, nil } -func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) { +func registerFromRemoteConnector(userManager *usermanager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) { if ses.Identity.ID == "" { return "", errors.New("No Identity found in session.") } diff --git a/server/server.go b/server/server.go index f9a81259..a8f61f42 100644 --- a/server/server.go +++ b/server/server.go @@ -22,10 +22,11 @@ import ( "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/refresh" "github.com/coreos/dex/session" + sessionmanager "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" usersapi "github.com/coreos/dex/user/api" useremail "github.com/coreos/dex/user/email" - "github.com/coreos/dex/user/manager" + usermanager "github.com/coreos/dex/user/manager" ) const ( @@ -57,7 +58,7 @@ type Server struct { IssuerURL url.URL KeyManager key.PrivateKeyManager KeySetRepo key.PrivateKeySetRepo - SessionManager *session.SessionManager + SessionManager *sessionmanager.SessionManager ClientIdentityRepo client.ClientIdentityRepo ConnectorConfigRepo connector.ConnectorConfigRepo Templates *template.Template @@ -69,7 +70,7 @@ type Server struct { HealthChecks []health.Checkable Connectors []connector.Connector UserRepo user.UserRepo - UserManager *manager.UserManager + UserManager *usermanager.UserManager PasswordInfoRepo user.PasswordInfoRepo RefreshTokenRepo refresh.RefreshTokenRepo UserEmailer *useremail.UserEmailer diff --git a/server/server_test.go b/server/server_test.go index 65e0162f..7bedc333 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -10,8 +10,9 @@ import ( "time" "github.com/coreos/dex/client" + "github.com/coreos/dex/db" "github.com/coreos/dex/refresh/refreshtest" - "github.com/coreos/dex/session" + "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/key" @@ -68,7 +69,7 @@ func (ss *StaticSigner) JWK() jose.JWK { return jose.JWK{} } -func staticGenerateCodeFunc(code string) session.GenerateCodeFunc { +func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc { return func() (string, error) { return code, nil } @@ -120,7 +121,7 @@ func TestServerProviderConfig(t *testing.T) { } func TestServerNewSession(t *testing.T) { - sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) + sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) srv := &Server{ SessionManager: sm, } @@ -197,7 +198,7 @@ func TestServerLogin(t *testing.T) { signer: &StaticSigner{sig: []byte("beer"), err: nil}, } - sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) + sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) sm.GenerateCode = staticGenerateCodeFunc("fakecode") sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"}) if err != nil { @@ -245,7 +246,7 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) { km := &StaticKeyManager{ signer: &StaticSigner{sig: nil, err: errors.New("fail")}, } - sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) + sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, KeyManager: km, @@ -286,7 +287,7 @@ func TestServerLoginDisabledUser(t *testing.T) { signer: &StaticSigner{sig: []byte("beer"), err: nil}, } - sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) + sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) sm.GenerateCode = staticGenerateCodeFunc("fakecode") sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"}) if err != nil { @@ -343,7 +344,7 @@ func TestServerCodeToken(t *testing.T) { km := &StaticKeyManager{ signer: &StaticSigner{sig: []byte("beer"), err: nil}, } - sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) + sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) userRepo, err := makeNewUserRepo() if err != nil { @@ -424,7 +425,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) { km := &StaticKeyManager{ signer: &StaticSigner{sig: []byte("beer"), err: nil}, } - sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) + sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, @@ -518,7 +519,7 @@ func TestServerTokenFail(t *testing.T) { } for i, tt := range tests { - sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) + sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) sm.GenerateCode = func() (string, error) { return keyFixture, nil } sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope) diff --git a/server/testutil.go b/server/testutil.go index b3770121..c61bbfea 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -10,12 +10,13 @@ import ( "github.com/coreos/dex/client" "github.com/coreos/dex/connector" + "github.com/coreos/dex/db" "github.com/coreos/dex/email" "github.com/coreos/dex/repo" - "github.com/coreos/dex/session" + sessionmanager "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" useremail "github.com/coreos/dex/user/email" - "github.com/coreos/dex/user/manager" + usermanager "github.com/coreos/dex/user/manager" ) const ( @@ -75,13 +76,13 @@ var ( type testFixtures struct { srv *Server userRepo user.UserRepo - sessionManager *session.SessionManager + sessionManager *sessionmanager.SessionManager emailer *email.TemplatizedEmailer redirectURL url.URL clientIdentityRepo client.ClientIdentityRepo } -func sequentialGenerateCodeFunc() session.GenerateCodeFunc { +func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc { x := 0 return func() (string, error) { x += 1 @@ -113,9 +114,9 @@ func makeTestFixtures() (*testFixtures, error) { } connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs) - manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{}) + manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, usermanager.ManagerOptions{}) - sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) + sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) sessionManager.GenerateCode = sequentialGenerateCodeFunc() emailer, err := email.NewTemplatizedEmailerFromGlobs( diff --git a/session/manager.go b/session/manager/manager.go similarity index 72% rename from session/manager.go rename to session/manager/manager.go index 27935ea8..c0ac6d4a 100644 --- a/session/manager.go +++ b/session/manager/manager.go @@ -1,4 +1,4 @@ -package session +package manager import ( "crypto/rand" @@ -10,6 +10,7 @@ import ( "github.com/jonboulle/clockwork" + "github.com/coreos/dex/session" "github.com/coreos/go-oidc/oidc" ) @@ -27,11 +28,11 @@ func DefaultGenerateCode() (string, error) { return base64.URLEncoding.EncodeToString(b), nil } -func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager { +func NewSessionManager(sRepo session.SessionRepo, skRepo session.SessionKeyRepo) *SessionManager { return &SessionManager{ GenerateCode: DefaultGenerateCode, Clock: clockwork.NewRealClock(), - ValidityWindow: DefaultSessionValidityWindow, + ValidityWindow: session.DefaultSessionValidityWindow, sessions: sRepo, keys: skRepo, } @@ -41,8 +42,8 @@ type SessionManager struct { GenerateCode GenerateCodeFunc Clock clockwork.Clock ValidityWindow time.Duration - sessions SessionRepo - keys SessionKeyRepo + sessions session.SessionRepo + keys session.SessionKeyRepo } func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) { @@ -52,10 +53,10 @@ func (m *SessionManager) NewSession(connectorID, clientID, clientState string, r } now := m.Clock.Now() - s := Session{ + s := session.Session{ ConnectorID: connectorID, ID: sID, - State: SessionStateNew, + State: session.SessionStateNew, CreatedAt: now, ExpiresAt: now.Add(m.ValidityWindow), ClientID: clientID, @@ -80,11 +81,12 @@ func (m *SessionManager) NewSessionKey(sessionID string) (string, error) { return "", err } - k := SessionKey{ + k := session.SessionKey{ Key: key, SessionID: sessionID, } + sessionKeyValidityWindow := 10 * time.Minute //RFC6749 err = m.keys.Push(k, sessionKeyValidityWindow) if err != nil { return "", err @@ -97,7 +99,7 @@ func (m *SessionManager) ExchangeKey(key string) (string, error) { return m.keys.Pop(key) } -func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) { +func (m *SessionManager) getSessionInState(sessionID string, state session.SessionState) (*session.Session, error) { s, err := m.sessions.Get(sessionID) if err != nil { return nil, err @@ -110,14 +112,14 @@ func (m *SessionManager) getSessionInState(sessionID string, state SessionState) return s, nil } -func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) { - s, err := m.getSessionInState(sessionID, SessionStateNew) +func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*session.Session, error) { + s, err := m.getSessionInState(sessionID, session.SessionStateNew) if err != nil { return nil, err } s.Identity = ident - s.State = SessionStateRemoteAttached + s.State = session.SessionStateRemoteAttached if err = m.sessions.Update(*s); err != nil { return nil, err @@ -126,14 +128,14 @@ func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Ident return s, nil } -func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) { - s, err := m.getSessionInState(sessionID, SessionStateRemoteAttached) +func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.Session, error) { + s, err := m.getSessionInState(sessionID, session.SessionStateRemoteAttached) if err != nil { return nil, err } s.UserID = userID - s.State = SessionStateIdentified + s.State = session.SessionStateIdentified if err = m.sessions.Update(*s); err != nil { return nil, err @@ -142,13 +144,13 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, return s, nil } -func (m *SessionManager) Kill(sessionID string) (*Session, error) { +func (m *SessionManager) Kill(sessionID string) (*session.Session, error) { s, err := m.sessions.Get(sessionID) if err != nil { return nil, err } - s.State = SessionStateDead + s.State = session.SessionStateDead if err = m.sessions.Update(*s); err != nil { return nil, err @@ -157,6 +159,6 @@ func (m *SessionManager) Kill(sessionID string) (*Session, error) { return s, nil } -func (m *SessionManager) Get(sessionID string) (*Session, error) { +func (m *SessionManager) Get(sessionID string) (*session.Session, error) { return m.sessions.Get(sessionID) } diff --git a/session/manager_test.go b/session/manager/manager_test.go similarity index 84% rename from session/manager_test.go rename to session/manager/manager_test.go index 4e925ec1..7e55486e 100644 --- a/session/manager_test.go +++ b/session/manager/manager_test.go @@ -1,9 +1,11 @@ -package session +package manager import ( "net/url" "testing" + "github.com/coreos/dex/db" + "github.com/coreos/dex/session" "github.com/coreos/go-oidc/oidc" ) @@ -13,8 +15,13 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc { } } +func newManager(t *testing.T) *SessionManager { + dbMap := db.NewMemDB() + return NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap)) +} + func TestSessionManagerNewSession(t *testing.T) { - sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) + sm := newManager(t) sm.GenerateCode = staticGenerateCodeFunc("boo") got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) if err != nil { @@ -26,7 +33,7 @@ func TestSessionManagerNewSession(t *testing.T) { } func TestSessionAttachRemoteIdentityTwice(t *testing.T) { - sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) + sm := newManager(t) sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -43,7 +50,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) { } func TestSessionManagerExchangeKey(t *testing.T) { - sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) + sm := newManager(t) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -68,8 +75,8 @@ func TestSessionManagerExchangeKey(t *testing.T) { } func TestSessionManagerGetSessionInStateNoExist(t *testing.T) { - sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) - ses, err := sm.getSessionInState("123", SessionStateNew) + sm := newManager(t) + ses, err := sm.getSessionInState("123", session.SessionStateNew) if err == nil { t.Errorf("Expected non-nil error") } @@ -79,12 +86,12 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) { } func TestSessionManagerGetSessionInStateWrongState(t *testing.T) { - sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) + sm := newManager(t) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } - ses, err := sm.getSessionInState(sessionID, SessionStateDead) + ses, err := sm.getSessionInState(sessionID, session.SessionStateDead) if err == nil { t.Errorf("Expected non-nil error") } @@ -94,7 +101,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) { } func TestSessionManagerKill(t *testing.T) { - sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) + sm := newManager(t) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) diff --git a/session/repo.go b/session/repo.go index 88ec9b23..9ae86276 100644 --- a/session/repo.go +++ b/session/repo.go @@ -1,11 +1,6 @@ package session -import ( - "errors" - "time" - - "github.com/jonboulle/clockwork" -) +import "time" type SessionRepo interface { Get(string) (*Session, error) @@ -17,87 +12,3 @@ type SessionKeyRepo interface { Push(SessionKey, time.Duration) error Pop(string) (string, error) } - -func NewSessionRepo() SessionRepo { - return NewSessionRepoWithClock(clockwork.NewRealClock()) -} - -func NewSessionRepoWithClock(clock clockwork.Clock) SessionRepo { - return &memSessionRepo{ - store: make(map[string]Session), - clock: clock, - } -} - -type memSessionRepo struct { - store map[string]Session - clock clockwork.Clock -} - -func (m *memSessionRepo) Get(sessionID string) (*Session, error) { - s, ok := m.store[sessionID] - if !ok || s.ExpiresAt.Before(m.clock.Now()) { - return nil, errors.New("unrecognized ID") - } - return &s, nil -} - -func (m *memSessionRepo) Create(s Session) error { - if _, ok := m.store[s.ID]; ok { - return errors.New("ID exists") - } - - m.store[s.ID] = s - return nil -} - -func (m *memSessionRepo) Update(s Session) error { - if _, ok := m.store[s.ID]; !ok { - return errors.New("unrecognized ID") - } - m.store[s.ID] = s - return nil -} - -type expiringSessionKey struct { - SessionKey - expiresAt time.Time -} - -func NewSessionKeyRepo() SessionKeyRepo { - return NewSessionKeyRepoWithClock(clockwork.NewRealClock()) -} - -func NewSessionKeyRepoWithClock(clock clockwork.Clock) SessionKeyRepo { - return &memSessionKeyRepo{ - store: make(map[string]expiringSessionKey), - clock: clock, - } -} - -type memSessionKeyRepo struct { - store map[string]expiringSessionKey - clock clockwork.Clock -} - -func (m *memSessionKeyRepo) Pop(key string) (string, error) { - esk, ok := m.store[key] - if !ok { - return "", errors.New("unrecognized key") - } - defer delete(m.store, key) - - if esk.expiresAt.Before(m.clock.Now()) { - return "", errors.New("expired key") - } - - return esk.SessionKey.SessionID, nil -} - -func (m *memSessionKeyRepo) Push(sk SessionKey, ttl time.Duration) error { - m.store[sk.Key] = expiringSessionKey{ - SessionKey: sk, - expiresAt: m.clock.Now().Add(ttl), - } - return nil -} diff --git a/test b/test index 8b01cd42..0ea61d78 100755 --- a/test +++ b/test @@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"} source ./build -TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email admin" +TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session session/manager user user/api user/manager email admin" FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log" # user has not provided PKG override