forked from mystiq/dex
Merge pull request #144 from bobbyrullo/no_register
server,cmd: Add flag for disabling registation
This commit is contained in:
commit
48b3b38c8b
7 changed files with 45 additions and 27 deletions
|
@ -44,6 +44,8 @@ func main() {
|
||||||
emailFrom := fs.String("email-from", "no-reply@coreos.com", "emails sent from dex will come from this address")
|
emailFrom := fs.String("email-from", "no-reply@coreos.com", "emails sent from dex will come from this address")
|
||||||
emailConfig := fs.String("email-cfg", "./static/fixtures/emailer.json", "configures emailer.")
|
emailConfig := fs.String("email-cfg", "./static/fixtures/emailer.json", "configures emailer.")
|
||||||
|
|
||||||
|
enableRegistration := fs.Bool("enable-registration", false, "Allows users to self-register")
|
||||||
|
|
||||||
noDB := fs.Bool("no-db", false, "manage entities in-process w/o any encryption, used only for single-node testing")
|
noDB := fs.Bool("no-db", false, "manage entities in-process w/o any encryption, used only for single-node testing")
|
||||||
|
|
||||||
// UI-related:
|
// UI-related:
|
||||||
|
@ -120,6 +122,7 @@ func main() {
|
||||||
EmailerConfigFile: *emailConfig,
|
EmailerConfigFile: *emailConfig,
|
||||||
IssuerName: *issuerName,
|
IssuerName: *issuerName,
|
||||||
IssuerLogoURL: *issuerLogoURL,
|
IssuerLogoURL: *issuerLogoURL,
|
||||||
|
EnableRegistration: *enableRegistration,
|
||||||
}
|
}
|
||||||
|
|
||||||
if *noDB {
|
if *noDB {
|
||||||
|
|
|
@ -33,6 +33,7 @@ type ServerConfig struct {
|
||||||
EmailFromAddress string
|
EmailFromAddress string
|
||||||
EmailerConfigFile string
|
EmailerConfigFile string
|
||||||
StateConfig StateConfigurer
|
StateConfig StateConfigurer
|
||||||
|
EnableRegistration bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateConfigurer interface {
|
type StateConfigurer interface {
|
||||||
|
@ -56,7 +57,7 @@ func (cfg *ServerConfig) Server() (*Server, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tpl, err := getTemplates(cfg.IssuerName, cfg.IssuerLogoURL, cfg.TemplateDir)
|
tpl, err := getTemplates(cfg.IssuerName, cfg.IssuerLogoURL, cfg.EnableRegistration, cfg.TemplateDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -69,6 +70,8 @@ func (cfg *ServerConfig) Server() (*Server, error) {
|
||||||
|
|
||||||
HealthChecks: []health.Checkable{km},
|
HealthChecks: []health.Checkable{km},
|
||||||
Connectors: []connector.Connector{},
|
Connectors: []connector.Connector{},
|
||||||
|
|
||||||
|
EnableRegistration: cfg.EnableRegistration,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = cfg.StateConfig.Configure(&srv)
|
err = cfg.StateConfig.Configure(&srv)
|
||||||
|
@ -183,7 +186,8 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTemplates(issuerName, issuerLogoURL string, dir string) (*template.Template, error) {
|
func getTemplates(issuerName, issuerLogoURL string,
|
||||||
|
enableRegister bool, dir string) (*template.Template, error) {
|
||||||
tpl := template.New("").Funcs(map[string]interface{}{
|
tpl := template.New("").Funcs(map[string]interface{}{
|
||||||
"issuerName": func() string {
|
"issuerName": func() string {
|
||||||
return issuerName
|
return issuerName
|
||||||
|
@ -191,6 +195,9 @@ func getTemplates(issuerName, issuerLogoURL string, dir string) (*template.Templ
|
||||||
"issuerLogoURL": func() string {
|
"issuerLogoURL": func() string {
|
||||||
return issuerLogoURL
|
return issuerLogoURL
|
||||||
},
|
},
|
||||||
|
"enableRegister": func() bool {
|
||||||
|
return enableRegister
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return tpl.ParseGlob(dir + "/*.html")
|
return tpl.ParseGlob(dir + "/*.html")
|
||||||
|
|
|
@ -254,7 +254,7 @@ func renderLoginPage(w http.ResponseWriter, r *http.Request, srv OIDCServer, idp
|
||||||
execTemplate(w, tpl, td)
|
execTemplate(w, tpl, td)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template) http.HandlerFunc {
|
func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc {
|
||||||
idx := makeConnectorMap(idpcs)
|
idx := makeConnectorMap(idpcs)
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != "GET" {
|
if r.Method != "GET" {
|
||||||
|
@ -264,7 +264,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
|
||||||
}
|
}
|
||||||
|
|
||||||
q := r.URL.Query()
|
q := r.URL.Query()
|
||||||
register := q.Get("register") == "1"
|
register := q.Get("register") == "1" && registrationEnabled
|
||||||
e := q.Get("error")
|
e := q.Get("error")
|
||||||
if e != "" {
|
if e != "" {
|
||||||
sessionKey := q.Get("state")
|
sessionKey := q.Get("state")
|
||||||
|
|
|
@ -51,7 +51,7 @@ func (c *fakeConnector) TrustedEmailProvider() bool {
|
||||||
|
|
||||||
func TestHandleAuthFuncMethodNotAllowed(t *testing.T) {
|
func TestHandleAuthFuncMethodNotAllowed(t *testing.T) {
|
||||||
for _, m := range []string{"POST", "PUT", "DELETE"} {
|
for _, m := range []string{"POST", "PUT", "DELETE"} {
|
||||||
hdlr := handleAuthFunc(nil, nil, nil)
|
hdlr := handleAuthFunc(nil, nil, nil, true)
|
||||||
req, err := http.NewRequest(m, "http://example.com", nil)
|
req, err := http.NewRequest(m, "http://example.com", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("case %s: unable to create HTTP request: %v", m, err)
|
t.Errorf("case %s: unable to create HTTP request: %v", m, err)
|
||||||
|
@ -170,7 +170,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
hdlr := handleAuthFunc(srv, idpcs, nil)
|
hdlr := handleAuthFunc(srv, idpcs, nil, true)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
|
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
|
||||||
req, err := http.NewRequest("GET", u, nil)
|
req, err := http.NewRequest("GET", u, nil)
|
||||||
|
@ -271,7 +271,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
hdlr := handleAuthFunc(srv, idpcs, nil)
|
hdlr := handleAuthFunc(srv, idpcs, nil, true)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
|
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
|
||||||
req, err := http.NewRequest("GET", u, nil)
|
req, err := http.NewRequest("GET", u, nil)
|
||||||
|
|
|
@ -72,6 +72,7 @@ type Server struct {
|
||||||
PasswordInfoRepo user.PasswordInfoRepo
|
PasswordInfoRepo user.PasswordInfoRepo
|
||||||
RefreshTokenRepo refresh.RefreshTokenRepo
|
RefreshTokenRepo refresh.RefreshTokenRepo
|
||||||
UserEmailer *useremail.UserEmailer
|
UserEmailer *useremail.UserEmailer
|
||||||
|
EnableRegistration bool
|
||||||
|
|
||||||
localConnectorID string
|
localConnectorID string
|
||||||
}
|
}
|
||||||
|
@ -198,11 +199,15 @@ func (s *Server) HTTPHandler() http.Handler {
|
||||||
clock := clockwork.NewRealClock()
|
clock := clockwork.NewRealClock()
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc(httpPathDiscovery, handleDiscoveryFunc(s.ProviderConfig()))
|
mux.HandleFunc(httpPathDiscovery, handleDiscoveryFunc(s.ProviderConfig()))
|
||||||
mux.HandleFunc(httpPathAuth, handleAuthFunc(s, s.Connectors, s.LoginTemplate))
|
mux.HandleFunc(httpPathAuth, handleAuthFunc(s, s.Connectors, s.LoginTemplate, s.EnableRegistration))
|
||||||
mux.HandleFunc(httpPathToken, handleTokenFunc(s))
|
mux.HandleFunc(httpPathToken, handleTokenFunc(s))
|
||||||
mux.HandleFunc(httpPathKeys, handleKeysFunc(s.KeyManager, clock))
|
mux.HandleFunc(httpPathKeys, handleKeysFunc(s.KeyManager, clock))
|
||||||
mux.Handle(httpPathHealth, makeHealthHandler(checks))
|
mux.Handle(httpPathHealth, makeHealthHandler(checks))
|
||||||
|
|
||||||
|
if s.EnableRegistration {
|
||||||
mux.HandleFunc(httpPathRegister, handleRegisterFunc(s))
|
mux.HandleFunc(httpPathRegister, handleRegisterFunc(s))
|
||||||
|
}
|
||||||
|
|
||||||
mux.HandleFunc(httpPathEmailVerify, handleEmailVerifyFunc(s.VerifyEmailTemplate,
|
mux.HandleFunc(httpPathEmailVerify, handleEmailVerifyFunc(s.VerifyEmailTemplate,
|
||||||
s.IssuerURL, s.KeyManager.PublicKeys, s.UserManager))
|
s.IssuerURL, s.KeyManager.PublicKeys, s.UserManager))
|
||||||
|
|
||||||
|
|
|
@ -126,7 +126,8 @@ func makeTestFixtures() (*testFixtures, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tpl, err := getTemplates("dex", "https://coreos.com/assets/images/brand/coreos-mark-30px.png", templatesLocation)
|
tpl, err := getTemplates("dex", "https://coreos.com/assets/images/brand/coreos-mark-30px.png",
|
||||||
|
true, templatesLocation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,8 +70,10 @@
|
||||||
{{ if .Register }}
|
{{ if .Register }}
|
||||||
Already have an account? <a href="{{ .RegisterOrLoginURL }}">Log in</a>
|
Already have an account? <a href="{{ .RegisterOrLoginURL }}">Log in</a>
|
||||||
{{ else }}
|
{{ else }}
|
||||||
|
{{ if enableRegister }}
|
||||||
Don't have an account yet? <a href="{{ .RegisterOrLoginURL }}">Register</a>
|
Don't have an account yet? <a href="{{ .RegisterOrLoginURL }}">Register</a>
|
||||||
{{ end }}
|
{{ end }}
|
||||||
|
{{ end }}
|
||||||
</div>
|
</div>
|
||||||
{{ end }}
|
{{ end }}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue