forked from mystiq/dex
*: remove in memory user repo
This commit is contained in:
parent
95560404a3
commit
2726f4dcdf
13 changed files with 167 additions and 480 deletions
|
@ -4,7 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/schema/adminschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
|
@ -22,7 +22,9 @@ type testFixtures struct {
|
|||
func makeTestFixtures() *testFixtures {
|
||||
f := &testFixtures{}
|
||||
|
||||
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
||||
dbMap := db.NewMemDB()
|
||||
f.ur = func() user.UserRepo {
|
||||
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
|
||||
{
|
||||
User: user.User{
|
||||
ID: "ID-1",
|
||||
|
@ -38,6 +40,12 @@ func makeTestFixtures() *testFixtures {
|
|||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||
{
|
||||
UserID: "ID-1",
|
||||
|
@ -47,7 +55,7 @@ func makeTestFixtures() *testFixtures {
|
|||
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
})
|
||||
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
||||
f.adAPI = NewAdminAPI(f.mgr, f.ur, f.pwr, "local")
|
||||
|
||||
return f
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/db"
|
||||
|
@ -49,10 +50,12 @@ func newUserRepo(t *testing.T, users []user.UserWithRemoteIdentities) user.UserR
|
|||
if users == nil {
|
||||
users = []user.UserWithRemoteIdentities{}
|
||||
}
|
||||
var dbMap *gorp.DbMap
|
||||
if os.Getenv("DEX_TEST_DSN") == "" {
|
||||
return user.NewUserRepoFromUsers(users)
|
||||
dbMap = db.NewMemDB()
|
||||
} else {
|
||||
dbMap = connect(t)
|
||||
}
|
||||
dbMap := connect(t)
|
||||
repo, err := db.NewUserRepoFromUsers(dbMap, users)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to add users: %v", err)
|
||||
|
@ -416,59 +419,6 @@ func findRemoteIdentity(rids []user.RemoteIdentity, rid user.RemoteIdentity) int
|
|||
return -1
|
||||
}
|
||||
|
||||
func TestNewUserRepoFromUsers(t *testing.T) {
|
||||
tests := []struct {
|
||||
users []user.UserWithRemoteIdentities
|
||||
}{
|
||||
{
|
||||
users: []user.UserWithRemoteIdentities{
|
||||
{
|
||||
User: user.User{
|
||||
ID: "123",
|
||||
Email: "email123@example.com",
|
||||
},
|
||||
RemoteIdentities: []user.RemoteIdentity{},
|
||||
},
|
||||
{
|
||||
User: user.User{
|
||||
ID: "456",
|
||||
Email: "email456@example.com",
|
||||
},
|
||||
RemoteIdentities: []user.RemoteIdentity{
|
||||
{
|
||||
ID: "remoteID",
|
||||
ConnectorID: "connID",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := user.NewUserRepoFromUsers(tt.users)
|
||||
for _, want := range tt.users {
|
||||
gotUser, err := repo.Get(nil, want.User.ID)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want nil err: %v", i, err)
|
||||
}
|
||||
|
||||
gotRIDs, err := repo.GetRemoteIdentities(nil, want.User.ID)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want nil err: %v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want.User, gotUser) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, want.User, gotUser)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want.RemoteIdentities, gotRIDs) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, want.RemoteIdentities, gotRIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetByEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
email string
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
@ -45,13 +45,20 @@ func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, erro
|
|||
}
|
||||
|
||||
func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) {
|
||||
ur := user.NewUserRepoFromUsers(users)
|
||||
dbMap := db.NewMemDB()
|
||||
ur := func() user.UserRepo {
|
||||
repo, err := db.NewUserRepoFromUsers(dbMap, users)
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
|
||||
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs(
|
||||
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
|
||||
)
|
||||
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
um := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
||||
um.Clock = clock
|
||||
return ur, pwr, um
|
||||
}
|
||||
|
|
|
@ -139,7 +139,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
|||
Email: "testemail@example.com",
|
||||
DisplayName: "displayname",
|
||||
}
|
||||
userRepo := user.NewUserRepo()
|
||||
userRepo := db.NewUserRepo(db.NewMemDB())
|
||||
if err := userRepo.Create(nil, usr); err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
|
28
repo/repo.go
28
repo/repo.go
|
@ -1,7 +1,5 @@
|
|||
package repo
|
||||
|
||||
import "errors"
|
||||
|
||||
// Transaction is an abstraction of transactions typically found in database systems.
|
||||
// One of Commit() or Rollback() must be called on each transaction.
|
||||
type Transaction interface {
|
||||
|
@ -13,29 +11,3 @@ type Transaction interface {
|
|||
}
|
||||
|
||||
type TransactionFactory func() (Transaction, error)
|
||||
|
||||
// InMemTransaction satisifies the Transaction interface for in-memory systems.
|
||||
// However, the only thing it really does is ensure that the same transaction is
|
||||
// can't be committed/rolled back more than once. As such, this can lead to data
|
||||
// corruption and should not be used in production systems.
|
||||
type InMemTransaction bool
|
||||
|
||||
func InMemTransactionFactory() (Transaction, error) {
|
||||
return new(InMemTransaction), nil
|
||||
}
|
||||
|
||||
func (i *InMemTransaction) Commit() error {
|
||||
return i.commitOrRollback()
|
||||
}
|
||||
|
||||
func (i *InMemTransaction) Rollback() error {
|
||||
return i.commitOrRollback()
|
||||
}
|
||||
|
||||
func (i *InMemTransaction) commitOrRollback() error {
|
||||
if *i {
|
||||
return errors.New("Already committed/rolled-back.")
|
||||
}
|
||||
*i = true
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
|
@ -17,7 +18,6 @@ import (
|
|||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/email"
|
||||
"github.com/coreos/dex/repo"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
|
@ -100,6 +100,8 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
return err
|
||||
}
|
||||
|
||||
dbMap := db.NewMemDB()
|
||||
|
||||
ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour))
|
||||
kRepo := key.NewPrivateKeySetRepo()
|
||||
if err = kRepo.Set(ks); err != nil {
|
||||
|
@ -127,20 +129,24 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
}
|
||||
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs)
|
||||
|
||||
sRepo := db.NewSessionRepo(db.NewMemDB())
|
||||
skRepo := db.NewSessionKeyRepo(db.NewMemDB())
|
||||
sRepo := db.NewSessionRepo(dbMap)
|
||||
skRepo := db.NewSessionKeyRepo(dbMap)
|
||||
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
||||
|
||||
userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile)
|
||||
users, err := loadUsers(cfg.UsersFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read users from file: %v", err)
|
||||
}
|
||||
userRepo, err := db.NewUserRepoFromUsers(dbMap, users)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pwiRepo := user.NewPasswordInfoRepo()
|
||||
|
||||
refTokRepo := db.NewRefreshTokenRepo(db.NewMemDB())
|
||||
refTokRepo := db.NewRefreshTokenRepo(dbMap)
|
||||
|
||||
txnFactory := repo.InMemTransactionFactory
|
||||
txnFactory := db.TransactionFactory(dbMap)
|
||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
||||
srv.ClientIdentityRepo = ciRepo
|
||||
srv.KeySetRepo = kRepo
|
||||
|
@ -154,6 +160,16 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
|
||||
}
|
||||
|
||||
func loadUsers(filepath string) (users []user.UserWithRemoteIdentities, err error) {
|
||||
f, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
err = json.NewDecoder(f).Decode(&users)
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
||||
if len(cfg.KeySecrets) == 0 {
|
||||
return errors.New("missing key secret")
|
||||
|
|
|
@ -76,7 +76,7 @@ func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc {
|
|||
}
|
||||
|
||||
func makeNewUserRepo() (user.UserRepo, error) {
|
||||
userRepo := user.NewUserRepo()
|
||||
userRepo := db.NewUserRepo(db.NewMemDB())
|
||||
|
||||
id := "testid-1"
|
||||
err := userRepo.Create(nil, user.User{
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/email"
|
||||
"github.com/coreos/dex/repo"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
|
@ -91,7 +90,11 @@ func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
|
|||
}
|
||||
|
||||
func makeTestFixtures() (*testFixtures, error) {
|
||||
userRepo := user.NewUserRepoFromUsers(testUsers)
|
||||
dbMap := db.NewMemDB()
|
||||
userRepo, err := db.NewUserRepoFromUsers(dbMap, testUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
|
||||
|
||||
connConfigs := []connector.ConnectorConfig{
|
||||
|
@ -114,7 +117,7 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
}
|
||||
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
|
||||
|
||||
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, usermanager.ManagerOptions{})
|
||||
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
|
||||
|
||||
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/db"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
|
@ -86,7 +86,9 @@ var (
|
|||
)
|
||||
|
||||
func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||
ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
||||
dbMap := db.NewMemDB()
|
||||
ur := func() user.UserRepo {
|
||||
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
|
||||
{
|
||||
User: user.User{
|
||||
ID: "ID-1",
|
||||
|
@ -115,6 +117,12 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||
{
|
||||
UserID: "ID-1",
|
||||
|
@ -128,7 +136,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
})
|
||||
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
mgr := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
||||
mgr.Clock = clock
|
||||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/email"
|
||||
"github.com/coreos/dex/user"
|
||||
)
|
||||
|
@ -45,7 +46,8 @@ func (t *testEmailer) SendMail(from, subject, text, html string, to ...string) e
|
|||
}
|
||||
|
||||
func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) {
|
||||
ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
||||
ur := func() user.UserRepo {
|
||||
repo, err := db.NewUserRepoFromUsers(db.NewMemDB(), []user.UserWithRemoteIdentities{
|
||||
{
|
||||
User: user.User{
|
||||
ID: "ID-1",
|
||||
|
@ -64,6 +66,12 @@ func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) {
|
|||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||
{
|
||||
UserID: "ID-1",
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/user"
|
||||
)
|
||||
|
||||
|
@ -26,7 +26,9 @@ func makeTestFixtures() *testFixtures {
|
|||
f := &testFixtures{}
|
||||
f.clock = clockwork.NewFakeClock()
|
||||
|
||||
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
||||
dbMap := db.NewMemDB()
|
||||
f.ur = func() user.UserRepo {
|
||||
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
|
||||
{
|
||||
User: user.User{
|
||||
ID: "ID-1",
|
||||
|
@ -52,6 +54,12 @@ func makeTestFixtures() *testFixtures {
|
|||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic("Failed to create user repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||
{
|
||||
UserID: "ID-1",
|
||||
|
@ -65,7 +73,7 @@ func makeTestFixtures() *testFixtures {
|
|||
f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
})
|
||||
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{})
|
||||
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, db.TransactionFactory(dbMap), ManagerOptions{})
|
||||
f.mgr.Clock = f.clock
|
||||
return f
|
||||
}
|
||||
|
|
254
user/user.go
254
user/user.go
|
@ -4,13 +4,10 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/pborman/uuid"
|
||||
|
@ -172,262 +169,11 @@ func ValidPassword(plaintext string) bool {
|
|||
return len(plaintext) > 5
|
||||
}
|
||||
|
||||
// NewUserRepo returns an in-memory UserRepo useful for development.
|
||||
func NewUserRepo() UserRepo {
|
||||
return &memUserRepo{
|
||||
usersByID: make(map[string]User),
|
||||
userIDsByEmail: make(map[string]string),
|
||||
userIDsByRemoteID: make(map[RemoteIdentity]string),
|
||||
remoteIDsByUserID: make(map[string]map[RemoteIdentity]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
type memUserRepo struct {
|
||||
usersByID map[string]User
|
||||
userIDsByEmail map[string]string
|
||||
userIDsByRemoteID map[RemoteIdentity]string
|
||||
remoteIDsByUserID map[string]map[RemoteIdentity]struct{}
|
||||
}
|
||||
|
||||
func (r *memUserRepo) Get(_ repo.Transaction, id string) (User, error) {
|
||||
user, ok := r.usersByID[id]
|
||||
if !ok {
|
||||
return User{}, ErrorNotFound
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
type usersByEmail []User
|
||||
|
||||
func (s usersByEmail) Len() int { return len(s) }
|
||||
func (s usersByEmail) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s usersByEmail) Less(i, j int) bool { return s[i].Email < s[j].Email }
|
||||
|
||||
func (r *memUserRepo) List(tx repo.Transaction, filter UserFilter, maxResults int, nextPageToken string) ([]User, string, error) {
|
||||
var offset int
|
||||
var err error
|
||||
if nextPageToken != "" {
|
||||
filter, maxResults, offset, err = DecodeNextPageToken(nextPageToken)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
users := []User{}
|
||||
for _, usr := range r.usersByID {
|
||||
users = append(users, usr)
|
||||
}
|
||||
|
||||
sort.Sort(usersByEmail(users))
|
||||
|
||||
high := offset + maxResults
|
||||
|
||||
var tok string
|
||||
if high >= len(users) {
|
||||
high = len(users)
|
||||
} else {
|
||||
tok, err = EncodeNextPageToken(filter, maxResults, high)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if len(users[offset:high]) == 0 {
|
||||
return nil, "", ErrorNotFound
|
||||
}
|
||||
return users[offset:high], tok, nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) GetByEmail(tx repo.Transaction, email string) (User, error) {
|
||||
userID, ok := r.userIDsByEmail[email]
|
||||
if !ok {
|
||||
return User{}, ErrorNotFound
|
||||
}
|
||||
return r.Get(tx, userID)
|
||||
}
|
||||
|
||||
func (r *memUserRepo) Create(_ repo.Transaction, user User) error {
|
||||
if user.ID == "" {
|
||||
return ErrorInvalidID
|
||||
}
|
||||
|
||||
if !ValidEmail(user.Email) {
|
||||
return ErrorInvalidEmail
|
||||
}
|
||||
|
||||
// make sure no one has the same ID; if using UUID the chances of this
|
||||
// happening are astronomically small.
|
||||
_, ok := r.usersByID[user.ID]
|
||||
if ok {
|
||||
return ErrorDuplicateID
|
||||
}
|
||||
|
||||
// make sure there's no other user with the same Email
|
||||
_, ok = r.userIDsByEmail[user.Email]
|
||||
if ok {
|
||||
return ErrorDuplicateEmail
|
||||
}
|
||||
|
||||
r.set(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) Update(_ repo.Transaction, user User) error {
|
||||
if user.ID == "" {
|
||||
return ErrorInvalidID
|
||||
}
|
||||
|
||||
if !ValidEmail(user.Email) {
|
||||
return ErrorInvalidEmail
|
||||
}
|
||||
|
||||
// make sure this user exists already
|
||||
_, ok := r.usersByID[user.ID]
|
||||
if !ok {
|
||||
return ErrorNotFound
|
||||
}
|
||||
|
||||
// make sure there's no other user with the same Email
|
||||
otherID, ok := r.userIDsByEmail[user.Email]
|
||||
if ok && otherID != user.ID {
|
||||
return ErrorDuplicateEmail
|
||||
}
|
||||
|
||||
r.set(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error {
|
||||
if id == "" {
|
||||
return ErrorInvalidID
|
||||
}
|
||||
user, ok := r.usersByID[id]
|
||||
if !ok {
|
||||
return ErrorNotFound
|
||||
}
|
||||
user.Disabled = disable
|
||||
r.set(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
|
||||
_, ok := r.usersByID[userID]
|
||||
if !ok {
|
||||
return ErrorNotFound
|
||||
}
|
||||
_, ok = r.userIDsByRemoteID[ri]
|
||||
if ok {
|
||||
return ErrorDuplicateRemoteIdentity
|
||||
}
|
||||
|
||||
r.userIDsByRemoteID[ri] = userID
|
||||
rIDs, ok := r.remoteIDsByUserID[userID]
|
||||
if !ok {
|
||||
rIDs = make(map[RemoteIdentity]struct{})
|
||||
r.remoteIDsByUserID[userID] = rIDs
|
||||
}
|
||||
|
||||
rIDs[ri] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) RemoveRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
|
||||
otherID, ok := r.userIDsByRemoteID[ri]
|
||||
if !ok {
|
||||
return ErrorNotFound
|
||||
}
|
||||
if otherID != userID {
|
||||
return ErrorNotFound
|
||||
}
|
||||
delete(r.userIDsByRemoteID, ri)
|
||||
delete(r.remoteIDsByUserID[userID], ri)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) GetByRemoteIdentity(_ repo.Transaction, ri RemoteIdentity) (User, error) {
|
||||
userID, ok := r.userIDsByRemoteID[ri]
|
||||
if !ok {
|
||||
return User{}, ErrorNotFound
|
||||
}
|
||||
|
||||
user, ok := r.usersByID[userID]
|
||||
if !ok {
|
||||
return User{}, ErrorNotFound
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) GetRemoteIdentities(_ repo.Transaction, userID string) ([]RemoteIdentity, error) {
|
||||
ids := []RemoteIdentity{}
|
||||
for id := range r.remoteIDsByUserID[userID] {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) GetAdminCount(_ repo.Transaction) (int, error) {
|
||||
var i int
|
||||
for _, usr := range r.usersByID {
|
||||
if usr.Admin {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (r *memUserRepo) set(user User) error {
|
||||
r.usersByID[user.ID] = user
|
||||
r.userIDsByEmail[user.Email] = user.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
type UserWithRemoteIdentities struct {
|
||||
User User `json:"user"`
|
||||
RemoteIdentities []RemoteIdentity `json:"remoteIdentities"`
|
||||
}
|
||||
|
||||
// NewUserRepoFromFile returns an in-memory UserRepo useful for development given a JSON serialized file of Users.
|
||||
func NewUserRepoFromFile(loc string) (UserRepo, error) {
|
||||
us, err := readUsersFromFile(loc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewUserRepoFromUsers(us), nil
|
||||
}
|
||||
|
||||
func NewUserRepoFromUsers(us []UserWithRemoteIdentities) UserRepo {
|
||||
memUserRepo := NewUserRepo().(*memUserRepo)
|
||||
for _, u := range us {
|
||||
memUserRepo.set(u.User)
|
||||
for _, ri := range u.RemoteIdentities {
|
||||
memUserRepo.AddRemoteIdentity(nil, u.User.ID, ri)
|
||||
}
|
||||
}
|
||||
return memUserRepo
|
||||
}
|
||||
|
||||
func newUsersFromReader(r io.Reader) ([]UserWithRemoteIdentities, error) {
|
||||
var us []UserWithRemoteIdentities
|
||||
err := json.NewDecoder(r).Decode(&us)
|
||||
return us, err
|
||||
}
|
||||
|
||||
func readUsersFromFile(loc string) ([]UserWithRemoteIdentities, error) {
|
||||
uf, err := os.Open(loc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read users from file %q: %v", loc, err)
|
||||
}
|
||||
defer uf.Close()
|
||||
|
||||
us, err := newUsersFromReader(uf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return us, err
|
||||
}
|
||||
|
||||
func (u *User) UnmarshalJSON(data []byte) error {
|
||||
var dec struct {
|
||||
ID string `json:"id"`
|
||||
|
|
|
@ -2,7 +2,6 @@ package user
|
|||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
@ -10,44 +9,6 @@ import (
|
|||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
func TestNewUsersFromReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
json string
|
||||
want []UserWithRemoteIdentities
|
||||
}{
|
||||
{
|
||||
json: `[{"user":{"id":"12345", "displayName": "Elroy Canis", "email":"elroy23@example.com"}, "remoteIdentities":[{"connectorID":"google", "id":"elroy@example.com"}] }]`,
|
||||
want: []UserWithRemoteIdentities{
|
||||
{
|
||||
User: User{
|
||||
ID: "12345",
|
||||
DisplayName: "Elroy Canis",
|
||||
Email: "elroy23@example.com",
|
||||
},
|
||||
RemoteIdentities: []RemoteIdentity{
|
||||
{
|
||||
ConnectorID: "google",
|
||||
ID: "elroy@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r := strings.NewReader(tt.json)
|
||||
us, err := newUsersFromReader(r)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want nil err: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if diff := pretty.Compare(tt.want, us); diff != "" {
|
||||
t.Errorf("case %d: Compare(want, got): %v", i, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddToClaims(t *testing.T) {
|
||||
tests := []struct {
|
||||
user User
|
||||
|
|
Loading…
Reference in a new issue