forked from mystiq/dex
*: remove in memory user repo
This commit is contained in:
parent
95560404a3
commit
2726f4dcdf
13 changed files with 167 additions and 480 deletions
admin
functional/repo
integration
repo
server
user
|
@ -4,7 +4,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/schema/adminschema"
|
"github.com/coreos/dex/schema/adminschema"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/coreos/dex/user/manager"
|
"github.com/coreos/dex/user/manager"
|
||||||
|
@ -22,22 +22,30 @@ type testFixtures struct {
|
||||||
func makeTestFixtures() *testFixtures {
|
func makeTestFixtures() *testFixtures {
|
||||||
f := &testFixtures{}
|
f := &testFixtures{}
|
||||||
|
|
||||||
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
dbMap := db.NewMemDB()
|
||||||
{
|
f.ur = func() user.UserRepo {
|
||||||
User: user.User{
|
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
|
||||||
ID: "ID-1",
|
{
|
||||||
Email: "email-1@example.com",
|
User: user.User{
|
||||||
DisplayName: "Name-1",
|
ID: "ID-1",
|
||||||
|
Email: "email-1@example.com",
|
||||||
|
DisplayName: "Name-1",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
{
|
||||||
{
|
User: user.User{
|
||||||
User: user.User{
|
ID: "ID-2",
|
||||||
ID: "ID-2",
|
Email: "email-2@example.com",
|
||||||
Email: "email-2@example.com",
|
DisplayName: "Name-2",
|
||||||
DisplayName: "Name-2",
|
},
|
||||||
},
|
},
|
||||||
},
|
})
|
||||||
})
|
if err != nil {
|
||||||
|
panic("Failed to create user repo: " + err.Error())
|
||||||
|
}
|
||||||
|
return repo
|
||||||
|
}()
|
||||||
|
|
||||||
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||||
{
|
{
|
||||||
UserID: "ID-1",
|
UserID: "ID-1",
|
||||||
|
@ -47,7 +55,7 @@ func makeTestFixtures() *testFixtures {
|
||||||
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||||
&connector.LocalConnectorConfig{ID: "local"},
|
&connector.LocalConnectorConfig{ID: "local"},
|
||||||
})
|
})
|
||||||
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
||||||
f.adAPI = NewAdminAPI(f.mgr, f.ur, f.pwr, "local")
|
f.adAPI = NewAdminAPI(f.mgr, f.ur, f.pwr, "local")
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-gorp/gorp"
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
|
@ -49,10 +50,12 @@ func newUserRepo(t *testing.T, users []user.UserWithRemoteIdentities) user.UserR
|
||||||
if users == nil {
|
if users == nil {
|
||||||
users = []user.UserWithRemoteIdentities{}
|
users = []user.UserWithRemoteIdentities{}
|
||||||
}
|
}
|
||||||
|
var dbMap *gorp.DbMap
|
||||||
if os.Getenv("DEX_TEST_DSN") == "" {
|
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||||
return user.NewUserRepoFromUsers(users)
|
dbMap = db.NewMemDB()
|
||||||
|
} else {
|
||||||
|
dbMap = connect(t)
|
||||||
}
|
}
|
||||||
dbMap := connect(t)
|
|
||||||
repo, err := db.NewUserRepoFromUsers(dbMap, users)
|
repo, err := db.NewUserRepoFromUsers(dbMap, users)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to add users: %v", err)
|
t.Fatalf("Unable to add users: %v", err)
|
||||||
|
@ -416,59 +419,6 @@ func findRemoteIdentity(rids []user.RemoteIdentity, rid user.RemoteIdentity) int
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewUserRepoFromUsers(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
users []user.UserWithRemoteIdentities
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
users: []user.UserWithRemoteIdentities{
|
|
||||||
{
|
|
||||||
User: user.User{
|
|
||||||
ID: "123",
|
|
||||||
Email: "email123@example.com",
|
|
||||||
},
|
|
||||||
RemoteIdentities: []user.RemoteIdentity{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
User: user.User{
|
|
||||||
ID: "456",
|
|
||||||
Email: "email456@example.com",
|
|
||||||
},
|
|
||||||
RemoteIdentities: []user.RemoteIdentity{
|
|
||||||
{
|
|
||||||
ID: "remoteID",
|
|
||||||
ConnectorID: "connID",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
repo := user.NewUserRepoFromUsers(tt.users)
|
|
||||||
for _, want := range tt.users {
|
|
||||||
gotUser, err := repo.Get(nil, want.User.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("case %d: want nil err: %v", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
gotRIDs, err := repo.GetRemoteIdentities(nil, want.User.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("case %d: want nil err: %v", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(want.User, gotUser) {
|
|
||||||
t.Errorf("case %d: want=%#v got=%#v", i, want.User, gotUser)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(want.RemoteIdentities, gotRIDs) {
|
|
||||||
t.Errorf("case %d: want=%#v got=%#v", i, want.RemoteIdentities, gotRIDs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetByEmail(t *testing.T) {
|
func TestGetByEmail(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
email string
|
email string
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"github.com/jonboulle/clockwork"
|
"github.com/jonboulle/clockwork"
|
||||||
|
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/coreos/dex/user/manager"
|
"github.com/coreos/dex/user/manager"
|
||||||
)
|
)
|
||||||
|
@ -45,13 +45,20 @@ func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, erro
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) {
|
func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) {
|
||||||
ur := user.NewUserRepoFromUsers(users)
|
dbMap := db.NewMemDB()
|
||||||
|
ur := func() user.UserRepo {
|
||||||
|
repo, err := db.NewUserRepoFromUsers(dbMap, users)
|
||||||
|
if err != nil {
|
||||||
|
panic("Failed to create user repo: " + err.Error())
|
||||||
|
}
|
||||||
|
return repo
|
||||||
|
}()
|
||||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
|
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
|
||||||
|
|
||||||
ccr := connector.NewConnectorConfigRepoFromConfigs(
|
ccr := connector.NewConnectorConfigRepoFromConfigs(
|
||||||
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
|
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
|
||||||
)
|
)
|
||||||
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
um := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
||||||
um.Clock = clock
|
um.Clock = clock
|
||||||
return ur, pwr, um
|
return ur, pwr, um
|
||||||
}
|
}
|
||||||
|
|
|
@ -139,7 +139,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
||||||
Email: "testemail@example.com",
|
Email: "testemail@example.com",
|
||||||
DisplayName: "displayname",
|
DisplayName: "displayname",
|
||||||
}
|
}
|
||||||
userRepo := user.NewUserRepo()
|
userRepo := db.NewUserRepo(db.NewMemDB())
|
||||||
if err := userRepo.Create(nil, usr); err != nil {
|
if err := userRepo.Create(nil, usr); err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
28
repo/repo.go
28
repo/repo.go
|
@ -1,7 +1,5 @@
|
||||||
package repo
|
package repo
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
// Transaction is an abstraction of transactions typically found in database systems.
|
// Transaction is an abstraction of transactions typically found in database systems.
|
||||||
// One of Commit() or Rollback() must be called on each transaction.
|
// One of Commit() or Rollback() must be called on each transaction.
|
||||||
type Transaction interface {
|
type Transaction interface {
|
||||||
|
@ -13,29 +11,3 @@ type Transaction interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type TransactionFactory func() (Transaction, error)
|
type TransactionFactory func() (Transaction, error)
|
||||||
|
|
||||||
// InMemTransaction satisifies the Transaction interface for in-memory systems.
|
|
||||||
// However, the only thing it really does is ensure that the same transaction is
|
|
||||||
// can't be committed/rolled back more than once. As such, this can lead to data
|
|
||||||
// corruption and should not be used in production systems.
|
|
||||||
type InMemTransaction bool
|
|
||||||
|
|
||||||
func InMemTransactionFactory() (Transaction, error) {
|
|
||||||
return new(InMemTransaction), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *InMemTransaction) Commit() error {
|
|
||||||
return i.commitOrRollback()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *InMemTransaction) Rollback() error {
|
|
||||||
return i.commitOrRollback()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *InMemTransaction) commitOrRollback() error {
|
|
||||||
if *i {
|
|
||||||
return errors.New("Already committed/rolled-back.")
|
|
||||||
}
|
|
||||||
*i = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
|
@ -17,7 +18,6 @@ import (
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/email"
|
"github.com/coreos/dex/email"
|
||||||
"github.com/coreos/dex/repo"
|
|
||||||
sessionmanager "github.com/coreos/dex/session/manager"
|
sessionmanager "github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
useremail "github.com/coreos/dex/user/email"
|
useremail "github.com/coreos/dex/user/email"
|
||||||
|
@ -100,6 +100,8 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dbMap := db.NewMemDB()
|
||||||
|
|
||||||
ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour))
|
ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour))
|
||||||
kRepo := key.NewPrivateKeySetRepo()
|
kRepo := key.NewPrivateKeySetRepo()
|
||||||
if err = kRepo.Set(ks); err != nil {
|
if err = kRepo.Set(ks); err != nil {
|
||||||
|
@ -127,20 +129,24 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
||||||
}
|
}
|
||||||
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs)
|
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs)
|
||||||
|
|
||||||
sRepo := db.NewSessionRepo(db.NewMemDB())
|
sRepo := db.NewSessionRepo(dbMap)
|
||||||
skRepo := db.NewSessionKeyRepo(db.NewMemDB())
|
skRepo := db.NewSessionKeyRepo(dbMap)
|
||||||
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
||||||
|
|
||||||
userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile)
|
users, err := loadUsers(cfg.UsersFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to read users from file: %v", err)
|
return fmt.Errorf("unable to read users from file: %v", err)
|
||||||
}
|
}
|
||||||
|
userRepo, err := db.NewUserRepoFromUsers(dbMap, users)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
pwiRepo := user.NewPasswordInfoRepo()
|
pwiRepo := user.NewPasswordInfoRepo()
|
||||||
|
|
||||||
refTokRepo := db.NewRefreshTokenRepo(db.NewMemDB())
|
refTokRepo := db.NewRefreshTokenRepo(dbMap)
|
||||||
|
|
||||||
txnFactory := repo.InMemTransactionFactory
|
txnFactory := db.TransactionFactory(dbMap)
|
||||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
||||||
srv.ClientIdentityRepo = ciRepo
|
srv.ClientIdentityRepo = ciRepo
|
||||||
srv.KeySetRepo = kRepo
|
srv.KeySetRepo = kRepo
|
||||||
|
@ -154,6 +160,16 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loadUsers(filepath string) (users []user.UserWithRemoteIdentities, err error) {
|
||||||
|
f, err := os.Open(filepath)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
err = json.NewDecoder(f).Decode(&users)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
||||||
if len(cfg.KeySecrets) == 0 {
|
if len(cfg.KeySecrets) == 0 {
|
||||||
return errors.New("missing key secret")
|
return errors.New("missing key secret")
|
||||||
|
|
|
@ -76,7 +76,7 @@ func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeNewUserRepo() (user.UserRepo, error) {
|
func makeNewUserRepo() (user.UserRepo, error) {
|
||||||
userRepo := user.NewUserRepo()
|
userRepo := db.NewUserRepo(db.NewMemDB())
|
||||||
|
|
||||||
id := "testid-1"
|
id := "testid-1"
|
||||||
err := userRepo.Create(nil, user.User{
|
err := userRepo.Create(nil, user.User{
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/email"
|
"github.com/coreos/dex/email"
|
||||||
"github.com/coreos/dex/repo"
|
|
||||||
sessionmanager "github.com/coreos/dex/session/manager"
|
sessionmanager "github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
useremail "github.com/coreos/dex/user/email"
|
useremail "github.com/coreos/dex/user/email"
|
||||||
|
@ -91,7 +90,11 @@ func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeTestFixtures() (*testFixtures, error) {
|
func makeTestFixtures() (*testFixtures, error) {
|
||||||
userRepo := user.NewUserRepoFromUsers(testUsers)
|
dbMap := db.NewMemDB()
|
||||||
|
userRepo, err := db.NewUserRepoFromUsers(dbMap, testUsers)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
|
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
|
||||||
|
|
||||||
connConfigs := []connector.ConnectorConfig{
|
connConfigs := []connector.ConnectorConfig{
|
||||||
|
@ -114,7 +117,7 @@ func makeTestFixtures() (*testFixtures, error) {
|
||||||
}
|
}
|
||||||
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
|
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
|
||||||
|
|
||||||
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, usermanager.ManagerOptions{})
|
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
|
||||||
|
|
||||||
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/db"
|
||||||
schema "github.com/coreos/dex/schema/workerschema"
|
schema "github.com/coreos/dex/schema/workerschema"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/coreos/dex/user/manager"
|
"github.com/coreos/dex/user/manager"
|
||||||
|
@ -86,35 +86,43 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||||
ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
dbMap := db.NewMemDB()
|
||||||
{
|
ur := func() user.UserRepo {
|
||||||
User: user.User{
|
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
|
||||||
ID: "ID-1",
|
{
|
||||||
Email: "id1@example.com",
|
User: user.User{
|
||||||
Admin: true,
|
ID: "ID-1",
|
||||||
CreatedAt: clock.Now(),
|
Email: "id1@example.com",
|
||||||
|
Admin: true,
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
User: user.User{
|
||||||
|
ID: "ID-2",
|
||||||
|
Email: "id2@example.com",
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
User: user.User{
|
||||||
|
ID: "ID-3",
|
||||||
|
Email: "id3@example.com",
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
User: user.User{
|
||||||
|
ID: "ID-4",
|
||||||
|
Email: "id4@example.com",
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
Disabled: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}, {
|
})
|
||||||
User: user.User{
|
if err != nil {
|
||||||
ID: "ID-2",
|
panic("Failed to create user repo: " + err.Error())
|
||||||
Email: "id2@example.com",
|
}
|
||||||
CreatedAt: clock.Now(),
|
return repo
|
||||||
},
|
}()
|
||||||
}, {
|
|
||||||
User: user.User{
|
|
||||||
ID: "ID-3",
|
|
||||||
Email: "id3@example.com",
|
|
||||||
CreatedAt: clock.Now(),
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
User: user.User{
|
|
||||||
ID: "ID-4",
|
|
||||||
Email: "id4@example.com",
|
|
||||||
CreatedAt: clock.Now(),
|
|
||||||
Disabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||||
{
|
{
|
||||||
UserID: "ID-1",
|
UserID: "ID-1",
|
||||||
|
@ -128,7 +136,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||||
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||||
&connector.LocalConnectorConfig{ID: "local"},
|
&connector.LocalConnectorConfig{ID: "local"},
|
||||||
})
|
})
|
||||||
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
mgr := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
||||||
mgr.Clock = clock
|
mgr.Clock = clock
|
||||||
ci := oidc.ClientIdentity{
|
ci := oidc.ClientIdentity{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/coreos/go-oidc/key"
|
"github.com/coreos/go-oidc/key"
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/email"
|
"github.com/coreos/dex/email"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
)
|
)
|
||||||
|
@ -45,25 +46,32 @@ func (t *testEmailer) SendMail(from, subject, text, html string, to ...string) e
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) {
|
func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) {
|
||||||
ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
ur := func() user.UserRepo {
|
||||||
{
|
repo, err := db.NewUserRepoFromUsers(db.NewMemDB(), []user.UserWithRemoteIdentities{
|
||||||
User: user.User{
|
{
|
||||||
ID: "ID-1",
|
User: user.User{
|
||||||
Email: "id1@example.com",
|
ID: "ID-1",
|
||||||
Admin: true,
|
Email: "id1@example.com",
|
||||||
|
Admin: true,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
User: user.User{
|
||||||
|
ID: "ID-2",
|
||||||
|
Email: "id2@example.com",
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
User: user.User{
|
||||||
|
ID: "ID-3",
|
||||||
|
Email: "id3@example.com",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}, {
|
})
|
||||||
User: user.User{
|
if err != nil {
|
||||||
ID: "ID-2",
|
panic("Failed to create user repo: " + err.Error())
|
||||||
Email: "id2@example.com",
|
}
|
||||||
},
|
return repo
|
||||||
}, {
|
}()
|
||||||
User: user.User{
|
|
||||||
ID: "ID-3",
|
|
||||||
Email: "id3@example.com",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||||
{
|
{
|
||||||
UserID: "ID-1",
|
UserID: "ID-1",
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/repo"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,32 +26,40 @@ func makeTestFixtures() *testFixtures {
|
||||||
f := &testFixtures{}
|
f := &testFixtures{}
|
||||||
f.clock = clockwork.NewFakeClock()
|
f.clock = clockwork.NewFakeClock()
|
||||||
|
|
||||||
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
dbMap := db.NewMemDB()
|
||||||
{
|
f.ur = func() user.UserRepo {
|
||||||
User: user.User{
|
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
|
||||||
ID: "ID-1",
|
{
|
||||||
Email: "Email-1@example.com",
|
User: user.User{
|
||||||
},
|
ID: "ID-1",
|
||||||
RemoteIdentities: []user.RemoteIdentity{
|
Email: "Email-1@example.com",
|
||||||
{
|
},
|
||||||
ConnectorID: "local",
|
RemoteIdentities: []user.RemoteIdentity{
|
||||||
ID: "1",
|
{
|
||||||
|
ConnectorID: "local",
|
||||||
|
ID: "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
User: user.User{
|
||||||
|
ID: "ID-2",
|
||||||
|
Email: "Email-2@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
},
|
||||||
|
RemoteIdentities: []user.RemoteIdentity{
|
||||||
|
{
|
||||||
|
ConnectorID: "local",
|
||||||
|
ID: "2",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, {
|
})
|
||||||
User: user.User{
|
if err != nil {
|
||||||
ID: "ID-2",
|
panic("Failed to create user repo: " + err.Error())
|
||||||
Email: "Email-2@example.com",
|
}
|
||||||
EmailVerified: true,
|
return repo
|
||||||
},
|
}()
|
||||||
RemoteIdentities: []user.RemoteIdentity{
|
|
||||||
{
|
|
||||||
ConnectorID: "local",
|
|
||||||
ID: "2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||||
{
|
{
|
||||||
UserID: "ID-1",
|
UserID: "ID-1",
|
||||||
|
@ -65,7 +73,7 @@ func makeTestFixtures() *testFixtures {
|
||||||
f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||||
&connector.LocalConnectorConfig{ID: "local"},
|
&connector.LocalConnectorConfig{ID: "local"},
|
||||||
})
|
})
|
||||||
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{})
|
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, db.TransactionFactory(dbMap), ManagerOptions{})
|
||||||
f.mgr.Clock = f.clock
|
f.mgr.Clock = f.clock
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
254
user/user.go
254
user/user.go
|
@ -4,13 +4,10 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"net/mail"
|
"net/mail"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/jonboulle/clockwork"
|
"github.com/jonboulle/clockwork"
|
||||||
"github.com/pborman/uuid"
|
"github.com/pborman/uuid"
|
||||||
|
@ -172,262 +169,11 @@ func ValidPassword(plaintext string) bool {
|
||||||
return len(plaintext) > 5
|
return len(plaintext) > 5
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserRepo returns an in-memory UserRepo useful for development.
|
|
||||||
func NewUserRepo() UserRepo {
|
|
||||||
return &memUserRepo{
|
|
||||||
usersByID: make(map[string]User),
|
|
||||||
userIDsByEmail: make(map[string]string),
|
|
||||||
userIDsByRemoteID: make(map[RemoteIdentity]string),
|
|
||||||
remoteIDsByUserID: make(map[string]map[RemoteIdentity]struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type memUserRepo struct {
|
|
||||||
usersByID map[string]User
|
|
||||||
userIDsByEmail map[string]string
|
|
||||||
userIDsByRemoteID map[RemoteIdentity]string
|
|
||||||
remoteIDsByUserID map[string]map[RemoteIdentity]struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memUserRepo) Get(_ repo.Transaction, id string) (User, error) {
|
|
||||||
user, ok := r.usersByID[id]
|
|
||||||
if !ok {
|
|
||||||
return User{}, ErrorNotFound
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type usersByEmail []User
|
|
||||||
|
|
||||||
func (s usersByEmail) Len() int { return len(s) }
|
|
||||||
func (s usersByEmail) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
||||||
func (s usersByEmail) Less(i, j int) bool { return s[i].Email < s[j].Email }
|
|
||||||
|
|
||||||
func (r *memUserRepo) List(tx repo.Transaction, filter UserFilter, maxResults int, nextPageToken string) ([]User, string, error) {
|
|
||||||
var offset int
|
|
||||||
var err error
|
|
||||||
if nextPageToken != "" {
|
|
||||||
filter, maxResults, offset, err = DecodeNextPageToken(nextPageToken)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
users := []User{}
|
|
||||||
for _, usr := range r.usersByID {
|
|
||||||
users = append(users, usr)
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Sort(usersByEmail(users))
|
|
||||||
|
|
||||||
high := offset + maxResults
|
|
||||||
|
|
||||||
var tok string
|
|
||||||
if high >= len(users) {
|
|
||||||
high = len(users)
|
|
||||||
} else {
|
|
||||||
tok, err = EncodeNextPageToken(filter, maxResults, high)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users[offset:high]) == 0 {
|
|
||||||
return nil, "", ErrorNotFound
|
|
||||||
}
|
|
||||||
return users[offset:high], tok, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memUserRepo) GetByEmail(tx repo.Transaction, email string) (User, error) {
|
|
||||||
userID, ok := r.userIDsByEmail[email]
|
|
||||||
if !ok {
|
|
||||||
return User{}, ErrorNotFound
|
|
||||||
}
|
|
||||||
return r.Get(tx, userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memUserRepo) Create(_ repo.Transaction, user User) error {
|
|
||||||
if user.ID == "" {
|
|
||||||
return ErrorInvalidID
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ValidEmail(user.Email) {
|
|
||||||
return ErrorInvalidEmail
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure no one has the same ID; if using UUID the chances of this
|
|
||||||
// happening are astronomically small.
|
|
||||||
_, ok := r.usersByID[user.ID]
|
|
||||||
if ok {
|
|
||||||
return ErrorDuplicateID
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure there's no other user with the same Email
|
|
||||||
_, ok = r.userIDsByEmail[user.Email]
|
|
||||||
if ok {
|
|
||||||
return ErrorDuplicateEmail
|
|
||||||
}
|
|
||||||
|
|
||||||
r.set(user)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memUserRepo) Update(_ repo.Transaction, user User) error {
|
|
||||||
if user.ID == "" {
|
|
||||||
return ErrorInvalidID
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ValidEmail(user.Email) {
|
|
||||||
return ErrorInvalidEmail
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure this user exists already
|
|
||||||
_, ok := r.usersByID[user.ID]
|
|
||||||
if !ok {
|
|
||||||
return ErrorNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure there's no other user with the same Email
|
|
||||||
otherID, ok := r.userIDsByEmail[user.Email]
|
|
||||||
if ok && otherID != user.ID {
|
|
||||||
return ErrorDuplicateEmail
|
|
||||||
}
|
|
||||||
|
|
||||||
r.set(user)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error {
|
|
||||||
if id == "" {
|
|
||||||
return ErrorInvalidID
|
|
||||||
}
|
|
||||||
user, ok := r.usersByID[id]
|
|
||||||
if !ok {
|
|
||||||
return ErrorNotFound
|
|
||||||
}
|
|
||||||
user.Disabled = disable
|
|
||||||
r.set(user)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
|
|
||||||
_, ok := r.usersByID[userID]
|
|
||||||
if !ok {
|
|
||||||
return ErrorNotFound
|
|
||||||
}
|
|
||||||
_, ok = r.userIDsByRemoteID[ri]
|
|
||||||
if ok {
|
|
||||||
return ErrorDuplicateRemoteIdentity
|
|
||||||
}
|
|
||||||
|
|
||||||
r.userIDsByRemoteID[ri] = userID
|
|
||||||
rIDs, ok := r.remoteIDsByUserID[userID]
|
|
||||||
if !ok {
|
|
||||||
rIDs = make(map[RemoteIdentity]struct{})
|
|
||||||
r.remoteIDsByUserID[userID] = rIDs
|
|
||||||
}
|
|
||||||
|
|
||||||
rIDs[ri] = struct{}{}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memUserRepo) RemoveRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
|
|
||||||
otherID, ok := r.userIDsByRemoteID[ri]
|
|
||||||
if !ok {
|
|
||||||
return ErrorNotFound
|
|
||||||
}
|
|
||||||
if otherID != userID {
|
|
||||||
return ErrorNotFound
|
|
||||||
}
|
|
||||||
delete(r.userIDsByRemoteID, ri)
|
|
||||||
delete(r.remoteIDsByUserID[userID], ri)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memUserRepo) GetByRemoteIdentity(_ repo.Transaction, ri RemoteIdentity) (User, error) {
|
|
||||||
userID, ok := r.userIDsByRemoteID[ri]
|
|
||||||
if !ok {
|
|
||||||
return User{}, ErrorNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
user, ok := r.usersByID[userID]
|
|
||||||
if !ok {
|
|
||||||
return User{}, ErrorNotFound
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memUserRepo) GetRemoteIdentities(_ repo.Transaction, userID string) ([]RemoteIdentity, error) {
|
|
||||||
ids := []RemoteIdentity{}
|
|
||||||
for id := range r.remoteIDsByUserID[userID] {
|
|
||||||
ids = append(ids, id)
|
|
||||||
}
|
|
||||||
return ids, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memUserRepo) GetAdminCount(_ repo.Transaction) (int, error) {
|
|
||||||
var i int
|
|
||||||
for _, usr := range r.usersByID {
|
|
||||||
if usr.Admin {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *memUserRepo) set(user User) error {
|
|
||||||
r.usersByID[user.ID] = user
|
|
||||||
r.userIDsByEmail[user.Email] = user.ID
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserWithRemoteIdentities struct {
|
type UserWithRemoteIdentities struct {
|
||||||
User User `json:"user"`
|
User User `json:"user"`
|
||||||
RemoteIdentities []RemoteIdentity `json:"remoteIdentities"`
|
RemoteIdentities []RemoteIdentity `json:"remoteIdentities"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserRepoFromFile returns an in-memory UserRepo useful for development given a JSON serialized file of Users.
|
|
||||||
func NewUserRepoFromFile(loc string) (UserRepo, error) {
|
|
||||||
us, err := readUsersFromFile(loc)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return NewUserRepoFromUsers(us), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUserRepoFromUsers(us []UserWithRemoteIdentities) UserRepo {
|
|
||||||
memUserRepo := NewUserRepo().(*memUserRepo)
|
|
||||||
for _, u := range us {
|
|
||||||
memUserRepo.set(u.User)
|
|
||||||
for _, ri := range u.RemoteIdentities {
|
|
||||||
memUserRepo.AddRemoteIdentity(nil, u.User.ID, ri)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return memUserRepo
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUsersFromReader(r io.Reader) ([]UserWithRemoteIdentities, error) {
|
|
||||||
var us []UserWithRemoteIdentities
|
|
||||||
err := json.NewDecoder(r).Decode(&us)
|
|
||||||
return us, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func readUsersFromFile(loc string) ([]UserWithRemoteIdentities, error) {
|
|
||||||
uf, err := os.Open(loc)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to read users from file %q: %v", loc, err)
|
|
||||||
}
|
|
||||||
defer uf.Close()
|
|
||||||
|
|
||||||
us, err := newUsersFromReader(uf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return us, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *User) UnmarshalJSON(data []byte) error {
|
func (u *User) UnmarshalJSON(data []byte) error {
|
||||||
var dec struct {
|
var dec struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
|
|
|
@ -2,7 +2,6 @@ package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
@ -10,44 +9,6 @@ import (
|
||||||
"github.com/coreos/go-oidc/jose"
|
"github.com/coreos/go-oidc/jose"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewUsersFromReader(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
json string
|
|
||||||
want []UserWithRemoteIdentities
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
json: `[{"user":{"id":"12345", "displayName": "Elroy Canis", "email":"elroy23@example.com"}, "remoteIdentities":[{"connectorID":"google", "id":"elroy@example.com"}] }]`,
|
|
||||||
want: []UserWithRemoteIdentities{
|
|
||||||
{
|
|
||||||
User: User{
|
|
||||||
ID: "12345",
|
|
||||||
DisplayName: "Elroy Canis",
|
|
||||||
Email: "elroy23@example.com",
|
|
||||||
},
|
|
||||||
RemoteIdentities: []RemoteIdentity{
|
|
||||||
{
|
|
||||||
ConnectorID: "google",
|
|
||||||
ID: "elroy@example.com",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
r := strings.NewReader(tt.json)
|
|
||||||
us, err := newUsersFromReader(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("case %d: want nil err: %v", i, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if diff := pretty.Compare(tt.want, us); diff != "" {
|
|
||||||
t.Errorf("case %d: Compare(want, got): %v", i, diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddToClaims(t *testing.T) {
|
func TestAddToClaims(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
user User
|
user User
|
||||||
|
|
Loading…
Add table
Reference in a new issue