diff --git a/client/client.go b/client/client.go index 40b86969..55590995 100644 --- a/client/client.go +++ b/client/client.go @@ -7,6 +7,7 @@ import ( "io" "net/url" "reflect" + "strings" "golang.org/x/crypto/bcrypt" @@ -38,6 +39,8 @@ func (v ValidationError) Error() string { const ( bcryptHashCost = 10 + + OOBRedirectURI = "urn:ietf:wg:oauth:2.0:oob" ) func HashSecret(creds oidc.ClientCredentials) ([]byte, error) { @@ -61,6 +64,34 @@ type Client struct { Public bool } +func (c Client) ValidRedirectURL(u *url.URL) (url.URL, error) { + if c.Public { + if u == nil { + return url.URL{}, ErrorInvalidRedirectURL + } + if u.String() == OOBRedirectURI { + return *u, nil + } + + if u.Scheme != "http" { + return url.URL{}, ErrorInvalidRedirectURL + } + + hostPort := strings.Split(u.Host, ":") + if len(hostPort) != 2 { + return url.URL{}, ErrorInvalidRedirectURL + } + + if hostPort[0] != "localhost" || u.Path != "" || u.RawPath != "" || u.RawQuery != "" || u.Fragment != "" { + return url.URL{}, ErrorInvalidRedirectURL + } + + return *u, nil + } + + return ValidRedirectURL(u, c.Metadata.RedirectURIs) +} + type ClientRepo interface { Get(tx repo.Transaction, clientID string) (Client, error) diff --git a/client/client_test.go b/client/client_test.go index 5766b2f8..bd984f76 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -204,7 +204,101 @@ func TestClientsFromReader(t *testing.T) { } } } +func TestClientValidRedirectURL(t *testing.T) { + makeClient := func(public bool, urls []string) Client { + cli := Client{ + Metadata: oidc.ClientMetadata{ + RedirectURIs: make([]url.URL, len(urls)), + }, + Public: public, + } + for i, s := range urls { + cli.Metadata.RedirectURIs[i] = mustParseURL(t, s) + } + return cli + } + tests := []struct { + u string + cli Client + + wantU string + wantErr bool + }{ + { + u: "http://auth.example.com", + cli: makeClient(false, []string{"http://auth.example.com"}), + wantU: "http://auth.example.com", + }, + { + u: "http://auth2.example.com", + cli: makeClient(false, []string{"http://auth.example.com", "http://auth2.example.com"}), + wantU: "http://auth2.example.com", + }, + { + u: "", + cli: makeClient(false, []string{"http://auth.example.com"}), + wantU: "http://auth.example.com", + }, + { + u: "", + cli: makeClient(false, []string{"http://auth.example.com", "http://auth2.example.com"}), + wantErr: true, + }, + { + u: "http://localhost:8080", + cli: makeClient(true, []string{}), + wantU: "http://localhost:8080", + }, + { + u: OOBRedirectURI, + cli: makeClient(true, []string{}), + wantU: OOBRedirectURI, + }, + { + u: "", + cli: makeClient(true, []string{}), + wantErr: true, + }, + { + u: "http://localhost:8080/hey_there", + cli: makeClient(true, []string{}), + wantErr: true, + }, + { + u: "http://auth.google.com:8080", + cli: makeClient(true, []string{}), + wantErr: true, + }, + } + + for i, tt := range tests { + var testURL *url.URL + if tt.u == "" { + testURL = nil + } else { + u := mustParseURL(t, tt.u) + testURL = &u + } + + u, err := tt.cli.ValidRedirectURL(testURL) + if tt.wantErr { + if err == nil { + t.Errorf("case %d: want non-nil error", i) + } + continue + } + + if err != nil { + t.Errorf("case %d: unexpected error: %v", i, err) + } + + if diff := pretty.Compare(mustParseURL(t, tt.wantU), u); diff != "" { + t.Fatalf("case %d: Compare(wantU, u): %v", i, diff) + } + } + +} func mustParseURL(t *testing.T, s string) url.URL { u, err := url.Parse(s) if err != nil { diff --git a/integration/oidc_test.go b/integration/oidc_test.go index 64c404b5..c7f5377c 100644 --- a/integration/oidc_test.go +++ b/integration/oidc_test.go @@ -303,30 +303,69 @@ func TestHTTPClientCredsToken(t *testing.T) { }, }, } - cis := []client.LoadableClient{{Client: ci}} - srv, err := mockServer(cis) - if err != nil { - t.Fatalf("Unexpected error setting up server: %v", err) + ci2 := ci + ci2.Credentials.ID = "not_a_client" + + ciPublic := ci + ciPublic.Public = true + ciPublic.Credentials.ID = "public" + + cis := []client.LoadableClient{{Client: ci}, {Client: ciPublic}} + tests := []struct { + cli client.Client + clients []client.LoadableClient + wantErr bool + }{ + { + cli: ci, + clients: cis, + wantErr: false, + }, + { + cli: ci2, + clients: cis, + wantErr: true, + }, + { + cli: ciPublic, + clients: cis, + wantErr: true, + }, } - cl, err := mockClient(srv, ci) - if err != nil { - t.Fatalf("Unexpected error setting up OIDC client: %v", err) - } + for i, tt := range tests { + srv, err := mockServer(tt.clients) + if err != nil { + t.Fatalf("case %d: Unexpected error setting up server: %v", i, err) + } - tok, err := cl.ClientCredsToken([]string{"openid"}) - if err != nil { - t.Fatalf("Failed getting client token: %v", err) - } + cl, err := mockClient(srv, tt.cli) + if err != nil { + t.Fatalf("case %d: Unexpected error setting up OIDC client: %v", i, err) + } - claims, err := tok.Claims() - if err != nil { - t.Fatalf("Failed parsing claims from client token: %v", err) - } + tok, err := cl.ClientCredsToken([]string{"openid"}) + if tt.wantErr { + if err == nil { + t.Errorf("case %d: want non-nil error", i) + } + continue + } - if err := verifyUserClaims(claims, &ci, nil, srv.IssuerURL); err != nil { - t.Fatalf("Failed to verify claims: %v", err) + if err != nil { + t.Fatalf("case %d: Failed getting client token: %v", i, err) + continue + } + + claims, err := tok.Claims() + if err != nil { + t.Fatalf("case %d: Failed parsing claims from client token: %v", i, err) + } + + if err := verifyUserClaims(claims, &ci, nil, srv.IssuerURL); err != nil { + t.Fatalf("case %d: Failed to verify claims: %v", i, err) + } } } diff --git a/server/http.go b/server/http.go index d5100999..9d2cc984 100644 --- a/server/http.go +++ b/server/http.go @@ -300,7 +300,6 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T } cli, err := srv.Client(acr.ClientID) - cm := cli.Metadata if err != nil { log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err) writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State) @@ -312,13 +311,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T return } - if len(cm.RedirectURIs) == 0 { - log.Errorf("Client %q has no redirect URLs", acr.ClientID) - writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State) - return - } - - redirectURL, err := client.ValidRedirectURL(acr.RedirectURL, cm.RedirectURIs) + redirectURL, err := cli.ValidRedirectURL(acr.RedirectURL) if err != nil { switch err { case (client.ErrorCantChooseRedirectURL): diff --git a/server/http_test.go b/server/http_test.go index 1a2de802..637f53b7 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -105,6 +105,30 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { wantLocation: "http://fake.example.com", }, + // valid redirect_uri for public client + { + query: url.Values{ + "response_type": []string{"code"}, + "redirect_uri": []string{"http://localhost:8080"}, + "client_id": []string{testPublicClientID}, + "connector_id": []string{"fake"}, + "scope": []string{"openid"}, + }, + wantCode: http.StatusFound, + wantLocation: "http://fake.example.com", + }, + // valid OOB redirect_uri for public client + { + query: url.Values{ + "response_type": []string{"code"}, + "redirect_uri": []string{client.OOBRedirectURI}, + "client_id": []string{testPublicClientID}, + "connector_id": []string{"fake"}, + "scope": []string{"openid"}, + }, + wantCode: http.StatusFound, + wantLocation: "http://fake.example.com", + }, // provided redirect_uri does not match client { query: url.Values{ @@ -173,6 +197,17 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { }, wantCode: http.StatusBadRequest, }, + // invalid redirect_uri for public client + { + query: url.Values{ + "response_type": []string{"code"}, + "redirect_uri": []string{client.OOBRedirectURI + "oops"}, + "client_id": []string{testPublicClientID}, + "connector_id": []string{"fake"}, + "scope": []string{"openid"}, + }, + wantCode: http.StatusBadRequest, + }, } for i, tt := range tests { diff --git a/server/server.go b/server/server.go index 2f6cb263..4fe53e25 100644 --- a/server/server.go +++ b/server/server.go @@ -408,6 +408,15 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) { } func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) { + cli, err := s.Client(creds.ID) + if err != nil { + return nil, err + } + + if cli.Public { + return nil, oauth2.NewError(oauth2.ErrorInvalidClient) + } + ok, err := s.ClientManager.Authenticate(creds) if err != nil { log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err) diff --git a/server/testutil_test.go b/server/testutil_test.go index 16317a03..212a30ae 100644 --- a/server/testutil_test.go +++ b/server/testutil_test.go @@ -39,6 +39,13 @@ var ( ID: testClientID, Secret: clientTestSecret, } + + testPublicClientID = "publicclient.example.com" + publicClientTestSecret = base64.URLEncoding.EncodeToString([]byte("secret")) + testPublicClientCredentials = oidc.ClientCredentials{ + ID: testPublicClientID, + Secret: publicClientTestSecret, + } testClients = []client.LoadableClient{ { Client: client.Client{ @@ -50,6 +57,12 @@ var ( }, }, }, + { + Client: client.Client{ + Credentials: testPublicClientCredentials, + Public: true, + }, + }, } testConnectorID1 = "IDPC-1"