Merge pull request #199 from ericchiang/validate_connector

api: validate local connector existence before creating user
This commit is contained in:
bobbyrullo 2015-12-07 17:44:22 -08:00
commit 521aeae3db
20 changed files with 317 additions and 134 deletions

View file

@ -6,17 +6,18 @@ import (
"github.com/coreos/dex/schema/adminschema"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
)
// AdminAPI provides the logic necessary to implement the Admin API.
type AdminAPI struct {
userManager *user.Manager
userManager *manager.UserManager
userRepo user.UserRepo
passwordInfoRepo user.PasswordInfoRepo
localConnectorID string
}
func NewAdminAPI(userManager *user.Manager, userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, localConnectorID string) *AdminAPI {
func NewAdminAPI(userManager *manager.UserManager, userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, localConnectorID string) *AdminAPI {
if localConnectorID == "" {
panic("must specify non-blank localConnectorID")
}

View file

@ -17,7 +17,7 @@ import (
"github.com/coreos/dex/pkg/log"
ptime "github.com/coreos/dex/pkg/time"
"github.com/coreos/dex/server"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
)
var version = "DEV"
@ -99,8 +99,9 @@ func main() {
userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := user.NewManager(userRepo,
pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{})
connCfgRepo := db.NewConnectorConfigRepo(dbc)
userManager := manager.NewUserManager(userRepo,
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID)
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
if err != nil {

View file

@ -4,6 +4,8 @@ import (
"encoding/json"
"io"
"os"
"github.com/coreos/dex/repo"
)
func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) {
@ -41,6 +43,19 @@ type memConnectorConfigRepo struct {
configs []ConnectorConfig
}
func NewConnectorConfigRepoFromConfigs(cfgs []ConnectorConfig) ConnectorConfigRepo {
return &memConnectorConfigRepo{configs: cfgs}
}
func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) {
return r.configs, nil
}
func (r *memConnectorConfigRepo) GetConnectorByID(_ repo.Transaction, id string) (ConnectorConfig, error) {
for _, cfg := range r.configs {
if cfg.ConnectorID() == id {
return cfg, nil
}
}
return nil, ErrorNotFound
}

View file

@ -1,14 +1,18 @@
package connector
import (
"errors"
"html/template"
"net/http"
"net/url"
"github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc"
"github.com/coreos/pkg/health"
)
var ErrorNotFound = errors.New("connector not found in repository")
type Connector interface {
ID() string
LoginURL(sessionKey, prompt string) (string, error)
@ -34,4 +38,5 @@ type ConnectorConfig interface {
type ConnectorConfigRepo interface {
All() ([]ConnectorConfig, error)
GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error)
}

View file

@ -1,6 +1,7 @@
package db
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
@ -9,6 +10,7 @@ import (
"github.com/lib/pq"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
)
const (
@ -91,6 +93,18 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
return cfgs, nil
}
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
qt := pq.QuoteIdentifier(connectorConfigTableName)
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
var c connectorConfigModel
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
if err == sql.ErrNoRows {
return nil, connector.ErrorNotFound
}
}
return c.ConnectorConfig()
}
func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
insert := make([]interface{}, len(cfgs))
for i, cfg := range cfgs {
@ -119,3 +133,15 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
return tx.Commit()
}
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}

View file

@ -0,0 +1,71 @@
package repo
import (
"fmt"
"os"
"testing"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
)
type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo
var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory
func init() {
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs
} else {
makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn)
}
}
func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory {
return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
dbMap := initDB(dsn)
repo := db.NewConnectorConfigRepo(dbMap)
if err := repo.Set(cfgs); err != nil {
panic(fmt.Sprintf("Unable to set connector configs: %v", err))
}
return repo
}
}
func TestConnectorConfigRepoGetByID(t *testing.T) {
tests := []struct {
cfgs []connector.ConnectorConfig
id string
err error
}{
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
},
id: "local",
},
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local1"},
&connector.LocalConnectorConfig{ID: "local2"},
},
id: "local2",
},
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local1"},
&connector.LocalConnectorConfig{ID: "local2"},
},
id: "foo",
err: connector.ErrorNotFound,
},
}
for i, tt := range tests {
repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs)
if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err {
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
}
}
}

