forked from mystiq/dex
203 lines
4.5 KiB
Go
203 lines
4.5 KiB
Go
|
package server
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"net/url"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/coreos/go-oidc/oidc"
|
||
|
|
||
|
"github.com/coreos/dex/client"
|
||
|
"github.com/coreos/dex/connector"
|
||
|
)
|
||
|
|
||
|
func makeCrossClientTestFixtures() (*testFixtures, error) {
|
||
|
f, err := makeTestFixtures()
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("couldn't make test fixtures: %v", err)
|
||
|
}
|
||
|
|
||
|
creds := map[string]oidc.ClientCredentials{}
|
||
|
for _, cliData := range []struct {
|
||
|
id string
|
||
|
authorized []string
|
||
|
}{
|
||
|
{
|
||
|
id: "client_a",
|
||
|
}, {
|
||
|
id: "client_b",
|
||
|
authorized: []string{"client_a"},
|
||
|
}, {
|
||
|
id: "client_c",
|
||
|
authorized: []string{"client_a", "client_b"},
|
||
|
},
|
||
|
} {
|
||
|
u := url.URL{
|
||
|
Scheme: "https://",
|
||
|
Path: cliData.id,
|
||
|
Host: "auth.example.com",
|
||
|
}
|
||
|
cliCreds, err := f.clientRepo.New(client.Client{
|
||
|
Credentials: oidc.ClientCredentials{
|
||
|
ID: cliData.id,
|
||
|
},
|
||
|
Metadata: oidc.ClientMetadata{
|
||
|
RedirectURIs: []url.URL{u},
|
||
|
},
|
||
|
})
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("Unexpected error creating clients: %v", err)
|
||
|
}
|
||
|
creds[cliData.id] = *cliCreds
|
||
|
err = f.clientRepo.SetTrustedPeers(cliData.id, cliData.authorized)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("Unexpected error setting cross-client authorizers: %v", err)
|
||
|
}
|
||
|
}
|
||
|
return f, nil
|
||
|
}
|
||
|
|
||
|
func TestServerCrossClientAuthAllowed(t *testing.T) {
|
||
|
f, err := makeCrossClientTestFixtures()
|
||
|
if err != nil {
|
||
|
t.Fatalf("couldn't make test fixtures: %v", err)
|
||
|
}
|
||
|
|
||
|
tests := []struct {
|
||
|
reqClient string
|
||
|
authClient string
|
||
|
wantAuthorized bool
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{
|
||
|
reqClient: "client_b",
|
||
|
authClient: "client_a",
|
||
|
wantAuthorized: false,
|
||
|
wantErr: false,
|
||
|
},
|
||
|
{
|
||
|
reqClient: "client_a",
|
||
|
authClient: "client_b",
|
||
|
wantAuthorized: true,
|
||
|
wantErr: false,
|
||
|
},
|
||
|
{
|
||
|
reqClient: "client_a",
|
||
|
authClient: "client_c",
|
||
|
wantAuthorized: true,
|
||
|
wantErr: false,
|
||
|
},
|
||
|
{
|
||
|
reqClient: "client_c",
|
||
|
authClient: "client_b",
|
||
|
wantAuthorized: false,
|
||
|
wantErr: false,
|
||
|
},
|
||
|
{
|
||
|
reqClient: "client_c",
|
||
|
authClient: "nope",
|
||
|
wantErr: false,
|
||
|
},
|
||
|
}
|
||
|
for i, tt := range tests {
|
||
|
got, err := f.srv.CrossClientAuthAllowed(tt.reqClient, tt.authClient)
|
||
|
if tt.wantErr {
|
||
|
if err == nil {
|
||
|
t.Errorf("case %d: want non-nil err", i)
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
if err != nil {
|
||
|
t.Errorf("case %d: unexpected err %v: ", i, err)
|
||
|
}
|
||
|
|
||
|
if got != tt.wantAuthorized {
|
||
|
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAuthorized, got)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestHandleAuthCrossClient(t *testing.T) {
|
||
|
f, err := makeCrossClientTestFixtures()
|
||
|
if err != nil {
|
||
|
t.Fatalf("couldn't make test fixtures: %v", err)
|
||
|
}
|
||
|
|
||
|
tests := []struct {
|
||
|
scopes []string
|
||
|
clientID string
|
||
|
wantCode int
|
||
|
}{
|
||
|
{
|
||
|
scopes: []string{ScopeGoogleCrossClient + "client_a"},
|
||
|
clientID: "client_b",
|
||
|
wantCode: http.StatusBadRequest,
|
||
|
},
|
||
|
{
|
||
|
scopes: []string{ScopeGoogleCrossClient + "client_b"},
|
||
|
clientID: "client_a",
|
||
|
wantCode: http.StatusFound,
|
||
|
},
|
||
|
{
|
||
|
scopes: []string{ScopeGoogleCrossClient + "client_b"},
|
||
|
clientID: "client_a",
|
||
|
wantCode: http.StatusFound,
|
||
|
},
|
||
|
{
|
||
|
scopes: []string{ScopeGoogleCrossClient + "client_c"},
|
||
|
clientID: "client_a",
|
||
|
wantCode: http.StatusFound,
|
||
|
},
|
||
|
{
|
||
|
// Two clients that client_a is authorized to mint tokens for.
|
||
|
scopes: []string{
|
||
|
ScopeGoogleCrossClient + "client_c",
|
||
|
ScopeGoogleCrossClient + "client_b",
|
||
|
},
|
||
|
clientID: "client_a",
|
||
|
wantCode: http.StatusFound,
|
||
|
},
|
||
|
{
|
||
|
// Two clients that client_a is authorized to mint tokens for.
|
||
|
scopes: []string{
|
||
|
ScopeGoogleCrossClient + "client_c",
|
||
|
ScopeGoogleCrossClient + "client_a",
|
||
|
},
|
||
|
clientID: "client_b",
|
||
|
wantCode: http.StatusBadRequest,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
idpcs := []connector.Connector{
|
||
|
&fakeConnector{loginURL: "http://fake.example.com"},
|
||
|
}
|
||
|
|
||
|
for i, tt := range tests {
|
||
|
hdlr := handleAuthFunc(f.srv, idpcs, nil, true)
|
||
|
w := httptest.NewRecorder()
|
||
|
|
||
|
query := url.Values{
|
||
|
"response_type": []string{"code"},
|
||
|
"client_id": []string{tt.clientID},
|
||
|
"connector_id": []string{"fake"},
|
||
|
"scope": []string{strings.Join(append([]string{"openid"}, tt.scopes...), " ")},
|
||
|
}
|
||
|
u := fmt.Sprintf("http://server.example.com?%s", query.Encode())
|
||
|
req, err := http.NewRequest("GET", u, nil)
|
||
|
if err != nil {
|
||
|
t.Errorf("case %d: unable to form HTTP request: %v", i, err)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
hdlr.ServeHTTP(w, req)
|
||
|
if tt.wantCode != w.Code {
|
||
|
t.Errorf("case %d: HTTP code mismatch: want=%d got=%d", i, tt.wantCode, w.Code)
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
|
||
|
}
|