diff --git a/server/cross_client_test.go b/server/cross_client_test.go index 22d5b252..00099622 100644 --- a/server/cross_client_test.go +++ b/server/cross_client_test.go @@ -180,7 +180,7 @@ func TestHandleAuthCrossClient(t *testing.T) { } for i, tt := range tests { - hdlr := handleAuthFunc(f.srv, idpcs, nil, true) + hdlr := handleAuthFunc(f.srv, url.URL{}, idpcs, nil, true) w := httptest.NewRecorder() query := url.Values{ diff --git a/server/http.go b/server/http.go index 007a9d4e..773246ac 100644 --- a/server/http.go +++ b/server/http.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/url" + "path" "strings" "time" @@ -266,7 +267,7 @@ func renderLoginPage(w http.ResponseWriter, r *http.Request, srv OIDCServer, idp execTemplate(w, tpl, td) } -func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc { +func handleAuthFunc(srv OIDCServer, baseURL url.URL, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc { idx := makeConnectorMap(idpcs) return func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { @@ -358,7 +359,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T if ok { q := url.Values{} q.Set("code", key) - ru := httpPathRegister + "?" + q.Encode() + ru := path.Join(baseURL.Path, httpPathRegister) + "?" + q.Encode() w.Header().Set("Location", ru) w.WriteHeader(http.StatusFound) return diff --git a/server/http_test.go b/server/http_test.go index 28f0897a..fc5b2acc 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "html/template" "math/big" "net/http" "net/http/httptest" @@ -54,7 +55,7 @@ func (c *fakeConnector) TrustedEmailProvider() bool { func TestHandleAuthFuncMethodNotAllowed(t *testing.T) { for _, m := range []string{"POST", "PUT", "DELETE"} { - hdlr := handleAuthFunc(nil, nil, nil, true) + hdlr := handleAuthFunc(nil, url.URL{}, nil, nil, true) req, err := http.NewRequest(m, "http://example.com", nil) if err != nil { t.Errorf("case %s: unable to create HTTP request: %v", m, err) @@ -72,13 +73,28 @@ func TestHandleAuthFuncMethodNotAllowed(t *testing.T) { } } +func newLocalConnector(t *testing.T, id string) connector.Connector { + config := connector.LocalConnectorConfig{ID: id} + tmpl, err := template.New(connector.LoginPageTemplateName).Parse("") + if err != nil { + t.Fatalf("creating login template: %v", err) + } + conn, err := config.Connector(url.URL{}, nil, tmpl) + if err != nil { + t.Fatalf("creating connector: %v", err) + } + return conn +} + func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { idpcs := []connector.Connector{ &fakeConnector{loginURL: "http://fake.example.com"}, + newLocalConnector(t, "local"), } tests := []struct { query url.Values + baseURL url.URL wantCode int wantLocation string }{ @@ -210,6 +226,34 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { }, wantCode: http.StatusBadRequest, }, + + // registration + { + query: url.Values{ + "response_type": []string{"code"}, + "redirect_uri": []string{"http://client.example.com/callback"}, + "client_id": []string{"client.example.com"}, + "connector_id": []string{"local"}, + "register": []string{"1"}, + "scope": []string{"openid"}, + }, + baseURL: url.URL{Scheme: "https", Host: "dex.example.com"}, // Root URL. + wantCode: http.StatusFound, + wantLocation: "/register?code=code-2", + }, + { + query: url.Values{ + "response_type": []string{"code"}, + "redirect_uri": []string{"http://client.example.com/callback"}, + "client_id": []string{"client.example.com"}, + "connector_id": []string{"local"}, + "register": []string{"1"}, + "scope": []string{"openid"}, + }, + baseURL: url.URL{Scheme: "https", Host: "dex.example.com", Path: "/foobar"}, + wantCode: http.StatusFound, + wantLocation: "/foobar/register?code=code-2", + }, } for i, tt := range tests { @@ -218,7 +262,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { t.Fatalf("error making test fixtures: %v", err) } - hdlr := handleAuthFunc(f.srv, idpcs, nil, true) + hdlr := handleAuthFunc(f.srv, tt.baseURL, idpcs, nil, true) w := httptest.NewRecorder() u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) req, err := http.NewRequest("GET", u, nil) @@ -322,7 +366,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { } for i, tt := range tests { - hdlr := handleAuthFunc(f.srv, idpcs, nil, true) + hdlr := handleAuthFunc(f.srv, url.URL{}, idpcs, nil, true) w := httptest.NewRecorder() u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) req, err := http.NewRequest("GET", u, nil) diff --git a/server/server.go b/server/server.go index 120015d1..6ec37535 100644 --- a/server/server.go +++ b/server/server.go @@ -238,7 +238,7 @@ func (s *Server) HTTPHandler() http.Handler { } handleFunc(httpPathDiscovery, handleDiscoveryFunc(s.ProviderConfig())) - handleFunc(httpPathAuth, handleAuthFunc(s, s.Connectors, s.LoginTemplate, s.EnableRegistration)) + handleFunc(httpPathAuth, handleAuthFunc(s, s.IssuerURL, s.Connectors, s.LoginTemplate, s.EnableRegistration)) handleFunc(httpPathOOB, handleOOBFunc(s, s.OOBTemplate)) handleFunc(httpPathToken, handleTokenFunc(s)) handleFunc(httpPathKeys, handleKeysFunc(s.KeyManager, clock))