forked from mystiq/dex
client, server: public client restrictions
* disallow ClientCreds for public clients * clients can only redirect to localhost or OOB
This commit is contained in:
parent
4f85f3a479
commit
cdcf08066d
7 changed files with 240 additions and 26 deletions
|
@ -7,6 +7,7 @@ import (
|
|||
"io"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
|
@ -38,6 +39,8 @@ func (v ValidationError) Error() string {
|
|||
|
||||
const (
|
||||
bcryptHashCost = 10
|
||||
|
||||
OOBRedirectURI = "urn:ietf:wg:oauth:2.0:oob"
|
||||
)
|
||||
|
||||
func HashSecret(creds oidc.ClientCredentials) ([]byte, error) {
|
||||
|
@ -61,6 +64,34 @@ type Client struct {
|
|||
Public bool
|
||||
}
|
||||
|
||||
func (c Client) ValidRedirectURL(u *url.URL) (url.URL, error) {
|
||||
if c.Public {
|
||||
if u == nil {
|
||||
return url.URL{}, ErrorInvalidRedirectURL
|
||||
}
|
||||
if u.String() == OOBRedirectURI {
|
||||
return *u, nil
|
||||
}
|
||||
|
||||
if u.Scheme != "http" {
|
||||
return url.URL{}, ErrorInvalidRedirectURL
|
||||
}
|
||||
|
||||
hostPort := strings.Split(u.Host, ":")
|
||||
if len(hostPort) != 2 {
|
||||
return url.URL{}, ErrorInvalidRedirectURL
|
||||
}
|
||||
|
||||
if hostPort[0] != "localhost" || u.Path != "" || u.RawPath != "" || u.RawQuery != "" || u.Fragment != "" {
|
||||
return url.URL{}, ErrorInvalidRedirectURL
|
||||
}
|
||||
|
||||
return *u, nil
|
||||
}
|
||||
|
||||
return ValidRedirectURL(u, c.Metadata.RedirectURIs)
|
||||
}
|
||||
|
||||
type ClientRepo interface {
|
||||
Get(tx repo.Transaction, clientID string) (Client, error)
|
||||
|
||||
|
|
|
@ -204,7 +204,101 @@ func TestClientsFromReader(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
func TestClientValidRedirectURL(t *testing.T) {
|
||||
makeClient := func(public bool, urls []string) Client {
|
||||
cli := Client{
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: make([]url.URL, len(urls)),
|
||||
},
|
||||
Public: public,
|
||||
}
|
||||
for i, s := range urls {
|
||||
cli.Metadata.RedirectURIs[i] = mustParseURL(t, s)
|
||||
}
|
||||
return cli
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
u string
|
||||
cli Client
|
||||
|
||||
wantU string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
u: "http://auth.example.com",
|
||||
cli: makeClient(false, []string{"http://auth.example.com"}),
|
||||
wantU: "http://auth.example.com",
|
||||
},
|
||||
{
|
||||
u: "http://auth2.example.com",
|
||||
cli: makeClient(false, []string{"http://auth.example.com", "http://auth2.example.com"}),
|
||||
wantU: "http://auth2.example.com",
|
||||
},
|
||||
{
|
||||
u: "",
|
||||
cli: makeClient(false, []string{"http://auth.example.com"}),
|
||||
wantU: "http://auth.example.com",
|
||||
},
|
||||
{
|
||||
u: "",
|
||||
cli: makeClient(false, []string{"http://auth.example.com", "http://auth2.example.com"}),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
u: "http://localhost:8080",
|
||||
cli: makeClient(true, []string{}),
|
||||
wantU: "http://localhost:8080",
|
||||
},
|
||||
{
|
||||
u: OOBRedirectURI,
|
||||
cli: makeClient(true, []string{}),
|
||||
wantU: OOBRedirectURI,
|
||||
},
|
||||
{
|
||||
u: "",
|
||||
cli: makeClient(true, []string{}),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
u: "http://localhost:8080/hey_there",
|
||||
cli: makeClient(true, []string{}),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
u: "http://auth.google.com:8080",
|
||||
cli: makeClient(true, []string{}),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
var testURL *url.URL
|
||||
if tt.u == "" {
|
||||
testURL = nil
|
||||
} else {
|
||||
u := mustParseURL(t, tt.u)
|
||||
testURL = &u
|
||||
}
|
||||
|
||||
u, err := tt.cli.ValidRedirectURL(testURL)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil error", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(mustParseURL(t, tt.wantU), u); diff != "" {
|
||||
t.Fatalf("case %d: Compare(wantU, u): %v", i, diff)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
func mustParseURL(t *testing.T, s string) url.URL {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
|
|
|
@ -303,30 +303,69 @@ func TestHTTPClientCredsToken(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
cis := []client.LoadableClient{{Client: ci}}
|
||||
|
||||
srv, err := mockServer(cis)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error setting up server: %v", err)
|
||||
ci2 := ci
|
||||
ci2.Credentials.ID = "not_a_client"
|
||||
|
||||
ciPublic := ci
|
||||
ciPublic.Public = true
|
||||
ciPublic.Credentials.ID = "public"
|
||||
|
||||
cis := []client.LoadableClient{{Client: ci}, {Client: ciPublic}}
|
||||
tests := []struct {
|
||||
cli client.Client
|
||||
clients []client.LoadableClient
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
cli: ci,
|
||||
clients: cis,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
cli: ci2,
|
||||
clients: cis,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
cli: ciPublic,
|
||||
clients: cis,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
cl, err := mockClient(srv, ci)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error setting up OIDC client: %v", err)
|
||||
}
|
||||
for i, tt := range tests {
|
||||
srv, err := mockServer(tt.clients)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: Unexpected error setting up server: %v", i, err)
|
||||
}
|
||||
|
||||
tok, err := cl.ClientCredsToken([]string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed getting client token: %v", err)
|
||||
}
|
||||
cl, err := mockClient(srv, tt.cli)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: Unexpected error setting up OIDC client: %v", i, err)
|
||||
}
|
||||
|
||||
claims, err := tok.Claims()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed parsing claims from client token: %v", err)
|
||||
}
|
||||
tok, err := cl.ClientCredsToken([]string{"openid"})
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil error", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := verifyUserClaims(claims, &ci, nil, srv.IssuerURL); err != nil {
|
||||
t.Fatalf("Failed to verify claims: %v", err)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: Failed getting client token: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
claims, err := tok.Claims()
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: Failed parsing claims from client token: %v", i, err)
|
||||
}
|
||||
|
||||
if err := verifyUserClaims(claims, &ci, nil, srv.IssuerURL); err != nil {
|
||||
t.Fatalf("case %d: Failed to verify claims: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -300,7 +300,6 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
|
|||
}
|
||||
|
||||
cli, err := srv.Client(acr.ClientID)
|
||||
cm := cli.Metadata
|
||||
if err != nil {
|
||||
log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err)
|
||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
|
||||
|
@ -312,13 +311,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
|
|||
return
|
||||
}
|
||||
|
||||
if len(cm.RedirectURIs) == 0 {
|
||||
log.Errorf("Client %q has no redirect URLs", acr.ClientID)
|
||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURL, err := client.ValidRedirectURL(acr.RedirectURL, cm.RedirectURIs)
|
||||
redirectURL, err := cli.ValidRedirectURL(acr.RedirectURL)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case (client.ErrorCantChooseRedirectURL):
|
||||
|
|
|
@ -105,6 +105,30 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
wantLocation: "http://fake.example.com",
|
||||
},
|
||||
|
||||
// valid redirect_uri for public client
|
||||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{"http://localhost:8080"},
|
||||
"client_id": []string{testPublicClientID},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
wantLocation: "http://fake.example.com",
|
||||
},
|
||||
// valid OOB redirect_uri for public client
|
||||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{client.OOBRedirectURI},
|
||||
"client_id": []string{testPublicClientID},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
wantLocation: "http://fake.example.com",
|
||||
},
|
||||
// provided redirect_uri does not match client
|
||||
{
|
||||
query: url.Values{
|
||||
|
@ -173,6 +197,17 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
// invalid redirect_uri for public client
|
||||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{client.OOBRedirectURI + "oops"},
|
||||
"client_id": []string{testPublicClientID},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
|
|
|
@ -408,6 +408,15 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
|||
}
|
||||
|
||||
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
|
||||
cli, err := s.Client(creds.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cli.Public {
|
||||
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||
}
|
||||
|
||||
ok, err := s.ClientManager.Authenticate(creds)
|
||||
if err != nil {
|
||||
log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
|
||||
|
|
|
@ -39,6 +39,13 @@ var (
|
|||
ID: testClientID,
|
||||
Secret: clientTestSecret,
|
||||
}
|
||||
|
||||
testPublicClientID = "publicclient.example.com"
|
||||
publicClientTestSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
|
||||
testPublicClientCredentials = oidc.ClientCredentials{
|
||||
ID: testPublicClientID,
|
||||
Secret: publicClientTestSecret,
|
||||
}
|
||||
testClients = []client.LoadableClient{
|
||||
{
|
||||
Client: client.Client{
|
||||
|
@ -50,6 +57,12 @@ var (
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Client: client.Client{
|
||||
Credentials: testPublicClientCredentials,
|
||||
Public: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testConnectorID1 = "IDPC-1"
|
||||
|
|
Loading…
Reference in a new issue