*: remove in memory session repos
Move manager to it's own package so it can import db. Move all references to the in memory session repos to use sqlite3.
This commit is contained in:
parent
5052d8007f
commit
7bac93aa20
15 changed files with 99 additions and 161 deletions
12
db/conn.go
12
db/conn.go
|
@ -101,3 +101,15 @@ func rollback(tx *gorp.Transaction) {
|
||||||
log.Errorf("unable to rollback: %v", err)
|
log.Errorf("unable to rollback: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewMemDB creates a new in memory sqlite3 database.
|
||||||
|
func NewMemDB() *gorp.DbMap {
|
||||||
|
dbMap, err := NewConnection(Config{DSN: "sqlite3://:memory:"})
|
||||||
|
if err != nil {
|
||||||
|
panic("Failed to create in memory database: " + err.Error())
|
||||||
|
}
|
||||||
|
if _, err := MigrateToLatest(dbMap); err != nil {
|
||||||
|
panic("In memory database migration failed: " + err.Error())
|
||||||
|
}
|
||||||
|
return dbMap
|
||||||
|
}
|
||||||
|
|
|
@ -65,7 +65,7 @@ CREATE TABLE session (
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE session_key (
|
CREATE TABLE session_key (
|
||||||
key text NOT NULL UNIQUE,
|
key text NOT NULL,
|
||||||
session_id text,
|
session_id text,
|
||||||
expires_at bigint,
|
expires_at bigint,
|
||||||
stale integer
|
stale integer
|
||||||
|
|
|
@ -15,7 +15,7 @@ import (
|
||||||
func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
|
func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
|
||||||
clock := clockwork.NewFakeClock()
|
clock := clockwork.NewFakeClock()
|
||||||
if os.Getenv("DEX_TEST_DSN") == "" {
|
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||||
return session.NewSessionRepoWithClock(clock), clock
|
return db.NewSessionRepoWithClock(db.NewMemDB(), clock), clock
|
||||||
}
|
}
|
||||||
dbMap := connect(t)
|
dbMap := connect(t)
|
||||||
return db.NewSessionRepoWithClock(dbMap, clock), clock
|
return db.NewSessionRepoWithClock(dbMap, clock), clock
|
||||||
|
@ -24,7 +24,7 @@ func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
|
||||||
func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) {
|
func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) {
|
||||||
clock := clockwork.NewFakeClock()
|
clock := clockwork.NewFakeClock()
|
||||||
if os.Getenv("DEX_TEST_DSN") == "" {
|
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||||
return session.NewSessionKeyRepoWithClock(clock), clock
|
return db.NewSessionKeyRepoWithClock(db.NewMemDB(), clock), clock
|
||||||
}
|
}
|
||||||
dbMap := connect(t)
|
dbMap := connect(t)
|
||||||
return db.NewSessionKeyRepoWithClock(dbMap, clock), clock
|
return db.NewSessionKeyRepoWithClock(dbMap, clock), clock
|
||||||
|
|
|
@ -10,10 +10,11 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
|
"github.com/coreos/dex/db"
|
||||||
phttp "github.com/coreos/dex/pkg/http"
|
phttp "github.com/coreos/dex/pkg/http"
|
||||||
"github.com/coreos/dex/refresh/refreshtest"
|
"github.com/coreos/dex/refresh/refreshtest"
|
||||||
"github.com/coreos/dex/server"
|
"github.com/coreos/dex/server"
|
||||||
"github.com/coreos/dex/session"
|
"github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/coreos/go-oidc/jose"
|
"github.com/coreos/go-oidc/jose"
|
||||||
"github.com/coreos/go-oidc/key"
|
"github.com/coreos/go-oidc/key"
|
||||||
|
@ -33,7 +34,7 @@ func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
srv := &server.Server{
|
srv := &server.Server{
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
|
@ -120,7 +121,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
||||||
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
|
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
|
||||||
|
|
||||||
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
|
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
|
||||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
|
|
||||||
k, err := key.GeneratePrivateKey()
|
k, err := key.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -19,10 +19,10 @@ import (
|
||||||
"github.com/coreos/dex/email"
|
"github.com/coreos/dex/email"
|
||||||
"github.com/coreos/dex/refresh"
|
"github.com/coreos/dex/refresh"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
"github.com/coreos/dex/session"
|
sessionmanager "github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
useremail "github.com/coreos/dex/user/email"
|
useremail "github.com/coreos/dex/user/email"
|
||||||
"github.com/coreos/dex/user/manager"
|
usermanager "github.com/coreos/dex/user/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
|
@ -128,9 +128,9 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
||||||
}
|
}
|
||||||
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs)
|
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs)
|
||||||
|
|
||||||
sRepo := session.NewSessionRepo()
|
sRepo := db.NewSessionRepo(db.NewMemDB())
|
||||||
skRepo := session.NewSessionKeyRepo()
|
skRepo := db.NewSessionKeyRepo(db.NewMemDB())
|
||||||
sm := session.NewSessionManager(sRepo, skRepo)
|
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
||||||
|
|
||||||
userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile)
|
userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -142,7 +142,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
||||||
refTokRepo := refresh.NewRefreshTokenRepo()
|
refTokRepo := refresh.NewRefreshTokenRepo()
|
||||||
|
|
||||||
txnFactory := repo.InMemTransactionFactory
|
txnFactory := repo.InMemTransactionFactory
|
||||||
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
|
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
||||||
srv.ClientIdentityRepo = ciRepo
|
srv.ClientIdentityRepo = ciRepo
|
||||||
srv.KeySetRepo = kRepo
|
srv.KeySetRepo = kRepo
|
||||||
srv.ConnectorConfigRepo = cfgRepo
|
srv.ConnectorConfigRepo = cfgRepo
|
||||||
|
@ -180,10 +180,10 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
||||||
cfgRepo := db.NewConnectorConfigRepo(dbc)
|
cfgRepo := db.NewConnectorConfigRepo(dbc)
|
||||||
userRepo := db.NewUserRepo(dbc)
|
userRepo := db.NewUserRepo(dbc)
|
||||||
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
||||||
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
|
||||||
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
|
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
|
||||||
|
|
||||||
sm := session.NewSessionManager(sRepo, skRepo)
|
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
||||||
|
|
||||||
srv.ClientIdentityRepo = ciRepo
|
srv.ClientIdentityRepo = ciRepo
|
||||||
srv.KeySetRepo = kRepo
|
srv.KeySetRepo = kRepo
|
||||||
|
|
|
@ -17,7 +17,8 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/session"
|
"github.com/coreos/dex/db"
|
||||||
|
"github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/go-oidc/jose"
|
"github.com/coreos/go-oidc/jose"
|
||||||
"github.com/coreos/go-oidc/oauth2"
|
"github.com/coreos/go-oidc/oauth2"
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
|
@ -75,7 +76,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
}
|
}
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()),
|
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
||||||
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||||
oidc.ClientIdentity{
|
oidc.ClientIdentity{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
|
@ -198,7 +199,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
||||||
}
|
}
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()),
|
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
||||||
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
|
||||||
oidc.ClientIdentity{
|
oidc.ClientIdentity{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
|
|
|
@ -9,10 +9,10 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/session"
|
sessionmanager "github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
useremail "github.com/coreos/dex/user/email"
|
useremail "github.com/coreos/dex/user/email"
|
||||||
"github.com/coreos/dex/user/manager"
|
usermanager "github.com/coreos/dex/user/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sendResetPasswordEmailData struct {
|
type sendResetPasswordEmailData struct {
|
||||||
|
@ -28,7 +28,7 @@ type sendResetPasswordEmailData struct {
|
||||||
type SendResetPasswordEmailHandler struct {
|
type SendResetPasswordEmailHandler struct {
|
||||||
tpl *template.Template
|
tpl *template.Template
|
||||||
emailer *useremail.UserEmailer
|
emailer *useremail.UserEmailer
|
||||||
sm *session.SessionManager
|
sm *sessionmanager.SessionManager
|
||||||
cr client.ClientIdentityRepo
|
cr client.ClientIdentityRepo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,7 +182,7 @@ type resetPasswordTemplateData struct {
|
||||||
type ResetPasswordHandler struct {
|
type ResetPasswordHandler struct {
|
||||||
tpl *template.Template
|
tpl *template.Template
|
||||||
issuerURL url.URL
|
issuerURL url.URL
|
||||||
um *manager.UserManager
|
um *usermanager.UserManager
|
||||||
keysFunc func() ([]key.PublicKey, error)
|
keysFunc func() ([]key.PublicKey, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -238,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() {
|
||||||
cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext)
|
cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err {
|
switch err {
|
||||||
case manager.ErrorPasswordAlreadyChanged:
|
case usermanager.ErrorPasswordAlreadyChanged:
|
||||||
r.data.Error = "Link Expired"
|
r.data.Error = "Link Expired"
|
||||||
r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email."
|
r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email."
|
||||||
r.data.DontShowForm = true
|
r.data.DontShowForm = true
|
||||||
|
|
|
@ -10,8 +10,9 @@ import (
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/session"
|
"github.com/coreos/dex/session"
|
||||||
|
sessionmanager "github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/coreos/dex/user/manager"
|
usermanager "github.com/coreos/dex/user/manager"
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -274,7 +275,7 @@ func makeClientRedirectURL(baseRedirURL url.URL, code, clientState string) *url.
|
||||||
return &ru
|
return &ru
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) {
|
func registerFromLocalConnector(userManager *usermanager.UserManager, sessionManager *sessionmanager.SessionManager, ses *session.Session, email, password string) (string, error) {
|
||||||
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
|
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -289,7 +290,7 @@ func registerFromLocalConnector(userManager *manager.UserManager, sessionManager
|
||||||
return userID, nil
|
return userID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
|
func registerFromRemoteConnector(userManager *usermanager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
|
||||||
if ses.Identity.ID == "" {
|
if ses.Identity.ID == "" {
|
||||||
return "", errors.New("No Identity found in session.")
|
return "", errors.New("No Identity found in session.")
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,10 +22,11 @@ import (
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/refresh"
|
"github.com/coreos/dex/refresh"
|
||||||
"github.com/coreos/dex/session"
|
"github.com/coreos/dex/session"
|
||||||
|
sessionmanager "github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
usersapi "github.com/coreos/dex/user/api"
|
usersapi "github.com/coreos/dex/user/api"
|
||||||
useremail "github.com/coreos/dex/user/email"
|
useremail "github.com/coreos/dex/user/email"
|
||||||
"github.com/coreos/dex/user/manager"
|
usermanager "github.com/coreos/dex/user/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -57,7 +58,7 @@ type Server struct {
|
||||||
IssuerURL url.URL
|
IssuerURL url.URL
|
||||||
KeyManager key.PrivateKeyManager
|
KeyManager key.PrivateKeyManager
|
||||||
KeySetRepo key.PrivateKeySetRepo
|
KeySetRepo key.PrivateKeySetRepo
|
||||||
SessionManager *session.SessionManager
|
SessionManager *sessionmanager.SessionManager
|
||||||
ClientIdentityRepo client.ClientIdentityRepo
|
ClientIdentityRepo client.ClientIdentityRepo
|
||||||
ConnectorConfigRepo connector.ConnectorConfigRepo
|
ConnectorConfigRepo connector.ConnectorConfigRepo
|
||||||
Templates *template.Template
|
Templates *template.Template
|
||||||
|
@ -69,7 +70,7 @@ type Server struct {
|
||||||
HealthChecks []health.Checkable
|
HealthChecks []health.Checkable
|
||||||
Connectors []connector.Connector
|
Connectors []connector.Connector
|
||||||
UserRepo user.UserRepo
|
UserRepo user.UserRepo
|
||||||
UserManager *manager.UserManager
|
UserManager *usermanager.UserManager
|
||||||
PasswordInfoRepo user.PasswordInfoRepo
|
PasswordInfoRepo user.PasswordInfoRepo
|
||||||
RefreshTokenRepo refresh.RefreshTokenRepo
|
RefreshTokenRepo refresh.RefreshTokenRepo
|
||||||
UserEmailer *useremail.UserEmailer
|
UserEmailer *useremail.UserEmailer
|
||||||
|
|
|
@ -10,8 +10,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/refresh/refreshtest"
|
"github.com/coreos/dex/refresh/refreshtest"
|
||||||
"github.com/coreos/dex/session"
|
"github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/coreos/go-oidc/jose"
|
"github.com/coreos/go-oidc/jose"
|
||||||
"github.com/coreos/go-oidc/key"
|
"github.com/coreos/go-oidc/key"
|
||||||
|
@ -68,7 +69,7 @@ func (ss *StaticSigner) JWK() jose.JWK {
|
||||||
return jose.JWK{}
|
return jose.JWK{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func staticGenerateCodeFunc(code string) session.GenerateCodeFunc {
|
func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc {
|
||||||
return func() (string, error) {
|
return func() (string, error) {
|
||||||
return code, nil
|
return code, nil
|
||||||
}
|
}
|
||||||
|
@ -120,7 +121,7 @@ func TestServerProviderConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerNewSession(t *testing.T) {
|
func TestServerNewSession(t *testing.T) {
|
||||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
SessionManager: sm,
|
SessionManager: sm,
|
||||||
}
|
}
|
||||||
|
@ -197,7 +198,7 @@ func TestServerLogin(t *testing.T) {
|
||||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||||
}
|
}
|
||||||
|
|
||||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
|
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
|
||||||
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
|
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -245,7 +246,7 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
|
||||||
km := &StaticKeyManager{
|
km := &StaticKeyManager{
|
||||||
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
|
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
|
||||||
}
|
}
|
||||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
|
@ -286,7 +287,7 @@ func TestServerLoginDisabledUser(t *testing.T) {
|
||||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||||
}
|
}
|
||||||
|
|
||||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
|
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
|
||||||
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
|
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -343,7 +344,7 @@ func TestServerCodeToken(t *testing.T) {
|
||||||
km := &StaticKeyManager{
|
km := &StaticKeyManager{
|
||||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||||
}
|
}
|
||||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
|
|
||||||
userRepo, err := makeNewUserRepo()
|
userRepo, err := makeNewUserRepo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -424,7 +425,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
|
||||||
km := &StaticKeyManager{
|
km := &StaticKeyManager{
|
||||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||||
}
|
}
|
||||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
|
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
|
@ -518,7 +519,7 @@ func TestServerTokenFail(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
sm.GenerateCode = func() (string, error) { return keyFixture, nil }
|
sm.GenerateCode = func() (string, error) { return keyFixture, nil }
|
||||||
|
|
||||||
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope)
|
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope)
|
||||||
|
|
|
@ -10,12 +10,13 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/email"
|
"github.com/coreos/dex/email"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
"github.com/coreos/dex/session"
|
sessionmanager "github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
useremail "github.com/coreos/dex/user/email"
|
useremail "github.com/coreos/dex/user/email"
|
||||||
"github.com/coreos/dex/user/manager"
|
usermanager "github.com/coreos/dex/user/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -75,13 +76,13 @@ var (
|
||||||
type testFixtures struct {
|
type testFixtures struct {
|
||||||
srv *Server
|
srv *Server
|
||||||
userRepo user.UserRepo
|
userRepo user.UserRepo
|
||||||
sessionManager *session.SessionManager
|
sessionManager *sessionmanager.SessionManager
|
||||||
emailer *email.TemplatizedEmailer
|
emailer *email.TemplatizedEmailer
|
||||||
redirectURL url.URL
|
redirectURL url.URL
|
||||||
clientIdentityRepo client.ClientIdentityRepo
|
clientIdentityRepo client.ClientIdentityRepo
|
||||||
}
|
}
|
||||||
|
|
||||||
func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
|
func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
|
||||||
x := 0
|
x := 0
|
||||||
return func() (string, error) {
|
return func() (string, error) {
|
||||||
x += 1
|
x += 1
|
||||||
|
@ -113,9 +114,9 @@ func makeTestFixtures() (*testFixtures, error) {
|
||||||
}
|
}
|
||||||
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
|
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
|
||||||
|
|
||||||
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, usermanager.ManagerOptions{})
|
||||||
|
|
||||||
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
||||||
|
|
||||||
emailer, err := email.NewTemplatizedEmailerFromGlobs(
|
emailer, err := email.NewTemplatizedEmailerFromGlobs(
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package session
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/jonboulle/clockwork"
|
"github.com/jonboulle/clockwork"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/session"
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -27,11 +28,11 @@ func DefaultGenerateCode() (string, error) {
|
||||||
return base64.URLEncoding.EncodeToString(b), nil
|
return base64.URLEncoding.EncodeToString(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager {
|
func NewSessionManager(sRepo session.SessionRepo, skRepo session.SessionKeyRepo) *SessionManager {
|
||||||
return &SessionManager{
|
return &SessionManager{
|
||||||
GenerateCode: DefaultGenerateCode,
|
GenerateCode: DefaultGenerateCode,
|
||||||
Clock: clockwork.NewRealClock(),
|
Clock: clockwork.NewRealClock(),
|
||||||
ValidityWindow: DefaultSessionValidityWindow,
|
ValidityWindow: session.DefaultSessionValidityWindow,
|
||||||
sessions: sRepo,
|
sessions: sRepo,
|
||||||
keys: skRepo,
|
keys: skRepo,
|
||||||
}
|
}
|
||||||
|
@ -41,8 +42,8 @@ type SessionManager struct {
|
||||||
GenerateCode GenerateCodeFunc
|
GenerateCode GenerateCodeFunc
|
||||||
Clock clockwork.Clock
|
Clock clockwork.Clock
|
||||||
ValidityWindow time.Duration
|
ValidityWindow time.Duration
|
||||||
sessions SessionRepo
|
sessions session.SessionRepo
|
||||||
keys SessionKeyRepo
|
keys session.SessionKeyRepo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
|
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
|
||||||
|
@ -52,10 +53,10 @@ func (m *SessionManager) NewSession(connectorID, clientID, clientState string, r
|
||||||
}
|
}
|
||||||
|
|
||||||
now := m.Clock.Now()
|
now := m.Clock.Now()
|
||||||
s := Session{
|
s := session.Session{
|
||||||
ConnectorID: connectorID,
|
ConnectorID: connectorID,
|
||||||
ID: sID,
|
ID: sID,
|
||||||
State: SessionStateNew,
|
State: session.SessionStateNew,
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
ExpiresAt: now.Add(m.ValidityWindow),
|
ExpiresAt: now.Add(m.ValidityWindow),
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
|
@ -80,11 +81,12 @@ func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
k := SessionKey{
|
k := session.SessionKey{
|
||||||
Key: key,
|
Key: key,
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessionKeyValidityWindow := 10 * time.Minute //RFC6749
|
||||||
err = m.keys.Push(k, sessionKeyValidityWindow)
|
err = m.keys.Push(k, sessionKeyValidityWindow)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -97,7 +99,7 @@ func (m *SessionManager) ExchangeKey(key string) (string, error) {
|
||||||
return m.keys.Pop(key)
|
return m.keys.Pop(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) {
|
func (m *SessionManager) getSessionInState(sessionID string, state session.SessionState) (*session.Session, error) {
|
||||||
s, err := m.sessions.Get(sessionID)
|
s, err := m.sessions.Get(sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -110,14 +112,14 @@ func (m *SessionManager) getSessionInState(sessionID string, state SessionState)
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) {
|
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*session.Session, error) {
|
||||||
s, err := m.getSessionInState(sessionID, SessionStateNew)
|
s, err := m.getSessionInState(sessionID, session.SessionStateNew)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Identity = ident
|
s.Identity = ident
|
||||||
s.State = SessionStateRemoteAttached
|
s.State = session.SessionStateRemoteAttached
|
||||||
|
|
||||||
if err = m.sessions.Update(*s); err != nil {
|
if err = m.sessions.Update(*s); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -126,14 +128,14 @@ func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Ident
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) {
|
func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.Session, error) {
|
||||||
s, err := m.getSessionInState(sessionID, SessionStateRemoteAttached)
|
s, err := m.getSessionInState(sessionID, session.SessionStateRemoteAttached)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.UserID = userID
|
s.UserID = userID
|
||||||
s.State = SessionStateIdentified
|
s.State = session.SessionStateIdentified
|
||||||
|
|
||||||
if err = m.sessions.Update(*s); err != nil {
|
if err = m.sessions.Update(*s); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -142,13 +144,13 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session,
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) Kill(sessionID string) (*Session, error) {
|
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
|
||||||
s, err := m.sessions.Get(sessionID)
|
s, err := m.sessions.Get(sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.State = SessionStateDead
|
s.State = session.SessionStateDead
|
||||||
|
|
||||||
if err = m.sessions.Update(*s); err != nil {
|
if err = m.sessions.Update(*s); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -157,6 +159,6 @@ func (m *SessionManager) Kill(sessionID string) (*Session, error) {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) Get(sessionID string) (*Session, error) {
|
func (m *SessionManager) Get(sessionID string) (*session.Session, error) {
|
||||||
return m.sessions.Get(sessionID)
|
return m.sessions.Get(sessionID)
|
||||||
}
|
}
|
|
@ -1,9 +1,11 @@
|
||||||
package session
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/db"
|
||||||
|
"github.com/coreos/dex/session"
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,8 +15,13 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newManager(t *testing.T) *SessionManager {
|
||||||
|
dbMap := db.NewMemDB()
|
||||||
|
return NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
|
||||||
|
}
|
||||||
|
|
||||||
func TestSessionManagerNewSession(t *testing.T) {
|
func TestSessionManagerNewSession(t *testing.T) {
|
||||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
sm := newManager(t)
|
||||||
sm.GenerateCode = staticGenerateCodeFunc("boo")
|
sm.GenerateCode = staticGenerateCodeFunc("boo")
|
||||||
got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -26,7 +33,7 @@ func TestSessionManagerNewSession(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
|
func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
|
||||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
sm := newManager(t)
|
||||||
sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
@ -43,7 +50,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionManagerExchangeKey(t *testing.T) {
|
func TestSessionManagerExchangeKey(t *testing.T) {
|
||||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
sm := newManager(t)
|
||||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
@ -68,8 +75,8 @@ func TestSessionManagerExchangeKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
|
func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
|
||||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
sm := newManager(t)
|
||||||
ses, err := sm.getSessionInState("123", SessionStateNew)
|
ses, err := sm.getSessionInState("123", session.SessionStateNew)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected non-nil error")
|
t.Errorf("Expected non-nil error")
|
||||||
}
|
}
|
||||||
|
@ -79,12 +86,12 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
|
func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
|
||||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
sm := newManager(t)
|
||||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
ses, err := sm.getSessionInState(sessionID, SessionStateDead)
|
ses, err := sm.getSessionInState(sessionID, session.SessionStateDead)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected non-nil error")
|
t.Errorf("Expected non-nil error")
|
||||||
}
|
}
|
||||||
|
@ -94,7 +101,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionManagerKill(t *testing.T) {
|
func TestSessionManagerKill(t *testing.T) {
|
||||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
sm := newManager(t)
|
||||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
|
@ -1,11 +1,6 @@
|
||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import "time"
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jonboulle/clockwork"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SessionRepo interface {
|
type SessionRepo interface {
|
||||||
Get(string) (*Session, error)
|
Get(string) (*Session, error)
|
||||||
|
@ -17,87 +12,3 @@ type SessionKeyRepo interface {
|
||||||
Push(SessionKey, time.Duration) error
|
Push(SessionKey, time.Duration) error
|
||||||
Pop(string) (string, error)
|
Pop(string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSessionRepo() SessionRepo {
|
|
||||||
return NewSessionRepoWithClock(clockwork.NewRealClock())
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSessionRepoWithClock(clock clockwork.Clock) SessionRepo {
|
|
||||||
return &memSessionRepo{
|
|
||||||
store: make(map[string]Session),
|
|
||||||
clock: clock,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type memSessionRepo struct {
|
|
||||||
store map[string]Session
|
|
||||||
clock clockwork.Clock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *memSessionRepo) Get(sessionID string) (*Session, error) {
|
|
||||||
s, ok := m.store[sessionID]
|
|
||||||
if !ok || s.ExpiresAt.Before(m.clock.Now()) {
|
|
||||||
return nil, errors.New("unrecognized ID")
|
|
||||||
}
|
|
||||||
return &s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *memSessionRepo) Create(s Session) error {
|
|
||||||
if _, ok := m.store[s.ID]; ok {
|
|
||||||
return errors.New("ID exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
m.store[s.ID] = s
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *memSessionRepo) Update(s Session) error {
|
|
||||||
if _, ok := m.store[s.ID]; !ok {
|
|
||||||
return errors.New("unrecognized ID")
|
|
||||||
}
|
|
||||||
m.store[s.ID] = s
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type expiringSessionKey struct {
|
|
||||||
SessionKey
|
|
||||||
expiresAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSessionKeyRepo() SessionKeyRepo {
|
|
||||||
return NewSessionKeyRepoWithClock(clockwork.NewRealClock())
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSessionKeyRepoWithClock(clock clockwork.Clock) SessionKeyRepo {
|
|
||||||
return &memSessionKeyRepo{
|
|
||||||
store: make(map[string]expiringSessionKey),
|
|
||||||
clock: clock,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type memSessionKeyRepo struct {
|
|
||||||
store map[string]expiringSessionKey
|
|
||||||
clock clockwork.Clock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *memSessionKeyRepo) Pop(key string) (string, error) {
|
|
||||||
esk, ok := m.store[key]
|
|
||||||
if !ok {
|
|
||||||
return "", errors.New("unrecognized key")
|
|
||||||
}
|
|
||||||
defer delete(m.store, key)
|
|
||||||
|
|
||||||
if esk.expiresAt.Before(m.clock.Now()) {
|
|
||||||
return "", errors.New("expired key")
|
|
||||||
}
|
|
||||||
|
|
||||||
return esk.SessionKey.SessionID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *memSessionKeyRepo) Push(sk SessionKey, ttl time.Duration) error {
|
|
||||||
m.store[sk.Key] = expiringSessionKey{
|
|
||||||
SessionKey: sk,
|
|
||||||
expiresAt: m.clock.Now().Add(ttl),
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
2
test
2
test
|
@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
|
||||||
|
|
||||||
source ./build
|
source ./build
|
||||||
|
|
||||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email admin"
|
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session session/manager user user/api user/manager email admin"
|
||||||
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
|
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
|
||||||
|
|
||||||
# user has not provided PKG override
|
# user has not provided PKG override
|
||||||
|
|
Reference in a new issue