templates: add new relativeURL function

Signed-off-by: Yannis Zarkadas <yanniszark@arrikto.com>
This commit is contained in:
Yannis Zarkadas 2019-09-27 17:04:43 +03:00
parent 839130f01c
commit 27944d4f8f
2 changed files with 131 additions and 13 deletions

View file

@ -6,7 +6,9 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"os" "os"
"path"
"path/filepath" "path/filepath"
"sort" "sort"
"strings" "strings"
@ -94,7 +96,7 @@ func loadWebConfig(c webConfig) (static, theme http.Handler, templates *template
c.dir = "./web" c.dir = "./web"
} }
if c.logoURL == "" { if c.logoURL == "" {
c.logoURL = join(c.issuerURL, "theme/logo.png") c.logoURL = "theme/logo.png"
} }
if err := dirExists(c.dir); err != nil { if err := dirExists(c.dir); err != nil {
@ -136,10 +138,15 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
return nil, fmt.Errorf("no files in template dir %q", templatesDir) return nil, fmt.Errorf("no files in template dir %q", templatesDir)
} }
issuerURL, err := url.Parse(c.issuerURL)
if err != nil {
return nil, fmt.Errorf("error parsing issuerURL: %v", err)
}
funcs := map[string]interface{}{ funcs := map[string]interface{}{
"issuer": func() string { return c.issuer }, "issuer": func() string { return c.issuer },
"logo": func() string { return c.logoURL }, "logo": func() string { return c.logoURL },
"url": func(s string) string { return join(c.issuerURL, s) }, "url": func(reqPath, assetPath string) string { return relativeURL(issuerURL.Path, reqPath, assetPath) },
"lower": strings.ToLower, "lower": strings.ToLower,
"extra": func(k string) string { return c.extra[k] }, "extra": func(k string) string { return c.extra[k] },
} }
@ -166,6 +173,69 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
}, nil }, nil
} }
// relativeURL returns the URL of the asset relative to the URL of the request path.
// The serverPath is consulted to trim any prefix due in case it is not listening
// to the root path.
//
// Algorithm:
// 1. Remove common prefix of serverPath and reqPath
// 2. Remove common prefix of assetPath and reqPath
// 3. For each part of reqPath remaining(minus one), go up one level (..)
// 4. For each part of assetPath remaining, append it to result
//
//eg
//server listens at localhost/dex so serverPath is dex
//reqPath is /dex/auth
//assetPath is static/main.css
//relativeURL("/dex", "/dex/auth", "static/main.css") = "../static/main.css"
func relativeURL(serverPath, reqPath, assetPath string) string {
splitPath := func(p string) []string {
res := []string{}
parts := strings.Split(path.Clean(p), "/")
for _, part := range parts {
if part != "" {
res = append(res, part)
}
}
return res
}
stripCommonParts := func(s1, s2 []string) ([]string, []string) {
min := len(s1)
if len(s2) < min {
min = len(s2)
}
splitIndex := min
for i := 0; i < min; i++ {
if s1[i] != s2[i] {
splitIndex = i
break
}
}
return s1[splitIndex:], s2[splitIndex:]
}
server, req, asset := splitPath(serverPath), splitPath(reqPath), splitPath(assetPath)
// Remove common prefix of request path with server path
server, req = stripCommonParts(server, req)
// Remove common prefix of request path with asset path
asset, req = stripCommonParts(asset, req)
// For each part of the request remaining (minus one) -> go up one level (..)
// For each part of the asset remaining -> append it
var relativeURL string
for i := 0; i < len(req)-1; i++ {
relativeURL = path.Join("..", relativeURL)
}
relativeURL = path.Join(relativeURL, path.Join(asset...))
return relativeURL
}
var scopeDescriptions = map[string]string{ var scopeDescriptions = map[string]string{
"offline_access": "Have offline access", "offline_access": "Have offline access",
"profile": "View basic profile information", "profile": "View basic profile information",
@ -184,26 +254,28 @@ func (n byName) Len() int { return len(n) }
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name } func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] } func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
func (t *templates) login(w http.ResponseWriter, connectors []connectorInfo) error { func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error {
sort.Sort(byName(connectors)) sort.Sort(byName(connectors))
data := struct { data := struct {
Connectors []connectorInfo Connectors []connectorInfo
}{connectors} ReqPath string
}{connectors, r.URL.Path}
return renderTemplate(w, t.loginTmpl, data) return renderTemplate(w, t.loginTmpl, data)
} }
func (t *templates) password(w http.ResponseWriter, postURL, lastUsername, usernamePrompt string, lastWasInvalid, showBacklink bool) error { func (t *templates) password(r *http.Request, w http.ResponseWriter, postURL, lastUsername, usernamePrompt string, lastWasInvalid, showBacklink bool, reqPath string) error {
data := struct { data := struct {
PostURL string PostURL string
BackLink bool BackLink bool
Username string Username string
UsernamePrompt string UsernamePrompt string
Invalid bool Invalid bool
}{postURL, showBacklink, lastUsername, usernamePrompt, lastWasInvalid} ReqPath string
}{postURL, showBacklink, lastUsername, usernamePrompt, lastWasInvalid, r.URL.Path}
return renderTemplate(w, t.passwordTmpl, data) return renderTemplate(w, t.passwordTmpl, data)
} }
func (t *templates) approval(w http.ResponseWriter, authReqID, username, clientName string, scopes []string) error { func (t *templates) approval(r *http.Request, w http.ResponseWriter, authReqID, username, clientName string, scopes []string, reqPath string) error {
accesses := []string{} accesses := []string{}
for _, scope := range scopes { for _, scope := range scopes {
access, ok := scopeDescriptions[scope] access, ok := scopeDescriptions[scope]
@ -217,23 +289,26 @@ func (t *templates) approval(w http.ResponseWriter, authReqID, username, clientN
Client string Client string
AuthReqID string AuthReqID string
Scopes []string Scopes []string
}{username, clientName, authReqID, accesses} ReqPath string
}{username, clientName, authReqID, accesses, r.URL.Path}
return renderTemplate(w, t.approvalTmpl, data) return renderTemplate(w, t.approvalTmpl, data)
} }
func (t *templates) oob(w http.ResponseWriter, code string) error { func (t *templates) oob(r *http.Request, w http.ResponseWriter, code string, reqPath string) error {
data := struct { data := struct {
Code string Code string
}{code} ReqPath string
}{code, r.URL.Path}
return renderTemplate(w, t.oobTmpl, data) return renderTemplate(w, t.oobTmpl, data)
} }
func (t *templates) err(w http.ResponseWriter, errCode int, errMsg string) error { func (t *templates) err(r *http.Request, w http.ResponseWriter, errCode int, errMsg string) error {
w.WriteHeader(errCode) w.WriteHeader(errCode)
data := struct { data := struct {
ErrType string ErrType string
ErrMsg string ErrMsg string
}{http.StatusText(errCode), errMsg} ReqPath string
}{http.StatusText(errCode), errMsg, r.URL.Path}
if err := t.errorTmpl.Execute(w, data); err != nil { if err := t.errorTmpl.Execute(w, data); err != nil {
return fmt.Errorf("Error rendering template %s: %s", t.errorTmpl.Name(), err) return fmt.Errorf("Error rendering template %s: %s", t.errorTmpl.Name(), err)
} }

View file

@ -1 +1,44 @@
package server package server
import "testing"
func TestRelativeURL(t *testing.T) {
tests := []struct {
name string
serverPath string
reqPath string
assetPath string
expected string
}{
{
name: "server-root-req-one-level-asset-two-level",
serverPath: "/",
reqPath: "/auth",
assetPath: "/theme/main.css",
expected: "theme/main.css",
},
{
name: "server-one-level-req-one-level-asset-two-level",
serverPath: "/dex",
reqPath: "/dex/auth",
assetPath: "/theme/main.css",
expected: "theme/main.css",
},
{
name: "server-root-req-two-level-asset-three-level",
serverPath: "/dex",
reqPath: "/dex/auth/connector",
assetPath: "assets/css/main.css",
expected: "../assets/css/main.css",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual := relativeURL(test.serverPath, test.reqPath, test.assetPath)
if actual != test.expected {
t.Fatalf("Got '%s'. Expected '%s'", actual, test.expected)
}
})
}
}