forked from mystiq/dex
Merge pull request #199 from ericchiang/validate_connector
api: validate local connector existence before creating user
This commit is contained in:
commit
521aeae3db
20 changed files with 317 additions and 134 deletions
|
@ -6,17 +6,18 @@ import (
|
|||
|
||||
"github.com/coreos/dex/schema/adminschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
// AdminAPI provides the logic necessary to implement the Admin API.
|
||||
type AdminAPI struct {
|
||||
userManager *user.Manager
|
||||
userManager *manager.UserManager
|
||||
userRepo user.UserRepo
|
||||
passwordInfoRepo user.PasswordInfoRepo
|
||||
localConnectorID string
|
||||
}
|
||||
|
||||
func NewAdminAPI(userManager *user.Manager, userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, localConnectorID string) *AdminAPI {
|
||||
func NewAdminAPI(userManager *manager.UserManager, userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, localConnectorID string) *AdminAPI {
|
||||
if localConnectorID == "" {
|
||||
panic("must specify non-blank localConnectorID")
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ import (
|
|||
"github.com/coreos/dex/pkg/log"
|
||||
ptime "github.com/coreos/dex/pkg/time"
|
||||
"github.com/coreos/dex/server"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
var version = "DEV"
|
||||
|
@ -99,8 +99,9 @@ func main() {
|
|||
|
||||
userRepo := db.NewUserRepo(dbc)
|
||||
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
||||
userManager := user.NewManager(userRepo,
|
||||
pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{})
|
||||
connCfgRepo := db.NewConnectorConfigRepo(dbc)
|
||||
userManager := manager.NewUserManager(userRepo,
|
||||
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||
adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID)
|
||||
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
|
||||
if err != nil {
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) {
|
||||
|
@ -41,6 +43,19 @@ type memConnectorConfigRepo struct {
|
|||
configs []ConnectorConfig
|
||||
}
|
||||
|
||||
func NewConnectorConfigRepoFromConfigs(cfgs []ConnectorConfig) ConnectorConfigRepo {
|
||||
return &memConnectorConfigRepo{configs: cfgs}
|
||||
}
|
||||
|
||||
func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) {
|
||||
return r.configs, nil
|
||||
}
|
||||
|
||||
func (r *memConnectorConfigRepo) GetConnectorByID(_ repo.Transaction, id string) (ConnectorConfig, error) {
|
||||
for _, cfg := range r.configs {
|
||||
if cfg.ConnectorID() == id {
|
||||
return cfg, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrorNotFound
|
||||
}
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
package connector
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
"github.com/coreos/pkg/health"
|
||||
)
|
||||
|
||||
var ErrorNotFound = errors.New("connector not found in repository")
|
||||
|
||||
type Connector interface {
|
||||
ID() string
|
||||
LoginURL(sessionKey, prompt string) (string, error)
|
||||
|
@ -34,4 +38,5 @@ type ConnectorConfig interface {
|
|||
|
||||
type ConnectorConfigRepo interface {
|
||||
All() ([]ConnectorConfig, error)
|
||||
GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -9,6 +10,7 @@ import (
|
|||
"github.com/lib/pq"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -91,6 +93,18 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
|
|||
return cfgs, nil
|
||||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
|
||||
qt := pq.QuoteIdentifier(connectorConfigTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
|
||||
var c connectorConfigModel
|
||||
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, connector.ErrorNotFound
|
||||
}
|
||||
}
|
||||
return c.ConnectorConfig()
|
||||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
|
||||
insert := make([]interface{}, len(cfgs))
|
||||
for i, cfg := range cfgs {
|
||||
|
@ -119,3 +133,15 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
|
|||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
|
||||
if tx == nil {
|
||||
return r.dbMap
|
||||
}
|
||||
|
||||
gorpTx, ok := tx.(*gorp.Transaction)
|
||||
if !ok {
|
||||
panic("wrong kind of transaction passed to a DB repo")
|
||||
}
|
||||
return gorpTx
|
||||
}
|
||||
|
|
71
functional/repo/connector_repo_test.go
Normal file
71
functional/repo/connector_repo_test.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package repo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
)
|
||||
|
||||
type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo
|
||||
|
||||
var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory
|
||||
|
||||
func init() {
|
||||
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
|
||||
makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs
|
||||
} else {
|
||||
makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory {
|
||||
return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
|
||||
dbMap := initDB(dsn)
|
||||
|
||||
repo := db.NewConnectorConfigRepo(dbMap)
|
||||
if err := repo.Set(cfgs); err != nil {
|
||||
panic(fmt.Sprintf("Unable to set connector configs: %v", err))
|
||||
}
|
||||
return repo
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectorConfigRepoGetByID(t *testing.T) {
|
||||
tests := []struct {
|
||||
cfgs []connector.ConnectorConfig
|
||||
id string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
cfgs: []connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
},
|
||||
id: "local",
|
||||
},
|
||||
{
|
||||
cfgs: []connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local1"},
|
||||
&connector.LocalConnectorConfig{ID: "local2"},
|
||||
},
|
||||
id: "local2",
|
||||
},
|
||||
{
|
||||
cfgs: []connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local1"},
|
||||
&connector.LocalConnectorConfig{ID: "local2"},
|
||||
},
|
||||
id: "foo",
|
||||
err: connector.ErrorNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs)
|
||||
if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err {
|
||||
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,8 +10,10 @@ import (
|
|||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -42,11 +44,14 @@ func (t *tokenHandlerTransport) RoundTrip(r *http.Request) (*http.Response, erro
|
|||
return &resp, nil
|
||||
}
|
||||
|
||||
func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *user.Manager) {
|
||||
func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.PasswordInfo) (user.UserRepo, user.PasswordInfoRepo, *manager.UserManager) {
|
||||
ur := user.NewUserRepoFromUsers(users)
|
||||
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)
|
||||
|
||||
um := user.NewManager(ur, pwr, repo.InMemTransactionFactory, user.ManagerOptions{})
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs(
|
||||
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
|
||||
)
|
||||
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
um.Clock = clock
|
||||
return ur, pwr, um
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
type ServerConfig struct {
|
||||
|
@ -133,7 +134,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
refTokRepo := refresh.NewRefreshTokenRepo()
|
||||
|
||||
txnFactory := repo.InMemTransactionFactory
|
||||
userManager := user.NewManager(userRepo, pwiRepo, txnFactory, user.ManagerOptions{})
|
||||
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
|
||||
srv.ClientIdentityRepo = ciRepo
|
||||
srv.KeySetRepo = kRepo
|
||||
srv.ConnectorConfigRepo = cfgRepo
|
||||
|
@ -171,7 +172,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
|||
cfgRepo := db.NewConnectorConfigRepo(dbc)
|
||||
userRepo := db.NewUserRepo(dbc)
|
||||
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
||||
userManager := user.NewManager(userRepo, pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{})
|
||||
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
|
||||
|
||||
sm := session.NewSessionManager(sRepo, skRepo)
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
// handleVerifyEmailResendFunc will resend an email-verification email given a valid JWT for the user and a redirect URL.
|
||||
|
@ -190,7 +191,7 @@ type emailVerifiedTemplateData struct {
|
|||
}
|
||||
|
||||
func handleEmailVerifyFunc(verifiedTpl *template.Template, issuer url.URL, keysFunc func() ([]key.PublicKey,
|
||||
error), userManager *user.Manager) http.HandlerFunc {
|
||||
error), userManager *manager.UserManager) http.HandlerFunc {
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
|
@ -217,12 +218,12 @@ func handleEmailVerifyFunc(verifiedTpl *template.Template, issuer url.URL, keysF
|
|||
cbURL, err := userManager.VerifyEmail(ev)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case user.ErrorEmailAlreadyVerified:
|
||||
case manager.ErrorEmailAlreadyVerified:
|
||||
execTemplateWithStatus(w, verifiedTpl, emailVerifiedTemplateData{
|
||||
Error: "Invalid Verification Link",
|
||||
Message: "Your email link has expired or has already been verified.",
|
||||
}, http.StatusBadRequest)
|
||||
case user.ErrorEVEmailDoesntMatch:
|
||||
case manager.ErrorEVEmailDoesntMatch:
|
||||
execTemplateWithStatus(w, verifiedTpl, emailVerifiedTemplateData{
|
||||
Error: "Invalid Verification Link",
|
||||
Message: "Your email link does not match the email address on file. Perhaps you have a more recent verification link?",
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
)
|
||||
|
@ -18,7 +19,7 @@ type invitationTemplateData struct {
|
|||
type InvitationHandler struct {
|
||||
issuerURL url.URL
|
||||
passwordResetURL url.URL
|
||||
um *user.Manager
|
||||
um *manager.UserManager
|
||||
keysFunc func() ([]key.PublicKey, error)
|
||||
signerFunc func() (jose.Signer, error)
|
||||
redirectValidityWindow time.Duration
|
||||
|
@ -55,13 +56,13 @@ func (h *InvitationHandler) handleGET(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
_, err = h.um.VerifyEmail(invite)
|
||||
if err != nil && err != user.ErrorEmailAlreadyVerified {
|
||||
if err != nil && err != manager.ErrorEmailAlreadyVerified {
|
||||
// Allow AlreadyVerified folks to pass through- otherwise
|
||||
// folks who encounter an error after passing this point will
|
||||
// never be able to set their passwords.
|
||||
log.Debugf("error attempting to verify email: %v", err)
|
||||
switch err {
|
||||
case user.ErrorEVEmailDoesntMatch:
|
||||
case manager.ErrorEVEmailDoesntMatch:
|
||||
writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidRequest,
|
||||
"Your email does not match the email address on file"))
|
||||
return
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
type sendResetPasswordEmailData struct {
|
||||
|
@ -181,7 +182,7 @@ type resetPasswordTemplateData struct {
|
|||
type ResetPasswordHandler struct {
|
||||
tpl *template.Template
|
||||
issuerURL url.URL
|
||||
um *user.Manager
|
||||
um *manager.UserManager
|
||||
keysFunc func() ([]key.PublicKey, error)
|
||||
}
|
||||
|
||||
|
@ -237,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() {
|
|||
cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case user.ErrorPasswordAlreadyChanged:
|
||||
case manager.ErrorPasswordAlreadyChanged:
|
||||
r.data.Error = "Link Expired"
|
||||
r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email."
|
||||
r.data.DontShowForm = true
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
|
@ -222,7 +223,7 @@ func handleRegisterFunc(s *Server) http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func registerFromLocalConnector(userManager *user.Manager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) {
|
||||
func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) {
|
||||
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -237,7 +238,7 @@ func registerFromLocalConnector(userManager *user.Manager, sessionManager *sessi
|
|||
return userID, nil
|
||||
}
|
||||
|
||||
func registerFromRemoteConnector(userManager *user.Manager, ses *session.Session, email string, emailVerified bool) (string, error) {
|
||||
func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
|
||||
if ses.Identity.ID == "" {
|
||||
return "", errors.New("No Identity found in session.")
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/coreos/dex/user"
|
||||
usersapi "github.com/coreos/dex/user/api"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -68,7 +69,7 @@ type Server struct {
|
|||
HealthChecks []health.Checkable
|
||||
Connectors []connector.Connector
|
||||
UserRepo user.UserRepo
|
||||
UserManager *user.Manager
|
||||
UserManager *manager.UserManager
|
||||
PasswordInfoRepo user.PasswordInfoRepo
|
||||
RefreshTokenRepo refresh.RefreshTokenRepo
|
||||
UserEmailer *useremail.UserEmailer
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/coreos/dex/session"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -91,7 +92,6 @@ func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
|
|||
func makeTestFixtures() (*testFixtures, error) {
|
||||
userRepo := user.NewUserRepoFromUsers(testUsers)
|
||||
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
|
||||
manager := user.NewManager(userRepo, pwRepo, repo.InMemTransactionFactory, user.ManagerOptions{})
|
||||
|
||||
connConfigs := []connector.ConnectorConfig{
|
||||
&connector.OIDCConnectorConfig{
|
||||
|
@ -111,6 +111,9 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
ID: "local",
|
||||
},
|
||||
}
|
||||
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
|
||||
|
||||
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
|
||||
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/api"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -33,11 +34,11 @@ var (
|
|||
type UserMgmtServer struct {
|
||||
api *api.UsersAPI
|
||||
jwtvFactory JWTVerifierFactory
|
||||
um *user.Manager
|
||||
um *manager.UserManager
|
||||
cir client.ClientIdentityRepo
|
||||
}
|
||||
|
||||
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *user.Manager, cir client.ClientIdentityRepo) *UserMgmtServer {
|
||||
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *manager.UserManager, cir client.ClientIdentityRepo) *UserMgmtServer {
|
||||
return &UserMgmtServer{
|
||||
api: userMgmtAPI,
|
||||
jwtvFactory: jwtvFactory,
|
||||
|
|
2
test
2
test
|
@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
|
|||
|
||||
source ./build
|
||||
|
||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api email"
|
||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email"
|
||||
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
|
||||
|
||||
# user has not provided PKG override
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/coreos/dex/pkg/log"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -81,7 +82,7 @@ func (e Error) Error() string {
|
|||
// calling User. It is assumed that the clientID has already validated as an
|
||||
// admin app before calling.
|
||||
type UsersAPI struct {
|
||||
manager *user.Manager
|
||||
manager *manager.UserManager
|
||||
localConnectorID string
|
||||
clientIdentityRepo client.ClientIdentityRepo
|
||||
emailer Emailer
|
||||
|
@ -96,7 +97,7 @@ type Creds struct {
|
|||
User user.User
|
||||
}
|
||||
|
||||
func NewUsersAPI(manager *user.Manager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI {
|
||||
func NewUsersAPI(manager *manager.UserManager, cir client.ClientIdentityRepo, emailer Emailer, localConnectorID string) *UsersAPI {
|
||||
return &UsersAPI{
|
||||
manager: manager,
|
||||
clientIdentityRepo: cir,
|
||||
|
|
|
@ -10,9 +10,11 @@ import (
|
|||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
type testEmailer struct {
|
||||
|
@ -123,7 +125,10 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
Password: []byte("password-2"),
|
||||
},
|
||||
})
|
||||
mgr := user.NewManager(ur, pwr, repo.InMemTransactionFactory, user.ManagerOptions{})
|
||||
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
})
|
||||
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
|
||||
mgr.Clock = clock
|
||||
ci := oidc.ClientIdentity{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package user
|
||||
package manager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
@ -6,8 +6,10 @@ import (
|
|||
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/user"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -19,13 +21,14 @@ var (
|
|||
|
||||
// Manager performs user-related "business-logic" functions on user and related objects.
|
||||
// This is in contrast to the Repos which perform little more than CRUD operations.
|
||||
type Manager struct {
|
||||
type UserManager struct {
|
||||
Clock clockwork.Clock
|
||||
|
||||
userRepo UserRepo
|
||||
pwRepo PasswordInfoRepo
|
||||
userRepo user.UserRepo
|
||||
pwRepo user.PasswordInfoRepo
|
||||
connCfgRepo connector.ConnectorConfigRepo
|
||||
begin repo.TransactionFactory
|
||||
userIDGenerator UserIDGenerator
|
||||
userIDGenerator user.UserIDGenerator
|
||||
}
|
||||
|
||||
type ManagerOptions struct {
|
||||
|
@ -34,58 +37,59 @@ type ManagerOptions struct {
|
|||
// variable policies
|
||||
}
|
||||
|
||||
func NewManager(userRepo UserRepo, pwRepo PasswordInfoRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *Manager {
|
||||
return &Manager{
|
||||
func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, connCfgRepo connector.ConnectorConfigRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager {
|
||||
return &UserManager{
|
||||
Clock: clockwork.NewRealClock(),
|
||||
|
||||
userRepo: userRepo,
|
||||
pwRepo: pwRepo,
|
||||
connCfgRepo: connCfgRepo,
|
||||
begin: txnFactory,
|
||||
userIDGenerator: DefaultUserIDGenerator,
|
||||
userIDGenerator: user.DefaultUserIDGenerator,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Get(id string) (User, error) {
|
||||
func (m *UserManager) Get(id string) (user.User, error) {
|
||||
return m.userRepo.Get(nil, id)
|
||||
}
|
||||
|
||||
func (m *Manager) List(filter UserFilter, maxResults int, nextPageToken string) ([]User, string, error) {
|
||||
func (m *UserManager) List(filter user.UserFilter, maxResults int, nextPageToken string) ([]user.User, string, error) {
|
||||
return m.userRepo.List(nil, filter, maxResults, nextPageToken)
|
||||
}
|
||||
|
||||
// CreateUser creates a new user with the given hashedPassword; the connID should be the ID of the local connector.
|
||||
// The userID of the created user is returned as the first argument.
|
||||
func (m *Manager) CreateUser(user User, hashedPassword Password, connID string) (string, error) {
|
||||
func (m *UserManager) CreateUser(usr user.User, hashedPassword user.Password, connID string) (string, error) {
|
||||
tx, err := m.begin()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
insertedUser, err := m.insertNewUser(tx, user.Email, user.EmailVerified)
|
||||
insertedUser, err := m.insertNewUser(tx, usr.Email, usr.EmailVerified)
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
|
||||
user.ID = insertedUser.ID
|
||||
user.CreatedAt = insertedUser.CreatedAt
|
||||
err = m.userRepo.Update(tx, user)
|
||||
usr.ID = insertedUser.ID
|
||||
usr.CreatedAt = insertedUser.CreatedAt
|
||||
err = m.userRepo.Update(tx, usr)
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
|
||||
rid := RemoteIdentity{
|
||||
rid := user.RemoteIdentity{
|
||||
ConnectorID: connID,
|
||||
ID: user.ID,
|
||||
ID: usr.ID,
|
||||
}
|
||||
if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil {
|
||||
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
|
||||
pwi := PasswordInfo{
|
||||
UserID: user.ID,
|
||||
pwi := user.PasswordInfo{
|
||||
UserID: usr.ID,
|
||||
Password: hashedPassword,
|
||||
}
|
||||
err = m.pwRepo.Create(tx, pwi)
|
||||
|
@ -99,10 +103,10 @@ func (m *Manager) CreateUser(user User, hashedPassword Password, connID string)
|
|||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
return user.ID, nil
|
||||
return usr.ID, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Disable(userID string, disabled bool) error {
|
||||
func (m *UserManager) Disable(userID string, disabled bool) error {
|
||||
tx, err := m.begin()
|
||||
|
||||
if err = m.userRepo.Disable(tx, userID, disabled); err != nil {
|
||||
|
@ -119,7 +123,7 @@ func (m *Manager) Disable(userID string, disabled bool) error {
|
|||
}
|
||||
|
||||
// RegisterWithRemoteIdentity creates new user and attaches the given remote identity.
|
||||
func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) {
|
||||
func (m *UserManager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid user.RemoteIdentity) (string, error) {
|
||||
tx, err := m.begin()
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -127,20 +131,20 @@ func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, r
|
|||
|
||||
if _, err = m.userRepo.GetByRemoteIdentity(tx, rid); err == nil {
|
||||
rollback(tx)
|
||||
return "", ErrorDuplicateRemoteIdentity
|
||||
return "", user.ErrorDuplicateRemoteIdentity
|
||||
}
|
||||
if err != ErrorNotFound {
|
||||
if err != user.ErrorNotFound {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
|
||||
user, err := m.insertNewUser(tx, email, emailVerified)
|
||||
usr, err := m.insertNewUser(tx, email, emailVerified)
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil {
|
||||
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
|
@ -150,44 +154,44 @@ func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, r
|
|||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
return user.ID, nil
|
||||
return usr.ID, nil
|
||||
}
|
||||
|
||||
// RegisterWithPassword creates a new user with the given name and password.
|
||||
// connID is the connector ID of the ConnectorLocal connector.
|
||||
func (m *Manager) RegisterWithPassword(email, plaintext, connID string) (string, error) {
|
||||
func (m *UserManager) RegisterWithPassword(email, plaintext, connID string) (string, error) {
|
||||
tx, err := m.begin()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if !ValidPassword(plaintext) {
|
||||
if !user.ValidPassword(plaintext) {
|
||||
rollback(tx)
|
||||
return "", ErrorInvalidPassword
|
||||
return "", user.ErrorInvalidPassword
|
||||
}
|
||||
|
||||
user, err := m.insertNewUser(tx, email, false)
|
||||
usr, err := m.insertNewUser(tx, email, false)
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
|
||||
rid := RemoteIdentity{
|
||||
rid := user.RemoteIdentity{
|
||||
ConnectorID: connID,
|
||||
ID: user.ID,
|
||||
ID: usr.ID,
|
||||
}
|
||||
if err := m.userRepo.AddRemoteIdentity(tx, user.ID, rid); err != nil {
|
||||
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
|
||||
password, err := NewPasswordFromPlaintext(plaintext)
|
||||
password, err := user.NewPasswordFromPlaintext(plaintext)
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
pwi := PasswordInfo{
|
||||
UserID: user.ID,
|
||||
pwi := user.PasswordInfo{
|
||||
UserID: usr.ID,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
|
@ -202,7 +206,7 @@ func (m *Manager) RegisterWithPassword(email, plaintext, connID string) (string,
|
|||
rollback(tx)
|
||||
return "", err
|
||||
}
|
||||
return user.ID, nil
|
||||
return usr.ID, nil
|
||||
}
|
||||
|
||||
type EmailVerifiable interface {
|
||||
|
@ -218,31 +222,31 @@ type EmailVerifiable interface {
|
|||
// create it, ensuring that the token was signed and that the JWT was not
|
||||
// expired.
|
||||
// The callback url (i.e. where to send the user after the verification) is returned.
|
||||
func (m *Manager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) {
|
||||
func (m *UserManager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) {
|
||||
tx, err := m.begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := m.userRepo.Get(tx, ev.UserID())
|
||||
usr, err := m.userRepo.Get(tx, ev.UserID())
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.Email != ev.Email() {
|
||||
if usr.Email != ev.Email() {
|
||||
rollback(tx)
|
||||
return nil, ErrorEVEmailDoesntMatch
|
||||
}
|
||||
|
||||
if user.EmailVerified {
|
||||
if usr.EmailVerified {
|
||||
rollback(tx)
|
||||
return nil, ErrorEmailAlreadyVerified
|
||||
}
|
||||
|
||||
user.EmailVerified = true
|
||||
usr.EmailVerified = true
|
||||
|
||||
err = m.userRepo.Update(tx, user)
|
||||
err = m.userRepo.Update(tx, usr)
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return nil, err
|
||||
|
@ -258,19 +262,19 @@ func (m *Manager) VerifyEmail(ev EmailVerifiable) (*url.URL, error) {
|
|||
|
||||
type PasswordChangeable interface {
|
||||
UserID() string
|
||||
Password() Password
|
||||
Password() user.Password
|
||||
Callback() *url.URL
|
||||
}
|
||||
|
||||
func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url.URL, error) {
|
||||
func (m *UserManager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url.URL, error) {
|
||||
tx, err := m.begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !ValidPassword(plaintext) {
|
||||
if !user.ValidPassword(plaintext) {
|
||||
rollback(tx)
|
||||
return nil, ErrorInvalidPassword
|
||||
return nil, user.ErrorInvalidPassword
|
||||
}
|
||||
|
||||
pwi, err := m.pwRepo.Get(tx, pwr.UserID())
|
||||
|
@ -284,7 +288,7 @@ func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url
|
|||
return nil, ErrorPasswordAlreadyChanged
|
||||
}
|
||||
|
||||
newPass, err := NewPasswordFromPlaintext(plaintext)
|
||||
newPass, err := user.NewPasswordFromPlaintext(plaintext)
|
||||
if err != nil {
|
||||
rollback(tx)
|
||||
return nil, err
|
||||
|
@ -305,36 +309,46 @@ func (m *Manager) ChangePassword(pwr PasswordChangeable, plaintext string) (*url
|
|||
return pwr.Callback(), nil
|
||||
}
|
||||
|
||||
func (m *Manager) insertNewUser(tx repo.Transaction, email string, emailVerified bool) (User, error) {
|
||||
if !ValidEmail(email) {
|
||||
return User{}, ErrorInvalidEmail
|
||||
func (m *UserManager) insertNewUser(tx repo.Transaction, email string, emailVerified bool) (user.User, error) {
|
||||
if !user.ValidEmail(email) {
|
||||
return user.User{}, user.ErrorInvalidEmail
|
||||
}
|
||||
|
||||
var err error
|
||||
if _, err = m.userRepo.GetByEmail(tx, email); err == nil {
|
||||
return User{}, ErrorDuplicateEmail
|
||||
return user.User{}, user.ErrorDuplicateEmail
|
||||
}
|
||||
if err != ErrorNotFound {
|
||||
return User{}, err
|
||||
if err != user.ErrorNotFound {
|
||||
return user.User{}, err
|
||||
}
|
||||
|
||||
userID, err := m.userIDGenerator()
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
return user.User{}, err
|
||||
}
|
||||
|
||||
user := User{
|
||||
usr := user.User{
|
||||
ID: userID,
|
||||
Email: email,
|
||||
EmailVerified: emailVerified,
|
||||
CreatedAt: m.Clock.Now(),
|
||||
}
|
||||
|
||||
err = m.userRepo.Create(tx, user)
|
||||
err = m.userRepo.Create(tx, usr)
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
return user.User{}, err
|
||||
}
|
||||
return user, nil
|
||||
return usr, nil
|
||||
}
|
||||
|
||||
func (m *UserManager) addRemoteIdentity(tx repo.Transaction, userID string, rid user.RemoteIdentity) error {
|
||||
if _, err := m.connCfgRepo.GetConnectorByID(tx, rid.ConnectorID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.userRepo.AddRemoteIdentity(tx, userID, rid); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rollback(tx repo.Transaction) {
|
|
@ -1,4 +1,4 @@
|
|||
package user
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
|
@ -9,13 +9,16 @@ import (
|
|||
"github.com/jonboulle/clockwork"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/dex/user"
|
||||
)
|
||||
|
||||
type testFixtures struct {
|
||||
ur UserRepo
|
||||
pwr PasswordInfoRepo
|
||||
mgr *Manager
|
||||
ur user.UserRepo
|
||||
pwr user.PasswordInfoRepo
|
||||
ccr connector.ConnectorConfigRepo
|
||||
mgr *UserManager
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
|
@ -23,25 +26,25 @@ func makeTestFixtures() *testFixtures {
|
|||
f := &testFixtures{}
|
||||
f.clock = clockwork.NewFakeClock()
|
||||
|
||||
f.ur = NewUserRepoFromUsers([]UserWithRemoteIdentities{
|
||||
f.ur = user.NewUserRepoFromUsers([]user.UserWithRemoteIdentities{
|
||||
{
|
||||
User: User{
|
||||
User: user.User{
|
||||
ID: "ID-1",
|
||||
Email: "Email-1@example.com",
|
||||
},
|
||||
RemoteIdentities: []RemoteIdentity{
|
||||
RemoteIdentities: []user.RemoteIdentity{
|
||||
{
|
||||
ConnectorID: "local",
|
||||
ID: "1",
|
||||
},
|
||||
},
|
||||
}, {
|
||||
User: User{
|
||||
User: user.User{
|
||||
ID: "ID-2",
|
||||
Email: "Email-2@example.com",
|
||||
EmailVerified: true,
|
||||
},
|
||||
RemoteIdentities: []RemoteIdentity{
|
||||
RemoteIdentities: []user.RemoteIdentity{
|
||||
{
|
||||
ConnectorID: "local",
|
||||
ID: "2",
|
||||
|
@ -49,7 +52,7 @@ func makeTestFixtures() *testFixtures {
|
|||
},
|
||||
},
|
||||
})
|
||||
f.pwr = NewPasswordInfoRepoFromPasswordInfos([]PasswordInfo{
|
||||
f.pwr = user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
|
||||
{
|
||||
UserID: "ID-1",
|
||||
Password: []byte("password-1"),
|
||||
|
@ -59,7 +62,10 @@ func makeTestFixtures() *testFixtures {
|
|||
Password: []byte("password-2"),
|
||||
},
|
||||
})
|
||||
f.mgr = NewManager(f.ur, f.pwr, repo.InMemTransactionFactory, ManagerOptions{})
|
||||
f.ccr = connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
|
||||
&connector.LocalConnectorConfig{ID: "local"},
|
||||
})
|
||||
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, repo.InMemTransactionFactory, ManagerOptions{})
|
||||
f.mgr.Clock = f.clock
|
||||
return f
|
||||
}
|
||||
|
@ -68,13 +74,13 @@ func TestRegisterWithRemoteIdentity(t *testing.T) {
|
|||
tests := []struct {
|
||||
email string
|
||||
emailVerified bool
|
||||
rid RemoteIdentity
|
||||
rid user.RemoteIdentity
|
||||
err error
|
||||
}{
|
||||
{
|
||||
email: "email@example.com",
|
||||
emailVerified: false,
|
||||
rid: RemoteIdentity{
|
||||
rid: user.RemoteIdentity{
|
||||
ConnectorID: "local",
|
||||
ID: "1234",
|
||||
},
|
||||
|
@ -82,20 +88,29 @@ func TestRegisterWithRemoteIdentity(t *testing.T) {
|
|||
},
|
||||
{
|
||||
emailVerified: false,
|
||||
rid: RemoteIdentity{
|
||||
rid: user.RemoteIdentity{
|
||||
ConnectorID: "local",
|
||||
ID: "1234",
|
||||
},
|
||||
err: ErrorInvalidEmail,
|
||||
err: user.ErrorInvalidEmail,
|
||||
},
|
||||
{
|
||||
email: "email@example.com",
|
||||
emailVerified: false,
|
||||
rid: RemoteIdentity{
|
||||
rid: user.RemoteIdentity{
|
||||
ConnectorID: "local",
|
||||
ID: "1",
|
||||
},
|
||||
err: ErrorDuplicateRemoteIdentity,
|
||||
err: user.ErrorDuplicateRemoteIdentity,
|
||||
},
|
||||
{
|
||||
email: "anotheremail@example.com",
|
||||
emailVerified: false,
|
||||
rid: user.RemoteIdentity{
|
||||
ConnectorID: "idonotexist",
|
||||
ID: "1",
|
||||
},
|
||||
err: connector.ErrorNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -148,17 +163,17 @@ func TestRegisterWithPassword(t *testing.T) {
|
|||
},
|
||||
{
|
||||
plaintext: "secretpassword123",
|
||||
err: ErrorInvalidEmail,
|
||||
err: user.ErrorInvalidEmail,
|
||||
},
|
||||
{
|
||||
email: "email@example.com",
|
||||
err: ErrorInvalidPassword,
|
||||
err: user.ErrorInvalidPassword,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
f := makeTestFixtures()
|
||||
connID := "connID"
|
||||
connID := "local"
|
||||
userID, err := f.mgr.RegisterWithPassword(
|
||||
tt.email,
|
||||
tt.plaintext,
|
||||
|
@ -183,7 +198,7 @@ func TestRegisterWithPassword(t *testing.T) {
|
|||
t.Errorf("case %d: user.EmailVerified: want=%v, got=%v", i, false, usr.EmailVerified)
|
||||
}
|
||||
|
||||
ridUSR, err := f.ur.GetByRemoteIdentity(nil, RemoteIdentity{
|
||||
ridUSR, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{
|
||||
ID: userID,
|
||||
ConnectorID: connID,
|
||||
})
|
||||
|
@ -220,12 +235,12 @@ func TestVerifyEmail(t *testing.T) {
|
|||
callback := "http://client.example.com/callback"
|
||||
expires := time.Hour * 3
|
||||
|
||||
makeClaims := func(usr User) jose.Claims {
|
||||
makeClaims := func(usr user.User) jose.Claims {
|
||||
return map[string]interface{}{
|
||||
"iss": issuer.String(),
|
||||
"aud": clientID,
|
||||
ClaimEmailVerificationCallback: callback,
|
||||
ClaimEmailVerificationEmail: usr.Email,
|
||||
user.ClaimEmailVerificationCallback: callback,
|
||||
user.ClaimEmailVerificationEmail: usr.Email,
|
||||
"exp": float64(now.Add(expires).Unix()),
|
||||
"sub": usr.ID,
|
||||
"iat": float64(now.Unix()),
|
||||
|
@ -238,28 +253,28 @@ func TestVerifyEmail(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
// happy path
|
||||
evClaims: makeClaims(User{ID: "ID-1", Email: "Email-1@example.com"}),
|
||||
evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-1@example.com"}),
|
||||
},
|
||||
{
|
||||
// non-matching email
|
||||
evClaims: makeClaims(User{ID: "ID-1", Email: "Email-2@example.com"}),
|
||||
evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-2@example.com"}),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
// already verified email
|
||||
evClaims: makeClaims(User{ID: "ID-2", Email: "Email-2@example.com"}),
|
||||
evClaims: makeClaims(user.User{ID: "ID-2", Email: "Email-2@example.com"}),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
// non-existent user.
|
||||
evClaims: makeClaims(User{ID: "ID-UNKNOWN", Email: "noone@example.com"}),
|
||||
evClaims: makeClaims(user.User{ID: "ID-UNKNOWN", Email: "noone@example.com"}),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
f := makeTestFixtures()
|
||||
cb, err := f.mgr.VerifyEmail(EmailVerification{tt.evClaims})
|
||||
cb, err := f.mgr.VerifyEmail(user.EmailVerification{tt.evClaims})
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil err", i)
|
||||
|
@ -271,9 +286,9 @@ func TestVerifyEmail(t *testing.T) {
|
|||
t.Errorf("case %d: want err=nil got=%q", i, err)
|
||||
}
|
||||
|
||||
if cb.String() != tt.evClaims[ClaimEmailVerificationCallback] {
|
||||
if cb.String() != tt.evClaims[user.ClaimEmailVerificationCallback] {
|
||||
t.Errorf("case %d: want=%q, got=%q", i, cb.String(),
|
||||
tt.evClaims[ClaimEmailVerificationCallback])
|
||||
tt.evClaims[user.ClaimEmailVerificationCallback])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -290,8 +305,8 @@ func TestChangePassword(t *testing.T) {
|
|||
return map[string]interface{}{
|
||||
"iss": issuer.String(),
|
||||
"aud": clientID,
|
||||
ClaimPasswordResetCallback: callback,
|
||||
ClaimPasswordResetPassword: password,
|
||||
user.ClaimPasswordResetCallback: callback,
|
||||
user.ClaimPasswordResetPassword: password,
|
||||
"exp": float64(now.Add(expires).Unix()),
|
||||
"sub": usrID,
|
||||
"iat": float64(now.Unix()),
|
||||
|
@ -329,7 +344,7 @@ func TestChangePassword(t *testing.T) {
|
|||
|
||||
for i, tt := range tests {
|
||||
f := makeTestFixtures()
|
||||
cb, err := f.mgr.ChangePassword(PasswordReset{tt.pwrClaims}, tt.newPassword)
|
||||
cb, err := f.mgr.ChangePassword(user.PasswordReset{tt.pwrClaims}, tt.newPassword)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil err", i)
|
||||
|
@ -346,47 +361,61 @@ func TestChangePassword(t *testing.T) {
|
|||
if cb != nil {
|
||||
cbString = cb.String()
|
||||
}
|
||||
if cbString != tt.pwrClaims[ClaimPasswordResetCallback] {
|
||||
if cbString != tt.pwrClaims[user.ClaimPasswordResetCallback] {
|
||||
t.Errorf("case %d: want=%q, got=%q", i, cb.String(),
|
||||
tt.pwrClaims[ClaimPasswordResetCallback])
|
||||
tt.pwrClaims[user.ClaimPasswordResetCallback])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
usr User
|
||||
hashedPW Password
|
||||
usr user.User
|
||||
hashedPW user.Password
|
||||
localID string // defaults to "local"
|
||||
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
usr: User{
|
||||
usr: user.User{
|
||||
DisplayName: "Bob Exampleson",
|
||||
Email: "bob@example.com",
|
||||
},
|
||||
hashedPW: Password("I am a hash"),
|
||||
hashedPW: user.Password("I am a hash"),
|
||||
},
|
||||
{
|
||||
usr: User{
|
||||
usr: user.User{
|
||||
DisplayName: "Al Adminson",
|
||||
Email: "al@example.com",
|
||||
Admin: true,
|
||||
},
|
||||
hashedPW: Password("I am a hash"),
|
||||
hashedPW: user.Password("I am a hash"),
|
||||
},
|
||||
{
|
||||
usr: User{
|
||||
usr: user.User{
|
||||
DisplayName: "Ed Emailless",
|
||||
},
|
||||
hashedPW: Password("I am a hash"),
|
||||
hashedPW: user.Password("I am a hash"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
usr: user.User{
|
||||
DisplayName: "Eric Exampleson",
|
||||
Email: "eric@example.com",
|
||||
},
|
||||
hashedPW: user.Password("I am a hash"),
|
||||
localID: "abadlocalid",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
f := makeTestFixtures()
|
||||
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, "local")
|
||||
localID := "local"
|
||||
if tt.localID != "" {
|
||||
localID = tt.localID
|
||||
}
|
||||
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, localID)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil err", i)
|
||||
|
@ -422,7 +451,7 @@ func TestCreateUser(t *testing.T) {
|
|||
t.Errorf("case %d: want=%q, got=%q", i, tt.hashedPW, pwi.Password)
|
||||
}
|
||||
|
||||
ridUser, err := f.ur.GetByRemoteIdentity(nil, RemoteIdentity{
|
||||
ridUser, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{
|
||||
ID: id,
|
||||
ConnectorID: "local",
|
||||
})
|
Loading…
Reference in a new issue