client, server: public client restrictions

* disallow ClientCreds for public clients
* clients can only redirect to localhost or OOB
This commit is contained in:
Bobby Rullo 2016-06-17 15:35:03 -07:00
parent 4f85f3a479
commit cdcf08066d
7 changed files with 240 additions and 26 deletions

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"net/url" "net/url"
"reflect" "reflect"
"strings"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -38,6 +39,8 @@ func (v ValidationError) Error() string {
const ( const (
bcryptHashCost = 10 bcryptHashCost = 10
OOBRedirectURI = "urn:ietf:wg:oauth:2.0:oob"
) )
func HashSecret(creds oidc.ClientCredentials) ([]byte, error) { func HashSecret(creds oidc.ClientCredentials) ([]byte, error) {
@ -61,6 +64,34 @@ type Client struct {
Public bool Public bool
} }
func (c Client) ValidRedirectURL(u *url.URL) (url.URL, error) {
if c.Public {
if u == nil {
return url.URL{}, ErrorInvalidRedirectURL
}
if u.String() == OOBRedirectURI {
return *u, nil
}
if u.Scheme != "http" {
return url.URL{}, ErrorInvalidRedirectURL
}
hostPort := strings.Split(u.Host, ":")
if len(hostPort) != 2 {
return url.URL{}, ErrorInvalidRedirectURL
}
if hostPort[0] != "localhost" || u.Path != "" || u.RawPath != "" || u.RawQuery != "" || u.Fragment != "" {
return url.URL{}, ErrorInvalidRedirectURL
}
return *u, nil
}
return ValidRedirectURL(u, c.Metadata.RedirectURIs)
}
type ClientRepo interface { type ClientRepo interface {
Get(tx repo.Transaction, clientID string) (Client, error) Get(tx repo.Transaction, clientID string) (Client, error)

View file

@ -204,7 +204,101 @@ func TestClientsFromReader(t *testing.T) {
} }
} }
} }
func TestClientValidRedirectURL(t *testing.T) {
makeClient := func(public bool, urls []string) Client {
cli := Client{
Metadata: oidc.ClientMetadata{
RedirectURIs: make([]url.URL, len(urls)),
},
Public: public,
}
for i, s := range urls {
cli.Metadata.RedirectURIs[i] = mustParseURL(t, s)
}
return cli
}
tests := []struct {
u string
cli Client
wantU string
wantErr bool
}{
{
u: "http://auth.example.com",
cli: makeClient(false, []string{"http://auth.example.com"}),
wantU: "http://auth.example.com",
},
{
u: "http://auth2.example.com",
cli: makeClient(false, []string{"http://auth.example.com", "http://auth2.example.com"}),
wantU: "http://auth2.example.com",
},
{
u: "",
cli: makeClient(false, []string{"http://auth.example.com"}),
wantU: "http://auth.example.com",
},
{
u: "",
cli: makeClient(false, []string{"http://auth.example.com", "http://auth2.example.com"}),
wantErr: true,
},
{
u: "http://localhost:8080",
cli: makeClient(true, []string{}),
wantU: "http://localhost:8080",
},
{
u: OOBRedirectURI,
cli: makeClient(true, []string{}),
wantU: OOBRedirectURI,
},
{
u: "",
cli: makeClient(true, []string{}),
wantErr: true,
},
{
u: "http://localhost:8080/hey_there",
cli: makeClient(true, []string{}),
wantErr: true,
},
{
u: "http://auth.google.com:8080",
cli: makeClient(true, []string{}),
wantErr: true,
},
}
for i, tt := range tests {
var testURL *url.URL
if tt.u == "" {
testURL = nil
} else {
u := mustParseURL(t, tt.u)
testURL = &u
}
u, err := tt.cli.ValidRedirectURL(testURL)
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want non-nil error", i)
}
continue
}
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
if diff := pretty.Compare(mustParseURL(t, tt.wantU), u); diff != "" {
t.Fatalf("case %d: Compare(wantU, u): %v", i, diff)
}
}
}
func mustParseURL(t *testing.T, s string) url.URL { func mustParseURL(t *testing.T, s string) url.URL {
u, err := url.Parse(s) u, err := url.Parse(s)
if err != nil { if err != nil {

View file

@ -303,30 +303,69 @@ func TestHTTPClientCredsToken(t *testing.T) {
}, },
}, },
} }
cis := []client.LoadableClient{{Client: ci}}
srv, err := mockServer(cis) ci2 := ci
if err != nil { ci2.Credentials.ID = "not_a_client"
t.Fatalf("Unexpected error setting up server: %v", err)
ciPublic := ci
ciPublic.Public = true
ciPublic.Credentials.ID = "public"
cis := []client.LoadableClient{{Client: ci}, {Client: ciPublic}}
tests := []struct {
cli client.Client
clients []client.LoadableClient
wantErr bool
}{
{
cli: ci,
clients: cis,
wantErr: false,
},
{
cli: ci2,
clients: cis,
wantErr: true,
},
{
cli: ciPublic,
clients: cis,
wantErr: true,
},
} }
cl, err := mockClient(srv, ci) for i, tt := range tests {
srv, err := mockServer(tt.clients)
if err != nil { if err != nil {
t.Fatalf("Unexpected error setting up OIDC client: %v", err) t.Fatalf("case %d: Unexpected error setting up server: %v", i, err)
}
cl, err := mockClient(srv, tt.cli)
if err != nil {
t.Fatalf("case %d: Unexpected error setting up OIDC client: %v", i, err)
} }
tok, err := cl.ClientCredsToken([]string{"openid"}) tok, err := cl.ClientCredsToken([]string{"openid"})
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want non-nil error", i)
}
continue
}
if err != nil { if err != nil {
t.Fatalf("Failed getting client token: %v", err) t.Fatalf("case %d: Failed getting client token: %v", i, err)
continue
} }
claims, err := tok.Claims() claims, err := tok.Claims()
if err != nil { if err != nil {
t.Fatalf("Failed parsing claims from client token: %v", err) t.Fatalf("case %d: Failed parsing claims from client token: %v", i, err)
} }
if err := verifyUserClaims(claims, &ci, nil, srv.IssuerURL); err != nil { if err := verifyUserClaims(claims, &ci, nil, srv.IssuerURL); err != nil {
t.Fatalf("Failed to verify claims: %v", err) t.Fatalf("case %d: Failed to verify claims: %v", i, err)
}
} }
} }