View file

@ -10,8 +10,10 @@ import (
"github.com/coreos/go-oidc/key"
"github.com/jonboulle/clockwork"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
)
var (
@ -42,11 +44,14 @@ func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, erro
return &resp, nil
}
func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *user.Manager) {
func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) {
ur := user.NewUserRepoFromUsers(users)
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
um := user.NewManager(ur, pwr, repo.InMemTransactionFactory, user.ManagerOptions{})
ccr := connector.NewConnectorConfigRepoFromConfigs(
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
)
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
um.Clock = clock
return ur, pwr, um
}

View file

@ -22,6 +22,7 @@ import (
"github.com/coreos/dex/session"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
)
type ServerConfig struct {
@ -133,7 +134,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
refTokRepo := refresh.NewRefreshTokenRepo()
txnFactory := repo.InMemTransactionFactory
userManager := user.NewManager(userRepo, pwiRepo, txnFactory, user.ManagerOptions{})
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo
@ -171,7 +172,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
cfgRepo := db.NewConnectorConfigRepo(dbc)
userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := user.NewManager(userRepo, pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{})
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
sm := session.NewSessionManager(sRepo, skRepo)

View file

@ -15,6 +15,7 @@ import (
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
)
// handleVerifyEmailResendFunc will resend an email-verification email given a valid JWT for the user and a redirect URL.
@ -190,7 +191,7 @@ type emailVerifiedTemplateData struct {
}
func handleEmailVerifyFunc(verifiedTpl *template.Template, issuer url.URL, keysFunc func() ([]key.PublicKey,
error), userManager *user.Manager) http.HandlerFunc {
error), userManager *manager.UserManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
@ -217,12 +218,12 @@ func handleEmailVerifyFunc(verifiedTpl *template.Template, issuer url.URL, keysF
cbURL, err := userManager.VerifyEmail(ev)
if err != nil {
switch err {
case user.ErrorEmailAlreadyVerified:
case manager.ErrorEmailAlreadyVerified:
execTemplateWithStatus(w, verifiedTpl, emailVerifiedTemplateData{
Error: "Invalid Verification Link",
Message: "Your email link has expired or has already been verified.",
}, http.StatusBadRequest)
case user.ErrorEVEmailDoesntMatch:
case manager.ErrorEVEmailDoesntMatch:
execTemplateWithStatus(w, verifiedTpl, emailVerifiedTemplateData{
Error: "Invalid Verification Link",
Message: "Your email link does not match the email address on file. Perhaps you have a more recent verification link?",

View file

@ -7,6 +7,7 @@ import (
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
)
@ -18,7 +19,7 @@ type invitationTemplateData struct {
type InvitationHandler struct {
issuerURL url.URL
passwordResetURL url.URL
um *user.Manager
um *manager.UserManager
keysFunc func() ([]key.PublicKey, error)
signerFunc func() (jose.Signer, error)
redirectValidityWindow time.Duration
@ -55,13 +56,13 @@ func (h *InvitationHandler) handleGET(w http.ResponseWriter, r *http.Request) {
}
_, err = h.um.VerifyEmail(invite)
if err != nil && err != user.ErrorEmailAlreadyVerified {
if err != nil && err != manager.ErrorEmailAlreadyVerified {
// Allow AlreadyVerified folks to pass through- otherwise
// folks who encounter an error after passing this point will
// never be able to set their passwords.
log.Debugf("error attempting to verify email: %v", err)
switch err {
case user.ErrorEVEmailDoesntMatch:
case manager.ErrorEVEmailDoesntMatch:
writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidRequest,
"Your email does not match the email address on file"))
return

View file

@ -12,6 +12,7 @@ import (
"github.com/coreos/dex/session"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
)
type sendResetPasswordEmailData struct {
@ -181,7 +182,7 @@ type resetPasswordTemplateData struct {
type ResetPasswordHandler struct {
tpl *template.Template
issuerURL url.URL
um *user.Manager
um *manager.UserManager
keysFunc func() ([]key.PublicKey, error)
}
@ -237,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() {
cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext)
if err != nil {
switch err {
case user.ErrorPasswordAlreadyChanged:
case manager.ErrorPasswordAlreadyChanged:
r.data.Error = "Link Expired"
r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email."
r.data.DontShowForm = true

View file

@ -11,6 +11,7 @@ import (
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
"github.com/coreos/go-oidc/oidc"
)
@ -222,7 +223,7 @@ func handleRegisterFunc(s *Server) http.HandlerFunc {
}
}
func registerFromLocalConnector(userManager *user.Manager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) {
func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) {
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
if err != nil {
return "", err
@ -237,7 +238,7 @@ func registerFromLocalConnector(userManager *user.Manager, sessionManager *sessi
return userID, nil
}
func registerFromRemoteConnector(userManager *user.Manager, ses *session.Session, email string, emailVerified bool) (string, error) {
func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
if ses.Identity.ID == "" {
return "", errors.New("No Identity found in session.")
}

View file

@ -25,6 +25,7 @@ import (
"github.com/coreos/dex/user"
usersapi "github.com/coreos/dex/user/api"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
)
const (
@ -68,7 +69,7 @@ type Server struct {
HealthChecks []health.Checkable
Connectors []connector.Connector
UserRepo user.UserRepo
UserManager *user.Manager
UserManager *manager.UserManager
PasswordInfoRepo user.PasswordInfoRepo
RefreshTokenRepo refresh.RefreshTokenRepo
UserEmailer *useremail.UserEmailer

View file

@ -15,6 +15,7 @@ import (
"github.com/coreos/dex/session"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
)
const (
@ -91,7 +92,6 @@ func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
func makeTestFixtures() (*testFixtures, error) {
userRepo := user.NewUserRepoFromUsers(testUsers)
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
manager := user.NewManager(userRepo, pwRepo, repo.InMemTransactionFactory, user.ManagerOptions{})
connConfigs := []connector.ConnectorConfig{
&connector.OIDCConnectorConfig{
@ -111,6 +111,9 @@ func makeTestFixtures() (*testFixtures, error) {
ID: "local",
},
}
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sessionManager.GenerateCode = sequentialGenerateCodeFunc()

View file

@ -16,6 +16,7 @@ import (
schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/api"
"github.com/coreos/dex/user/manager"
)
const (
@ -33,11 +34,11 @@ var (
type UserMgmtServer struct {
api *api.UsersAPI
jwtvFactory JWTVerifierFactory
um *user.Manager
um *manager.UserManager
cir client.ClientIdentityRepo
}
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *user.Manager, cir client.ClientIdentityRepo) *UserMgmtServer {
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *manager.UserManager, cir client.ClientIdentityRepo) *UserMgmtServer {
return &UserMgmtServer{
api: userMgmtAPI,
jwtvFactory: jwtvFactory,

2
test
View file

@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
source ./build
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api email"
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email"
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
# user has not provided PKG override

View file

@ -13,6 +13,7 @@ import (
"github.com/coreos/dex/pkg/log"
schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
)
var (
@ -81,7 +82,7 @@ func (e Error) Error() string {
// calling User. It is assumed that the clientID has already validated as an
// admin app before calling.
type UsersAPI struct {
manager *user.Manager
manager *manager.UserManager
localConnectorID string
clientIdentityRepo client.ClientIdentityRepo
emailer Emailer
@ -96,7 +97,7 @@ type Creds struct {
User user.User
}
func NewUsersAPI(manager *user.Manager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI {
func NewUsersAPI(manager *manager.UserManager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI {
return &UsersAPI{
manager: manager,
clientIdentityRepo: cir,

View file

@ -10,9 +10,11 @@ import (
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
)
type testEmailer struct {
@ -123,7 +125,10 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
Password: []byte("password-2"),
},
})
mgr := user.NewManager(ur, pwr, repo.InMemTransactionFactory, user.ManagerOptions{})
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
})
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
mgr.Clock = clock
ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{

View file

@ -1,4 +1,4 @@
package user
package manager
import (
"errors"
@ -6,8 +6,10 @@ import (
"github.com/jonboulle/clockwork"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/user"
)
var (
@ -19,13 +21,14 @@ var (
// Manager performs user-related "business-logic" functions on user and related objects.
// This is in contrast to the Repos which perform little more than CRUD operations.
type Manager struct {
type UserManager struct {
Clock clockwork.Clock
userRepo UserRepo
pwRepo PasswordInfoRepo
userRepo user.UserRepo
pwRepo user.PasswordInfoRepo
connCfgRepo connector.ConnectorConfigRepo
begin repo.TransactionFactory
userIDGenerator UserIDGenerator
userIDGenerator user.UserIDGenerator
}
type ManagerOptions struct {
@ -34,58 +37,59 @@ type ManagerOptions struct {
// variable policies
}
func NewManager(userRepo UserRepo, pwRepo PasswordInfoRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *Manager {
return &Manager{
func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, connCfgRepo connector.ConnectorConfigRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager {
return &UserManager{
Clock: clockwork.NewRealClock(),
userRepo: userRepo,
pwRepo: pwRepo,
connCfgRepo: connCfgRepo,
begin: txnFactory,
userIDGenerator: DefaultUserIDGenerator,
userIDGenerator: user.DefaultUserIDGenerator,
}
}
func (m *Manager) Get(id string) (User, error) {
func (m *UserManager) Get(id string) (user.User, error) {
return m.userRepo.Get(nil, id)
}
func (m *Manager) List(filter UserFilter, maxResults int, nextPageToken string) ([]User, string, error) {
func (m *UserManager) List(filter user.UserFilter, maxResults int, nextPageToken string) ([]user.User, string, error) {
return m.userRepo.List(nil, filter, maxResults, nextPageToken)
}
// CreateUser creates a new user with the given hashedPassword; the connID should be the ID of the local connector.
// The userID of the created user is returned as the first argument.
func (m *Manager) CreateUser(user User, hashedPassword Password, connID string) (string, error) {
func (m *UserManager) CreateUser(usr user.User, hashedPassword user.Password, connID string) (string, error) {
tx, err := m.begin()
if err != nil {
return "", err
}
insertedUser, err := m.insertNewUser(tx, user.Email, user.EmailVerified)
insertedUser, err := m.insertNewUser(tx, usr.Email, usr.EmailVerified)
if err != nil {
rollback(tx)
return "", err
}
user.ID = insertedUser.ID
user.CreatedAt = insertedUser.CreatedAt
err = m.userRepo.Update(tx, user)
usr.ID = insertedUser.ID
usr.CreatedAt = insertedUser.CreatedAt
err = m.userRepo.Update(tx, usr)
if err != nil {
rollback(tx)
return "", err
}
rid := RemoteIdentity{
rid := user.RemoteIdentity{
ConnectorID: connID,
ID: user.ID,
ID: usr.ID,
}
if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil {
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx)
return "", err
}
pwi := PasswordInfo{
UserID: user.ID,
pwi := user.PasswordInfo{
UserID: usr.ID,
Password: hashedPassword,
}
err = m.pwRepo.Create(tx, pwi)
@ -99,10 +103,10 @@ func (m *Manager) CreateUser(user User, hashedPassword Password, connID string)
rollback(tx)
return "", err
}
return user.ID, nil
return usr.ID, nil
}
func (m *Manager) Disable(userID string, disabled bool) error {
func (m *UserManager) Disable(userID string, disabled bool) error {
tx, err := m.begin()
if err = m.userRepo.Disable(tx, userID, disabled); err != nil {
@ -119,7 +123,7 @@ func (m *Manager) Disable(userID string, disabled bool) error {
}
// RegisterWithRemoteIdentity creates new user and attaches the given remote identity.
func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) {
func (m *UserManager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid user.RemoteIdentity) (string, error) {
tx, err := m.begin()
if err != nil {
return "", err
@ -127,20 +131,20 @@ func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, r
if _, err = m.userRepo.GetByRemoteIdentity(tx, rid); err == nil {
rollback(tx)
return "", ErrorDuplicateRemoteIdentity
return "", user.ErrorDuplicateRemoteIdentity
}
if err != ErrorNotFound {
if err != user.ErrorNotFound {
rollback(tx)
return "", err
}
user, err := m.insertNewUser(tx, email, emailVerified)
usr, err := m.insertNewUser(tx, email, emailVerified)
if err != nil {
rollback(tx)
return "", err
}
if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil {
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx)
return "", err
}
@ -150,44 +154,44 @@ func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, r
rollback(tx)
return "", err
}
return user.ID, nil
return usr.ID, nil
}
// RegisterWithPassword creates a new user with the given name and password.
// connID is the connector ID of the ConnectorLocal connector.
func (m *Manager) RegisterWithPassword(email, plaintext, connID string) (string, error) {
func (m *UserManager) RegisterWithPassword(email, plaintext, connID string) (string, error) {
tx, err := m.begin()
if err != nil {
return "", err
}
if !ValidPassword(plaintext) {
if !user.ValidPassword(plaintext) {
rollback(tx)
return "", ErrorInvalidPassword
return "", user.ErrorInvalidPassword
}
user, err := m.insertNewUser(tx, email, false)
usr, err := m.insertNewUser(tx, email, false)
if err != nil {
rollback(tx)
return "", err
}
rid := RemoteIdentity{
rid := user.RemoteIdentity{
ConnectorID: connID,
ID: user.ID,
ID: usr.ID,
}
if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil {
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx)
return "", err
}
password, err := NewPasswordFromPlaintext(plaintext)
password, err := user.NewPasswordFromPlaintext(plaintext)
if err != nil {
rollback(tx)
return "", err
}
pwi := PasswordInfo{
UserID: user.ID,
pwi := user.PasswordInfo{
UserID: usr.ID,
Password: password,
}
@ -202,7 +206,7 @@ func (m *Manager) RegisterWithPassword(email, plaintext, connID string) (string,
rollback(tx)
return "", err
}
return user.ID, nil
return usr.ID, nil
}
type EmailVerifiable interface {
@ -218,31 +222,31 @@ type EmailVerifiable interface {
// create it, ensuring that the token was signed and that the JWT was not
// expired.
// The callback url (i.e. where to send the user after the verification) is returned.
func (m *Manager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) {
func (m *UserManager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) {
tx, err := m.begin()
if err != nil {
return nil, err
}
user, err := m.userRepo.Get(tx, ev.UserID())
usr, err := m.userRepo.Get(tx, ev.UserID())
if err != nil {
rollback(tx)
return nil, err
}
if user.Email != ev.Email() {
if usr.Email != ev.Email() {
rollback(tx)
return nil, ErrorEVEmailDoesntMatch
}
if user.EmailVerified {
if usr.EmailVerified {
rollback(tx)
return nil, ErrorEmailAlreadyVerified
}
user.EmailVerified = true
usr.EmailVerified = true
err = m.userRepo.Update(tx, user)
err = m.userRepo.Update(tx, usr)
if err != nil {
rollback(tx)
return nil, err
@ -258,19 +262,19 @@ func (m *Manager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) {
type PasswordChangeable interface {
UserID() string
Password() Password
Password() user.Password
Callback() *url.URL
}
func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url.URL, error) {
func (m *UserManager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url.URL, error) {
tx, err := m.begin()
if err != nil {
return nil, err
}
if !ValidPassword(plaintext) {
if !user.ValidPassword(plaintext) {
rollback(tx)
return nil, ErrorInvalidPassword
return nil, user.ErrorInvalidPassword
}
pwi, err := m.pwRepo.Get(tx, pwr.UserID())
@ -284,7 +288,7 @@ func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url
return nil, ErrorPasswordAlreadyChanged
}
newPass, err := NewPasswordFromPlaintext(plaintext)
newPass, err := user.NewPasswordFromPlaintext(plaintext)
if err != nil {
rollback(tx)
return nil, err
@ -305,36 +309,46 @@ func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url
return pwr.Callback(), nil
}
func (m *Manager) insertNewUser(tx repo.Transaction, email string, emailVerified bool) (User, error) {
if !ValidEmail(email) {
return User{}, ErrorInvalidEmail
func (m *UserManager) insertNewUser(tx repo.Transaction, email string, emailVerified bool) (user.User, error) {
if !user.ValidEmail(email) {
return user.User{}, user.ErrorInvalidEmail
}
var err error
if _, err = m.userRepo.GetByEmail(tx, email); err == nil {
return User{}, ErrorDuplicateEmail
return user.User{}, user.ErrorDuplicateEmail
}
if err != ErrorNotFound {
return User{}, err
if err != user.ErrorNotFound {
return user.User{}, err
}
userID, err := m.userIDGenerator()
if err != nil {
return User{}, err
return user.User{}, err
}
user := User{
usr := user.User{
ID: userID,
Email: email,
EmailVerified: emailVerified,
CreatedAt: m.Clock.Now(),
}
err = m.userRepo.Create(tx, user)
err = m.userRepo.Create(tx, usr)
if err != nil {
return User{}, err
return user.User{}, err
}
return user, nil
return usr, nil
}
func (m *UserManager) addRemoteIdentity(tx repo.Transaction, userID string, rid user.RemoteIdentity) error {
if _, err := m.connCfgRepo.GetConnectorByID(tx, rid.ConnectorID); err != nil {
return err
}
if err := m.userRepo.AddRemoteIdentity(tx, userID, rid); err != nil {
return err
}
return nil
}
func rollback(tx repo.Transaction) {

View file

@ -1,4 +1,4 @@
package user
package manager
import (
"net/url"
@ -9,13 +9,16 @@ import (
"github.com/jonboulle/clockwork"
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/user"
)
type testFixtures struct {
ur UserRepo
pwr PasswordInfoRepo
mgr *Manager
ur user.UserRepo
pwr user.PasswordInfoRepo
ccr connector.ConnectorConfigRepo
mgr *UserManager
clock clockwork.Clock
}
@ -23,25 +26,25 @@ func makeTestFixtures() *testFixtures {
f := &testFixtures{}
f.clock = clockwork.NewFakeClock()
f.ur = NewUserRepoFromUsers([]UserWithRemoteIdentities{
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
{
User: User{
User: user.User{
ID: "ID-1",
Email: "Email-1@example.com",
},
RemoteIdentities: []RemoteIdentity{
RemoteIdentities: []user.RemoteIdentity{
{
ConnectorID: "local",
ID: "1",
},
},
}, {
User: User{
User: user.User{
ID: "ID-2",
Email: "Email-2@example.com",
EmailVerified: true,
},
RemoteIdentities: []RemoteIdentity{
RemoteIdentities: []user.RemoteIdentity{
{
ConnectorID: "local",
ID: "2",
@ -49,7 +52,7 @@ func makeTestFixtures() *testFixtures {
},
},
})
f.pwr = NewPasswordInfoRepoFromPasswordInfos([]PasswordInfo{
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
{
UserID: "ID-1",
Password: []byte("password-1"),
@ -59,7 +62,10 @@ func makeTestFixtures() *testFixtures {
Password: []byte("password-2"),
},
})
f.mgr = NewManager(f.ur, f.pwr, repo.InMemTransactionFactory, ManagerOptions{})
f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
})
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{})
f.mgr.Clock = f.clock
return f
}
@ -68,13 +74,13 @@ func TestRegisterWithRemoteIdentity(t *testing.T) {
tests := []struct {
email string
emailVerified bool
rid RemoteIdentity
rid user.RemoteIdentity
err error
}{
{
email: "email@example.com",
emailVerified: false,
rid: RemoteIdentity{
rid: user.RemoteIdentity{
ConnectorID: "local",
ID: "1234",
},
@ -82,20 +88,29 @@ func TestRegisterWithRemoteIdentity(t *testing.T) {
},
{
emailVerified: false,
rid: RemoteIdentity{
rid: user.RemoteIdentity{
ConnectorID: "local",
ID: "1234",
},
err: ErrorInvalidEmail,
err: user.ErrorInvalidEmail,
},
{
email: "email@example.com",
emailVerified: false,
rid: RemoteIdentity{
rid: user.RemoteIdentity{
ConnectorID: "local",
ID: "1",
},
err: ErrorDuplicateRemoteIdentity,
err: user.ErrorDuplicateRemoteIdentity,
},
{
email: "anotheremail@example.com",
emailVerified: false,
rid: user.RemoteIdentity{
ConnectorID: "idonotexist",
ID: "1",
},
err: connector.ErrorNotFound,
},
}
@ -148,17 +163,17 @@ func TestRegisterWithPassword(t *testing.T) {
},
{
plaintext: "secretpassword123",
err: ErrorInvalidEmail,
err: user.ErrorInvalidEmail,
},
{
email: "email@example.com",
err: ErrorInvalidPassword,
err: user.ErrorInvalidPassword,
},
}
for i, tt := range tests {
f := makeTestFixtures()
connID := "connID"
connID := "local"
userID, err := f.mgr.RegisterWithPassword(
tt.email,
tt.plaintext,
@ -183,7 +198,7 @@ func TestRegisterWithPassword(t *testing.T) {
t.Errorf("case %d: user.EmailVerified: want=%v, got=%v", i, false, usr.EmailVerified)
}
ridUSR, err := f.ur.GetByRemoteIdentity(nil, RemoteIdentity{
ridUSR, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{
ID: userID,
ConnectorID: connID,
})
@ -220,12 +235,12 @@ func TestVerifyEmail(t *testing.T) {
callback := "http://client.example.com/callback"
expires := time.Hour * 3
makeClaims := func(usr User) jose.Claims {
makeClaims := func(usr user.User) jose.Claims {
return map[string]interface{}{
"iss": issuer.String(),
"aud": clientID,
ClaimEmailVerificationCallback: callback,
ClaimEmailVerificationEmail: usr.Email,
user.ClaimEmailVerificationCallback: callback,
user.ClaimEmailVerificationEmail: usr.Email,
"exp": float64(now.Add(expires).Unix()),
"sub": usr.ID,
"iat": float64(now.Unix()),
@ -238,28 +253,28 @@ func TestVerifyEmail(t *testing.T) {
}{
{
// happy path
evClaims: makeClaims(User{ID: "ID-1", Email: "Email-1@example.com"}),
evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-1@example.com"}),
},
{
// non-matching email
evClaims: makeClaims(User{ID: "ID-1", Email: "Email-2@example.com"}),
evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-2@example.com"}),
wantErr: true,
},
{
// already verified email
evClaims: makeClaims(User{ID: "ID-2", Email: "Email-2@example.com"}),
evClaims: makeClaims(user.User{ID: "ID-2", Email: "Email-2@example.com"}),
wantErr: true,
},
{
// non-existent user.
evClaims: makeClaims(User{ID: "ID-UNKNOWN", Email: "noone@example.com"}),
evClaims: makeClaims(user.User{ID: "ID-UNKNOWN", Email: "noone@example.com"}),
wantErr: true,
},
}
for i, tt := range tests {
f := makeTestFixtures()
cb, err := f.mgr.VerifyEmail(EmailVerification{tt.evClaims})
cb, err := f.mgr.VerifyEmail(user.EmailVerification{tt.evClaims})
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want non-nil err", i)
@ -271,9 +286,9 @@ func TestVerifyEmail(t *testing.T) {
t.Errorf("case %d: want err=nil got=%q", i, err)
}
if cb.String() != tt.evClaims[ClaimEmailVerificationCallback] {
if cb.String() != tt.evClaims[user.ClaimEmailVerificationCallback] {
t.Errorf("case %d: want=%q, got=%q", i, cb.String(),
tt.evClaims[ClaimEmailVerificationCallback])
tt.evClaims[user.ClaimEmailVerificationCallback])
}
}
}
@ -290,8 +305,8 @@ func TestChangePassword(t *testing.T) {
return map[string]interface{}{
"iss": issuer.String(),
"aud": clientID,
ClaimPasswordResetCallback: callback,
ClaimPasswordResetPassword: password,
user.ClaimPasswordResetCallback: callback,
user.ClaimPasswordResetPassword: password,
"exp": float64(now.Add(expires).Unix()),
"sub": usrID,
"iat": float64(now.Unix()),
@ -329,7 +344,7 @@ func TestChangePassword(t *testing.T) {
for i, tt := range tests {
f := makeTestFixtures()
cb, err := f.mgr.ChangePassword(PasswordReset{tt.pwrClaims}, tt.newPassword)
cb, err := f.mgr.ChangePassword(user.PasswordReset{tt.pwrClaims}, tt.newPassword)
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want non-nil err", i)
@ -346,47 +361,61 @@ func TestChangePassword(t *testing.T) {
if cb != nil {
cbString = cb.String()
}
if cbString != tt.pwrClaims[ClaimPasswordResetCallback] {
if cbString != tt.pwrClaims[user.ClaimPasswordResetCallback] {
t.Errorf("case %d: want=%q, got=%q", i, cb.String(),
tt.pwrClaims[ClaimPasswordResetCallback])
tt.pwrClaims[user.ClaimPasswordResetCallback])
}
}
}
func TestCreateUser(t *testing.T) {
tests := []struct {
usr User
hashedPW Password
usr user.User
hashedPW user.Password
localID string // defaults to "local"
wantErr bool
}{
{
usr: User{
usr: user.User{
DisplayName: "Bob Exampleson",
Email: "bob@example.com",
},
hashedPW: Password("I am a hash"),
hashedPW: user.Password("I am a hash"),
},
{
usr: User{
usr: user.User{
DisplayName: "Al Adminson",
Email: "al@example.com",
Admin: true,
},
hashedPW: Password("I am a hash"),
hashedPW: user.Password("I am a hash"),
},
{
usr: User{
usr: user.User{
DisplayName: "Ed Emailless",
},
hashedPW: Password("I am a hash"),
hashedPW: user.Password("I am a hash"),
wantErr: true,
},
{
usr: user.User{
DisplayName: "Eric Exampleson",
Email: "eric@example.com",
},
hashedPW: user.Password("I am a hash"),
localID: "abadlocalid",
wantErr: true,
},
}
for i, tt := range tests {
f := makeTestFixtures()
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, "local")
localID := "local"
if tt.localID != "" {
localID = tt.localID
}
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, localID)
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want non-nil err", i)
@ -422,7 +451,7 @@ func TestCreateUser(t *testing.T) {
t.Errorf("case %d: want=%q, got=%q", i, tt.hashedPW, pwi.Password)
}
ridUser, err := f.ur.GetByRemoteIdentity(nil, RemoteIdentity{
ridUser, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{
ID: id,
ConnectorID: "local",
})