forked from mystiq/dex
user/manager: connector must exists when creating remote identity
Add ConnectorConfigRepo to UserManager. When trying to create a RemoteIdentity, validate that the connector ID exists. Fixes #198
This commit is contained in:
parent
d518447282
commit
f43655a8c3
12 changed files with 183 additions and 14 deletions
|
@ -99,8 +99,9 @@ func main() {
|
||||||
|
|
||||||
userRepo := db.NewUserRepo(dbc)
|
userRepo := db.NewUserRepo(dbc)
|
||||||
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
||||||
|
connCfgRepo := db.NewConnectorConfigRepo(dbc)
|
||||||
userManager := manager.NewUserManager(userRepo,
|
userManager := manager.NewUserManager(userRepo,
|
||||||
pwiRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||||
adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID)
|
adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID)
|
||||||
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
|
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/repo"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) {
|
func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) {
|
||||||
|
@ -41,6 +43,19 @@ type memConnectorConfigRepo struct {
|
||||||
configs []ConnectorConfig
|
configs []ConnectorConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewConnectorConfigRepoFromConfigs(cfgs []ConnectorConfig) ConnectorConfigRepo {
|
||||||
|
return &memConnectorConfigRepo{configs: cfgs}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) {
|
func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) {
|
||||||
return r.configs, nil
|
return r.configs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *memConnectorConfigRepo) GetConnectorByID(_ repo.Transaction, id string) (ConnectorConfig, error) {
|
||||||
|
for _, cfg := range r.configs {
|
||||||
|
if cfg.ConnectorID() == id {
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrorNotFound
|
||||||
|
}
|
||||||
|
|
|
@ -1,14 +1,18 @@
|
||||||
package connector
|
package connector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/repo"
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
"github.com/coreos/pkg/health"
|
"github.com/coreos/pkg/health"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrorNotFound = errors.New("connector not found in repository")
|
||||||
|
|
||||||
type Connector interface {
|
type Connector interface {
|
||||||
ID() string
|
ID() string
|
||||||
LoginURL(sessionKey, prompt string) (string, error)
|
LoginURL(sessionKey, prompt string) (string, error)
|
||||||
|
@ -34,4 +38,5 @@ type ConnectorConfig interface {
|
||||||
|
|
||||||
type ConnectorConfigRepo interface {
|
type ConnectorConfigRepo interface {
|
||||||
All() ([]ConnectorConfig, error)
|
All() ([]ConnectorConfig, error)
|
||||||
|
GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -9,6 +10,7 @@ import (
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
|
"github.com/coreos/dex/repo"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -91,6 +93,18 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
||||||
return cfgs, nil
|
return cfgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
|
||||||
|
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
||||||
|
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
|
||||||
|
var c connectorConfigModel
|
||||||
|
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, connector.ErrorNotFound
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.ConnectorConfig()
|
||||||
|
}
|
||||||
|
|
||||||
func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
|
func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
|
||||||
insert := make([]interface{}, len(cfgs))
|
insert := make([]interface{}, len(cfgs))
|
||||||
for i, cfg := range cfgs {
|
for i, cfg := range cfgs {
|
||||||
|
@ -119,3 +133,15 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
|
||||||
|
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
||||||
|
if tx == nil {
|
||||||
|
return r.dbMap
|
||||||
|
}
|
||||||
|
|
||||||
|
gorpTx, ok := tx.(*gorp.Transaction)
|
||||||
|
if !ok {
|
||||||
|
panic("wrong kind of transaction passed to a DB repo")
|
||||||
|
}
|
||||||
|
return gorpTx
|
||||||
|
}
|
||||||
|
|
71
functional/repo/connector_repo_test.go
Normal file
71
functional/repo/connector_repo_test.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package repo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/connector"
|
||||||
|
"github.com/coreos/dex/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo
|
||||||
|
|
||||||
|
var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
|
||||||
|
makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs
|
||||||
|
} else {
|
||||||
|
makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory {
|
||||||
|
return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
|
||||||
|
dbMap := initDB(dsn)
|
||||||
|
|
||||||
|
repo := db.NewConnectorConfigRepo(dbMap)
|
||||||
|
if err := repo.Set(cfgs); err != nil {
|
||||||
|
panic(fmt.Sprintf("Unable to set connector configs: %v", err))
|
||||||
|
}
|
||||||
|
return repo
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectorConfigRepoGetByID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
cfgs []connector.ConnectorConfig
|
||||||
|
id string
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
cfgs: []connector.ConnectorConfig{
|
||||||
|
&connector.LocalConnectorConfig{ID: "local"},
|
||||||
|
},
|
||||||
|
id: "local",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
cfgs: []connector.ConnectorConfig{
|
||||||
|
&connector.LocalConnectorConfig{ID: "local1"},
|
||||||
|
&connector.LocalConnectorConfig{ID: "local2"},
|
||||||
|
},
|
||||||
|
id: "local2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
cfgs: []connector.ConnectorConfig{
|
||||||
|
&connector.LocalConnectorConfig{ID: "local1"},
|
||||||
|
&connector.LocalConnectorConfig{ID: "local2"},
|
||||||
|
},
|
||||||
|
id: "foo",
|
||||||
|
err: connector.ErrorNotFound,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs)
|
||||||
|
if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err {
|
||||||
|
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/coreos/go-oidc/key"
|
"github.com/coreos/go-oidc/key"
|
||||||
"github.com/jonboulle/clockwork"
|
"github.com/jonboulle/clockwork"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/coreos/dex/user/manager"
|
"github.com/coreos/dex/user/manager"
|
||||||
|
@ -47,7 +48,10 @@ func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.Pas
|
||||||
ur := user.NewUserRepoFromUsers(users)
|
ur := user.NewUserRepoFromUsers(users)
|
||||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
|
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
|
||||||
|
|
||||||
um := manager.NewUserManager(ur, pwr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
ccr := connector.NewConnectorConfigRepoFromConfigs(
|
||||||
|
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
|
||||||
|
)
|
||||||
|
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||||
um.Clock = clock
|
um.Clock = clock
|
||||||
return ur, pwr, um
|
return ur, pwr, um
|
||||||
}
|
}
|
||||||
|
|
|
@ -134,7 +134,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
||||||
refTokRepo := refresh.NewRefreshTokenRepo()
|
refTokRepo := refresh.NewRefreshTokenRepo()
|
||||||
|
|
||||||
txnFactory := repo.InMemTransactionFactory
|
txnFactory := repo.InMemTransactionFactory
|
||||||
userManager := manager.NewUserManager(userRepo, pwiRepo, txnFactory, manager.ManagerOptions{})
|
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
|
||||||
srv.ClientIdentityRepo = ciRepo
|
srv.ClientIdentityRepo = ciRepo
|
||||||
srv.KeySetRepo = kRepo
|
srv.KeySetRepo = kRepo
|
||||||
srv.ConnectorConfigRepo = cfgRepo
|
srv.ConnectorConfigRepo = cfgRepo
|
||||||
|
@ -172,7 +172,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
||||||
cfgRepo := db.NewConnectorConfigRepo(dbc)
|
cfgRepo := db.NewConnectorConfigRepo(dbc)
|
||||||
userRepo := db.NewUserRepo(dbc)
|
userRepo := db.NewUserRepo(dbc)
|
||||||
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
||||||
userManager := manager.NewUserManager(userRepo, pwiRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||||
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
|
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
|
||||||
|
|
||||||
sm := session.NewSessionManager(sRepo, skRepo)
|
sm := session.NewSessionManager(sRepo, skRepo)
|
||||||
|
|
|
@ -92,7 +92,6 @@ func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
|
||||||
func makeTestFixtures() (*testFixtures, error) {
|
func makeTestFixtures() (*testFixtures, error) {
|
||||||
userRepo := user.NewUserRepoFromUsers(testUsers)
|
userRepo := user.NewUserRepoFromUsers(testUsers)
|
||||||
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
|
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
|
||||||
manager := manager.NewUserManager(userRepo, pwRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
|
||||||
|
|
||||||
connConfigs := []connector.ConnectorConfig{
|
connConfigs := []connector.ConnectorConfig{
|
||||||
&connector.OIDCConnectorConfig{
|
&connector.OIDCConnectorConfig{
|
||||||
|
@ -112,6 +111,9 @@ func makeTestFixtures() (*testFixtures, error) {
|
||||||
ID: "local",
|
ID: "local",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
|
||||||
|
|
||||||
|
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||||
|
|
||||||
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||||
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
||||||
|
|
2
test
2
test
|
@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
|
||||||
|
|
||||||
source ./build
|
source ./build
|
||||||
|
|
||||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api email"
|
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email"
|
||||||
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
|
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
|
||||||
|
|
||||||
# user has not provided PKG override
|
# user has not provided PKG override
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
schema "github.com/coreos/dex/schema/workerschema"
|
schema "github.com/coreos/dex/schema/workerschema"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
|
@ -124,7 +125,10 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||||
Password: []byte("password-2"),
|
Password: []byte("password-2"),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
mgr := manager.NewUserManager(ur, pwr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||||
|
&connector.LocalConnectorConfig{ID: "local"},
|
||||||
|
})
|
||||||
|
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||||
mgr.Clock = clock
|
mgr.Clock = clock
|
||||||
ci := oidc.ClientIdentity{
|
ci := oidc.ClientIdentity{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
|
|
||||||
"github.com/jonboulle/clockwork"
|
"github.com/jonboulle/clockwork"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
|
@ -25,6 +26,7 @@ type UserManager struct {
|
||||||
|
|
||||||
userRepo user.UserRepo
|
userRepo user.UserRepo
|
||||||
pwRepo user.PasswordInfoRepo
|
pwRepo user.PasswordInfoRepo
|
||||||
|
connCfgRepo connector.ConnectorConfigRepo
|
||||||
begin repo.TransactionFactory
|
begin repo.TransactionFactory
|
||||||
userIDGenerator user.UserIDGenerator
|
userIDGenerator user.UserIDGenerator
|
||||||
}
|
}
|
||||||
|
@ -35,12 +37,13 @@ type ManagerOptions struct {
|
||||||
// variable policies
|
// variable policies
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager {
|
func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, connCfgRepo connector.ConnectorConfigRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager {
|
||||||
return &UserManager{
|
return &UserManager{
|
||||||
Clock: clockwork.NewRealClock(),
|
Clock: clockwork.NewRealClock(),
|
||||||
|
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
pwRepo: pwRepo,
|
pwRepo: pwRepo,
|
||||||
|
connCfgRepo: connCfgRepo,
|
||||||
begin: txnFactory,
|
begin: txnFactory,
|
||||||
userIDGenerator: user.DefaultUserIDGenerator,
|
userIDGenerator: user.DefaultUserIDGenerator,
|
||||||
}
|
}
|
||||||
|
@ -80,7 +83,7 @@ func (m *UserManager) CreateUser(usr user.User, hashedPassword user.Password, co
|
||||||
ConnectorID: connID,
|
ConnectorID: connID,
|
||||||
ID: usr.ID,
|
ID: usr.ID,
|
||||||
}
|
}
|
||||||
if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil {
|
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||||
rollback(tx)
|
rollback(tx)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -141,7 +144,7 @@ func (m *UserManager) RegisterWithRemoteIdentity(email string, emailVerified boo
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil {
|
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||||
rollback(tx)
|
rollback(tx)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -177,7 +180,7 @@ func (m *UserManager) RegisterWithPassword(email, plaintext, connID string) (str
|
||||||
ConnectorID: connID,
|
ConnectorID: connID,
|
||||||
ID: usr.ID,
|
ID: usr.ID,
|
||||||
}
|
}
|
||||||
if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil {
|
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||||
rollback(tx)
|
rollback(tx)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -338,6 +341,16 @@ func (m *UserManager) insertNewUser(tx repo.Transaction, email string, emailVeri
|
||||||
return usr, nil
|
return usr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *UserManager) addRemoteIdentity(tx repo.Transaction, userID string, rid user.RemoteIdentity) error {
|
||||||
|
if _, err := m.connCfgRepo.GetConnectorByID(tx, rid.ConnectorID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := m.userRepo.AddRemoteIdentity(tx, userID, rid); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func rollback(tx repo.Transaction) {
|
func rollback(tx repo.Transaction) {
|
||||||
err := tx.Rollback()
|
err := tx.Rollback()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/jonboulle/clockwork"
|
"github.com/jonboulle/clockwork"
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/repo"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
)
|
)
|
||||||
|
@ -16,6 +17,7 @@ import (
|
||||||
type testFixtures struct {
|
type testFixtures struct {
|
||||||
ur user.UserRepo
|
ur user.UserRepo
|
||||||
pwr user.PasswordInfoRepo
|
pwr user.PasswordInfoRepo
|
||||||
|
ccr connector.ConnectorConfigRepo
|
||||||
mgr *UserManager
|
mgr *UserManager
|
||||||
clock clockwork.Clock
|
clock clockwork.Clock
|
||||||
}
|
}
|
||||||
|
@ -60,7 +62,10 @@ func makeTestFixtures() *testFixtures {
|
||||||
Password: []byte("password-2"),
|
Password: []byte("password-2"),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
f.mgr = NewUserManager(f.ur, f.pwr, repo.InMemTransactionFactory, ManagerOptions{})
|
f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||||
|
&connector.LocalConnectorConfig{ID: "local"},
|
||||||
|
})
|
||||||
|
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{})
|
||||||
f.mgr.Clock = f.clock
|
f.mgr.Clock = f.clock
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
@ -98,6 +103,15 @@ func TestRegisterWithRemoteIdentity(t *testing.T) {
|
||||||
},
|
},
|
||||||
err: user.ErrorDuplicateRemoteIdentity,
|
err: user.ErrorDuplicateRemoteIdentity,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
email: "anotheremail@example.com",
|
||||||
|
emailVerified: false,
|
||||||
|
rid: user.RemoteIdentity{
|
||||||
|
ConnectorID: "idonotexist",
|
||||||
|
ID: "1",
|
||||||
|
},
|
||||||
|
err: connector.ErrorNotFound,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
|
@ -159,7 +173,7 @@ func TestRegisterWithPassword(t *testing.T) {
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
f := makeTestFixtures()
|
f := makeTestFixtures()
|
||||||
connID := "connID"
|
connID := "local"
|
||||||
userID, err := f.mgr.RegisterWithPassword(
|
userID, err := f.mgr.RegisterWithPassword(
|
||||||
tt.email,
|
tt.email,
|
||||||
tt.plaintext,
|
tt.plaintext,
|
||||||
|
@ -358,6 +372,7 @@ func TestCreateUser(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
usr user.User
|
usr user.User
|
||||||
hashedPW user.Password
|
hashedPW user.Password
|
||||||
|
localID string // defaults to "local"
|
||||||
|
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
|
@ -383,11 +398,24 @@ func TestCreateUser(t *testing.T) {
|
||||||
hashedPW: user.Password("I am a hash"),
|
hashedPW: user.Password("I am a hash"),
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
usr: user.User{
|
||||||
|
DisplayName: "Eric Exampleson",
|
||||||
|
Email: "eric@example.com",
|
||||||
|
},
|
||||||
|
hashedPW: user.Password("I am a hash"),
|
||||||
|
localID: "abadlocalid",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
f := makeTestFixtures()
|
f := makeTestFixtures()
|
||||||
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, "local")
|
localID := "local"
|
||||||
|
if tt.localID != "" {
|
||||||
|
localID = tt.localID
|
||||||
|
}
|
||||||
|
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, localID)
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("case %d: want non-nil err", i)
|
t.Errorf("case %d: want non-nil err", i)
|
||||||
|
|
Loading…
Reference in a new issue