diff --git a/cmd/dex-worker/main.go b/cmd/dex-worker/main.go index bf1facfc..1ebbc10e 100644 --- a/cmd/dex-worker/main.go +++ b/cmd/dex-worker/main.go @@ -44,6 +44,8 @@ func main() { emailFrom := fs.String("email-from", "no-reply@coreos.com", "emails sent from dex will come from this address") emailConfig := fs.String("email-cfg", "./static/fixtures/emailer.json", "configures emailer.") + enableRegistration := fs.Bool("enable-registration", false, "Allows users to self-register") + noDB := fs.Bool("no-db", false, "manage entities in-process w/o any encryption, used only for single-node testing") // UI-related: @@ -113,13 +115,14 @@ func main() { } scfg := server.ServerConfig{ - IssuerURL: *issuer, - TemplateDir: *templates, - EmailTemplateDirs: emailTemplateDirs, - EmailFromAddress: *emailFrom, - EmailerConfigFile: *emailConfig, - IssuerName: *issuerName, - IssuerLogoURL: *issuerLogoURL, + IssuerURL: *issuer, + TemplateDir: *templates, + EmailTemplateDirs: emailTemplateDirs, + EmailFromAddress: *emailFrom, + EmailerConfigFile: *emailConfig, + IssuerName: *issuerName, + IssuerLogoURL: *issuerLogoURL, + EnableRegistration: *enableRegistration, } if *noDB { diff --git a/server/config.go b/server/config.go index 3c3d4432..fe7c41f7 100644 --- a/server/config.go +++ b/server/config.go @@ -25,14 +25,15 @@ import ( ) type ServerConfig struct { - IssuerURL string - IssuerName string - IssuerLogoURL string - TemplateDir string - EmailTemplateDirs []string - EmailFromAddress string - EmailerConfigFile string - StateConfig StateConfigurer + IssuerURL string + IssuerName string + IssuerLogoURL string + TemplateDir string + EmailTemplateDirs []string + EmailFromAddress string + EmailerConfigFile string + StateConfig StateConfigurer + EnableRegistration bool } type StateConfigurer interface { @@ -56,7 +57,7 @@ func (cfg *ServerConfig) Server() (*Server, error) { return nil, err } - tpl, err := getTemplates(cfg.IssuerName, cfg.IssuerLogoURL, cfg.TemplateDir) + tpl, err := getTemplates(cfg.IssuerName, cfg.IssuerLogoURL, cfg.EnableRegistration, cfg.TemplateDir) if err != nil { return nil, err } @@ -69,6 +70,8 @@ func (cfg *ServerConfig) Server() (*Server, error) { HealthChecks: []health.Checkable{km}, Connectors: []connector.Connector{}, + + EnableRegistration: cfg.EnableRegistration, } err = cfg.StateConfig.Configure(&srv) @@ -183,7 +186,8 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error { return nil } -func getTemplates(issuerName, issuerLogoURL string, dir string) (*template.Template, error) { +func getTemplates(issuerName, issuerLogoURL string, + enableRegister bool, dir string) (*template.Template, error) { tpl := template.New("").Funcs(map[string]interface{}{ "issuerName": func() string { return issuerName @@ -191,6 +195,9 @@ func getTemplates(issuerName, issuerLogoURL string, dir string) (*template.Templ "issuerLogoURL": func() string { return issuerLogoURL }, + "enableRegister": func() bool { + return enableRegister + }, }) return tpl.ParseGlob(dir + "/*.html") diff --git a/server/http.go b/server/http.go index 170817d6..4746b114 100644 --- a/server/http.go +++ b/server/http.go @@ -254,7 +254,7 @@ func renderLoginPage(w http.ResponseWriter, r *http.Request, srv OIDCServer, idp execTemplate(w, tpl, td) } -func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template) http.HandlerFunc { +func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc { idx := makeConnectorMap(idpcs) return func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { @@ -264,7 +264,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T } q := r.URL.Query() - register := q.Get("register") == "1" + register := q.Get("register") == "1" && registrationEnabled e := q.Get("error") if e != "" { sessionKey := q.Get("state") diff --git a/server/http_test.go b/server/http_test.go index 4f812204..acc172f3 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -51,7 +51,7 @@ func (c *fakeConnector) TrustedEmailProvider() bool { func TestHandleAuthFuncMethodNotAllowed(t *testing.T) { for _, m := range []string{"POST", "PUT", "DELETE"} { - hdlr := handleAuthFunc(nil, nil, nil) + hdlr := handleAuthFunc(nil, nil, nil, true) req, err := http.NewRequest(m, "http://example.com", nil) if err != nil { t.Errorf("case %s: unable to create HTTP request: %v", m, err) @@ -170,7 +170,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { } for i, tt := range tests { - hdlr := handleAuthFunc(srv, idpcs, nil) + hdlr := handleAuthFunc(srv, idpcs, nil, true) w := httptest.NewRecorder() u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) req, err := http.NewRequest("GET", u, nil) @@ -271,7 +271,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { } for i, tt := range tests { - hdlr := handleAuthFunc(srv, idpcs, nil) + hdlr := handleAuthFunc(srv, idpcs, nil, true) w := httptest.NewRecorder() u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) req, err := http.NewRequest("GET", u, nil) diff --git a/server/server.go b/server/server.go index 38714781..91f7b5d0 100644 --- a/server/server.go +++ b/server/server.go @@ -72,6 +72,7 @@ type Server struct { PasswordInfoRepo user.PasswordInfoRepo RefreshTokenRepo refresh.RefreshTokenRepo UserEmailer *useremail.UserEmailer + EnableRegistration bool localConnectorID string } @@ -198,11 +199,15 @@ func (s *Server) HTTPHandler() http.Handler { clock := clockwork.NewRealClock() mux := http.NewServeMux() mux.HandleFunc(httpPathDiscovery, handleDiscoveryFunc(s.ProviderConfig())) - mux.HandleFunc(httpPathAuth, handleAuthFunc(s, s.Connectors, s.LoginTemplate)) + mux.HandleFunc(httpPathAuth, handleAuthFunc(s, s.Connectors, s.LoginTemplate, s.EnableRegistration)) mux.HandleFunc(httpPathToken, handleTokenFunc(s)) mux.HandleFunc(httpPathKeys, handleKeysFunc(s.KeyManager, clock)) mux.Handle(httpPathHealth, makeHealthHandler(checks)) - mux.HandleFunc(httpPathRegister, handleRegisterFunc(s)) + + if s.EnableRegistration { + mux.HandleFunc(httpPathRegister, handleRegisterFunc(s)) + } + mux.HandleFunc(httpPathEmailVerify, handleEmailVerifyFunc(s.VerifyEmailTemplate, s.IssuerURL, s.KeyManager.PublicKeys, s.UserManager)) diff --git a/server/testutil.go b/server/testutil.go index 8df077e4..b43d61b5 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -126,7 +126,8 @@ func makeTestFixtures() (*testFixtures, error) { return nil, err } - tpl, err := getTemplates("dex", "https://coreos.com/assets/images/brand/coreos-mark-30px.png", templatesLocation) + tpl, err := getTemplates("dex", "https://coreos.com/assets/images/brand/coreos-mark-30px.png", + true, templatesLocation) if err != nil { return nil, err } diff --git a/static/html/login.html b/static/html/login.html index 330497f5..2e9af292 100644 --- a/static/html/login.html +++ b/static/html/login.html @@ -68,9 +68,11 @@ {{ if not .Error }} {{ end }}