forked from mystiq/dex
client, server: public client restrictions
* disallow ClientCreds for public clients * clients can only redirect to localhost or OOB
This commit is contained in:
parent
4f85f3a479
commit
cdcf08066d
7 changed files with 240 additions and 26 deletions
|
@ -7,6 +7,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
@ -38,6 +39,8 @@ func (v ValidationError) Error() string {
|
||||||
|
|
||||||
const (
|
const (
|
||||||
bcryptHashCost = 10
|
bcryptHashCost = 10
|
||||||
|
|
||||||
|
OOBRedirectURI = "urn:ietf:wg:oauth:2.0:oob"
|
||||||
)
|
)
|
||||||
|
|
||||||
func HashSecret(creds oidc.ClientCredentials) ([]byte, error) {
|
func HashSecret(creds oidc.ClientCredentials) ([]byte, error) {
|
||||||
|
@ -61,6 +64,34 @@ type Client struct {
|
||||||
Public bool
|
Public bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c Client) ValidRedirectURL(u *url.URL) (url.URL, error) {
|
||||||
|
if c.Public {
|
||||||
|
if u == nil {
|
||||||
|
return url.URL{}, ErrorInvalidRedirectURL
|
||||||
|
}
|
||||||
|
if u.String() == OOBRedirectURI {
|
||||||
|
return *u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Scheme != "http" {
|
||||||
|
return url.URL{}, ErrorInvalidRedirectURL
|
||||||
|
}
|
||||||
|
|
||||||
|
hostPort := strings.Split(u.Host, ":")
|
||||||
|
if len(hostPort) != 2 {
|
||||||
|
return url.URL{}, ErrorInvalidRedirectURL
|
||||||
|
}
|
||||||
|
|
||||||
|
if hostPort[0] != "localhost" || u.Path != "" || u.RawPath != "" || u.RawQuery != "" || u.Fragment != "" {
|
||||||
|
return url.URL{}, ErrorInvalidRedirectURL
|
||||||
|
}
|
||||||
|
|
||||||
|
return *u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ValidRedirectURL(u, c.Metadata.RedirectURIs)
|
||||||
|
}
|
||||||
|
|
||||||
type ClientRepo interface {
|
type ClientRepo interface {
|
||||||
Get(tx repo.Transaction, clientID string) (Client, error)
|
Get(tx repo.Transaction, clientID string) (Client, error)
|
||||||
|
|
||||||
|
|
|
@ -204,7 +204,101 @@ func TestClientsFromReader(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func TestClientValidRedirectURL(t *testing.T) {
|
||||||
|
makeClient := func(public bool, urls []string) Client {
|
||||||
|
cli := Client{
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: make([]url.URL, len(urls)),
|
||||||
|
},
|
||||||
|
Public: public,
|
||||||
|
}
|
||||||
|
for i, s := range urls {
|
||||||
|
cli.Metadata.RedirectURIs[i] = mustParseURL(t, s)
|
||||||
|
}
|
||||||
|
return cli
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
u string
|
||||||
|
cli Client
|
||||||
|
|
||||||
|
wantU string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
u: "http://auth.example.com",
|
||||||
|
cli: makeClient(false, []string{"http://auth.example.com"}),
|
||||||
|
wantU: "http://auth.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
u: "http://auth2.example.com",
|
||||||
|
cli: makeClient(false, []string{"http://auth.example.com", "http://auth2.example.com"}),
|
||||||
|
wantU: "http://auth2.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
u: "",
|
||||||
|
cli: makeClient(false, []string{"http://auth.example.com"}),
|
||||||
|
wantU: "http://auth.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
u: "",
|
||||||
|
cli: makeClient(false, []string{"http://auth.example.com", "http://auth2.example.com"}),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
u: "http://localhost:8080",
|
||||||
|
cli: makeClient(true, []string{}),
|
||||||
|
wantU: "http://localhost:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
u: OOBRedirectURI,
|
||||||
|
cli: makeClient(true, []string{}),
|
||||||
|
wantU: OOBRedirectURI,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
u: "",
|
||||||
|
cli: makeClient(true, []string{}),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
u: "http://localhost:8080/hey_there",
|
||||||
|
cli: makeClient(true, []string{}),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
u: "http://auth.google.com:8080",
|
||||||
|
cli: makeClient(true, []string{}),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
var testURL *url.URL
|
||||||
|
if tt.u == "" {
|
||||||
|
testURL = nil
|
||||||
|
} else {
|
||||||
|
u := mustParseURL(t, tt.u)
|
||||||
|
testURL = &u
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := tt.cli.ValidRedirectURL(testURL)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("case %d: want non-nil error", i)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := pretty.Compare(mustParseURL(t, tt.wantU), u); diff != "" {
|
||||||
|
t.Fatalf("case %d: Compare(wantU, u): %v", i, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
func mustParseURL(t *testing.T, s string) url.URL {
|
func mustParseURL(t *testing.T, s string) url.URL {
|
||||||
u, err := url.Parse(s)
|
u, err := url.Parse(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -303,30 +303,69 @@ func TestHTTPClientCredsToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cis := []client.LoadableClient{{Client: ci}}
|
|
||||||
|
|
||||||
srv, err := mockServer(cis)
|
ci2 := ci
|
||||||
if err != nil {
|
ci2.Credentials.ID = "not_a_client"
|
||||||
t.Fatalf("Unexpected error setting up server: %v", err)
|
|
||||||
|
ciPublic := ci
|
||||||
|
ciPublic.Public = true
|
||||||
|
ciPublic.Credentials.ID = "public"
|
||||||
|
|
||||||
|
cis := []client.LoadableClient{{Client: ci}, {Client: ciPublic}}
|
||||||
|
tests := []struct {
|
||||||
|
cli client.Client
|
||||||
|
clients []client.LoadableClient
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
cli: ci,
|
||||||
|
clients: cis,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
cli: ci2,
|
||||||
|
clients: cis,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
cli: ciPublic,
|
||||||
|
clients: cis,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
cl, err := mockClient(srv, ci)
|
for i, tt := range tests {
|
||||||
|
srv, err := mockServer(tt.clients)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error setting up OIDC client: %v", err)
|
t.Fatalf("case %d: Unexpected error setting up server: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cl, err := mockClient(srv, tt.cli)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("case %d: Unexpected error setting up OIDC client: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tok, err := cl.ClientCredsToken([]string{"openid"})
|
tok, err := cl.ClientCredsToken([]string{"openid"})
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("case %d: want non-nil error", i)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed getting client token: %v", err)
|
t.Fatalf("case %d: Failed getting client token: %v", i, err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := tok.Claims()
|
claims, err := tok.Claims()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed parsing claims from client token: %v", err)
|
t.Fatalf("case %d: Failed parsing claims from client token: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := verifyUserClaims(claims, &ci, nil, srv.IssuerURL); err != nil {
|
if err := verifyUserClaims(claims, &ci, nil, srv.IssuerURL); err != nil {
|
||||||
t.Fatalf("Failed to verify claims: %v", err)
|
t.Fatalf("case %d: Failed to verify claims: %v", i, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -300,7 +300,6 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
|
||||||
}
|
}
|
||||||
|
|
||||||
cli, err := srv.Client(acr.ClientID)
|
cli, err := srv.Client(acr.ClientID)
|
||||||
cm := cli.Metadata
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err)
|
log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err)
|
||||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
|
||||||
|
@ -312,13 +311,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cm.RedirectURIs) == 0 {
|
redirectURL, err := cli.ValidRedirectURL(acr.RedirectURL)
|
||||||
log.Errorf("Client %q has no redirect URLs", acr.ClientID)
|
|
||||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
redirectURL, err := client.ValidRedirectURL(acr.RedirectURL, cm.RedirectURIs)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err {
|
switch err {
|
||||||
case (client.ErrorCantChooseRedirectURL):
|
case (client.ErrorCantChooseRedirectURL):
|
||||||
|
|
|
@ -105,6 +105,30 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
wantLocation: "http://fake.example.com",
|
wantLocation: "http://fake.example.com",
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// valid redirect_uri for public client
|
||||||
|
{
|
||||||
|
query: url.Values{
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"redirect_uri": []string{"http://localhost:8080"},
|
||||||
|
"client_id": []string{testPublicClientID},
|
||||||
|
"connector_id": []string{"fake"},
|
||||||
|
"scope": []string{"openid"},
|
||||||
|
},
|
||||||
|
wantCode: http.StatusFound,
|
||||||
|
wantLocation: "http://fake.example.com",
|
||||||
|
},
|
||||||
|
// valid OOB redirect_uri for public client
|
||||||
|
{
|
||||||
|
query: url.Values{
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"redirect_uri": []string{client.OOBRedirectURI},
|
||||||
|
"client_id": []string{testPublicClientID},
|
||||||
|
"connector_id": []string{"fake"},
|
||||||
|
"scope": []string{"openid"},
|
||||||
|
},
|
||||||
|
wantCode: http.StatusFound,
|
||||||
|
wantLocation: "http://fake.example.com",
|
||||||
|
},
|
||||||
// provided redirect_uri does not match client
|
// provided redirect_uri does not match client
|
||||||
{
|
{
|
||||||
query: url.Values{
|
query: url.Values{
|
||||||
|
@ -173,6 +197,17 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
},
|
},
|
||||||
wantCode: http.StatusBadRequest,
|
wantCode: http.StatusBadRequest,
|
||||||
},
|
},
|
||||||
|
// invalid redirect_uri for public client
|
||||||
|
{
|
||||||
|
query: url.Values{
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"redirect_uri": []string{client.OOBRedirectURI + "oops"},
|
||||||
|
"client_id": []string{testPublicClientID},
|
||||||
|
"connector_id": []string{"fake"},
|
||||||
|
"scope": []string{"openid"},
|
||||||
|
},
|
||||||
|
wantCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
|
|
|
@ -408,6 +408,15 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
|
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
|
||||||
|
cli, err := s.Client(creds.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cli.Public {
|
||||||
|
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
|
||||||
|
}
|
||||||
|
|
||||||
ok, err := s.ClientManager.Authenticate(creds)
|
ok, err := s.ClientManager.Authenticate(creds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
|
log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
|
||||||
|
|
|
@ -39,6 +39,13 @@ var (
|
||||||
ID: testClientID,
|
ID: testClientID,
|
||||||
Secret: clientTestSecret,
|
Secret: clientTestSecret,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
testPublicClientID = "publicclient.example.com"
|
||||||
|
publicClientTestSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
|
||||||
|
testPublicClientCredentials = oidc.ClientCredentials{
|
||||||
|
ID: testPublicClientID,
|
||||||
|
Secret: publicClientTestSecret,
|
||||||
|
}
|
||||||
testClients = []client.LoadableClient{
|
testClients = []client.LoadableClient{
|
||||||
{
|
{
|
||||||
Client: client.Client{
|
Client: client.Client{
|
||||||
|
@ -50,6 +57,12 @@ var (
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Client: client.Client{
|
||||||
|
Credentials: testPublicClientCredentials,
|
||||||
|
Public: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
testConnectorID1 = "IDPC-1"
|
testConnectorID1 = "IDPC-1"
|
||||||
|
|
Loading…
Reference in a new issue