user/manager: connector must exists when creating remote identity

Add ConnectorConfigRepo to UserManager. When trying to create a
RemoteIdentity, validate that the connector ID exists.

Fixes #198
This commit is contained in:
Eric Chiang 2015-12-07 17:19:55 -08:00
parent d518447282
commit f43655a8c3
12 changed files with 183 additions and 14 deletions

View file

@ -99,8 +99,9 @@ func main() {
userRepo := db.NewUserRepo(dbc) userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc)
connCfgRepo := db.NewConnectorConfigRepo(dbc)
userManager := manager.NewUserManager(userRepo, userManager := manager.NewUserManager(userRepo,
pwiRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID) adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID)
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...) kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
if err != nil { if err != nil {

View file

@ -4,6 +4,8 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"os" "os"
"github.com/coreos/dex/repo"
) )
func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) { func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) {
@ -41,6 +43,19 @@ type memConnectorConfigRepo struct {
configs []ConnectorConfig configs []ConnectorConfig
} }
func NewConnectorConfigRepoFromConfigs(cfgs []ConnectorConfig) ConnectorConfigRepo {
return &memConnectorConfigRepo{configs: cfgs}
}
func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) { func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) {
return r.configs, nil return r.configs, nil
} }
func (r *memConnectorConfigRepo) GetConnectorByID(_ repo.Transaction, id string) (ConnectorConfig, error) {
for _, cfg := range r.configs {
if cfg.ConnectorID() == id {
return cfg, nil
}
}
return nil, ErrorNotFound
}

View file

@ -1,14 +1,18 @@
package connector package connector
import ( import (
"errors"
"html/template" "html/template"
"net/http" "net/http"
"net/url" "net/url"
"github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/coreos/pkg/health" "github.com/coreos/pkg/health"
) )
var ErrorNotFound = errors.New("connector not found in repository")
type Connector interface { type Connector interface {
ID() string ID() string
LoginURL(sessionKey, prompt string) (string, error) LoginURL(sessionKey, prompt string) (string, error)
@ -34,4 +38,5 @@ type ConnectorConfig interface {
type ConnectorConfigRepo interface { type ConnectorConfigRepo interface {
All() ([]ConnectorConfig, error) All() ([]ConnectorConfig, error)
GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error)
} }

View file

@ -1,6 +1,7 @@
package db package db
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -9,6 +10,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
) )
const ( const (
@ -91,6 +93,18 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
return cfgs, nil return cfgs, nil
} }
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
qt := pq.QuoteIdentifier(connectorConfigTableName)
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
var c connectorConfigModel
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
if err == sql.ErrNoRows {
return nil, connector.ErrorNotFound
}
}
return c.ConnectorConfig()
}
func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error { func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
insert := make([]interface{}, len(cfgs)) insert := make([]interface{}, len(cfgs))
for i, cfg := range cfgs { for i, cfg := range cfgs {
@ -119,3 +133,15 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
return tx.Commit() return tx.Commit()
} }
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}

View file

@ -0,0 +1,71 @@
package repo
import (
"fmt"
"os"
"testing"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
)
type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo
var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory
func init() {
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs
} else {
makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn)
}
}
func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory {
return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
dbMap := initDB(dsn)
repo := db.NewConnectorConfigRepo(dbMap)
if err := repo.Set(cfgs); err != nil {
panic(fmt.Sprintf("Unable to set connector configs: %v", err))
}
return repo
}
}
func TestConnectorConfigRepoGetByID(t *testing.T) {
tests := []struct {
cfgs []connector.ConnectorConfig
id string
err error
}{
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
},
id: "local",
},
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local1"},
&connector.LocalConnectorConfig{ID: "local2"},
},
id: "local2",
},
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local1"},
&connector.LocalConnectorConfig{ID: "local2"},
},
id: "foo",
err: connector.ErrorNotFound,
},
}
for i, tt := range tests {
repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs)
if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err {
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
}
}
}

View file

