diff --git a/connector/connector_ldap.go b/connector/connector_ldap.go index 8ba45aed..a9d64686 100644 --- a/connector/connector_ldap.go +++ b/connector/connector_ldap.go @@ -251,9 +251,9 @@ func (c *LDAPConnector) LoginURL(sessionKey, prompt string) (string, error) { return path.Join(c.namespace.Path, "login") + "?" + enc, nil } -func (c *LDAPConnector) Register(mux *http.ServeMux, errorURL url.URL) { - route := path.Join(c.namespace.Path, "login") - mux.Handle(route, handlePasswordLogin(c.loginFunc, c.loginTpl, c, route, errorURL)) +func (c *LDAPConnector) Handler(errorURL url.URL) http.Handler { + route := path.Join(c.namespace.Path, "/login") + return handlePasswordLogin(c.loginFunc, c.loginTpl, c, route, errorURL) } func (c *LDAPConnector) Sync() chan struct{} { diff --git a/connector/connector_local.go b/connector/connector_local.go index a5822587..1f89c277 100644 --- a/connector/connector_local.go +++ b/connector/connector_local.go @@ -85,9 +85,9 @@ func (c *LocalConnector) LoginURL(sessionKey, prompt string) (string, error) { return path.Join(c.namespace.Path, "login") + "?" + enc, nil } -func (c *LocalConnector) Register(mux *http.ServeMux, errorURL url.URL) { - route := c.namespace.Path + "/login" - mux.Handle(route, handlePasswordLogin(c.loginFunc, c.loginTpl, c.idp, route, errorURL)) +func (c *LocalConnector) Handler(errorURL url.URL) http.Handler { + route := path.Join(c.namespace.Path, "/login") + return handlePasswordLogin(c.loginFunc, c.loginTpl, c.idp, route, errorURL) } func (c *LocalConnector) Sync() chan struct{} { diff --git a/connector/connector_oauth2.go b/connector/connector_oauth2.go index 3a5ed145..4e6354a0 100644 --- a/connector/connector_oauth2.go +++ b/connector/connector_oauth2.go @@ -53,8 +53,8 @@ func (c *OAuth2Connector) LoginURL(sessionKey, prompt string) (string, error) { return c.conn.Client().AuthCodeURL(sessionKey, oauth2.GrantTypeAuthCode, prompt), nil } -func (c *OAuth2Connector) Register(mux *http.ServeMux, errorURL url.URL) { - mux.Handle(c.cbURL.Path, c.handleCallbackFunc(c.loginFunc, errorURL)) +func (c *OAuth2Connector) Handler(errorURL url.URL) http.Handler { + return c.handleCallbackFunc(c.loginFunc, errorURL) } func (c *OAuth2Connector) handleCallbackFunc(lf oidc.LoginFunc, errorURL url.URL) http.HandlerFunc { diff --git a/connector/connector_oidc.go b/connector/connector_oidc.go index e7abb7eb..ee320c86 100644 --- a/connector/connector_oidc.go +++ b/connector/connector_oidc.go @@ -90,8 +90,8 @@ func (c *OIDCConnector) LoginURL(sessionKey, prompt string) (string, error) { return oac.AuthCodeURL(sessionKey, "", prompt), nil } -func (c *OIDCConnector) Register(mux *http.ServeMux, errorURL url.URL) { - mux.Handle(c.cbURL.Path, c.handleCallbackFunc(c.loginFunc, errorURL)) +func (c *OIDCConnector) Handler(errorURL url.URL) http.Handler { + return c.handleCallbackFunc(c.loginFunc, errorURL) } func (c *OIDCConnector) Sync() chan struct{} { diff --git a/connector/interface.go b/connector/interface.go index 6c79ffe1..d3013dd5 100644 --- a/connector/interface.go +++ b/connector/interface.go @@ -21,12 +21,12 @@ type Connector interface { // and OAuth2 prompt type. LoginURL(sessionKey, prompt string) (string, error) - // Register allows connectors to register a callback handler with the + // Handler allows connectors to register a callback handler with the // dex server. // - // Connectors should register with a path that extends the namespace - // URL provided when the Connector is instantiated. - Register(mux *http.ServeMux, errorURL url.URL) + // Connectors will handle any path that extends the namespace URL provided + // when the Connector is instantiated. + Handler(errorURL url.URL) http.Handler // Sync triggers any long-running tasks needed to maintain the // Connector's operation. For example, this would encompass diff --git a/server/http_test.go b/server/http_test.go index 637f53b7..28f0897a 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -40,7 +40,9 @@ func (f *fakeConnector) LoginURL(sessionKey, prompt string) (string, error) { return f.loginURL, nil } -func (f *fakeConnector) Register(mux *http.ServeMux, errorURL url.URL) {} +func (f *fakeConnector) Handler(errorURL url.URL) http.Handler { + return http.HandlerFunc(http.NotFound) +} func (f *fakeConnector) Sync() chan struct{} { return nil diff --git a/server/server.go b/server/server.go index 32c1f0ec..3cb9e5f3 100644 --- a/server/server.go +++ b/server/server.go @@ -269,7 +269,9 @@ func (s *Server) HTTPHandler() http.Handler { if err != nil { log.Fatal(err) } - idpc.Register(mux, *errorURL) + // NOTE(ericchiang): This path MUST end in a "/" in order to indicate a + // path prefix rather than an absolute path. + mux.Handle(path.Join(httpPathAuth, idpc.ID())+"/", idpc.Handler(*errorURL)) } apiBasePath := path.Join(httpPathAPI, APIVersion)