Merge pull request #199 from ericchiang/validate_connector

api: validate local connector existence before creating user
This commit is contained in:
bobbyrullo 2015-12-07 17:44:22 -08:00
commit 521aeae3db
20 changed files with 317 additions and 134 deletions

View file

@ -6,17 +6,18 @@ import (
"github.com/coreos/dex/schema/adminschema" "github.com/coreos/dex/schema/adminschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
) )
// AdminAPI provides the logic necessary to implement the Admin API. // AdminAPI provides the logic necessary to implement the Admin API.
type AdminAPI struct { type AdminAPI struct {
userManager *user.Manager userManager *manager.UserManager
userRepo user.UserRepo userRepo user.UserRepo
passwordInfoRepo user.PasswordInfoRepo passwordInfoRepo user.PasswordInfoRepo
localConnectorID string localConnectorID string
} }
func NewAdminAPI(userManager *user.Manager, userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, localConnectorID string) *AdminAPI { func NewAdminAPI(userManager *manager.UserManager, userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, localConnectorID string) *AdminAPI {
if localConnectorID == "" { if localConnectorID == "" {
panic("must specify non-blank localConnectorID") panic("must specify non-blank localConnectorID")
} }

View file

@ -17,7 +17,7 @@ import (
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
ptime "github.com/coreos/dex/pkg/time" ptime "github.com/coreos/dex/pkg/time"
"github.com/coreos/dex/server" "github.com/coreos/dex/server"
"github.com/coreos/dex/user" "github.com/coreos/dex/user/manager"
) )
var version = "DEV" var version = "DEV"
@ -99,8 +99,9 @@ func main() {
userRepo := db.NewUserRepo(dbc) userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := user.NewManager(userRepo, connCfgRepo := db.NewConnectorConfigRepo(dbc)
pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{}) userManager := manager.NewUserManager(userRepo,
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID) adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID)
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...) kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
if err != nil { if err != nil {

View file

@ -4,6 +4,8 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"os" "os"
"github.com/coreos/dex/repo"
) )
func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) { func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) {
@ -41,6 +43,19 @@ type memConnectorConfigRepo struct {
configs []ConnectorConfig configs []ConnectorConfig
} }
func NewConnectorConfigRepoFromConfigs(cfgs []ConnectorConfig) ConnectorConfigRepo {
return &memConnectorConfigRepo{configs: cfgs}
}
func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) { func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) {
return r.configs, nil return r.configs, nil
} }
func (r *memConnectorConfigRepo) GetConnectorByID(_ repo.Transaction, id string) (ConnectorConfig, error) {
for _, cfg := range r.configs {
if cfg.ConnectorID() == id {
return cfg, nil
}
}
return nil, ErrorNotFound
}

View file

@ -1,14 +1,18 @@
package connector package connector
import ( import (
"errors"
"html/template" "html/template"
"net/http" "net/http"
"net/url" "net/url"
"github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/coreos/pkg/health" "github.com/coreos/pkg/health"
) )
var ErrorNotFound = errors.New("connector not found in repository")
type Connector interface { type Connector interface {
ID() string ID() string
LoginURL(sessionKey, prompt string) (string, error) LoginURL(sessionKey, prompt string) (string, error)
@ -34,4 +38,5 @@ type ConnectorConfig interface {
type ConnectorConfigRepo interface { type ConnectorConfigRepo interface {
All() ([]ConnectorConfig, error) All() ([]ConnectorConfig, error)
GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error)
} }

View file

@ -1,6 +1,7 @@
package db package db
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -9,6 +10,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
) )
const ( const (
@ -91,6 +93,18 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
return cfgs, nil return cfgs, nil
} }
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
qt := pq.QuoteIdentifier(connectorConfigTableName)
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
var c connectorConfigModel
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
if err == sql.ErrNoRows {
return nil, connector.ErrorNotFound
}
}
return c.ConnectorConfig()
}
func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error { func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
insert := make([]interface{}, len(cfgs)) insert := make([]interface{}, len(cfgs))
for i, cfg := range cfgs { for i, cfg := range cfgs {
@ -119,3 +133,15 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
return tx.Commit() return tx.Commit()
} }
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}
gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}

View file

@ -0,0 +1,71 @@
package repo
import (
"fmt"
"os"
"testing"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
)
type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo
var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory
func init() {
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs
} else {
makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn)
}
}
func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory {
return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
dbMap := initDB(dsn)
repo := db.NewConnectorConfigRepo(dbMap)
if err := repo.Set(cfgs); err != nil {
panic(fmt.Sprintf("Unable to set connector configs: %v", err))
}
return repo
}
}
func TestConnectorConfigRepoGetByID(t *testing.T) {
tests := []struct {
cfgs []connector.ConnectorConfig
id string
err error
}{
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
},
id: "local",
},
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local1"},
&connector.LocalConnectorConfig{ID: "local2"},
},
id: "local2",
},
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local1"},
&connector.LocalConnectorConfig{ID: "local2"},
},
id: "foo",
err: connector.ErrorNotFound,
},
}
for i, tt := range tests {
repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs)
if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err {
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
}
}
}

