From f43655a8c37094882e67ef18ff325f69283561e7 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Mon, 7 Dec 2015 17:19:55 -0800 Subject: [PATCH] user/manager: connector must exists when creating remote identity Add ConnectorConfigRepo to UserManager. When trying to create a RemoteIdentity, validate that the connector ID exists. Fixes #198 --- cmd/dex-overlord/main.go | 3 +- connector/config_repo.go | 15 ++++++ connector/interface.go | 5 ++ db/connector_config.go | 26 ++++++++++ functional/repo/connector_repo_test.go | 71 ++++++++++++++++++++++++++ integration/common_test.go | 6 ++- server/config.go | 4 +- server/testutil.go | 4 +- test | 2 +- user/api/api_test.go | 6 ++- user/manager/manager.go | 21 ++++++-- user/manager/manager_test.go | 34 ++++++++++-- 12 files changed, 183 insertions(+), 14 deletions(-) create mode 100644 functional/repo/connector_repo_test.go diff --git a/cmd/dex-overlord/main.go b/cmd/dex-overlord/main.go index 649db873..b8d48fc9 100644 --- a/cmd/dex-overlord/main.go +++ b/cmd/dex-overlord/main.go @@ -99,8 +99,9 @@ func main() { userRepo := db.NewUserRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc) + connCfgRepo := db.NewConnectorConfigRepo(dbc) userManager := manager.NewUserManager(userRepo, - pwiRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) + pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID) kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...) if err != nil { diff --git a/connector/config_repo.go b/connector/config_repo.go index 212a4cd7..dd2d2c09 100644 --- a/connector/config_repo.go +++ b/connector/config_repo.go @@ -4,6 +4,8 @@ import ( "encoding/json" "io" "os" + + "github.com/coreos/dex/repo" ) func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) { @@ -41,6 +43,19 @@ type memConnectorConfigRepo struct { configs []ConnectorConfig } +func NewConnectorConfigRepoFromConfigs(cfgs []ConnectorConfig) ConnectorConfigRepo { + return &memConnectorConfigRepo{configs: cfgs} +} + func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) { return r.configs, nil } + +func (r *memConnectorConfigRepo) GetConnectorByID(_ repo.Transaction, id string) (ConnectorConfig, error) { + for _, cfg := range r.configs { + if cfg.ConnectorID() == id { + return cfg, nil + } + } + return nil, ErrorNotFound +} diff --git a/connector/interface.go b/connector/interface.go index 44cbb5da..36d0dcc6 100644 --- a/connector/interface.go +++ b/connector/interface.go @@ -1,14 +1,18 @@ package connector import ( + "errors" "html/template" "net/http" "net/url" + "github.com/coreos/dex/repo" "github.com/coreos/go-oidc/oidc" "github.com/coreos/pkg/health" ) +var ErrorNotFound = errors.New("connector not found in repository") + type Connector interface { ID() string LoginURL(sessionKey, prompt string) (string, error) @@ -34,4 +38,5 @@ type ConnectorConfig interface { type ConnectorConfigRepo interface { All() ([]ConnectorConfig, error) + GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error) } diff --git a/db/connector_config.go b/db/connector_config.go index f9d8386e..c8044ec7 100644 --- a/db/connector_config.go +++ b/db/connector_config.go @@ -1,6 +1,7 @@ package db import ( + "database/sql" "encoding/json" "errors" "fmt" @@ -9,6 +10,7 @@ import ( "github.com/lib/pq" "github.com/coreos/dex/connector" + "github.com/coreos/dex/repo" ) const ( @@ -91,6 +93,18 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) { return cfgs, nil } +func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) { + qt := pq.QuoteIdentifier(connectorConfigTableName) + q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt) + var c connectorConfigModel + if err := r.executor(tx).SelectOne(&c, q, id); err != nil { + if err == sql.ErrNoRows { + return nil, connector.ErrorNotFound + } + } + return c.ConnectorConfig() +} + func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error { insert := make([]interface{}, len(cfgs)) for i, cfg := range cfgs { @@ -119,3 +133,15 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error { return tx.Commit() } + +func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor { + if tx == nil { + return r.dbMap + } + + gorpTx, ok := tx.(*gorp.Transaction) + if !ok { + panic("wrong kind of transaction passed to a DB repo") + } + return gorpTx +} diff --git a/functional/repo/connector_repo_test.go b/functional/repo/connector_repo_test.go new file mode 100644 index 00000000..7d5369d4 --- /dev/null +++ b/functional/repo/connector_repo_test.go @@ -0,0 +1,71 @@ +package repo + +import ( + "fmt" + "os" + "testing" + + "github.com/coreos/dex/connector" + "github.com/coreos/dex/db" +) + +type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo + +var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory + +func init() { + if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" { + makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs + } else { + makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn) + } +} + +func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory { + return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo { + dbMap := initDB(dsn) + + repo := db.NewConnectorConfigRepo(dbMap) + if err := repo.Set(cfgs); err != nil { + panic(fmt.Sprintf("Unable to set connector configs: %v", err)) + } + return repo + } +} + +func TestConnectorConfigRepoGetByID(t *testing.T) { + tests := []struct { + cfgs []connector.ConnectorConfig + id string + err error + }{ + { + cfgs: []connector.ConnectorConfig{ + &connector.LocalConnectorConfig{ID: "local"}, + }, + id: "local", + }, + { + cfgs: []connector.ConnectorConfig{ + &connector.LocalConnectorConfig{ID: "local1"}, + &connector.LocalConnectorConfig{ID: "local2"}, + }, + id: "local2", + }, + { + cfgs: []connector.ConnectorConfig{ + &connector.LocalConnectorConfig{ID: "local1"}, + &connector.LocalConnectorConfig{ID: "local2"}, + }, + id: "foo", + err: connector.ErrorNotFound, + }, + } + + for i, tt := range tests { + repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs) + if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err { + t.Errorf("case %d: want=%v, got=%v", i, tt.err, err) + } + } +} diff --git a/integration/common_test.go b/integration/common_test.go index 6029c175..a203cbf0 100644 --- a/integration/common_test.go +++ b/integration/common_test.go @@ -10,6 +10,7 @@ import ( "github.com/coreos/go-oidc/key" "github.com/jonboulle/clockwork" + "github.com/coreos/dex/connector" "github.com/coreos/dex/repo" "github.com/coreos/dex/user" "github.com/coreos/dex/user/manager" @@ -47,7 +48,10 @@ func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.Pas ur := user.NewUserRepoFromUsers(users) pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords) - um := manager.NewUserManager(ur, pwr, repo.InMemTransactionFactory, manager.ManagerOptions{}) + ccr := connector.NewConnectorConfigRepoFromConfigs( + []connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}}, + ) + um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{}) um.Clock = clock return ur, pwr, um } diff --git a/server/config.go b/server/config.go index 870bfd71..f70065a3 100644 --- a/server/config.go +++ b/server/config.go @@ -134,7 +134,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { refTokRepo := refresh.NewRefreshTokenRepo() txnFactory := repo.InMemTransactionFactory - userManager := manager.NewUserManager(userRepo, pwiRepo, txnFactory, manager.ManagerOptions{}) + userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{}) srv.ClientIdentityRepo = ciRepo srv.KeySetRepo = kRepo srv.ConnectorConfigRepo = cfgRepo @@ -172,7 +172,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error { cfgRepo := db.NewConnectorConfigRepo(dbc) userRepo := db.NewUserRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc) - userManager := manager.NewUserManager(userRepo, pwiRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) + userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) refreshTokenRepo := db.NewRefreshTokenRepo(dbc) sm := session.NewSessionManager(sRepo, skRepo) diff --git a/server/testutil.go b/server/testutil.go index 946a3299..99615539 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -92,7 +92,6 @@ func sequentialGenerateCodeFunc() session.GenerateCodeFunc { func makeTestFixtures() (*testFixtures, error) { userRepo := user.NewUserRepoFromUsers(testUsers) pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos) - manager := manager.NewUserManager(userRepo, pwRepo, repo.InMemTransactionFactory, manager.ManagerOptions{}) connConfigs := []connector.ConnectorConfig{ &connector.OIDCConnectorConfig{ @@ -112,6 +111,9 @@ func makeTestFixtures() (*testFixtures, error) { ID: "local", }, } + connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs) + + manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{}) sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sessionManager.GenerateCode = sequentialGenerateCodeFunc() diff --git a/test b/test index 3d26c7ba..981ca110 100755 --- a/test +++ b/test @@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"} source ./build -TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api email" +TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email" FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log" # user has not provided PKG override diff --git a/user/api/api_test.go b/user/api/api_test.go index 4d1789df..967db2ac 100644 --- a/user/api/api_test.go +++ b/user/api/api_test.go @@ -10,6 +10,7 @@ import ( "github.com/kylelemons/godebug/pretty" "github.com/coreos/dex/client" + "github.com/coreos/dex/connector" "github.com/coreos/dex/repo" schema "github.com/coreos/dex/schema/workerschema" "github.com/coreos/dex/user" @@ -124,7 +125,10 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { Password: []byte("password-2"), }, }) - mgr := manager.NewUserManager(ur, pwr, repo.InMemTransactionFactory, manager.ManagerOptions{}) + ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{ + &connector.LocalConnectorConfig{ID: "local"}, + }) + mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{}) mgr.Clock = clock ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ diff --git a/user/manager/manager.go b/user/manager/manager.go index 06249a9c..a37ca8b1 100644 --- a/user/manager/manager.go +++ b/user/manager/manager.go @@ -6,6 +6,7 @@ import ( "github.com/jonboulle/clockwork" + "github.com/coreos/dex/connector" "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/repo" "github.com/coreos/dex/user" @@ -25,6 +26,7 @@ type UserManager struct { userRepo user.UserRepo pwRepo user.PasswordInfoRepo + connCfgRepo connector.ConnectorConfigRepo begin repo.TransactionFactory userIDGenerator user.UserIDGenerator } @@ -35,12 +37,13 @@ type ManagerOptions struct { // variable policies } -func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager { +func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, connCfgRepo connector.ConnectorConfigRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager { return &UserManager{ Clock: clockwork.NewRealClock(), userRepo: userRepo, pwRepo: pwRepo, + connCfgRepo: connCfgRepo, begin: txnFactory, userIDGenerator: user.DefaultUserIDGenerator, } @@ -80,7 +83,7 @@ func (m *UserManager) CreateUser(usr user.User, hashedPassword user.Password, co ConnectorID: connID, ID: usr.ID, } - if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil { + if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil { rollback(tx) return "", err } @@ -141,7 +144,7 @@ func (m *UserManager) RegisterWithRemoteIdentity(email string, emailVerified boo return "", err } - if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil { + if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil { rollback(tx) return "", err } @@ -177,7 +180,7 @@ func (m *UserManager) RegisterWithPassword(email, plaintext, connID string) (str ConnectorID: connID, ID: usr.ID, } - if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil { + if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil { rollback(tx) return "", err } @@ -338,6 +341,16 @@ func (m *UserManager) insertNewUser(tx repo.Transaction, email string, emailVeri return usr, nil } +func (m *UserManager) addRemoteIdentity(tx repo.Transaction, userID string, rid user.RemoteIdentity) error { + if _, err := m.connCfgRepo.GetConnectorByID(tx, rid.ConnectorID); err != nil { + return err + } + if err := m.userRepo.AddRemoteIdentity(tx, userID, rid); err != nil { + return err + } + return nil +} + func rollback(tx repo.Transaction) { err := tx.Rollback() if err != nil { diff --git a/user/manager/manager_test.go b/user/manager/manager_test.go index 0838a525..fbe0a4a3 100644 --- a/user/manager/manager_test.go +++ b/user/manager/manager_test.go @@ -9,6 +9,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/kylelemons/godebug/pretty" + "github.com/coreos/dex/connector" "github.com/coreos/dex/repo" "github.com/coreos/dex/user" ) @@ -16,6 +17,7 @@ import ( type testFixtures struct { ur user.UserRepo pwr user.PasswordInfoRepo + ccr connector.ConnectorConfigRepo mgr *UserManager clock clockwork.Clock } @@ -60,7 +62,10 @@ func makeTestFixtures() *testFixtures { Password: []byte("password-2"), }, }) - f.mgr = NewUserManager(f.ur, f.pwr, repo.InMemTransactionFactory, ManagerOptions{}) + f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{ + &connector.LocalConnectorConfig{ID: "local"}, + }) + f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{}) f.mgr.Clock = f.clock return f } @@ -98,6 +103,15 @@ func TestRegisterWithRemoteIdentity(t *testing.T) { }, err: user.ErrorDuplicateRemoteIdentity, }, + { + email: "anotheremail@example.com", + emailVerified: false, + rid: user.RemoteIdentity{ + ConnectorID: "idonotexist", + ID: "1", + }, + err: connector.ErrorNotFound, + }, } for i, tt := range tests { @@ -159,7 +173,7 @@ func TestRegisterWithPassword(t *testing.T) { for i, tt := range tests { f := makeTestFixtures() - connID := "connID" + connID := "local" userID, err := f.mgr.RegisterWithPassword( tt.email, tt.plaintext, @@ -358,6 +372,7 @@ func TestCreateUser(t *testing.T) { tests := []struct { usr user.User hashedPW user.Password + localID string // defaults to "local" wantErr bool }{ @@ -383,11 +398,24 @@ func TestCreateUser(t *testing.T) { hashedPW: user.Password("I am a hash"), wantErr: true, }, + { + usr: user.User{ + DisplayName: "Eric Exampleson", + Email: "eric@example.com", + }, + hashedPW: user.Password("I am a hash"), + localID: "abadlocalid", + wantErr: true, + }, } for i, tt := range tests { f := makeTestFixtures() - id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, "local") + localID := "local" + if tt.localID != "" { + localID = tt.localID + } + id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, localID) if tt.wantErr { if err == nil { t.Errorf("case %d: want non-nil err", i)