*: remove in memory user repo

This commit is contained in:
Eric Chiang 2016-02-09 12:22:40 -08:00
parent 95560404a3
commit 2726f4dcdf
13 changed files with 167 additions and 480 deletions

View file

@ -4,7 +4,7 @@ import (
"testing"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/db"
"github.com/coreos/dex/schema/adminschema"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
@ -22,7 +22,9 @@ type testFixtures struct {
func makeTestFixtures() *testFixtures {
f := &testFixtures{}
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
dbMap := db.NewMemDB()
f.ur = func() user.UserRepo {
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
{
User: user.User{
ID: "ID-1",
@ -38,6 +40,12 @@ func makeTestFixtures() *testFixtures {
},
},
})
if err != nil {
panic("Failed to create user repo: " + err.Error())
}
return repo
}()
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
{
UserID: "ID-1",
@ -47,7 +55,7 @@ func makeTestFixtures() *testFixtures {
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
})
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
f.adAPI = NewAdminAPI(f.mgr, f.ur, f.pwr, "local")
return f

View file

@ -7,6 +7,7 @@ import (
"testing"
"time"
"github.com/go-gorp/gorp"
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/db"
@ -49,10 +50,12 @@ func newUserRepo(t *testing.T, users []user.UserWithRemoteIdentities) user.UserR
if users == nil {
users = []user.UserWithRemoteIdentities{}
}
var dbMap *gorp.DbMap
if os.Getenv("DEX_TEST_DSN") == "" {
return user.NewUserRepoFromUsers(users)
dbMap = db.NewMemDB()
} else {
dbMap = connect(t)
}
dbMap := connect(t)
repo, err := db.NewUserRepoFromUsers(dbMap, users)
if err != nil {
t.Fatalf("Unable to add users: %v", err)
@ -416,59 +419,6 @@ func findRemoteIdentity(rids []user.RemoteIdentity, rid user.RemoteIdentity) int
return -1
}
func TestNewUserRepoFromUsers(t *testing.T) {
tests := []struct {
users []user.UserWithRemoteIdentities
}{
{
users: []user.UserWithRemoteIdentities{
{
User: user.User{
ID: "123",
Email: "email123@example.com",
},
RemoteIdentities: []user.RemoteIdentity{},
},
{
User: user.User{
ID: "456",
Email: "email456@example.com",
},
RemoteIdentities: []user.RemoteIdentity{
{
ID: "remoteID",
ConnectorID: "connID",
},
},
},
},
},
}
for i, tt := range tests {
repo := user.NewUserRepoFromUsers(tt.users)
for _, want := range tt.users {
gotUser, err := repo.Get(nil, want.User.ID)
if err != nil {
t.Errorf("case %d: want nil err: %v", i, err)
}
gotRIDs, err := repo.GetRemoteIdentities(nil, want.User.ID)
if err != nil {
t.Errorf("case %d: want nil err: %v", i, err)
}
if !reflect.DeepEqual(want.User, gotUser) {
t.Errorf("case %d: want=%#v got=%#v", i, want.User, gotUser)
}
if !reflect.DeepEqual(want.RemoteIdentities, gotRIDs) {
t.Errorf("case %d: want=%#v got=%#v", i, want.RemoteIdentities, gotRIDs)
}
}
}
}
func TestGetByEmail(t *testing.T) {
tests := []struct {
email string

View file

@ -11,7 +11,7 @@ import (
"github.com/jonboulle/clockwork"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/db"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
)
@ -45,13 +45,20 @@ func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, erro
}
func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) {
ur := user.NewUserRepoFromUsers(users)
dbMap := db.NewMemDB()
ur := func() user.UserRepo {
repo, err := db.NewUserRepoFromUsers(dbMap, users)
if err != nil {
panic("Failed to create user repo: " + err.Error())
}
return repo
}()
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
ccr := connector.NewConnectorConfigRepoFromConfigs(
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
)
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
um := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
um.Clock = clock
return ur, pwr, um
}

View file

@ -139,7 +139,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
Email: "testemail@example.com",
DisplayName: "displayname",
}
userRepo := user.NewUserRepo()
userRepo := db.NewUserRepo(db.NewMemDB())
if err := userRepo.Create(nil, usr); err != nil {
t.Fatalf("Unexpected error: %v", err)
}

View file

@ -1,7 +1,5 @@
package repo
import "errors"
// Transaction is an abstraction of transactions typically found in database systems.
// One of Commit() or Rollback() must be called on each transaction.
type Transaction interface {
@ -13,29 +11,3 @@ type Transaction interface {
}
type TransactionFactory func() (Transaction, error)
// InMemTransaction satisifies the Transaction interface for in-memory systems.
// However, the only thing it really does is ensure that the same transaction is
// can't be committed/rolled back more than once. As such, this can lead to data
// corruption and should not be used in production systems.
type InMemTransaction bool
func InMemTransactionFactory() (Transaction, error) {
return new(InMemTransaction), nil
}
func (i *InMemTransaction) Commit() error {
return i.commitOrRollback()
}
func (i *InMemTransaction) Rollback() error {
return i.commitOrRollback()
}
func (i *InMemTransaction) commitOrRollback() error {
if *i {
return errors.New("Already committed/rolled-back.")
}
*i = true
return nil
}

View file

@ -1,6 +1,7 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"html/template"
@ -17,7 +18,6 @@ import (
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email"
"github.com/coreos/dex/repo"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
@ -100,6 +100,8 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
return err
}
dbMap := db.NewMemDB()
ks := key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(24*time.Hour))
kRepo := key.NewPrivateKeySetRepo()
if err = kRepo.Set(ks); err != nil {
@ -127,20 +129,24 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
}
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs)
sRepo := db.NewSessionRepo(db.NewMemDB())
skRepo := db.NewSessionKeyRepo(db.NewMemDB())
sRepo := db.NewSessionRepo(dbMap)
skRepo := db.NewSessionKeyRepo(dbMap)
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile)
users, err := loadUsers(cfg.UsersFile)
if err != nil {
return fmt.Errorf("unable to read users from file: %v", err)
}
userRepo, err := db.NewUserRepoFromUsers(dbMap, users)
if err != nil {
return err
}
pwiRepo := user.NewPasswordInfoRepo()
refTokRepo := db.NewRefreshTokenRepo(db.NewMemDB())
refTokRepo := db.NewRefreshTokenRepo(dbMap)
txnFactory := repo.InMemTransactionFactory
txnFactory := db.TransactionFactory(dbMap)
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo
@ -154,6 +160,16 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
}
func loadUsers(filepath string) (users []user.UserWithRemoteIdentities, err error) {
f, err := os.Open(filepath)
if err != nil {
return
}
defer f.Close()
err = json.NewDecoder(f).Decode(&users)
return
}
func (cfg *MultiServerConfig) Configure(srv *Server) error {
if len(cfg.KeySecrets) == 0 {
return errors.New("missing key secret")

View file

@ -76,7 +76,7 @@ func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc {
}
func makeNewUserRepo() (user.UserRepo, error) {
userRepo := user.NewUserRepo()
userRepo := db.NewUserRepo(db.NewMemDB())
id := "testid-1"
err := userRepo.Create(nil, user.User{

View file

@ -12,7 +12,6 @@ import (
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email"
"github.com/coreos/dex/repo"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
@ -91,7 +90,11 @@ func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
}
func makeTestFixtures() (*testFixtures, error) {
userRepo := user.NewUserRepoFromUsers(testUsers)
dbMap := db.NewMemDB()
userRepo, err := db.NewUserRepoFromUsers(dbMap, testUsers)
if err != nil {
return nil, err
}
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
connConfigs := []connector.ConnectorConfig{
@ -114,7 +117,7 @@ func makeTestFixtures() (*testFixtures, error) {
}
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, usermanager.ManagerOptions{})
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sessionManager.GenerateCode = sequentialGenerateCodeFunc()

View file

@ -11,7 +11,7 @@ import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
@ -86,7 +86,9 @@ var (
)
func makeTestFixtures() (*UsersAPI, *testEmailer) {
ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
dbMap := db.NewMemDB()
ur := func() user.UserRepo {
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
{
User: user.User{
ID: "ID-1",
@ -115,6 +117,12 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
},
},
})
if err != nil {
panic("Failed to create user repo: " + err.Error())
}
return repo
}()
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
{
UserID: "ID-1",
@ -128,7 +136,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
})
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
mgr := manager.NewUserManager(ur, pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
mgr.Clock = clock
ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{

View file

@ -12,6 +12,7 @@ import (
"github.com/coreos/go-oidc/key"
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email"
"github.com/coreos/dex/user"
)
@ -45,7 +46,8 @@ func (t *testEmailer) SendMail(from, subject, text, html string, to ...string) e
}
func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) {
ur := user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
ur := func() user.UserRepo {
repo, err := db.NewUserRepoFromUsers(db.NewMemDB(), []user.UserWithRemoteIdentities{
{
User: user.User{
ID: "ID-1",
@ -64,6 +66,12 @@ func makeTestFixtures() (*UserEmailer, *testEmailer, *key.PublicKey) {
},
},
})
if err != nil {
panic("Failed to create user repo: " + err.Error())
}
return repo
}()
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
{
UserID: "ID-1",

View file

@ -10,7 +10,7 @@ import (
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/db"
"github.com/coreos/dex/user"
)
@ -26,7 +26,9 @@ func makeTestFixtures() *testFixtures {
f := &testFixtures{}
f.clock = clockwork.NewFakeClock()
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
dbMap := db.NewMemDB()
f.ur = func() user.UserRepo {
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{
{
User: user.User{
ID: "ID-1",
@ -52,6 +54,12 @@ func makeTestFixtures() *testFixtures {
},
},
})
if err != nil {
panic("Failed to create user repo: " + err.Error())
}
return repo
}()
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
{
UserID: "ID-1",
@ -65,7 +73,7 @@ func makeTestFixtures() *testFixtures {
f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
})
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{})
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, db.TransactionFactory(dbMap), ManagerOptions{})
f.mgr.Clock = f.clock
return f
}

View file

@ -4,13 +4,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"time"
"net/mail"
"net/url"
"os"
"sort"
"github.com/jonboulle/clockwork"
"github.com/pborman/uuid"
@ -172,262 +169,11 @@ func ValidPassword(plaintext string) bool {
return len(plaintext) > 5
}
// NewUserRepo returns an in-memory UserRepo useful for development.
func NewUserRepo() UserRepo {
return &memUserRepo{
usersByID: make(map[string]User),
userIDsByEmail: make(map[string]string),
userIDsByRemoteID: make(map[RemoteIdentity]string),
remoteIDsByUserID: make(map[string]map[RemoteIdentity]struct{}),
}
}
type memUserRepo struct {
usersByID map[string]User
userIDsByEmail map[string]string
userIDsByRemoteID map[RemoteIdentity]string
remoteIDsByUserID map[string]map[RemoteIdentity]struct{}
}
func (r *memUserRepo) Get(_ repo.Transaction, id string) (User, error) {
user, ok := r.usersByID[id]
if !ok {
return User{}, ErrorNotFound
}
return user, nil
}
type usersByEmail []User
func (s usersByEmail) Len() int { return len(s) }
func (s usersByEmail) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s usersByEmail) Less(i, j int) bool { return s[i].Email < s[j].Email }
func (r *memUserRepo) List(tx repo.Transaction, filter UserFilter, maxResults int, nextPageToken string) ([]User, string, error) {
var offset int
var err error
if nextPageToken != "" {
filter, maxResults, offset, err = DecodeNextPageToken(nextPageToken)
}
if err != nil {
return nil, "", err
}
users := []User{}
for _, usr := range r.usersByID {
users = append(users, usr)
}
sort.Sort(usersByEmail(users))
high := offset + maxResults
var tok string
if high >= len(users) {
high = len(users)
} else {
tok, err = EncodeNextPageToken(filter, maxResults, high)
}
if err != nil {
return nil, "", err
}
if len(users[offset:high]) == 0 {
return nil, "", ErrorNotFound
}
return users[offset:high], tok, nil
}
func (r *memUserRepo) GetByEmail(tx repo.Transaction, email string) (User, error) {
userID, ok := r.userIDsByEmail[email]
if !ok {
return User{}, ErrorNotFound
}
return r.Get(tx, userID)
}
func (r *memUserRepo) Create(_ repo.Transaction, user User) error {
if user.ID == "" {
return ErrorInvalidID
}
if !ValidEmail(user.Email) {
return ErrorInvalidEmail
}
// make sure no one has the same ID; if using UUID the chances of this
// happening are astronomically small.
_, ok := r.usersByID[user.ID]
if ok {
return ErrorDuplicateID
}
// make sure there's no other user with the same Email
_, ok = r.userIDsByEmail[user.Email]
if ok {
return ErrorDuplicateEmail
}
r.set(user)
return nil
}
func (r *memUserRepo) Update(_ repo.Transaction, user User) error {
if user.ID == "" {
return ErrorInvalidID
}
if !ValidEmail(user.Email) {
return ErrorInvalidEmail
}
// make sure this user exists already
_, ok := r.usersByID[user.ID]
if !ok {
return ErrorNotFound
}
// make sure there's no other user with the same Email
otherID, ok := r.userIDsByEmail[user.Email]
if ok && otherID != user.ID {
return ErrorDuplicateEmail
}
r.set(user)
return nil
}
func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error {
if id == "" {
return ErrorInvalidID
}
user, ok := r.usersByID[id]
if !ok {
return ErrorNotFound
}
user.Disabled = disable
r.set(user)
return nil
}
func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
_, ok := r.usersByID[userID]
if !ok {
return ErrorNotFound
}
_, ok = r.userIDsByRemoteID[ri]
if ok {
return ErrorDuplicateRemoteIdentity
}
r.userIDsByRemoteID[ri] = userID
rIDs, ok := r.remoteIDsByUserID[userID]
if !ok {
rIDs = make(map[RemoteIdentity]struct{})
r.remoteIDsByUserID[userID] = rIDs
}
rIDs[ri] = struct{}{}
return nil
}
func (r *memUserRepo) RemoveRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
otherID, ok := r.userIDsByRemoteID[ri]
if !ok {
return ErrorNotFound
}
if otherID != userID {
return ErrorNotFound
}
delete(r.userIDsByRemoteID, ri)
delete(r.remoteIDsByUserID[userID], ri)
return nil
}
func (r *memUserRepo) GetByRemoteIdentity(_ repo.Transaction, ri RemoteIdentity) (User, error) {
userID, ok := r.userIDsByRemoteID[ri]
if !ok {
return User{}, ErrorNotFound
}
user, ok := r.usersByID[userID]
if !ok {
return User{}, ErrorNotFound
}
return user, nil
}
func (r *memUserRepo) GetRemoteIdentities(_ repo.Transaction, userID string) ([]RemoteIdentity, error) {
ids := []RemoteIdentity{}
for id := range r.remoteIDsByUserID[userID] {
ids = append(ids, id)
}
return ids, nil
}
func (r *memUserRepo) GetAdminCount(_ repo.Transaction) (int, error) {
var i int
for _, usr := range r.usersByID {
if usr.Admin {
i++
}
}
return i, nil
}
func (r *memUserRepo) set(user User) error {
r.usersByID[user.ID] = user
r.userIDsByEmail[user.Email] = user.ID
return nil
}
type UserWithRemoteIdentities struct {
User User `json:"user"`
RemoteIdentities []RemoteIdentity `json:"remoteIdentities"`
}
// NewUserRepoFromFile returns an in-memory UserRepo useful for development given a JSON serialized file of Users.
func NewUserRepoFromFile(loc string) (UserRepo, error) {
us, err := readUsersFromFile(loc)
if err != nil {
return nil, err
}
return NewUserRepoFromUsers(us), nil
}
func NewUserRepoFromUsers(us []UserWithRemoteIdentities) UserRepo {
memUserRepo := NewUserRepo().(*memUserRepo)
for _, u := range us {
memUserRepo.set(u.User)
for _, ri := range u.RemoteIdentities {
memUserRepo.AddRemoteIdentity(nil, u.User.ID, ri)
}
}
return memUserRepo
}
func newUsersFromReader(r io.Reader) ([]UserWithRemoteIdentities, error) {
var us []UserWithRemoteIdentities
err := json.NewDecoder(r).Decode(&us)
return us, err
}
func readUsersFromFile(loc string) ([]UserWithRemoteIdentities, error) {
uf, err := os.Open(loc)
if err != nil {
return nil, fmt.Errorf("unable to read users from file %q: %v", loc, err)
}
defer uf.Close()
us, err := newUsersFromReader(uf)
if err != nil {
return nil, err
}
return us, err
}
func (u *User) UnmarshalJSON(data []byte) error {
var dec struct {
ID string `json:"id"`

View file

@ -2,7 +2,6 @@ package user
import (
"reflect"
"strings"
"testing"
"github.com/kylelemons/godebug/pretty"
@ -10,44 +9,6 @@ import (
"github.com/coreos/go-oidc/jose"
)
func TestNewUsersFromReader(t *testing.T) {
tests := []struct {
json string
want []UserWithRemoteIdentities
}{
{
json: `[{"user":{"id":"12345", "displayName": "Elroy Canis", "email":"elroy23@example.com"}, "remoteIdentities":[{"connectorID":"google", "id":"elroy@example.com"}] }]`,
want: []UserWithRemoteIdentities{
{
User: User{
ID: "12345",
DisplayName: "Elroy Canis",
Email: "elroy23@example.com",
},
RemoteIdentities: []RemoteIdentity{
{
ConnectorID: "google",
ID: "elroy@example.com",
},
},
},
},
},
}
for i, tt := range tests {
r := strings.NewReader(tt.json)
us, err := newUsersFromReader(r)
if err != nil {
t.Errorf("case %d: want nil err: %v", i, err)
continue
}
if diff := pretty.Compare(tt.want, us); diff != "" {
t.Errorf("case %d: Compare(want, got): %v", i, diff)
}
}
}
func TestAddToClaims(t *testing.T) {
tests := []struct {
user User