server: fix registration redirect for servers listenin at non-base URLs

This commit is contained in:
Eric Chiang 2016-08-19 15:49:31 -07:00
parent 7525e5623c
commit fa8f98acac
4 changed files with 52 additions and 7 deletions

View file

@ -180,7 +180,7 @@ func TestHandleAuthCrossClient(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
hdlr := handleAuthFunc(f.srv, idpcs, nil, true) hdlr := handleAuthFunc(f.srv, url.URL{}, idpcs, nil, true)
w := httptest.NewRecorder() w := httptest.NewRecorder()
query := url.Values{ query := url.Values{

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"path"
"strings" "strings"
"time" "time"
@ -266,7 +267,7 @@ func renderLoginPage(w http.ResponseWriter, r *http.Request, srv OIDCServer, idp
execTemplate(w, tpl, td) execTemplate(w, tpl, td)
} }
func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc { func handleAuthFunc(srv OIDCServer, baseURL url.URL, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc {
idx := makeConnectorMap(idpcs) idx := makeConnectorMap(idpcs)
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" { if r.Method != "GET" {
@ -358,7 +359,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
if ok { if ok {
q := url.Values{} q := url.Values{}
q.Set("code", key) q.Set("code", key)
ru := httpPathRegister + "?" + q.Encode() ru := path.Join(baseURL.Path, httpPathRegister) + "?" + q.Encode()
w.Header().Set("Location", ru) w.Header().Set("Location", ru)
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
return return

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"html/template"
"math/big" "math/big"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -54,7 +55,7 @@ func (c *fakeConnector) TrustedEmailProvider() bool {
func TestHandleAuthFuncMethodNotAllowed(t *testing.T) { func TestHandleAuthFuncMethodNotAllowed(t *testing.T) {
for _, m := range []string{"POST", "PUT", "DELETE"} { for _, m := range []string{"POST", "PUT", "DELETE"} {
hdlr := handleAuthFunc(nil, nil, nil, true) hdlr := handleAuthFunc(nil, url.URL{}, nil, nil, true)
req, err := http.NewRequest(m, "http://example.com", nil) req, err := http.NewRequest(m, "http://example.com", nil)
if err != nil { if err != nil {
t.Errorf("case %s: unable to create HTTP request: %v", m, err) t.Errorf("case %s: unable to create HTTP request: %v", m, err)
@ -72,13 +73,28 @@ func TestHandleAuthFuncMethodNotAllowed(t *testing.T) {
} }
} }
func newLocalConnector(t *testing.T, id string) connector.Connector {
config := connector.LocalConnectorConfig{ID: id}
tmpl, err := template.New(connector.LoginPageTemplateName).Parse("")
if err != nil {
t.Fatalf("creating login template: %v", err)
}
conn, err := config.Connector(url.URL{}, nil, tmpl)
if err != nil {
t.Fatalf("creating connector: %v", err)
}
return conn
}
func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
idpcs := []connector.Connector{ idpcs := []connector.Connector{
&fakeConnector{loginURL: "http://fake.example.com"}, &fakeConnector{loginURL: "http://fake.example.com"},
newLocalConnector(t, "local"),
} }
tests := []struct { tests := []struct {
query url.Values query url.Values
baseURL url.URL
wantCode int wantCode int
wantLocation string wantLocation string
}{ }{
@ -210,6 +226,34 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
}, },
wantCode: http.StatusBadRequest, wantCode: http.StatusBadRequest,
}, },
// registration
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{"http://client.example.com/callback"},
"client_id": []string{"client.example.com"},
"connector_id": []string{"local"},
"register": []string{"1"},
"scope": []string{"openid"},
},
baseURL: url.URL{Scheme: "https", Host: "dex.example.com"}, // Root URL.
wantCode: http.StatusFound,
wantLocation: "/register?code=code-2",
},
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{"http://client.example.com/callback"},
"client_id": []string{"client.example.com"},
"connector_id": []string{"local"},
"register": []string{"1"},
"scope": []string{"openid"},
},
baseURL: url.URL{Scheme: "https", Host: "dex.example.com", Path: "/foobar"},
wantCode: http.StatusFound,
wantLocation: "/foobar/register?code=code-2",
},
} }
for i, tt := range tests { for i, tt := range tests {
@ -218,7 +262,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
t.Fatalf("error making test fixtures: %v", err) t.Fatalf("error making test fixtures: %v", err)
} }
hdlr := handleAuthFunc(f.srv, idpcs, nil, true) hdlr := handleAuthFunc(f.srv, tt.baseURL, idpcs, nil, true)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
req, err := http.NewRequest("GET", u, nil) req, err := http.NewRequest("GET", u, nil)
@ -322,7 +366,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
hdlr := handleAuthFunc(f.srv, idpcs, nil, true) hdlr := handleAuthFunc(f.srv, url.URL{}, idpcs, nil, true)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
req, err := http.NewRequest("GET", u, nil) req, err := http.NewRequest("GET", u, nil)

View file

@ -238,7 +238,7 @@ func (s *Server) HTTPHandler() http.Handler {
} }
handleFunc(httpPathDiscovery, handleDiscoveryFunc(s.ProviderConfig())) handleFunc(httpPathDiscovery, handleDiscoveryFunc(s.ProviderConfig()))
handleFunc(httpPathAuth, handleAuthFunc(s, s.Connectors, s.LoginTemplate, s.EnableRegistration)) handleFunc(httpPathAuth, handleAuthFunc(s, s.IssuerURL, s.Connectors, s.LoginTemplate, s.EnableRegistration))
handleFunc(httpPathOOB, handleOOBFunc(s, s.OOBTemplate)) handleFunc(httpPathOOB, handleOOBFunc(s, s.OOBTemplate))
handleFunc(httpPathToken, handleTokenFunc(s)) handleFunc(httpPathToken, handleTokenFunc(s))
handleFunc(httpPathKeys, handleKeysFunc(s.KeyManager, clock)) handleFunc(httpPathKeys, handleKeysFunc(s.KeyManager, clock))