diff --git a/server/oauth2.go b/server/oauth2.go index b10609c9..528b25a6 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -12,6 +12,7 @@ import ( "fmt" "hash" "io" + "net" "net/http" "net/url" "strconv" @@ -518,9 +519,18 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool { if redirectURI == redirectURIOOB { return true } - if !strings.HasPrefix(redirectURI, "http://localhost:") { + + // verify that the host is of form "http://localhost:(port)(path)" or "http://localhost(path)" + u, err := url.Parse(redirectURI) + if err != nil { return false } - n, err := strconv.Atoi(strings.TrimPrefix(redirectURI, "https://localhost:")) - return err == nil && n <= 0 + if u.Scheme != "http" { + return false + } + if u.Host == "localhost" { + return true + } + host, _, err := net.SplitHostPort(u.Host) + return err == nil && host == "localhost" } diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 83de2256..dcf4947b 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -195,3 +195,67 @@ func TestAccessTokenHash(t *testing.T) { t.Errorf("expected %q got %q", googleAccessTokenHash, atHash) } } + +func TestValidRedirectURI(t *testing.T) { + tests := []struct { + client storage.Client + redirectURI string + wantValid bool + }{ + { + client: storage.Client{ + RedirectURIs: []string{"http://foo.com/bar"}, + }, + redirectURI: "http://foo.com/bar", + wantValid: true, + }, + { + client: storage.Client{ + RedirectURIs: []string{"http://foo.com/bar"}, + }, + redirectURI: "http://foo.com/bar/baz", + }, + { + client: storage.Client{ + Public: true, + }, + redirectURI: "urn:ietf:wg:oauth:2.0:oob", + wantValid: true, + }, + { + client: storage.Client{ + Public: true, + }, + redirectURI: "http://localhost:8080/", + wantValid: true, + }, + { + client: storage.Client{ + Public: true, + }, + redirectURI: "http://localhost:991/bar", + wantValid: true, + }, + { + client: storage.Client{ + Public: true, + }, + redirectURI: "http://localhost", + wantValid: true, + }, + { + client: storage.Client{ + Public: true, + }, + redirectURI: "http://localhost.localhost:8080/", + wantValid: false, + }, + } + for _, test := range tests { + got := validateRedirectURI(test.client, test.redirectURI) + if got != test.wantValid { + t.Errorf("client=%#v, redirectURI=%q, wanted valid=%t, got=%t", + test.client, test.redirectURI, test.wantValid, got) + } + } +}