diff --git a/cmd/dex/config_test.go b/cmd/dex/config_test.go index 8ee02d5a..9a36afcf 100644 --- a/cmd/dex/config_test.go +++ b/cmd/dex/config_test.go @@ -74,7 +74,6 @@ web: http: 127.0.0.1:5556 frontend: - dir: ./web extra: foo: bar @@ -144,7 +143,6 @@ logger: HTTP: "127.0.0.1:5556", }, Frontend: server.WebConfig{ - Dir: "./web", Extra: map[string]string{ "foo": "bar", }, @@ -274,7 +272,6 @@ web: http: 127.0.0.1:5556 frontend: - dir: ./web extra: foo: bar @@ -352,7 +349,6 @@ logger: HTTP: "127.0.0.1:5556", }, Frontend: server.WebConfig{ - Dir: "./web", Extra: map[string]string{ "foo": "bar", }, diff --git a/server/handlers_test.go b/server/handlers_test.go index 8ad59d94..96d228af 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -135,7 +135,7 @@ func TestHandleInvalidSAMLCallbacks(t *testing.T) { func TestConnectorLoginDoesNotAllowToChangeConnectorForAuthRequest(t *testing.T) { memStorage := memory.New(logger) - templates, err := loadTemplates(webConfig{}, "../web/templates") + templates, err := loadTemplates(WebConfig{Dir: http.Dir("../web")}, "../web/templates") if err != nil { t.Fatal("failed to load templates") } diff --git a/server/server.go b/server/server.go index a79b7cfd..79c5f5a0 100644 --- a/server/server.go +++ b/server/server.go @@ -108,7 +108,7 @@ type WebConfig struct { // * templates - HTML templates controlled by dex. // * themes/(theme) - Static static served at "( issuer URL )/theme". // - Dir string + Dir http.FileSystem // Defaults to "( issuer URL )/theme/logo.png" LogoURL string @@ -203,18 +203,9 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) supported[respType] = true } - web := webConfig{ - dir: c.Web.Dir, - logoURL: c.Web.LogoURL, - issuerURL: c.Issuer, - issuer: c.Web.Issuer, - theme: c.Web.Theme, - extra: c.Web.Extra, - } - - static, theme, tmpls, err := loadWebConfig(web) + tmpls, err := loadTemplates(c.Web, issuerURL.Path) if err != nil { - return nil, fmt.Errorf("server: failed to load web static: %v", err) + return nil, fmt.Errorf("server: failed to load templates: %v", err) } now := c.Now @@ -343,8 +334,8 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) } fmt.Fprintf(w, "Health check passed") })) - handlePrefix("/static", static) - handlePrefix("/theme", theme) + handlePrefix("/", http.FileServer(c.Web.Dir)) + s.mux = r s.startKeyRotation(ctx, rotationStrategy, now) diff --git a/server/server_test.go b/server/server_test.go index 87ca6c17..b0bacfd3 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -93,7 +93,7 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi Issuer: s.URL, Storage: memory.New(logger), Web: WebConfig{ - Dir: "../web", + Dir: http.Dir("../web"), }, Logger: logger, PrometheusRegistry: prometheus.NewRegistry(), @@ -132,7 +132,7 @@ func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateCo Issuer: s.URL, Storage: memory.New(logger), Web: WebConfig{ - Dir: "../web", + Dir: http.Dir("../web"), }, Logger: logger, PrometheusRegistry: prometheus.NewRegistry(), diff --git a/server/templates.go b/server/templates.go index bed1c6c8..b23083ba 100644 --- a/server/templates.go +++ b/server/templates.go @@ -1,13 +1,12 @@ package server import ( + "bytes" "fmt" "html/template" "io" - "io/ioutil" "net/http" "net/url" - "os" "path" "path/filepath" "sort" @@ -22,18 +21,10 @@ const ( tmplError = "error.html" tmplDevice = "device.html" tmplDeviceSuccess = "device_success.html" + tmplHeader = "header.html" + tmplFooter = "footer.html" ) -var requiredTmpls = []string{ - tmplApproval, - tmplLogin, - tmplPassword, - tmplOOB, - tmplError, - tmplDevice, - tmplDeviceSuccess, -} - type templates struct { loginTmpl *template.Template approvalTmpl *template.Template @@ -44,131 +35,93 @@ type templates struct { deviceSuccessTmpl *template.Template } -type webConfig struct { - dir string - logoURL string - issuer string - theme string - issuerURL string - extra map[string]string -} - -func dirExists(dir string) error { - stat, err := os.Stat(dir) - if err != nil { - if os.IsNotExist(err) { - return fmt.Errorf("directory %q does not exist", dir) - } - return fmt.Errorf("stat directory %q: %v", dir, err) - } - if !stat.IsDir() { - return fmt.Errorf("path %q is a file not a directory", dir) - } - return nil -} - -// loadWebConfig returns static assets, theme assets, and templates used by the frontend by -// reading the directory specified in the webConfig. -// -// The directory layout is expected to be: -// -// ( web directory ) -// |- static -// |- themes -// | |- (theme name) -// |- templates -// -func loadWebConfig(c webConfig) (http.Handler, http.Handler, *templates, error) { - // fallback to the default theme if the legacy theme name is provided - if c.theme == "coreos" || c.theme == "tectonic" { - c.theme = "" - } - if c.theme == "" { - c.theme = "light" - } - if c.issuer == "" { - c.issuer = "dex" - } - if c.dir == "" { - c.dir = "./web" - } - if c.logoURL == "" { - c.logoURL = "theme/logo.png" - } - - if err := dirExists(c.dir); err != nil { - return nil, nil, nil, fmt.Errorf("load web dir: %v", err) - } - - staticDir := filepath.Join(c.dir, "static") - templatesDir := filepath.Join(c.dir, "templates") - themeDir := filepath.Join(c.dir, "themes", c.theme) - - for _, dir := range []string{staticDir, templatesDir, themeDir} { - if err := dirExists(dir); err != nil { - return nil, nil, nil, fmt.Errorf("load dir: %v", err) - } - } - - static := http.FileServer(http.Dir(staticDir)) - theme := http.FileServer(http.Dir(themeDir)) - - templates, err := loadTemplates(c, templatesDir) - return static, theme, templates, err -} - // loadTemplates parses the expected templates from the provided directory. -func loadTemplates(c webConfig, templatesDir string) (*templates, error) { - files, err := ioutil.ReadDir(templatesDir) +func loadTemplates(c WebConfig, issuerPath string) (*templates, error) { + // fallback to the default theme if the legacy theme name is provided + if c.Theme == "coreos" || c.Theme == "tectonic" { + c.Theme = "" + } + if c.Theme == "" { + c.Theme = "light" + } + + if c.Issuer == "" { + c.Issuer = "dex" + } + + if c.LogoURL == "" { + c.LogoURL = "theme/logo.png" + } + + funcs := template.FuncMap{ + "issuer": func() string { return c.Issuer }, + "logo": func() string { return c.LogoURL }, + "url": func(reqPath, assetPath string) string { return relativeURL(issuerPath, reqPath, assetPath) }, + "theme": func(reqPath, assetPath string) string { + return relativeURL(issuerPath, reqPath, path.Join("themes", c.Theme, assetPath)) + }, + "lower": strings.ToLower, + "extra": func(k string) string { return c.Extra[k] }, + } + + group := template.New("") + + // load all of our templates individually. + // some http.FilSystem implementations don't implement Readdir + + loginTemplate, err := loadTemplate(c.Dir, tmplLogin, funcs, group) if err != nil { - return nil, fmt.Errorf("read dir: %v", err) + return nil, err } - filenames := []string{} - for _, file := range files { - if file.IsDir() { - continue - } - filenames = append(filenames, filepath.Join(templatesDir, file.Name())) - } - if len(filenames) == 0 { - return nil, fmt.Errorf("no files in template dir %q", templatesDir) - } - - issuerURL, err := url.Parse(c.issuerURL) + approvalTemplate, err := loadTemplate(c.Dir, tmplApproval, funcs, group) if err != nil { - return nil, fmt.Errorf("error parsing issuerURL: %v", err) + return nil, err } - funcs := map[string]interface{}{ - "issuer": func() string { return c.issuer }, - "logo": func() string { return c.logoURL }, - "url": func(reqPath, assetPath string) string { return relativeURL(issuerURL.Path, reqPath, assetPath) }, - "lower": strings.ToLower, - "extra": func(k string) string { return c.extra[k] }, - } - - tmpls, err := template.New("").Funcs(funcs).ParseFiles(filenames...) + passwordTemplate, err := loadTemplate(c.Dir, tmplPassword, funcs, group) if err != nil { - return nil, fmt.Errorf("parse files: %v", err) + return nil, err } - missingTmpls := []string{} - for _, tmplName := range requiredTmpls { - if tmpls.Lookup(tmplName) == nil { - missingTmpls = append(missingTmpls, tmplName) - } + + oobTemplate, err := loadTemplate(c.Dir, tmplOOB, funcs, group) + if err != nil { + return nil, err } - if len(missingTmpls) > 0 { - return nil, fmt.Errorf("missing template(s): %s", missingTmpls) + + errorTemplate, err := loadTemplate(c.Dir, tmplError, funcs, group) + if err != nil { + return nil, err } + + deviceTemplate, err := loadTemplate(c.Dir, tmplDevice, funcs, group) + if err != nil { + return nil, err + } + + deviceSuccessTemplate, err := loadTemplate(c.Dir, tmplDeviceSuccess, funcs, group) + if err != nil { + return nil, err + } + + _, err = loadTemplate(c.Dir, tmplHeader, funcs, group) + if err != nil { + // we don't actually care if this template exists + } + + _, err = loadTemplate(c.Dir, tmplFooter, funcs, group) + if err != nil { + // we don't actually care if this template exists + } + return &templates{ - loginTmpl: tmpls.Lookup(tmplLogin), - approvalTmpl: tmpls.Lookup(tmplApproval), - passwordTmpl: tmpls.Lookup(tmplPassword), - oobTmpl: tmpls.Lookup(tmplOOB), - errorTmpl: tmpls.Lookup(tmplError), - deviceTmpl: tmpls.Lookup(tmplDevice), - deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess), + loginTmpl: loginTemplate, + approvalTmpl: approvalTemplate, + passwordTmpl: passwordTemplate, + oobTmpl: oobTemplate, + errorTmpl: errorTemplate, + deviceTmpl: deviceTemplate, + deviceSuccessTmpl: deviceSuccessTemplate, }, nil } @@ -239,6 +192,22 @@ func relativeURL(serverPath, reqPath, assetPath string) string { return relativeURL } +// load a template by name from the templates dir +func loadTemplate(dir http.FileSystem, name string, funcs template.FuncMap, group *template.Template) (*template.Template, error) { + file, err := dir.Open(filepath.Join("templates", name)) + if err != nil { + return nil, err + } + + defer file.Close() + + var buffer bytes.Buffer + buffer.ReadFrom(file) + contents := buffer.String() + + return group.New(name).Funcs(funcs).Parse(contents) +} + var scopeDescriptions = map[string]string{ "offline_access": "Have offline access", "profile": "View basic profile information", diff --git a/web/templates/header.html b/web/templates/header.html index 0d4fea0f..78dde15e 100644 --- a/web/templates/header.html +++ b/web/templates/header.html @@ -6,8 +6,8 @@ {{ issuer }} - - + +