View file

@ -300,7 +300,6 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
} }
cli, err := srv.Client(acr.ClientID) cli, err := srv.Client(acr.ClientID)
cm := cli.Metadata
if err != nil { if err != nil {
log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err) log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err)
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State) writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
@ -312,13 +311,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
return return
} }
if len(cm.RedirectURIs) == 0 { redirectURL, err := cli.ValidRedirectURL(acr.RedirectURL)
log.Errorf("Client %q has no redirect URLs", acr.ClientID)
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
return
}
redirectURL, err := client.ValidRedirectURL(acr.RedirectURL, cm.RedirectURIs)
if err != nil { if err != nil {
switch err { switch err {
case (client.ErrorCantChooseRedirectURL): case (client.ErrorCantChooseRedirectURL):

View file

@ -105,6 +105,30 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
wantLocation: "http://fake.example.com", wantLocation: "http://fake.example.com",
}, },
// valid redirect_uri for public client
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{"http://localhost:8080"},
"client_id": []string{testPublicClientID},
"connector_id": []string{"fake"},
"scope": []string{"openid"},
},
wantCode: http.StatusFound,
wantLocation: "http://fake.example.com",
},
// valid OOB redirect_uri for public client
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{client.OOBRedirectURI},
"client_id": []string{testPublicClientID},
"connector_id": []string{"fake"},
"scope": []string{"openid"},
},
wantCode: http.StatusFound,
wantLocation: "http://fake.example.com",
},
// provided redirect_uri does not match client // provided redirect_uri does not match client
{ {
query: url.Values{ query: url.Values{
@ -173,6 +197,17 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
}, },
wantCode: http.StatusBadRequest, wantCode: http.StatusBadRequest,
}, },
// invalid redirect_uri for public client
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{client.OOBRedirectURI + "oops"},
"client_id": []string{testPublicClientID},
"connector_id": []string{"fake"},
"scope": []string{"openid"},
},
wantCode: http.StatusBadRequest,
},
} }
for i, tt := range tests { for i, tt := range tests {

View file

@ -408,6 +408,15 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
} }
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) { func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
cli, err := s.Client(creds.ID)
if err != nil {
return nil, err
}
if cli.Public {
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
}
ok, err := s.ClientManager.Authenticate(creds) ok, err := s.ClientManager.Authenticate(creds)
if err != nil { if err != nil {
log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err) log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)

View file

@ -39,6 +39,13 @@ var (
ID: testClientID, ID: testClientID,
Secret: clientTestSecret, Secret: clientTestSecret,
} }
testPublicClientID = "publicclient.example.com"
publicClientTestSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
testPublicClientCredentials = oidc.ClientCredentials{
ID: testPublicClientID,
Secret: publicClientTestSecret,
}
testClients = []client.LoadableClient{ testClients = []client.LoadableClient{
{ {
Client: client.Client{ Client: client.Client{
@ -50,6 +57,12 @@ var (
}, },
}, },
}, },
{
Client: client.Client{
Credentials: testPublicClientCredentials,
Public: true,
},
},
} }
testConnectorID1 = "IDPC-1" testConnectorID1 = "IDPC-1"