forked from mystiq/dex
server: fix registration redirect for servers listenin at non-base URLs
This commit is contained in:
parent
7525e5623c
commit
fa8f98acac
4 changed files with 52 additions and 7 deletions
|
@ -180,7 +180,7 @@ func TestHandleAuthCrossClient(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
hdlr := handleAuthFunc(f.srv, idpcs, nil, true)
|
hdlr := handleAuthFunc(f.srv, url.URL{}, idpcs, nil, true)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
query := url.Values{
|
query := url.Values{
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -266,7 +267,7 @@ func renderLoginPage(w http.ResponseWriter, r *http.Request, srv OIDCServer, idp
|
||||||
execTemplate(w, tpl, td)
|
execTemplate(w, tpl, td)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc {
|
func handleAuthFunc(srv OIDCServer, baseURL url.URL, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc {
|
||||||
idx := makeConnectorMap(idpcs)
|
idx := makeConnectorMap(idpcs)
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != "GET" {
|
if r.Method != "GET" {
|
||||||
|
@ -358,7 +359,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
|
||||||
if ok {
|
if ok {
|
||||||
q := url.Values{}
|
q := url.Values{}
|
||||||
q.Set("code", key)
|
q.Set("code", key)
|
||||||
ru := httpPathRegister + "?" + q.Encode()
|
ru := path.Join(baseURL.Path, httpPathRegister) + "?" + q.Encode()
|
||||||
w.Header().Set("Location", ru)
|
w.Header().Set("Location", ru)
|
||||||
w.WriteHeader(http.StatusFound)
|
w.WriteHeader(http.StatusFound)
|
||||||
return
|
return
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"html/template"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -54,7 +55,7 @@ func (c *fakeConnector) TrustedEmailProvider() bool {
|
||||||
|
|
||||||
func TestHandleAuthFuncMethodNotAllowed(t *testing.T) {
|
func TestHandleAuthFuncMethodNotAllowed(t *testing.T) {
|
||||||
for _, m := range []string{"POST", "PUT", "DELETE"} {
|
for _, m := range []string{"POST", "PUT", "DELETE"} {
|
||||||
hdlr := handleAuthFunc(nil, nil, nil, true)
|
hdlr := handleAuthFunc(nil, url.URL{}, nil, nil, true)
|
||||||
req, err := http.NewRequest(m, "http://example.com", nil)
|
req, err := http.NewRequest(m, "http://example.com", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("case %s: unable to create HTTP request: %v", m, err)
|
t.Errorf("case %s: unable to create HTTP request: %v", m, err)
|
||||||
|
@ -72,13 +73,28 @@ func TestHandleAuthFuncMethodNotAllowed(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newLocalConnector(t *testing.T, id string) connector.Connector {
|
||||||
|
config := connector.LocalConnectorConfig{ID: id}
|
||||||
|
tmpl, err := template.New(connector.LoginPageTemplateName).Parse("")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating login template: %v", err)
|
||||||
|
}
|
||||||
|
conn, err := config.Connector(url.URL{}, nil, tmpl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating connector: %v", err)
|
||||||
|
}
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
idpcs := []connector.Connector{
|
idpcs := []connector.Connector{
|
||||||
&fakeConnector{loginURL: "http://fake.example.com"},
|
&fakeConnector{loginURL: "http://fake.example.com"},
|
||||||
|
newLocalConnector(t, "local"),
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
query url.Values
|
query url.Values
|
||||||
|
baseURL url.URL
|
||||||
wantCode int
|
wantCode int
|
||||||
wantLocation string
|
wantLocation string
|
||||||
}{
|
}{
|
||||||
|
@ -210,6 +226,34 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
},
|
},
|
||||||
wantCode: http.StatusBadRequest,
|
wantCode: http.StatusBadRequest,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// registration
|
||||||
|
{
|
||||||
|
query: url.Values{
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||||
|
"client_id": []string{"client.example.com"},
|
||||||
|
"connector_id": []string{"local"},
|
||||||
|
"register": []string{"1"},
|
||||||
|
"scope": []string{"openid"},
|
||||||
|
},
|
||||||
|
baseURL: url.URL{Scheme: "https", Host: "dex.example.com"}, // Root URL.
|
||||||
|
wantCode: http.StatusFound,
|
||||||
|
wantLocation: "/register?code=code-2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
query: url.Values{
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||||
|
"client_id": []string{"client.example.com"},
|
||||||
|
"connector_id": []string{"local"},
|
||||||
|
"register": []string{"1"},
|
||||||
|
"scope": []string{"openid"},
|
||||||
|
},
|
||||||
|
baseURL: url.URL{Scheme: "https", Host: "dex.example.com", Path: "/foobar"},
|
||||||
|
wantCode: http.StatusFound,
|
||||||
|
wantLocation: "/foobar/register?code=code-2",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
|
@ -218,7 +262,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
t.Fatalf("error making test fixtures: %v", err)
|
t.Fatalf("error making test fixtures: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
hdlr := handleAuthFunc(f.srv, idpcs, nil, true)
|
hdlr := handleAuthFunc(f.srv, tt.baseURL, idpcs, nil, true)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
|
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
|
||||||
req, err := http.NewRequest("GET", u, nil)
|
req, err := http.NewRequest("GET", u, nil)
|
||||||
|
@ -322,7 +366,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
hdlr := handleAuthFunc(f.srv, idpcs, nil, true)
|
hdlr := handleAuthFunc(f.srv, url.URL{}, idpcs, nil, true)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
|
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
|
||||||
req, err := http.NewRequest("GET", u, nil)
|
req, err := http.NewRequest("GET", u, nil)
|
||||||
|
|
|
@ -238,7 +238,7 @@ func (s *Server) HTTPHandler() http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
handleFunc(httpPathDiscovery, handleDiscoveryFunc(s.ProviderConfig()))
|
handleFunc(httpPathDiscovery, handleDiscoveryFunc(s.ProviderConfig()))
|
||||||
handleFunc(httpPathAuth, handleAuthFunc(s, s.Connectors, s.LoginTemplate, s.EnableRegistration))
|
handleFunc(httpPathAuth, handleAuthFunc(s, s.IssuerURL, s.Connectors, s.LoginTemplate, s.EnableRegistration))
|
||||||
handleFunc(httpPathOOB, handleOOBFunc(s, s.OOBTemplate))
|
handleFunc(httpPathOOB, handleOOBFunc(s, s.OOBTemplate))
|
||||||
handleFunc(httpPathToken, handleTokenFunc(s))
|
handleFunc(httpPathToken, handleTokenFunc(s))
|
||||||
handleFunc(httpPathKeys, handleKeysFunc(s.KeyManager, clock))
|
handleFunc(httpPathKeys, handleKeysFunc(s.KeyManager, clock))
|
||||||
|
|
Loading…
Reference in a new issue