View file

@ -10,8 +10,10 @@ import (
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
) )
var ( var (
@ -42,11 +44,14 @@ func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, erro
return &resp, nil return &resp, nil
} }
func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *user.Manager) { func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) {
ur := user.NewUserRepoFromUsers(users) ur := user.NewUserRepoFromUsers(users)
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords) pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
um := user.NewManager(ur, pwr, repo.InMemTransactionFactory, user.ManagerOptions{}) ccr := connector.NewConnectorConfigRepoFromConfigs(
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
)
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
um.Clock = clock um.Clock = clock
return ur, pwr, um return ur, pwr, um
} }

View file

@ -22,6 +22,7 @@ import (
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
) )
type ServerConfig struct { type ServerConfig struct {
@ -133,7 +134,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
refTokRepo := refresh.NewRefreshTokenRepo() refTokRepo := refresh.NewRefreshTokenRepo()
txnFactory := repo.InMemTransactionFactory txnFactory := repo.InMemTransactionFactory
userManager := user.NewManager(userRepo, pwiRepo, txnFactory, user.ManagerOptions{}) userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
srv.ClientIdentityRepo = ciRepo srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo srv.ConnectorConfigRepo = cfgRepo
@ -171,7 +172,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
cfgRepo := db.NewConnectorConfigRepo(dbc) cfgRepo := db.NewConnectorConfigRepo(dbc)
userRepo := db.NewUserRepo(dbc) userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := user.NewManager(userRepo, pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{}) userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
refreshTokenRepo := db.NewRefreshTokenRepo(dbc) refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
sm := session.NewSessionManager(sRepo, skRepo) sm := session.NewSessionManager(sRepo, skRepo)

View file

@ -15,6 +15,7 @@ import (
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
) )
// handleVerifyEmailResendFunc will resend an email-verification email given a valid JWT for the user and a redirect URL. // handleVerifyEmailResendFunc will resend an email-verification email given a valid JWT for the user and a redirect URL.
@ -190,7 +191,7 @@ type emailVerifiedTemplateData struct {
} }
func handleEmailVerifyFunc(verifiedTpl *template.Template, issuer url.URL, keysFunc func() ([]key.PublicKey, func handleEmailVerifyFunc(verifiedTpl *template.Template, issuer url.URL, keysFunc func() ([]key.PublicKey,
error), userManager *user.Manager) http.HandlerFunc { error), userManager *manager.UserManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query() q := r.URL.Query()
@ -217,12 +218,12 @@ func handleEmailVerifyFunc(verifiedTpl *template.Template, issuer url.URL, keysF
cbURL, err := userManager.VerifyEmail(ev) cbURL, err := userManager.VerifyEmail(ev)
if err != nil { if err != nil {
switch err { switch err {
case user.ErrorEmailAlreadyVerified: case manager.ErrorEmailAlreadyVerified:
execTemplateWithStatus(w, verifiedTpl, emailVerifiedTemplateData{ execTemplateWithStatus(w, verifiedTpl, emailVerifiedTemplateData{
Error: "Invalid Verification Link", Error: "Invalid Verification Link",
Message: "Your email link has expired or has already been verified.", Message: "Your email link has expired or has already been verified.",
}, http.StatusBadRequest) }, http.StatusBadRequest)
case user.ErrorEVEmailDoesntMatch: case manager.ErrorEVEmailDoesntMatch:
execTemplateWithStatus(w, verifiedTpl, emailVerifiedTemplateData{ execTemplateWithStatus(w, verifiedTpl, emailVerifiedTemplateData{
Error: "Invalid Verification Link", Error: "Invalid Verification Link",
Message: "Your email link does not match the email address on file. Perhaps you have a more recent verification link?", Message: "Your email link does not match the email address on file. Perhaps you have a more recent verification link?",

View file

@ -7,6 +7,7 @@ import (
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
) )
@ -18,7 +19,7 @@ type invitationTemplateData struct {
type InvitationHandler struct { type InvitationHandler struct {
issuerURL url.URL issuerURL url.URL
passwordResetURL url.URL passwordResetURL url.URL
um *user.Manager um *manager.UserManager
keysFunc func() ([]key.PublicKey, error) keysFunc func() ([]key.PublicKey, error)
signerFunc func() (jose.Signer, error) signerFunc func() (jose.Signer, error)
redirectValidityWindow time.Duration redirectValidityWindow time.Duration
@ -55,13 +56,13 @@ func (h *InvitationHandler) handleGET(w http.ResponseWriter, r *http.Request) {
} }
_, err = h.um.VerifyEmail(invite) _, err = h.um.VerifyEmail(invite)
if err != nil && err != user.ErrorEmailAlreadyVerified { if err != nil && err != manager.ErrorEmailAlreadyVerified {
// Allow AlreadyVerified folks to pass through- otherwise // Allow AlreadyVerified folks to pass through- otherwise
// folks who encounter an error after passing this point will // folks who encounter an error after passing this point will
// never be able to set their passwords. // never be able to set their passwords.
log.Debugf("error attempting to verify email: %v", err) log.Debugf("error attempting to verify email: %v", err)
switch err { switch err {
case user.ErrorEVEmailDoesntMatch: case manager.ErrorEVEmailDoesntMatch:
writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidRequest, writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidRequest,
"Your email does not match the email address on file")) "Your email does not match the email address on file"))
return return

View file

@ -12,6 +12,7 @@ import (
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
) )
type sendResetPasswordEmailData struct { type sendResetPasswordEmailData struct {
@ -181,7 +182,7 @@ type resetPasswordTemplateData struct {
type ResetPasswordHandler struct { type ResetPasswordHandler struct {
tpl *template.Template tpl *template.Template
issuerURL url.URL issuerURL url.URL
um *user.Manager um *manager.UserManager
keysFunc func() ([]key.PublicKey, error) keysFunc func() ([]key.PublicKey, error)
} }
@ -237,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() {
cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext) cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext)
if err != nil { if err != nil {
switch err { switch err {
case user.ErrorPasswordAlreadyChanged: case manager.ErrorPasswordAlreadyChanged:
r.data.Error = "Link Expired" r.data.Error = "Link Expired"
r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email." r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email."
r.data.DontShowForm = true r.data.DontShowForm = true

View file

@ -11,6 +11,7 @@ import (
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
@ -222,7 +223,7 @@ func handleRegisterFunc(s *Server) http.HandlerFunc {
} }
} }
func registerFromLocalConnector(userManager *user.Manager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) { func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) {
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID) userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
if err != nil { if err != nil {
return "", err return "", err
@ -237,7 +238,7 @@ func registerFromLocalConnector(userManager *user.Manager, sessionManager *sessi
return userID, nil return userID, nil
} }
func registerFromRemoteConnector(userManager *user.Manager, ses *session.Session, email string, emailVerified bool) (string, error) { func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
if ses.Identity.ID == "" { if ses.Identity.ID == "" {
return "", errors.New("No Identity found in session.") return "", errors.New("No Identity found in session.")
} }

View file

@ -25,6 +25,7 @@ import (
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
usersapi "github.com/coreos/dex/user/api" usersapi "github.com/coreos/dex/user/api"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
) )
const ( const (
@ -68,7 +69,7 @@ type Server struct {
HealthChecks []health.Checkable HealthChecks []health.Checkable
Connectors []connector.Connector Connectors []connector.Connector
UserRepo user.UserRepo UserRepo user.UserRepo
UserManager *user.Manager UserManager *manager.UserManager
PasswordInfoRepo user.PasswordInfoRepo PasswordInfoRepo user.PasswordInfoRepo
RefreshTokenRepo refresh.RefreshTokenRepo RefreshTokenRepo refresh.RefreshTokenRepo
UserEmailer *useremail.UserEmailer UserEmailer *useremail.UserEmailer

View file

@ -15,6 +15,7 @@ import (
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
) )
const ( const (
@ -91,7 +92,6 @@ func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
func makeTestFixtures() (*testFixtures, error) { func makeTestFixtures() (*testFixtures, error) {
userRepo := user.NewUserRepoFromUsers(testUsers) userRepo := user.NewUserRepoFromUsers(testUsers)
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos) pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
manager := user.NewManager(userRepo, pwRepo, repo.InMemTransactionFactory, user.ManagerOptions{})
connConfigs := []connector.ConnectorConfig{ connConfigs := []connector.ConnectorConfig{
&connector.OIDCConnectorConfig{ &connector.OIDCConnectorConfig{
@ -111,6 +111,9 @@ func makeTestFixtures() (*testFixtures, error) {
ID: "local", ID: "local",
}, },
} }
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sessionManager.GenerateCode = sequentialGenerateCodeFunc() sessionManager.GenerateCode = sequentialGenerateCodeFunc()

View file

@ -16,6 +16,7 @@ import (
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/api" "github.com/coreos/dex/user/api"
"github.com/coreos/dex/user/manager"
) )
const ( const (
@ -33,11 +34,11 @@ var (
type UserMgmtServer struct { type UserMgmtServer struct {
api *api.UsersAPI api *api.UsersAPI
jwtvFactory JWTVerifierFactory jwtvFactory JWTVerifierFactory
um *user.Manager um *manager.UserManager
cir client.ClientIdentityRepo cir client.ClientIdentityRepo
} }
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *user.Manager, cir client.ClientIdentityRepo) *UserMgmtServer { func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *manager.UserManager, cir client.ClientIdentityRepo) *UserMgmtServer {
return &UserMgmtServer{ return &UserMgmtServer{
api: userMgmtAPI, api: userMgmtAPI,
jwtvFactory: jwtvFactory, jwtvFactory: jwtvFactory,

2
test
View file

@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
source ./build source ./build
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api email" TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email"
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log" FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
# user has not provided PKG override # user has not provided PKG override

View file

@ -13,6 +13,7 @@ import (
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
) )
var ( var (
@ -81,7 +82,7 @@ func (e Error) Error() string {
// calling User. It is assumed that the clientID has already validated as an // calling User. It is assumed that the clientID has already validated as an
// admin app before calling. // admin app before calling.
type UsersAPI struct { type UsersAPI struct {
manager *user.Manager manager *manager.UserManager
localConnectorID string localConnectorID string
clientIdentityRepo client.ClientIdentityRepo clientIdentityRepo client.ClientIdentityRepo
emailer Emailer emailer Emailer
@ -96,7 +97,7 @@ type Creds struct {
User user.User User user.User
} }
func NewUsersAPI(manager *user.Manager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI { func NewUsersAPI(manager *manager.UserManager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI {
return &UsersAPI{ return &UsersAPI{
manager: manager, manager: manager,
clientIdentityRepo: cir, clientIdentityRepo: cir,

View file

@ -10,9 +10,11 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
) )
type testEmailer struct { type testEmailer struct {
@ -123,7 +125,10 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
Password: []byte("password-2"), Password: []byte("password-2"),
}, },
}) })
mgr := user.NewManager(ur, pwr, repo.InMemTransactionFactory, user.ManagerOptions{}) ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
})
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
mgr.Clock = clock mgr.Clock = clock
ci := oidc.ClientIdentity{ ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{

View file

@ -1,4 +1,4 @@
package user package manager
import ( import (
"errors" "errors"
@ -6,8 +6,10 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/dex/user"
) )
var ( var (
@ -19,13 +21,14 @@ var (
// Manager performs user-related "business-logic" functions on user and related objects. // Manager performs user-related "business-logic" functions on user and related objects.
// This is in contrast to the Repos which perform little more than CRUD operations. // This is in contrast to the Repos which perform little more than CRUD operations.
type Manager struct { type UserManager struct {
Clock clockwork.Clock Clock clockwork.Clock
userRepo UserRepo userRepo user.UserRepo
pwRepo PasswordInfoRepo pwRepo user.PasswordInfoRepo
connCfgRepo connector.ConnectorConfigRepo
begin repo.TransactionFactory begin repo.TransactionFactory
userIDGenerator UserIDGenerator userIDGenerator user.UserIDGenerator
} }
type ManagerOptions struct { type ManagerOptions struct {
@ -34,58 +37,59 @@ type ManagerOptions struct {
// variable policies // variable policies
} }
func NewManager(userRepo UserRepo, pwRepo PasswordInfoRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *Manager { func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, connCfgRepo connector.ConnectorConfigRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager {
return &Manager{ return &UserManager{
Clock: clockwork.NewRealClock(), Clock: clockwork.NewRealClock(),
userRepo: userRepo, userRepo: userRepo,
pwRepo: pwRepo, pwRepo: pwRepo,
connCfgRepo: connCfgRepo,
begin: txnFactory, begin: txnFactory,
userIDGenerator: DefaultUserIDGenerator, userIDGenerator: user.DefaultUserIDGenerator,
} }
} }
func (m *Manager) Get(id string) (User, error) { func (m *UserManager) Get(id string) (user.User, error) {
return m.userRepo.Get(nil, id) return m.userRepo.Get(nil, id)
} }
func (m *Manager) List(filter UserFilter, maxResults int, nextPageToken string) ([]User, string, error) { func (m *UserManager) List(filter user.UserFilter, maxResults int, nextPageToken string) ([]user.User, string, error) {
return m.userRepo.List(nil, filter, maxResults, nextPageToken) return m.userRepo.List(nil, filter, maxResults, nextPageToken)
} }
// CreateUser creates a new user with the given hashedPassword; the connID should be the ID of the local connector. // CreateUser creates a new user with the given hashedPassword; the connID should be the ID of the local connector.
// The userID of the created user is returned as the first argument. // The userID of the created user is returned as the first argument.
func (m *Manager) CreateUser(user User, hashedPassword Password, connID string) (string, error) { func (m *UserManager) CreateUser(usr user.User, hashedPassword user.Password, connID string) (string, error) {
tx, err := m.begin() tx, err := m.begin()
if err != nil { if err != nil {
return "", err return "", err
} }
insertedUser, err := m.insertNewUser(tx, user.Email, user.EmailVerified) insertedUser, err := m.insertNewUser(tx, usr.Email, usr.EmailVerified)
if err != nil { if err != nil {
rollback(tx) rollback(tx)
return "", err return "", err
} }
user.ID = insertedUser.ID usr.ID = insertedUser.ID
user.CreatedAt = insertedUser.CreatedAt usr.CreatedAt = insertedUser.CreatedAt
err = m.userRepo.Update(tx, user) err = m.userRepo.Update(tx, usr)
if err != nil { if err != nil {
rollback(tx) rollback(tx)
return "", err return "", err
} }
rid := RemoteIdentity{ rid := user.RemoteIdentity{
ConnectorID: connID, ConnectorID: connID,
ID: user.ID, ID: usr.ID,
} }
if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil { if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx) rollback(tx)
return "", err return "", err
} }
pwi := PasswordInfo{ pwi := user.PasswordInfo{
UserID: user.ID, UserID: usr.ID,
Password: hashedPassword, Password: hashedPassword,
} }
err = m.pwRepo.Create(tx, pwi) err = m.pwRepo.Create(tx, pwi)
@ -99,10 +103,10 @@ func (m *Manager) CreateUser(user User, hashedPassword Password, connID string)
rollback(tx) rollback(tx)
return "", err return "", err
} }
return user.ID, nil return usr.ID, nil
} }
func (m *Manager) Disable(userID string, disabled bool) error { func (m *UserManager) Disable(userID string, disabled bool) error {
tx, err := m.begin() tx, err := m.begin()
if err = m.userRepo.Disable(tx, userID, disabled); err != nil { if err = m.userRepo.Disable(tx, userID, disabled); err != nil {
@ -119,7 +123,7 @@ func (m *Manager) Disable(userID string, disabled bool) error {
} }
// RegisterWithRemoteIdentity creates new user and attaches the given remote identity. // RegisterWithRemoteIdentity creates new user and attaches the given remote identity.
func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) { func (m *UserManager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid user.RemoteIdentity) (string, error) {
tx, err := m.begin() tx, err := m.begin()
if err != nil { if err != nil {
return "", err return "", err
@ -127,20 +131,20 @@ func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, r
if _, err = m.userRepo.GetByRemoteIdentity(tx, rid); err == nil { if _, err = m.userRepo.GetByRemoteIdentity(tx, rid); err == nil {
rollback(tx) rollback(tx)
return "", ErrorDuplicateRemoteIdentity return "", user.ErrorDuplicateRemoteIdentity
} }
if err != ErrorNotFound { if err != user.ErrorNotFound {
rollback(tx) rollback(tx)
return "", err return "", err
} }
user, err := m.insertNewUser(tx, email, emailVerified) usr, err := m.insertNewUser(tx, email, emailVerified)
if err != nil { if err != nil {
rollback(tx) rollback(tx)
return "", err return "", err
} }
if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil { if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx) rollback(tx)
return "", err return "", err
} }
@ -150,44 +154,44 @@ func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, r
rollback(tx) rollback(tx)
return "", err return "", err
} }
return user.ID, nil return usr.ID, nil
} }
// RegisterWithPassword creates a new user with the given name and password. // RegisterWithPassword creates a new user with the given name and password.
// connID is the connector ID of the ConnectorLocal connector. // connID is the connector ID of the ConnectorLocal connector.
func (m *Manager) RegisterWithPassword(email, plaintext, connID string) (string, error) { func (m *UserManager) RegisterWithPassword(email, plaintext, connID string) (string, error) {
tx, err := m.begin() tx, err := m.begin()
if err != nil { if err != nil {
return "", err return "", err
} }
if !ValidPassword(plaintext) { if !user.ValidPassword(plaintext) {
rollback(tx) rollback(tx)
return "", ErrorInvalidPassword return "", user.ErrorInvalidPassword
} }
user, err := m.insertNewUser(tx, email, false) usr, err := m.insertNewUser(tx, email, false)
if err != nil { if err != nil {
rollback(tx) rollback(tx)
return "", err return "", err
} }
rid := RemoteIdentity{ rid := user.RemoteIdentity{
ConnectorID: connID, ConnectorID: connID,
ID: user.ID, ID: usr.ID,
} }
if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil { if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx) rollback(tx)
return "", err return "", err
} }
password, err := NewPasswordFromPlaintext(plaintext) password, err := user.NewPasswordFromPlaintext(plaintext)
if err != nil { if err != nil {
rollback(tx) rollback(tx)
return "", err return "", err
} }
pwi := PasswordInfo{ pwi := user.PasswordInfo{
UserID: user.ID, UserID: usr.ID,
Password: password, Password: password,
} }
@ -202,7 +206,7 @@ func (m *Manager) RegisterWithPassword(email, plaintext, connID string) (string,
rollback(tx) rollback(tx)
return "", err return "", err
} }
return user.ID, nil return usr.ID, nil
} }
type EmailVerifiable interface { type EmailVerifiable interface {
@ -218,31 +222,31 @@ type EmailVerifiable interface {
// create it, ensuring that the token was signed and that the JWT was not // create it, ensuring that the token was signed and that the JWT was not
// expired. // expired.
// The callback url (i.e. where to send the user after the verification) is returned. // The callback url (i.e. where to send the user after the verification) is returned.
func (m *Manager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) { func (m *UserManager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) {
tx, err := m.begin() tx, err := m.begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := m.userRepo.Get(tx, ev.UserID()) usr, err := m.userRepo.Get(tx, ev.UserID())
if err != nil { if err != nil {
rollback(tx) rollback(tx)
return nil, err return nil, err
} }
if user.Email != ev.Email() { if usr.Email != ev.Email() {
rollback(tx) rollback(tx)
return nil, ErrorEVEmailDoesntMatch return nil, ErrorEVEmailDoesntMatch
} }
if user.EmailVerified { if usr.EmailVerified {
rollback(tx) rollback(tx)
return nil, ErrorEmailAlreadyVerified return nil, ErrorEmailAlreadyVerified
} }
user.EmailVerified = true usr.EmailVerified = true
err = m.userRepo.Update(tx, user) err = m.userRepo.Update(tx, usr)
if err != nil { if err != nil {
rollback(tx) rollback(tx)
return nil, err return nil, err
@ -258,19 +262,19 @@ func (m *Manager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) {
type PasswordChangeable interface { type PasswordChangeable interface {
UserID() string UserID() string
Password() Password Password() user.Password
Callback() *url.URL Callback() *url.URL
} }
func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url.URL, error) { func (m *UserManager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url.URL, error) {
tx, err := m.begin() tx, err := m.begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !ValidPassword(plaintext) { if !user.ValidPassword(plaintext) {
rollback(tx) rollback(tx)
return nil, ErrorInvalidPassword return nil, user.ErrorInvalidPassword
} }
pwi, err := m.pwRepo.Get(tx, pwr.UserID()) pwi, err := m.pwRepo.Get(tx, pwr.UserID())
@ -284,7 +288,7 @@ func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url
return nil, ErrorPasswordAlreadyChanged return nil, ErrorPasswordAlreadyChanged
} }
newPass, err := NewPasswordFromPlaintext(plaintext) newPass, err := user.NewPasswordFromPlaintext(plaintext)
if err != nil { if err != nil {
rollback(tx) rollback(tx)
return nil, err return nil, err
@ -305,36 +309,46 @@ func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url
return pwr.Callback(), nil return pwr.Callback(), nil
} }
func (m *Manager) insertNewUser(tx repo.Transaction, email string, emailVerified bool) (User, error) { func (m *UserManager) insertNewUser(tx repo.Transaction, email string, emailVerified bool) (user.User, error) {
if !ValidEmail(email) { if !user.ValidEmail(email) {
return User{}, ErrorInvalidEmail return user.User{}, user.ErrorInvalidEmail
} }
var err error var err error
if _, err = m.userRepo.GetByEmail(tx, email); err == nil { if _, err = m.userRepo.GetByEmail(tx, email); err == nil {
return User{}, ErrorDuplicateEmail return user.User{}, user.ErrorDuplicateEmail
} }
if err != ErrorNotFound { if err != user.ErrorNotFound {
return User{}, err return user.User{}, err
} }
userID, err := m.userIDGenerator() userID, err := m.userIDGenerator()
if err != nil { if err != nil {
return User{}, err return user.User{}, err
} }
user := User{ usr := user.User{
ID: userID, ID: userID,
Email: email, Email: email,
EmailVerified: emailVerified, EmailVerified: emailVerified,
CreatedAt: m.Clock.Now(), CreatedAt: m.Clock.Now(),
} }
err = m.userRepo.Create(tx, user) err = m.userRepo.Create(tx, usr)
if err != nil { if err != nil {
return User{}, err return user.User{}, err
} }
return user, nil return usr, nil
}
func (m *UserManager) addRemoteIdentity(tx repo.Transaction, userID string, rid user.RemoteIdentity) error {
if _, err := m.connCfgRepo.GetConnectorByID(tx, rid.ConnectorID); err != nil {
return err
}
if err := m.userRepo.AddRemoteIdentity(tx, userID, rid); err != nil {
return err
}
return nil
} }
func rollback(tx repo.Transaction) { func rollback(tx repo.Transaction) {

View file

@ -1,4 +1,4 @@
package user package manager
import ( import (
"net/url" "net/url"
@ -9,13 +9,16 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/dex/user"
) )
type testFixtures struct { type testFixtures struct {
ur UserRepo ur user.UserRepo
pwr PasswordInfoRepo pwr user.PasswordInfoRepo
mgr *Manager ccr connector.ConnectorConfigRepo
mgr *UserManager
clock clockwork.Clock clock clockwork.Clock
} }
@ -23,25 +26,25 @@ func makeTestFixtures() *testFixtures {
f := &testFixtures{} f := &testFixtures{}
f.clock = clockwork.NewFakeClock() f.clock = clockwork.NewFakeClock()
f.ur = NewUserRepoFromUsers([]UserWithRemoteIdentities{ f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
{ {
User: User{ User: user.User{
ID: "ID-1", ID: "ID-1",
Email: "Email-1@example.com", Email: "Email-1@example.com",
}, },
RemoteIdentities: []RemoteIdentity{ RemoteIdentities: []user.RemoteIdentity{
{ {
ConnectorID: "local", ConnectorID: "local",
ID: "1", ID: "1",
}, },
}, },
}, { }, {
User: User{ User: user.User{
ID: "ID-2", ID: "ID-2",
Email: "Email-2@example.com", Email: "Email-2@example.com",
EmailVerified: true, EmailVerified: true,
}, },
RemoteIdentities: []RemoteIdentity{ RemoteIdentities: []user.RemoteIdentity{
{ {
ConnectorID: "local", ConnectorID: "local",
ID: "2", ID: "2",
@ -49,7 +52,7 @@ func makeTestFixtures() *testFixtures {
}, },
}, },
}) })
f.pwr = NewPasswordInfoRepoFromPasswordInfos([]PasswordInfo{ f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
{ {
UserID: "ID-1", UserID: "ID-1",
Password: []byte("password-1"), Password: []byte("password-1"),
@ -59,7 +62,10 @@ func makeTestFixtures() *testFixtures {
Password: []byte("password-2"), Password: []byte("password-2"),
}, },
}) })
f.mgr = NewManager(f.ur, f.pwr, repo.InMemTransactionFactory, ManagerOptions{}) f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
})
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{})
f.mgr.Clock = f.clock f.mgr.Clock = f.clock
return f return f
} }
@ -68,13 +74,13 @@ func TestRegisterWithRemoteIdentity(t *testing.T) {
tests := []struct { tests := []struct {
email string email string
emailVerified bool emailVerified bool
rid RemoteIdentity rid user.RemoteIdentity
err error err error
}{ }{
{ {
email: "email@example.com", email: "email@example.com",
emailVerified: false, emailVerified: false,
rid: RemoteIdentity{ rid: user.RemoteIdentity{
ConnectorID: "local", ConnectorID: "local",
ID: "1234", ID: "1234",
}, },
@ -82,20 +88,29 @@ func TestRegisterWithRemoteIdentity(t *testing.T) {
}, },
{ {
emailVerified: false, emailVerified: false,
rid: RemoteIdentity{ rid: user.RemoteIdentity{
ConnectorID: "local", ConnectorID: "local",
ID: "1234", ID: "1234",
}, },
err: ErrorInvalidEmail, err: user.ErrorInvalidEmail,
}, },
{ {
email: "email@example.com", email: "email@example.com",
emailVerified: false, emailVerified: false,
rid: RemoteIdentity{ rid: user.RemoteIdentity{
ConnectorID: "local", ConnectorID: "local",
ID: "1", ID: "1",
}, },
err: ErrorDuplicateRemoteIdentity, err: user.ErrorDuplicateRemoteIdentity,
},
{
email: "anotheremail@example.com",
emailVerified: false,
rid: user.RemoteIdentity{
ConnectorID: "idonotexist",
ID: "1",
},
err: connector.ErrorNotFound,
}, },
} }
@ -148,17 +163,17 @@ func TestRegisterWithPassword(t *testing.T) {
}, },
{ {
plaintext: "secretpassword123", plaintext: "secretpassword123",
err: ErrorInvalidEmail, err: user.ErrorInvalidEmail,
}, },
{ {
email: "email@example.com", email: "email@example.com",
err: ErrorInvalidPassword, err: user.ErrorInvalidPassword,
}, },
} }
for i, tt := range tests { for i, tt := range tests {
f := makeTestFixtures() f := makeTestFixtures()
connID := "connID" connID := "local"
userID, err := f.mgr.RegisterWithPassword( userID, err := f.mgr.RegisterWithPassword(
tt.email, tt.email,
tt.plaintext, tt.plaintext,
@ -183,7 +198,7 @@ func TestRegisterWithPassword(t *testing.T) {
t.Errorf("case %d: user.EmailVerified: want=%v, got=%v", i, false, usr.EmailVerified) t.Errorf("case %d: user.EmailVerified: want=%v, got=%v", i, false, usr.EmailVerified)
} }
ridUSR, err := f.ur.GetByRemoteIdentity(nil, RemoteIdentity{ ridUSR, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{
ID: userID, ID: userID,
ConnectorID: connID, ConnectorID: connID,
}) })
@ -220,12 +235,12 @@ func TestVerifyEmail(t *testing.T) {
callback := "http://client.example.com/callback" callback := "http://client.example.com/callback"
expires := time.Hour * 3 expires := time.Hour * 3
makeClaims := func(usr User) jose.Claims { makeClaims := func(usr user.User) jose.Claims {
return map[string]interface{}{ return map[string]interface{}{
"iss": issuer.String(), "iss": issuer.String(),
"aud": clientID, "aud": clientID,
ClaimEmailVerificationCallback: callback, user.ClaimEmailVerificationCallback: callback,
ClaimEmailVerificationEmail: usr.Email, user.ClaimEmailVerificationEmail: usr.Email,
"exp": float64(now.Add(expires).Unix()), "exp": float64(now.Add(expires).Unix()),
"sub": usr.ID, "sub": usr.ID,
"iat": float64(now.Unix()), "iat": float64(now.Unix()),
@ -238,28 +253,28 @@ func TestVerifyEmail(t *testing.T) {
}{ }{
{ {
// happy path // happy path
evClaims: makeClaims(User{ID: "ID-1", Email: "Email-1@example.com"}), evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-1@example.com"}),
}, },
{ {
// non-matching email // non-matching email
evClaims: makeClaims(User{ID: "ID-1", Email: "Email-2@example.com"}), evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-2@example.com"}),
wantErr: true, wantErr: true,
}, },
{ {
// already verified email // already verified email
evClaims: makeClaims(User{ID: "ID-2", Email: "Email-2@example.com"}), evClaims: makeClaims(user.User{ID: "ID-2", Email: "Email-2@example.com"}),
wantErr: true, wantErr: true,
}, },
{ {
// non-existent user. // non-existent user.
evClaims: makeClaims(User{ID: "ID-UNKNOWN", Email: "noone@example.com"}), evClaims: makeClaims(user.User{ID: "ID-UNKNOWN", Email: "noone@example.com"}),
wantErr: true, wantErr: true,
}, },
} }
for i, tt := range tests { for i, tt := range tests {
f := makeTestFixtures() f := makeTestFixtures()
cb, err := f.mgr.VerifyEmail(EmailVerification{tt.evClaims}) cb, err := f.mgr.VerifyEmail(user.EmailVerification{tt.evClaims})
if tt.wantErr { if tt.wantErr {
if err == nil { if err == nil {
t.Errorf("case %d: want non-nil err", i) t.Errorf("case %d: want non-nil err", i)
@ -271,9 +286,9 @@ func TestVerifyEmail(t *testing.T) {
t.Errorf("case %d: want err=nil got=%q", i, err) t.Errorf("case %d: want err=nil got=%q", i, err)
} }
if cb.String() != tt.evClaims[ClaimEmailVerificationCallback] { if cb.String() != tt.evClaims[user.ClaimEmailVerificationCallback] {
t.Errorf("case %d: want=%q, got=%q", i, cb.String(), t.Errorf("case %d: want=%q, got=%q", i, cb.String(),
tt.evClaims[ClaimEmailVerificationCallback]) tt.evClaims[user.ClaimEmailVerificationCallback])
} }
} }
} }
@ -290,8 +305,8 @@ func TestChangePassword(t *testing.T) {
return map[string]interface{}{ return map[string]interface{}{
"iss": issuer.String(), "iss": issuer.String(),
"aud": clientID, "aud": clientID,
ClaimPasswordResetCallback: callback, user.ClaimPasswordResetCallback: callback,
ClaimPasswordResetPassword: password, user.ClaimPasswordResetPassword: password,
"exp": float64(now.Add(expires).Unix()), "exp": float64(now.Add(expires).Unix()),
"sub": usrID, "sub": usrID,
"iat": float64(now.Unix()), "iat": float64(now.Unix()),
@ -329,7 +344,7 @@ func TestChangePassword(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
f := makeTestFixtures() f := makeTestFixtures()
cb, err := f.mgr.ChangePassword(PasswordReset{tt.pwrClaims}, tt.newPassword) cb, err := f.mgr.ChangePassword(user.PasswordReset{tt.pwrClaims}, tt.newPassword)
if tt.wantErr { if tt.wantErr {
if err == nil { if err == nil {
t.Errorf("case %d: want non-nil err", i) t.Errorf("case %d: want non-nil err", i)
@ -346,47 +361,61 @@ func TestChangePassword(t *testing.T) {
if cb != nil { if cb != nil {
cbString = cb.String() cbString = cb.String()
} }
if cbString != tt.pwrClaims[ClaimPasswordResetCallback] { if cbString != tt.pwrClaims[user.ClaimPasswordResetCallback] {
t.Errorf("case %d: want=%q, got=%q", i, cb.String(), t.Errorf("case %d: want=%q, got=%q", i, cb.String(),
tt.pwrClaims[ClaimPasswordResetCallback]) tt.pwrClaims[user.ClaimPasswordResetCallback])
} }
} }
} }
func TestCreateUser(t *testing.T) { func TestCreateUser(t *testing.T) {
tests := []struct { tests := []struct {
usr User usr user.User
hashedPW Password hashedPW user.Password
localID string // defaults to "local"
wantErr bool wantErr bool
}{ }{
{ {
usr: User{ usr: user.User{
DisplayName: "Bob Exampleson", DisplayName: "Bob Exampleson",
Email: "bob@example.com", Email: "bob@example.com",
}, },
hashedPW: Password("I am a hash"), hashedPW: user.Password("I am a hash"),
}, },
{ {
usr: User{ usr: user.User{
DisplayName: "Al Adminson", DisplayName: "Al Adminson",
Email: "al@example.com", Email: "al@example.com",
Admin: true, Admin: true,
}, },
hashedPW: Password("I am a hash"), hashedPW: user.Password("I am a hash"),
}, },
{ {
usr: User{ usr: user.User{
DisplayName: "Ed Emailless", DisplayName: "Ed Emailless",
}, },
hashedPW: Password("I am a hash"), hashedPW: user.Password("I am a hash"),
wantErr: true,
},
{
usr: user.User{
DisplayName: "Eric Exampleson",
Email: "eric@example.com",
},
hashedPW: user.Password("I am a hash"),
localID: "abadlocalid",
wantErr: true, wantErr: true,
}, },
} }
for i, tt := range tests { for i, tt := range tests {
f := makeTestFixtures() f := makeTestFixtures()
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, "local") localID := "local"
if tt.localID != "" {
localID = tt.localID
}
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, localID)
if tt.wantErr { if tt.wantErr {
if err == nil { if err == nil {
t.Errorf("case %d: want non-nil err", i) t.Errorf("case %d: want non-nil err", i)
@ -422,7 +451,7 @@ func TestCreateUser(t *testing.T) {
t.Errorf("case %d: want=%q, got=%q", i, tt.hashedPW, pwi.Password) t.Errorf("case %d: want=%q, got=%q", i, tt.hashedPW, pwi.Password)
} }
ridUser, err := f.ur.GetByRemoteIdentity(nil, RemoteIdentity{ ridUser, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{
ID: id, ID: id,
ConnectorID: "local", ConnectorID: "local",
}) })