forked from mystiq/dex
user/manager: connector must exists when creating remote identity
Add ConnectorConfigRepo to UserManager. When trying to create a RemoteIdentity, validate that the connector ID exists. Fixes #198
This commit is contained in:
parent
d518447282
commit
f43655a8c3
12 changed files with 183 additions and 14 deletions
|
@ -99,8 +99,9 @@ func main() {
|
|||
|
||||
userRepo := db.NewUserRepo(dbc)
|
||||
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
||||
connCfgRepo := db.NewConnectorConfigRepo(dbc)
|
||||
userManager := manager.NewUserManager(userRepo,
|
||||
pwiRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||
adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID)
|
||||
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
|
||||
if err != nil {
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) {
|
||||
|
@ -41,6 +43,19 @@ type memConnectorConfigRepo struct {
|
|||
configs []ConnectorConfig
|
||||
}
|
||||
|
||||
func NewConnectorConfigRepoFromConfigs(cfgs []ConnectorConfig) ConnectorConfigRepo {
|
||||
return &memConnectorConfigRepo{configs: cfgs}
|
||||
}
|
||||
|
||||
func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) {
|
||||
return r.configs, nil
|
||||
}
|
||||
|
||||
func (r *memConnectorConfigRepo) GetConnectorByID(_ repo.Transaction, id string) (ConnectorConfig, error) {
|
||||
for _, cfg := range r.configs {
|
||||
if cfg.ConnectorID() == id {
|
||||
return cfg, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrorNotFound
|
||||
}
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
package connector
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
"github.com/coreos/pkg/health"
|
||||
)
|
||||
|
||||
var ErrorNotFound = errors.New("connector not found in repository")
|
||||
|
||||
type Connector interface {
|
||||
ID() string
|
||||
LoginURL(sessionKey, prompt string) (string, error)
|
||||
|
@ -34,4 +38,5 @@ type ConnectorConfig interface {
|
|||
|
||||
type ConnectorConfigRepo interface {
|
||||
All() ([]ConnectorConfig, error)
|
||||
GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -9,6 +10,7 @@ import (
|
|||
"github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -91,6 +93,18 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
|||
return cfgs, nil
|
||||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
|
||||
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
|
||||
var c connectorConfigModel
|
||||
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, connector.ErrorNotFound
|
||||
}
|
||||
}
|
||||
return c.ConnectorConfig()
|
||||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
|
||||
insert := make([]interface{}, len(cfgs))
|
||||
for i, cfg := range cfgs {
|
||||
|
@ -119,3 +133,15 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
|
|||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
||||
if tx == nil {
|
||||
return r.dbMap
|
||||
}
|
||||
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
return gorpTx
|
||||
}
|
||||
|
|
71
functional/repo/connector_repo_test.go
Normal file
71
functional/repo/connector_repo_test.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package repo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
)
|
||||
|
||||
type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo
|
||||
|
||||
var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory
|
||||
|
||||
func init() {
|
||||
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
|
||||
makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs
|
||||
} else {
|
||||
makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory {
|
||||
return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
|
||||
dbMap := initDB(dsn)
|
||||
|
||||
repo := db.NewConnectorConfigRepo(dbMap)
|
||||
if err := repo.Set(cfgs); err != nil {
|
||||
panic(fmt.Sprintf("Unable to set connector configs: %v", err))
|
||||
}
|
||||
return repo
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectorConfigRepoGetByID(t *testing.T) {
|
||||
tests := []struct {
|
||||
cfgs []connector.ConnectorConfig
|
||||
id string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
cfgs: []connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
},
|
||||
id: "local",
|
||||
},
|
||||
{
|
||||
cfgs: []connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local1"},
|
||||
&connector.LocalConnectorConfig{ID: "local2"},
|
||||
},
|
||||
id: "local2",
|
||||
},
|
||||
{
|
||||
cfgs: []connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local1"},
|
||||
&connector.LocalConnectorConfig{ID: "local2"},
|
||||
},
|
||||
id: "foo",
|
||||
err: connector.ErrorNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs)
|
||||
if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err {
|
||||
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
|
@ -47,7 +48,10 @@ func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.Pas
|
|||
ur := user.NewUserRepoFromUsers(users)
|
||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
|
||||
|
||||
um := manager.NewUserManager(ur, pwr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs(
|
||||
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
|
||||
)
|
||||
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
um.Clock = clock
|
||||
return ur, pwr, um
|
||||
}
|
||||
|
|
|
@ -134,7 +134,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
refTokRepo := refresh.NewRefreshTokenRepo()
|
||||
|
||||
txnFactory := repo.InMemTransactionFactory
|
||||
userManager := manager.NewUserManager(userRepo, pwiRepo, txnFactory, manager.ManagerOptions{})
|
||||
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
|
||||
srv.ClientIdentityRepo = ciRepo
|
||||
srv.KeySetRepo = kRepo
|
||||
srv.ConnectorConfigRepo = cfgRepo
|
||||
|
@ -172,7 +172,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
|||
cfgRepo := db.NewConnectorConfigRepo(dbc)
|
||||
userRepo := db.NewUserRepo(dbc)
|
||||
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
||||
userManager := manager.NewUserManager(userRepo, pwiRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
|
||||
|
||||
sm := session.NewSessionManager(sRepo, skRepo)
|
||||
|
|
|
@ -92,7 +92,6 @@ func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
|
|||
func makeTestFixtures() (*testFixtures, error) {
|
||||
userRepo := user.NewUserRepoFromUsers(testUsers)
|
||||
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
|
||||
manager := manager.NewUserManager(userRepo, pwRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
|
||||
connConfigs := []connector.ConnectorConfig{
|
||||
&connector.OIDCConnectorConfig{
|
||||
|
@ -112,6 +111,9 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
ID: "local",
|
||||
},
|
||||
}
|
||||
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
|
||||
|
||||
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
|
||||
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
||||
|
|
2
test
2
test
|
@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
|
|||
|
||||
source ./build
|
||||
|
||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api email"
|
||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email"
|
||||
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
|
||||
|
||||
# user has not provided PKG override
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/dex/user"
|
||||
|
@ -124,7 +125,10 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
Password: []byte("password-2"),
|
||||
},
|
||||
})
|
||||
mgr := manager.NewUserManager(ur, pwr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
})
|
||||
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
mgr.Clock = clock
|
||||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/user"
|
||||
|
@ -25,6 +26,7 @@ type UserManager struct {
|
|||
|
||||
userRepo user.UserRepo
|
||||
pwRepo user.PasswordInfoRepo
|
||||
connCfgRepo connector.ConnectorConfigRepo
|
||||
begin repo.TransactionFactory
|
||||
userIDGenerator user.UserIDGenerator
|
||||
}
|
||||
|
@ -35,12 +37,13 @@ type ManagerOptions struct {
|
|||
// variable policies
|
||||
}
|
||||
|
||||
func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager {
|
||||
func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, connCfgRepo connector.ConnectorConfigRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager {
|
||||
return &UserManager{
|
||||
Clock: clockwork.NewRealClock(),
|
||||
|
||||
userRepo: userRepo,
|
||||
pwRepo: pwRepo,
|
||||
connCfgRepo: connCfgRepo,
|
||||
begin: txnFactory,
|
||||
userIDGenerator: user.DefaultUserIDGenerator,
|
||||
}
|
||||
|
@ -80,7 +83,7 @@ func (m *UserManager) CreateUser(usr user.User, hashedPassword user.Password, co
|
|||
ConnectorID: connID,
|
||||
ID: usr.ID,
|
||||
}
|
||||
if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
|
@ -141,7 +144,7 @@ func (m *UserManager) RegisterWithRemoteIdentity(email string, emailVerified boo
|
|||
return "", err
|
||||
}
|
||||
|
||||
if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
|
@ -177,7 +180,7 @@ func (m *UserManager) RegisterWithPassword(email, plaintext, connID string) (str
|
|||
ConnectorID: connID,
|
||||
ID: usr.ID,
|
||||
}
|
||||
if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
|
@ -338,6 +341,16 @@ func (m *UserManager) insertNewUser(tx repo.Transaction, email string, emailVeri
|
|||
return usr, nil
|
||||
}
|
||||
|
||||
func (m *UserManager) addRemoteIdentity(tx repo.Transaction, userID string, rid user.RemoteIdentity) error {
|
||||
if _, err := m.connCfgRepo.GetConnectorByID(tx, rid.ConnectorID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.userRepo.AddRemoteIdentity(tx, userID, rid); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rollback(tx repo.Transaction) {
|
||||
err := tx.Rollback()
|
||||
if err != nil {
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/jonboulle/clockwork"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/user"
|
||||
)
|
||||
|
@ -16,6 +17,7 @@ import (
|
|||
type testFixtures struct {
|
||||
ur user.UserRepo
|
||||
pwr user.PasswordInfoRepo
|
||||
ccr connector.ConnectorConfigRepo
|
||||
mgr *UserManager
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
@ -60,7 +62,10 @@ func makeTestFixtures() *testFixtures {
|
|||
Password: []byte("password-2"),
|
||||
},
|
||||
})
|
||||
f.mgr = NewUserManager(f.ur, f.pwr, repo.InMemTransactionFactory, ManagerOptions{})
|
||||
f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
})
|
||||
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{})
|
||||
f.mgr.Clock = f.clock
|
||||
return f
|
||||
}
|
||||
|
@ -98,6 +103,15 @@ func TestRegisterWithRemoteIdentity(t *testing.T) {
|
|||
},
|
||||
err: user.ErrorDuplicateRemoteIdentity,
|
||||
},
|
||||
{
|
||||
email: "anotheremail@example.com",
|
||||
emailVerified: false,
|
||||
rid: user.RemoteIdentity{
|
||||
ConnectorID: "idonotexist",
|
||||
ID: "1",
|
||||
},
|
||||
err: connector.ErrorNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
|
@ -159,7 +173,7 @@ func TestRegisterWithPassword(t *testing.T) {
|
|||
|
||||
for i, tt := range tests {
|
||||
f := makeTestFixtures()
|
||||
connID := "connID"
|
||||
connID := "local"
|
||||
userID, err := f.mgr.RegisterWithPassword(
|
||||
tt.email,
|
||||
tt.plaintext,
|
||||
|
@ -358,6 +372,7 @@ func TestCreateUser(t *testing.T) {
|
|||
tests := []struct {
|
||||
usr user.User
|
||||
hashedPW user.Password
|
||||
localID string // defaults to "local"
|
||||
|
||||
wantErr bool
|
||||
}{
|
||||
|
@ -383,11 +398,24 @@ func TestCreateUser(t *testing.T) {
|
|||
hashedPW: user.Password("I am a hash"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
usr: user.User{
|
||||
DisplayName: "Eric Exampleson",
|
||||
Email: "eric@example.com",
|
||||
},
|
||||
hashedPW: user.Password("I am a hash"),
|
||||
localID: "abadlocalid",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
f := makeTestFixtures()
|
||||
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, "local")
|
||||
localID := "local"
|
||||
if tt.localID != "" {
|
||||
localID = tt.localID
|
||||
}
|
||||
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, localID)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil err", i)
|
||||
|
|
Loading…
Reference in a new issue