@ -10,6 +10,7 @@ import (
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" "github.com/coreos/dex/user/manager"
@ -47,7 +48,10 @@ func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.Pas
ur := user.NewUserRepoFromUsers(users) ur := user.NewUserRepoFromUsers(users)
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords) pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
um := manager.NewUserManager(ur, pwr, repo.InMemTransactionFactory, manager.ManagerOptions{}) ccr := connector.NewConnectorConfigRepoFromConfigs(
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
)
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
um.Clock = clock um.Clock = clock
return ur, pwr, um return ur, pwr, um
} }

View file

@ -134,7 +134,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
refTokRepo := refresh.NewRefreshTokenRepo() refTokRepo := refresh.NewRefreshTokenRepo()
txnFactory := repo.InMemTransactionFactory txnFactory := repo.InMemTransactionFactory
userManager := manager.NewUserManager(userRepo, pwiRepo, txnFactory, manager.ManagerOptions{}) userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
srv.ClientIdentityRepo = ciRepo srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo srv.ConnectorConfigRepo = cfgRepo
@ -172,7 +172,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
cfgRepo := db.NewConnectorConfigRepo(dbc) cfgRepo := db.NewConnectorConfigRepo(dbc)
userRepo := db.NewUserRepo(dbc) userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := manager.NewUserManager(userRepo, pwiRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
refreshTokenRepo := db.NewRefreshTokenRepo(dbc) refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
sm := session.NewSessionManager(sRepo, skRepo) sm := session.NewSessionManager(sRepo, skRepo)

View file

@ -92,7 +92,6 @@ func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
func makeTestFixtures() (*testFixtures, error) { func makeTestFixtures() (*testFixtures, error) {
userRepo := user.NewUserRepoFromUsers(testUsers) userRepo := user.NewUserRepoFromUsers(testUsers)
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos) pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
manager := manager.NewUserManager(userRepo, pwRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
connConfigs := []connector.ConnectorConfig{ connConfigs := []connector.ConnectorConfig{
&connector.OIDCConnectorConfig{ &connector.OIDCConnectorConfig{
@ -112,6 +111,9 @@ func makeTestFixtures() (*testFixtures, error) {
ID: "local", ID: "local",
}, },
} }
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sessionManager.GenerateCode = sequentialGenerateCodeFunc() sessionManager.GenerateCode = sequentialGenerateCodeFunc()

2
test
View file

@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
source ./build source ./build
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api email" TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email"
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log" FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
# user has not provided PKG override # user has not provided PKG override

View file

@ -10,6 +10,7 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
@ -124,7 +125,10 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
Password: []byte("password-2"), Password: []byte("password-2"),
}, },
}) })
mgr := manager.NewUserManager(ur, pwr, repo.InMemTransactionFactory, manager.ManagerOptions{}) ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
})
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
mgr.Clock = clock mgr.Clock = clock
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{

View file

@ -6,6 +6,7 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
@ -25,6 +26,7 @@ type UserManager struct {
userRepo user.UserRepo userRepo user.UserRepo
pwRepo user.PasswordInfoRepo pwRepo user.PasswordInfoRepo
connCfgRepo connector.ConnectorConfigRepo
begin repo.TransactionFactory begin repo.TransactionFactory
userIDGenerator user.UserIDGenerator userIDGenerator user.UserIDGenerator
} }
@ -35,12 +37,13 @@ type ManagerOptions struct {
// variable policies // variable policies
} }
func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager { func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, connCfgRepo connector.ConnectorConfigRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager {
return &UserManager{ return &UserManager{
Clock: clockwork.NewRealClock(), Clock: clockwork.NewRealClock(),
userRepo: userRepo, userRepo: userRepo,
pwRepo: pwRepo, pwRepo: pwRepo,
connCfgRepo: connCfgRepo,
begin: txnFactory, begin: txnFactory,
userIDGenerator: user.DefaultUserIDGenerator, userIDGenerator: user.DefaultUserIDGenerator,
} }
@ -80,7 +83,7 @@ func (m *UserManager) CreateUser(usr user.User, hashedPassword user.Password, co
ConnectorID: connID, ConnectorID: connID,
ID: usr.ID, ID: usr.ID,
} }
if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil { if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx) rollback(tx)
return "", err return "", err
} }
@ -141,7 +144,7 @@ func (m *UserManager) RegisterWithRemoteIdentity(email string, emailVerified boo
return "", err return "", err
} }
if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil { if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx) rollback(tx)
return "", err return "", err
} }
@ -177,7 +180,7 @@ func (m *UserManager) RegisterWithPassword(email, plaintext, connID string) (str
ConnectorID: connID, ConnectorID: connID,
ID: usr.ID, ID: usr.ID,
} }
if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil { if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx) rollback(tx)
return "", err return "", err
} }
@ -338,6 +341,16 @@ func (m *UserManager) insertNewUser(tx repo.Transaction, email string, emailVeri
return usr, nil return usr, nil
} }
func (m *UserManager) addRemoteIdentity(tx repo.Transaction, userID string, rid user.RemoteIdentity) error {
if _, err := m.connCfgRepo.GetConnectorByID(tx, rid.ConnectorID); err != nil {
return err
}
if err := m.userRepo.AddRemoteIdentity(tx, userID, rid); err != nil {
return err
}
return nil
}
func rollback(tx repo.Transaction) { func rollback(tx repo.Transaction) {
err := tx.Rollback() err := tx.Rollback()
if err != nil { if err != nil {

View file

@ -9,6 +9,7 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
) )
@ -16,6 +17,7 @@ import (
type testFixtures struct { type testFixtures struct {
ur user.UserRepo ur user.UserRepo
pwr user.PasswordInfoRepo pwr user.PasswordInfoRepo
ccr connector.ConnectorConfigRepo
mgr *UserManager mgr *UserManager
clock clockwork.Clock clock clockwork.Clock
} }
@ -60,7 +62,10 @@ func makeTestFixtures() *testFixtures {
Password: []byte("password-2"), Password: []byte("password-2"),
}, },
}) })
f.mgr = NewUserManager(f.ur, f.pwr, repo.InMemTransactionFactory, ManagerOptions{}) f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
})
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{})
f.mgr.Clock = f.clock f.mgr.Clock = f.clock
return f return f
} }
@ -98,6 +103,15 @@ func TestRegisterWithRemoteIdentity(t *testing.T) {
}, },
err: user.ErrorDuplicateRemoteIdentity, err: user.ErrorDuplicateRemoteIdentity,
}, },
{
email: "anotheremail@example.com",
emailVerified: false,
rid: user.RemoteIdentity{
ConnectorID: "idonotexist",
ID: "1",
},
err: connector.ErrorNotFound,
},
} }
for i, tt := range tests { for i, tt := range tests {
@ -159,7 +173,7 @@ func TestRegisterWithPassword(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
f := makeTestFixtures() f := makeTestFixtures()
connID := "connID" connID := "local"
userID, err := f.mgr.RegisterWithPassword( userID, err := f.mgr.RegisterWithPassword(
tt.email, tt.email,
tt.plaintext, tt.plaintext,
@ -358,6 +372,7 @@ func TestCreateUser(t *testing.T) {
tests := []struct { tests := []struct {
usr user.User usr user.User
hashedPW user.Password hashedPW user.Password
localID string // defaults to "local"
wantErr bool wantErr bool
}{ }{
@ -383,11 +398,24 @@ func TestCreateUser(t *testing.T) {
hashedPW: user.Password("I am a hash"), hashedPW: user.Password("I am a hash"),
wantErr: true, wantErr: true,
}, },
{
usr: user.User{
DisplayName: "Eric Exampleson",
Email: "eric@example.com",
},
hashedPW: user.Password("I am a hash"),
localID: "abadlocalid",
wantErr: true,
},
} }
for i, tt := range tests { for i, tt := range tests {
f := makeTestFixtures() f := makeTestFixtures()
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, "local") localID := "local"
if tt.localID != "" {
localID = tt.localID
}
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, localID)
if tt.wantErr { if tt.wantErr {
if err == nil { if err == nil {
t.Errorf("case %d: want non-nil err", i) t.Errorf("case %d: want non-nil err", i)