diff --git a/admin/api_test.go b/admin/api_test.go index d885908f..5a4ba827 100644 --- a/admin/api_test.go +++ b/admin/api_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/coreos/dex/connector" - "github.com/coreos/dex/repo" + "github.com/coreos/dex/db" "github.com/coreos/dex/schema/adminschema" "github.com/coreos/dex/user" "github.com/coreos/dex/user/manager" @@ -22,22 +22,30 @@ type testFixtures struct { func makeTestFixtures() *testFixtures { f := &testFixtures{} - f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{ - { - User: user.User{ - ID: "ID-1", - Email: "email-1@example.com", - DisplayName: "Name-1", + dbMap := db.NewMemDB() + f.ur = func() user.UserRepo { + repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{ + { + User: user.User{ + ID: "ID-1", + Email: "email-1@example.com", + DisplayName: "Name-1", + }, }, - }, - { - User: user.User{ - ID: "ID-2", - Email: "email-2@example.com", - DisplayName: "Name-2", + { + User: user.User{ + ID: "ID-2", + Email: "email-2@example.com", + DisplayName: "Name-2", + }, }, - }, - }) + }) + if err != nil { + panic("Failed to create user repo: " + err.Error()) + } + return repo + }() + f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ { UserID: "ID-1", @@ -47,7 +55,7 @@ func makeTestFixtures() *testFixtures { ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{ &connector.LocalConnectorConfig{ID: "local"}, }) - f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{}) + f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{}) f.adAPI = NewAdminAPI(f.mgr, f.ur, f.pwr, "local") return f diff --git a/functional/repo/user_repo_test.go b/functional/repo/user_repo_test.go index ecca2d22..3fa699df 100644 --- a/functional/repo/user_repo_test.go +++ b/functional/repo/user_repo_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/go-gorp/gorp" "github.com/kylelemons/godebug/pretty" "github.com/coreos/dex/db" @@ -49,10 +50,12 @@ func newUserRepo(t *testing.T, users []user.UserWithRemoteIdentities) user.UserR if users == nil { users = []user.UserWithRemoteIdentities{} } + var dbMap *gorp.DbMap if os.Getenv("DEX_TEST_DSN") == "" { - return user.NewUserRepoFromUsers(users) + dbMap = db.NewMemDB() + } else { + dbMap = connect(t) } - dbMap := connect(t) repo, err := db.NewUserRepoFromUsers(dbMap, users) if err != nil { t.Fatalf("Unable to add users: %v", err) @@ -416,59 +419,6 @@ func findRemoteIdentity(rids []user.RemoteIdentity, rid user.RemoteIdentity) int return -1 } -func TestNewUserRepoFromUsers(t *testing.T) { - tests := []struct { - users []user.UserWithRemoteIdentities - }{ - { - users: []user.UserWithRemoteIdentities{ - { - User: user.User{ - ID: "123", - Email: "email123@example.com", - }, - RemoteIdentities: []user.RemoteIdentity{}, - }, - { - User: user.User{ - ID: "456", - Email: "email456@example.com", - }, - RemoteIdentities: []user.RemoteIdentity{ - { - ID: "remoteID", - ConnectorID: "connID", - }, - }, - }, - }, - }, - } - - for i, tt := range tests { - repo := user.NewUserRepoFromUsers(tt.users) - for _, want := range tt.users { - gotUser, err := repo.Get(nil, want.User.ID) - if err != nil { - t.Errorf("case %d: want nil err: %v", i, err) - } - - gotRIDs, err := repo.GetRemoteIdentities(nil, want.User.ID) - if err != nil { - t.Errorf("case %d: want nil err: %v", i, err) - } - - if !reflect.DeepEqual(want.User, gotUser) { - t.Errorf("case %d: want=%#v got=%#v", i, want.User, gotUser) - } - - if !reflect.DeepEqual(want.RemoteIdentities, gotRIDs) { - t.Errorf("case %d: want=%#v got=%#v", i, want.RemoteIdentities, gotRIDs) - } - } - } -} - func TestGetByEmail(t *testing.T) { tests := []struct { email string diff --git a/integration/common_test.go b/integration/common_test.go index a203cbf0..058f7968 100644 --- a/integration/common_test.go +++ b/integration/common_test.go @@ -11,7 +11,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/coreos/dex/connector" - "github.com/coreos/dex/repo" + "github.com/coreos/dex/db" "github.com/coreos/dex/user" "github.com/coreos/dex/user/manager" ) @@ -45,13 +45,20 @@ func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, erro } func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) { - ur := user.NewUserRepoFromUsers(users) + dbMap := db.NewMemDB() + ur := func() user.UserRepo { + repo, err := db.NewUserRepoFromUsers(dbMap, users) + if err != nil { + panic("Failed to create user repo: " + err.Error()) + } + return repo + }() pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords) ccr := connector.NewConnectorConfigRepoFromConfigs( []connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}}, ) - um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{}) + um := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{}) um.Clock = clock return ur, pwr, um } diff --git a/integration/oidc_test.go b/integration/oidc_test.go index 4eefcea0..d4d32f62 100644 --- a/integration/oidc_test.go +++ b/integration/oidc_test.go @@ -139,7 +139,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { Email: "testemail@example.com", DisplayName: "displayname", } - userRepo := user.NewUserRepo() + userRepo := db.NewUserRepo(db.NewMemDB()) if err := userRepo.Create(nil, usr); err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/repo/repo.go b/repo/repo.go index b764d58c..3dd74603 100644 --- a/repo/repo.go +++ b/repo/repo.go @@ -1,7 +1,5 @@ package repo -import "errors" - // Transaction is an abstraction of transactions typically found in database systems. // One of Commit() or Rollback() must be called on each transaction. type Transaction interface { @@ -13,29 +11,3 @@ type Transaction interface { } type TransactionFactory func() (Transaction, error) - -// InMemTransaction satisifies the Transaction interface for in-memory systems. -// However, the only thing it really does is ensure that the same transaction is -// can't be committed/rolled back more than once. As such, this can lead to data -// corruption and should not be used in production systems. -type InMemTransaction bool - -func InMemTransactionFactory() (Transaction, error) { - return new(InMemTransaction), nil -} - -func (i *InMemTransaction) Commit() error { - return i.commitOrRollback() -} - -func (i *InMemTransaction) Rollback() error { - return i.commitOrRollback() -} - -func (i *InMemTransaction) commitOrRollback() error { - if *i { - return errors.New("Already committed/rolled-back.") - } - *i = true - return nil -} diff --git a/server/config.go b/server/config.go index 60ad64da..2e4ae35b 100644 --- a/server/config.go +++ b/server/config.go @@ -1,6 +1,7 @@ package server import ( + "encoding/json" "errors" "fmt" "html/template" @@ -17,7 +18,6 @@ import ( "github.com/coreos/dex/connector" "github.com/coreos/dex/db" "github.com/coreos/dex/email" - "github.com/coreos/dex/repo" sessionmanager "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" useremail "github.com/coreos/dex/user/email" @@ -100,6 +100,8 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { return err } + dbMap := db.NewMemDB() + ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour)) kRepo := key.NewPrivateKeySetRepo() if err = kRepo.Set(ks); err != nil { @@ -127,20 +129,24 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { } cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs) - sRepo := db.NewSessionRepo(db.NewMemDB()) - skRepo := db.NewSessionKeyRepo(db.NewMemDB()) + sRepo := db.NewSessionRepo(dbMap) + skRepo := db.NewSessionKeyRepo(dbMap) sm := sessionmanager.NewSessionManager(sRepo, skRepo) - userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile) + users, err := loadUsers(cfg.UsersFile) if err != nil { return fmt.Errorf("unable to read users from file: %v", err) } + userRepo, err := db.NewUserRepoFromUsers(dbMap, users) + if err != nil { + return err + } pwiRepo := user.NewPasswordInfoRepo() - refTokRepo := db.NewRefreshTokenRepo(db.NewMemDB()) + refTokRepo := db.NewRefreshTokenRepo(dbMap) - txnFactory := repo.InMemTransactionFactory + txnFactory := db.TransactionFactory(dbMap) userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{}) srv.ClientIdentityRepo = ciRepo srv.KeySetRepo = kRepo @@ -154,6 +160,16 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { } +func loadUsers(filepath string) (users []user.UserWithRemoteIdentities, err error) { + f, err := os.Open(filepath) + if err != nil { + return + } + defer f.Close() + err = json.NewDecoder(f).Decode(&users) + return +} + func (cfg *MultiServerConfig) Configure(srv *Server) error { if len(cfg.KeySecrets) == 0 { return errors.New("missing key secret") diff --git a/server/server_test.go b/server/server_test.go index 8fef717a..d3051339 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -76,7 +76,7 @@ func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc { } func makeNewUserRepo() (user.UserRepo, error) { - userRepo := user.NewUserRepo() + userRepo := db.NewUserRepo(db.NewMemDB()) id := "testid-1" err := userRepo.Create(nil, user.User{ diff --git a/server/testutil.go b/server/testutil.go index c61bbfea..579fdf04 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -12,7 +12,6 @@ import ( "github.com/coreos/dex/connector" "github.com/coreos/dex/db" "github.com/coreos/dex/email" - "github.com/coreos/dex/repo" sessionmanager "github.com/coreos/dex/session/manager" "github.com/coreos/dex/user" useremail "github.com/coreos/dex/user/email" @@ -91,7 +90,11 @@ func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc { } func makeTestFixtures() (*testFixtures, error) { - userRepo := user.NewUserRepoFromUsers(testUsers) + dbMap := db.NewMemDB() + userRepo, err := db.NewUserRepoFromUsers(dbMap, testUsers) + if err != nil { + return nil, err + } pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos) connConfigs := []connector.ConnectorConfig{ @@ -114,7 +117,7 @@ func makeTestFixtures() (*testFixtures, error) { } connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs) - manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, usermanager.ManagerOptions{}) + manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{}) sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) sessionManager.GenerateCode = sequentialGenerateCodeFunc() diff --git a/user/api/api_test.go b/user/api/api_test.go index 404fa9d8..d85aade2 100644 --- a/user/api/api_test.go +++ b/user/api/api_test.go @@ -11,7 +11,7 @@ import ( "github.com/coreos/dex/client" "github.com/coreos/dex/connector" - "github.com/coreos/dex/repo" + "github.com/coreos/dex/db" schema "github.com/coreos/dex/schema/workerschema" "github.com/coreos/dex/user" "github.com/coreos/dex/user/manager" @@ -86,35 +86,43 @@ var ( ) func makeTestFixtures() (*UsersAPI, *testEmailer) { - ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{ - { - User: user.User{ - ID: "ID-1", - Email: "id1@example.com", - Admin: true, - CreatedAt: clock.Now(), + dbMap := db.NewMemDB() + ur := func() user.UserRepo { + repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{ + { + User: user.User{ + ID: "ID-1", + Email: "id1@example.com", + Admin: true, + CreatedAt: clock.Now(), + }, + }, { + User: user.User{ + ID: "ID-2", + Email: "id2@example.com", + CreatedAt: clock.Now(), + }, + }, { + User: user.User{ + ID: "ID-3", + Email: "id3@example.com", + CreatedAt: clock.Now(), + }, + }, { + User: user.User{ + ID: "ID-4", + Email: "id4@example.com", + CreatedAt: clock.Now(), + Disabled: true, + }, }, - }, { - User: user.User{ - ID: "ID-2", - Email: "id2@example.com", - CreatedAt: clock.Now(), - }, - }, { - User: user.User{ - ID: "ID-3", - Email: "id3@example.com", - CreatedAt: clock.Now(), - }, - }, { - User: user.User{ - ID: "ID-4", - Email: "id4@example.com", - CreatedAt: clock.Now(), - Disabled: true, - }, - }, - }) + }) + if err != nil { + panic("Failed to create user repo: " + err.Error()) + } + return repo + }() + pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ { UserID: "ID-1", @@ -128,7 +136,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{ &connector.LocalConnectorConfig{ID: "local"}, }) - mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{}) + mgr := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{}) mgr.Clock = clock ci := oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ diff --git a/user/email/email_test.go b/user/email/email_test.go index b533160c..eb40642f 100644 --- a/user/email/email_test.go +++ b/user/email/email_test.go @@ -12,6 +12,7 @@ import ( "github.com/coreos/go-oidc/key" "github.com/kylelemons/godebug/pretty" + "github.com/coreos/dex/db" "github.com/coreos/dex/email" "github.com/coreos/dex/user" ) @@ -45,25 +46,32 @@ func (t *testEmailer) SendMail(from, subject, text, html string, to ...string) e } func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) { - ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{ - { - User: user.User{ - ID: "ID-1", - Email: "id1@example.com", - Admin: true, + ur := func() user.UserRepo { + repo, err := db.NewUserRepoFromUsers(db.NewMemDB(), []user.UserWithRemoteIdentities{ + { + User: user.User{ + ID: "ID-1", + Email: "id1@example.com", + Admin: true, + }, + }, { + User: user.User{ + ID: "ID-2", + Email: "id2@example.com", + }, + }, { + User: user.User{ + ID: "ID-3", + Email: "id3@example.com", + }, }, - }, { - User: user.User{ - ID: "ID-2", - Email: "id2@example.com", - }, - }, { - User: user.User{ - ID: "ID-3", - Email: "id3@example.com", - }, - }, - }) + }) + if err != nil { + panic("Failed to create user repo: " + err.Error()) + } + return repo + }() + pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ { UserID: "ID-1", diff --git a/user/manager/manager_test.go b/user/manager/manager_test.go index fbe0a4a3..337f790e 100644 --- a/user/manager/manager_test.go +++ b/user/manager/manager_test.go @@ -10,7 +10,7 @@ import ( "github.com/kylelemons/godebug/pretty" "github.com/coreos/dex/connector" - "github.com/coreos/dex/repo" + "github.com/coreos/dex/db" "github.com/coreos/dex/user" ) @@ -26,32 +26,40 @@ func makeTestFixtures() *testFixtures { f := &testFixtures{} f.clock = clockwork.NewFakeClock() - f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{ - { - User: user.User{ - ID: "ID-1", - Email: "Email-1@example.com", - }, - RemoteIdentities: []user.RemoteIdentity{ - { - ConnectorID: "local", - ID: "1", + dbMap := db.NewMemDB() + f.ur = func() user.UserRepo { + repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{ + { + User: user.User{ + ID: "ID-1", + Email: "Email-1@example.com", + }, + RemoteIdentities: []user.RemoteIdentity{ + { + ConnectorID: "local", + ID: "1", + }, + }, + }, { + User: user.User{ + ID: "ID-2", + Email: "Email-2@example.com", + EmailVerified: true, + }, + RemoteIdentities: []user.RemoteIdentity{ + { + ConnectorID: "local", + ID: "2", + }, }, }, - }, { - User: user.User{ - ID: "ID-2", - Email: "Email-2@example.com", - EmailVerified: true, - }, - RemoteIdentities: []user.RemoteIdentity{ - { - ConnectorID: "local", - ID: "2", - }, - }, - }, - }) + }) + if err != nil { + panic("Failed to create user repo: " + err.Error()) + } + return repo + }() + f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ { UserID: "ID-1", @@ -65,7 +73,7 @@ func makeTestFixtures() *testFixtures { f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{ &connector.LocalConnectorConfig{ID: "local"}, }) - f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{}) + f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, db.TransactionFactory(dbMap), ManagerOptions{}) f.mgr.Clock = f.clock return f } diff --git a/user/user.go b/user/user.go index 5e706e1b..979ee590 100644 --- a/user/user.go +++ b/user/user.go @@ -4,13 +4,10 @@ import ( "encoding/json" "errors" "fmt" - "io" "time" "net/mail" "net/url" - "os" - "sort" "github.com/jonboulle/clockwork" "github.com/pborman/uuid" @@ -172,262 +169,11 @@ func ValidPassword(plaintext string) bool { return len(plaintext) > 5 } -// NewUserRepo returns an in-memory UserRepo useful for development. -func NewUserRepo() UserRepo { - return &memUserRepo{ - usersByID: make(map[string]User), - userIDsByEmail: make(map[string]string), - userIDsByRemoteID: make(map[RemoteIdentity]string), - remoteIDsByUserID: make(map[string]map[RemoteIdentity]struct{}), - } -} - -type memUserRepo struct { - usersByID map[string]User - userIDsByEmail map[string]string - userIDsByRemoteID map[RemoteIdentity]string - remoteIDsByUserID map[string]map[RemoteIdentity]struct{} -} - -func (r *memUserRepo) Get(_ repo.Transaction, id string) (User, error) { - user, ok := r.usersByID[id] - if !ok { - return User{}, ErrorNotFound - } - return user, nil -} - -type usersByEmail []User - -func (s usersByEmail) Len() int { return len(s) } -func (s usersByEmail) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s usersByEmail) Less(i, j int) bool { return s[i].Email < s[j].Email } - -func (r *memUserRepo) List(tx repo.Transaction, filter UserFilter, maxResults int, nextPageToken string) ([]User, string, error) { - var offset int - var err error - if nextPageToken != "" { - filter, maxResults, offset, err = DecodeNextPageToken(nextPageToken) - } - if err != nil { - return nil, "", err - } - - users := []User{} - for _, usr := range r.usersByID { - users = append(users, usr) - } - - sort.Sort(usersByEmail(users)) - - high := offset + maxResults - - var tok string - if high >= len(users) { - high = len(users) - } else { - tok, err = EncodeNextPageToken(filter, maxResults, high) - } - - if err != nil { - return nil, "", err - } - - if len(users[offset:high]) == 0 { - return nil, "", ErrorNotFound - } - return users[offset:high], tok, nil -} - -func (r *memUserRepo) GetByEmail(tx repo.Transaction, email string) (User, error) { - userID, ok := r.userIDsByEmail[email] - if !ok { - return User{}, ErrorNotFound - } - return r.Get(tx, userID) -} - -func (r *memUserRepo) Create(_ repo.Transaction, user User) error { - if user.ID == "" { - return ErrorInvalidID - } - - if !ValidEmail(user.Email) { - return ErrorInvalidEmail - } - - // make sure no one has the same ID; if using UUID the chances of this - // happening are astronomically small. - _, ok := r.usersByID[user.ID] - if ok { - return ErrorDuplicateID - } - - // make sure there's no other user with the same Email - _, ok = r.userIDsByEmail[user.Email] - if ok { - return ErrorDuplicateEmail - } - - r.set(user) - return nil -} - -func (r *memUserRepo) Update(_ repo.Transaction, user User) error { - if user.ID == "" { - return ErrorInvalidID - } - - if !ValidEmail(user.Email) { - return ErrorInvalidEmail - } - - // make sure this user exists already - _, ok := r.usersByID[user.ID] - if !ok { - return ErrorNotFound - } - - // make sure there's no other user with the same Email - otherID, ok := r.userIDsByEmail[user.Email] - if ok && otherID != user.ID { - return ErrorDuplicateEmail - } - - r.set(user) - return nil -} - -func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error { - if id == "" { - return ErrorInvalidID - } - user, ok := r.usersByID[id] - if !ok { - return ErrorNotFound - } - user.Disabled = disable - r.set(user) - return nil -} - -func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error { - _, ok := r.usersByID[userID] - if !ok { - return ErrorNotFound - } - _, ok = r.userIDsByRemoteID[ri] - if ok { - return ErrorDuplicateRemoteIdentity - } - - r.userIDsByRemoteID[ri] = userID - rIDs, ok := r.remoteIDsByUserID[userID] - if !ok { - rIDs = make(map[RemoteIdentity]struct{}) - r.remoteIDsByUserID[userID] = rIDs - } - - rIDs[ri] = struct{}{} - return nil -} - -func (r *memUserRepo) RemoveRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error { - otherID, ok := r.userIDsByRemoteID[ri] - if !ok { - return ErrorNotFound - } - if otherID != userID { - return ErrorNotFound - } - delete(r.userIDsByRemoteID, ri) - delete(r.remoteIDsByUserID[userID], ri) - return nil -} - -func (r *memUserRepo) GetByRemoteIdentity(_ repo.Transaction, ri RemoteIdentity) (User, error) { - userID, ok := r.userIDsByRemoteID[ri] - if !ok { - return User{}, ErrorNotFound - } - - user, ok := r.usersByID[userID] - if !ok { - return User{}, ErrorNotFound - } - return user, nil -} - -func (r *memUserRepo) GetRemoteIdentities(_ repo.Transaction, userID string) ([]RemoteIdentity, error) { - ids := []RemoteIdentity{} - for id := range r.remoteIDsByUserID[userID] { - ids = append(ids, id) - } - return ids, nil -} - -func (r *memUserRepo) GetAdminCount(_ repo.Transaction) (int, error) { - var i int - for _, usr := range r.usersByID { - if usr.Admin { - i++ - } - } - return i, nil -} - -func (r *memUserRepo) set(user User) error { - r.usersByID[user.ID] = user - r.userIDsByEmail[user.Email] = user.ID - return nil -} - type UserWithRemoteIdentities struct { User User `json:"user"` RemoteIdentities []RemoteIdentity `json:"remoteIdentities"` } -// NewUserRepoFromFile returns an in-memory UserRepo useful for development given a JSON serialized file of Users. -func NewUserRepoFromFile(loc string) (UserRepo, error) { - us, err := readUsersFromFile(loc) - if err != nil { - return nil, err - } - return NewUserRepoFromUsers(us), nil -} - -func NewUserRepoFromUsers(us []UserWithRemoteIdentities) UserRepo { - memUserRepo := NewUserRepo().(*memUserRepo) - for _, u := range us { - memUserRepo.set(u.User) - for _, ri := range u.RemoteIdentities { - memUserRepo.AddRemoteIdentity(nil, u.User.ID, ri) - } - } - return memUserRepo -} - -func newUsersFromReader(r io.Reader) ([]UserWithRemoteIdentities, error) { - var us []UserWithRemoteIdentities - err := json.NewDecoder(r).Decode(&us) - return us, err -} - -func readUsersFromFile(loc string) ([]UserWithRemoteIdentities, error) { - uf, err := os.Open(loc) - if err != nil { - return nil, fmt.Errorf("unable to read users from file %q: %v", loc, err) - } - defer uf.Close() - - us, err := newUsersFromReader(uf) - if err != nil { - return nil, err - } - - return us, err -} - func (u *User) UnmarshalJSON(data []byte) error { var dec struct { ID string `json:"id"` diff --git a/user/user_test.go b/user/user_test.go index eac81588..17b91726 100644 --- a/user/user_test.go +++ b/user/user_test.go @@ -2,7 +2,6 @@ package user import ( "reflect" - "strings" "testing" "github.com/kylelemons/godebug/pretty" @@ -10,44 +9,6 @@ import ( "github.com/coreos/go-oidc/jose" ) -func TestNewUsersFromReader(t *testing.T) { - tests := []struct { - json string - want []UserWithRemoteIdentities - }{ - { - json: `[{"user":{"id":"12345", "displayName": "Elroy Canis", "email":"elroy23@example.com"}, "remoteIdentities":[{"connectorID":"google", "id":"elroy@example.com"}] }]`, - want: []UserWithRemoteIdentities{ - { - User: User{ - ID: "12345", - DisplayName: "Elroy Canis", - Email: "elroy23@example.com", - }, - RemoteIdentities: []RemoteIdentity{ - { - ConnectorID: "google", - ID: "elroy@example.com", - }, - }, - }, - }, - }, - } - - for i, tt := range tests { - r := strings.NewReader(tt.json) - us, err := newUsersFromReader(r) - if err != nil { - t.Errorf("case %d: want nil err: %v", i, err) - continue - } - if diff := pretty.Compare(tt.want, us); diff != "" { - t.Errorf("case %d: Compare(want, got): %v", i, diff) - } - } -} - func TestAddToClaims(t *testing.T) { tests := []struct { user User