Merge pull request #426 from bobbyrullo/cross_client_2

Cross client work
This commit is contained in:
bobbyrullo 2016-06-07 17:36:06 -07:00
commit a9d854e144
33 changed files with 936 additions and 201 deletions

View file

@ -141,7 +141,10 @@ func (a *AdminAPI) CreateClient(req adminschema.ClientCreateRequest) (adminschem
}
// metadata is guaranteed to have at least one redirect_uri by earlier validation.
creds, err := a.clientManager.New(cli)
creds, err := a.clientManager.New(cli, &clientmanager.ClientOptions{
TrustedPeers: req.Client.TrustedPeers,
})
if err != nil {
return adminschema.ClientCreateResponse{}, mapError(err)
}

View file

@ -15,6 +15,7 @@ import (
)
var (
ErrorInvalidClientID = errors.New("not a valid client ID")
ErrorInvalidRedirectURL = errors.New("not a valid redirect url for the given client")
ErrorCantChooseRedirectURL = errors.New("must provide a redirect url; client has many")
ErrorNoValidRedirectURLs = errors.New("no valid redirect URLs for this client.")
@ -60,6 +61,12 @@ type ClientRepo interface {
New(tx repo.Transaction, client Client) (*oidc.ClientCredentials, error)
Update(tx repo.Transaction, client Client) error
// GetTrustedPeers returns the list of clients authorized to mint ID token for the given client.
GetTrustedPeers(tx repo.Transaction, clientID string) ([]string, error)
// SetTrustedPeers sets the list of clients authorized to mint ID token for the given client.
SetTrustedPeers(tx repo.Transaction, clientID string, clientIDs []string) error
}
// ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise.

View file

@ -21,6 +21,10 @@ const (
maxSecretLength = 72
)
type ClientOptions struct {
TrustedPeers []string
}
type SecretGenerator func() ([]byte, error)
func DefaultSecretGenerator() ([]byte, error) {
@ -63,7 +67,7 @@ func NewClientManager(clientRepo client.ClientRepo, txnFactory repo.TransactionF
}
}
func (m *ClientManager) New(cli client.Client) (*oidc.ClientCredentials, error) {
func (m *ClientManager) New(cli client.Client, options *ClientOptions) (*oidc.ClientCredentials, error) {
tx, err := m.begin()
if err != nil {
return nil, err
@ -83,6 +87,13 @@ func (m *ClientManager) New(cli client.Client) (*oidc.ClientCredentials, error)
return nil, err
}
if options != nil && len(options.TrustedPeers) > 0 {
err = m.clientRepo.SetTrustedPeers(tx, creds.ID, options.TrustedPeers)
if err != nil {
return nil, err
}
}
err = tx.Commit()
if err != nil {
return nil, err

View file

@ -132,7 +132,7 @@ func TestAuthenticate(t *testing.T) {
cli := client.Client{
Metadata: cm,
}
cc, err := f.mgr.New(cli)
cc, err := f.mgr.New(cli, nil)
if err != nil {
t.Fatalf(err.Error())
}

View file

@ -34,7 +34,7 @@ func (d *dbDriver) NewClient(meta oidc.ClientMetadata) (*oidc.ClientCredentials,
cli := client.Client{
Metadata: meta,
}
return d.ciManager.New(cli)
return d.ciManager.New(cli, nil)
}
func (d *dbDriver) ConnectorConfigs() ([]connector.ConnectorConfig, error) {

View file

@ -16,7 +16,8 @@ import (
)
const (
clientTableName = "client_identity"
clientTableName = "client_identity"
trustedPeerTableName = "trusted_peers"
// postgres error codes
pgErrorCodeUniqueViolation = "23505" // unique_violation
@ -29,6 +30,13 @@ func init() {
autoinc: false,
pkey: []string{"id"},
})
register(table{
name: trustedPeerTableName,
model: trustedPeerModel{},
autoinc: false,
pkey: []string{"client_id", "trusted_client_id"},
})
}
func newClientModel(cli client.Client) (*clientModel, error) {
@ -58,6 +66,11 @@ type clientModel struct {
DexAdmin bool `db:"dex_admin"`
}
type trustedPeerModel struct {
ClientID string `db:"client_id"`
TrustedClientID string `db:"trusted_client_id"`
}
func (m *clientModel) Client() (*client.Client, error) {
ci := client.Client{
Credentials: oidc.ClientCredentials{
@ -254,3 +267,63 @@ func (r *clientRepo) update(tx repo.Transaction, cli client.Client) error {
_, err = ex.Update(cm)
return err
}
func (r *clientRepo) GetTrustedPeers(tx repo.Transaction, clientID string) ([]string, error) {
ex := r.executor(tx)
if clientID == "" {
return nil, client.ErrorInvalidClientID
}
qt := r.quote(trustedPeerTableName)
var ids []string
_, err := ex.Select(&ids, fmt.Sprintf("SELECT trusted_client_id from %s where client_id = $1", qt), clientID)
if err != nil {
if err != sql.ErrNoRows {
return nil, err
}
return nil, nil
}
return ids, nil
}
func (r *clientRepo) SetTrustedPeers(tx repo.Transaction, clientID string, clientIDs []string) error {
ex := r.executor(tx)
qt := r.quote(trustedPeerTableName)
// First delete all existing rows
_, err := ex.Exec(fmt.Sprintf("DELETE from %s where client_id = $1", qt), clientID)
if err != nil {
return err
}
// Ensure that the client exists.
_, err = r.get(tx, clientID)
if err != nil {
return err
}
// Verify that all the clients are valid
for _, curID := range clientIDs {
_, err := r.get(tx, curID)
if err != nil {
return err
}
}
// Set the clients
rows := []interface{}{}
for _, curID := range clientIDs {
rows = append(rows, &trustedPeerModel{
ClientID: clientID,
TrustedClientID: curID,
})
}
err = ex.Insert(rows...)
if err != nil {
return err
}
return nil
}

View file

@ -5,7 +5,7 @@ import (
"fmt"
"github.com/go-gorp/gorp"
"github.com/rubenv/sql-migrate"
migrate "github.com/rubenv/sql-migrate"
"github.com/coreos/dex/db/migrations"
)

View file

@ -70,4 +70,10 @@ CREATE TABLE session_key (
expires_at bigint,
stale integer
);
CREATE TABLE trusted_peers (
client_id text NOT NULL,
trusted_client_id text NOT NULL
);
`

View file

@ -0,0 +1,5 @@
-- +migrate Up
CREATE TABLE IF NOT EXISTS "trusted_peers" (
"client_id" text not null,
"trusted_client_id" text not null,
primary key ("client_id", "trusted_client_id")) ;

View file

@ -72,5 +72,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{
"-- +migrate Up\n\n-- This migration is a fix for a bug that allowed duplicate emails if they used different cases (see #338).\n-- When migrating, dex will not take the liberty of deleting rows for duplicate cases. Instead it will\n-- raise an exception and call for an admin to remove duplicates manually.\n\nCREATE OR REPLACE FUNCTION raise_exp() RETURNS VOID AS $$\nBEGIN\n RAISE EXCEPTION 'Found duplicate emails when using case insensitive comparision, cannot perform migration.';\nEND;\n$$ LANGUAGE plpgsql;\n\nSELECT LOWER(email),\n COUNT(email),\n CASE\n WHEN COUNT(email) > 1 THEN raise_exp()\n ELSE NULL\n END\nFROM authd_user\nGROUP BY LOWER(email);\n\nUPDATE authd_user SET email = LOWER(email);\n",
},
},
{
Id: "0012_add_cross_client_authorizers.sql",
Up: []string{
"-- +migrate Up\nCREATE TABLE IF NOT EXISTS \"trusted_peers\" (\n \"client_id\" text not null,\n \"trusted_client_id\" text not null,\n primary key (\"client_id\", \"trusted_client_id\")) ;\n",
},
},
},
}

View file

@ -248,7 +248,7 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us
if err != sql.ErrNoRows {
return nil, err
}
return nil, err
return nil, nil
}
if len(rims) == 0 {
return nil, nil

View file

@ -316,7 +316,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
cli := client.Client{
Metadata: cm,
}
cc, err := m.New(cli)
cc, err := m.New(cli, nil)
if err != nil {
t.Fatalf(err.Error())
}

View file

@ -86,7 +86,9 @@ func makeAdminAPITestFixtures() *adminAPITestFixtures {
var cliCount int
secGen := func() ([]byte, error) {
return []byte(fmt.Sprintf("client_%v", cliCount)), nil
id := []byte(fmt.Sprintf("client_%v", cliCount))
cliCount++
return id, nil
}
cr := db.NewClientRepo(dbMap)
clientIDGenerator := func(hostport string) (string, error) {
@ -379,6 +381,7 @@ func TestCreateClient(t *testing.T) {
}
return u
}
addIDAndSecret := func(cli adminschema.Client) *adminschema.Client {
cli.Id = "client_auth.example.com"
cli.Secret = base64.URLEncoding.EncodeToString([]byte("client_0"))
@ -404,16 +407,20 @@ func TestCreateClient(t *testing.T) {
adminMultiRedirect := adminClientGood
adminMultiRedirect.RedirectURIs = []string{"https://auth.example.com/", "https://auth2.example.com/"}
clientMultiRedirect := clientGoodAdmin
clientMultiRedirect := clientGood
clientMultiRedirect.Metadata.RedirectURIs = append(
clientMultiRedirect.Metadata.RedirectURIs,
*mustParseURL("https://auth2.example.com/"))
adminClientWithPeers := adminClientGood
adminClientWithPeers.TrustedPeers = []string{"test_client_0"}
tests := []struct {
req adminschema.ClientCreateRequest
want adminschema.ClientCreateResponse
wantClient client.Client
wantError int
req adminschema.ClientCreateRequest
want adminschema.ClientCreateResponse
wantClient client.Client
wantError int
wantTrustedPeers []string
}{
{
req: adminschema.ClientCreateRequest{},
@ -462,13 +469,35 @@ func TestCreateClient(t *testing.T) {
},
wantClient: clientMultiRedirect,
},
{
req: adminschema.ClientCreateRequest{
Client: &adminClientWithPeers,
},
want: adminschema.ClientCreateResponse{
Client: addIDAndSecret(adminClientWithPeers),
},
wantClient: clientGood,
wantTrustedPeers: []string{"test_client_0"},
},
}
for i, tt := range tests {
if i != 3 {
continue
}
f := makeAdminAPITestFixtures()
for j, r := range []string{"https://client0.example.com",
"https://client1.example.com"} {
_, err := f.cr.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: fmt.Sprintf("test_client_%d", j),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{*mustParseURL(r)},
},
})
if err != nil {
t.Errorf("case %d, client %d: unexpected error creating client: %v", i, j, err)
continue
}
}
resp, err := f.adClient.Client.Create(&tt.req).Do()
if tt.wantError != 0 {

View file

@ -618,7 +618,7 @@ func TestRefreshTokenEndpoints(t *testing.T) {
t.Errorf("case %d: expected client ids did not match actual: %s", i, diff)
}
for _, clientID := range ids {
if err := f.client.Clients.Revoke(tt.userID, clientID).Do(); err != nil {
if err := f.client.RefreshClient.Revoke(tt.userID, clientID).Do(); err != nil {
t.Errorf("case %d: failed to revoke client: %v", i, err)
}
}

View file

@ -34,7 +34,10 @@ __Version:__ v1
redirectURIs: [
string
],
secret: string // The client secret. Ignored in client create requests.
secret: string // The client secret. Ignored in client create requests.,
trustedPeers: [
string
]
}
```

View file

@ -148,6 +148,10 @@ type Client struct {
// Secret: The client secret. Ignored in client create requests.
Secret string `json:"secret,omitempty"`
// TrustedPeers: Array of ClientIDs of clients that are allowed to mint
// ID tokens for the client being created.
TrustedPeers []string `json:"trustedPeers,omitempty"`
}
type ClientCreateRequest struct {

View file

@ -84,6 +84,13 @@ const DiscoveryJSON = `{
"clientURI": {
"type": "string",
"description": "OPTIONAL. URL of the home page of the Client. The value of this field MUST point to a valid Web page. If present, the server SHOULD display this URL to the End-User in a followable fashion. If desired, representation of this Claim in different languages and scripts is represented as described in Section 2.1 ( Metadata Languages and Scripts ) ."
},
"trustedPeers": {
"type": "array",
"items": {
"type": "string"
},
"description": "Array of ClientIDs of clients that are allowed to mint ID tokens for the client being created."
}
}
},
@ -228,4 +235,5 @@ const DiscoveryJSON = `{
}
}
}
`

View file

@ -78,6 +78,13 @@
"clientURI": {
"type": "string",
"description": "OPTIONAL. URL of the home page of the Client. The value of this field MUST point to a valid Web page. If present, the server SHOULD display this URL to the End-User in a followable fashion. If desired, representation of this Claim in different languages and scripts is represented as described in Section 2.1 ( Metadata Languages and Scripts ) ."
},
"trustedPeers": {
"type": "array",
"items": {
"type": "string"
},
"description": "Array of ClientIDs of clients that are allowed to mint ID tokens for the client being created."
}
}
},
@ -222,3 +229,4 @@
}
}
}

View file

@ -199,7 +199,7 @@ A client with associated public metadata.
> __Description__
> List all clients that hold refresh tokens for the authenticated user.
> List all clients that hold refresh tokens for the specified user.
> __Parameters__
@ -221,19 +221,19 @@ A client with associated public metadata.
> __Summary__
> Revoke Clients
> Revoke RefreshClient
> __Description__
> Revoke all refresh tokens issues to the client for the authenticated user.
> Revoke all refresh tokens issues to the client for the specified user.
> __Parameters__
> |Name|Located in|Description|Required|Type|
|:-----|:-----|:-----|:-----|:-----|
| userid | path | | Yes | string |
| clientid | path | | Yes | string |
| userid | path | | Yes | string |
> __Responses__
@ -310,8 +310,8 @@ A client with associated public metadata.
> |Name|Located in|Description|Required|Type|
|:-----|:-----|:-----|:-----|:-----|
| maxResults | query | | No | integer |
| nextPageToken | query | | No | string |
| maxResults | query | | No | integer |
> __Responses__

View file

@ -334,82 +334,7 @@ func (c *ClientsListCall) Do() (*ClientPage, error) {
}
// method id "dex.Client.Revoke":
type ClientsRevokeCall struct {
s *Service
userid string
clientid string
opt_ map[string]interface{}
}
// Revoke: Revoke all refresh tokens issues to the client for the
// authenticated user.
func (r *ClientsService) Revoke(userid string, clientid string) *ClientsRevokeCall {
c := &ClientsRevokeCall{s: r.s, opt_: make(map[string]interface{})}
c.userid = userid
c.clientid = clientid
return c
}
// Fields allows partial responses to be retrieved.
// See https://developers.google.com/gdata/docs/2.0/basics#PartialResponse
// for more information.
func (c *ClientsRevokeCall) Fields(s ...googleapi.Field) *ClientsRevokeCall {
c.opt_["fields"] = googleapi.CombineFields(s)
return c
}
func (c *ClientsRevokeCall) Do() error {
var body io.Reader = nil
params := make(url.Values)
params.Set("alt", "json")
if v, ok := c.opt_["fields"]; ok {
params.Set("fields", fmt.Sprintf("%v", v))
}
urls := googleapi.ResolveRelative(c.s.BasePath, "account/{userid}/refresh/{clientid}")
urls += "?" + params.Encode()
req, _ := http.NewRequest("DELETE", urls, body)
googleapi.Expand(req.URL, map[string]string{
"userid": c.userid,
"clientid": c.clientid,
})
req.Header.Set("User-Agent", "google-api-go-client/0.5")
res, err := c.s.client.Do(req)
if err != nil {
return err
}
defer googleapi.CloseBody(res)
if err := googleapi.CheckResponse(res); err != nil {
return err
}
return nil
// {
// "description": "Revoke all refresh tokens issues to the client for the authenticated user.",
// "httpMethod": "DELETE",
// "id": "dex.Client.Revoke",
// "parameterOrder": [
// "userid",
// "clientid"
// ],
// "parameters": {
// "clientid": {
// "location": "path",
// "required": true,
// "type": "string"
// },
// "userid": {
// "location": "path",
// "required": true,
// "type": "string"
// }
// },
// "path": "account/{userid}/refresh/{clientid}"
// }
}
// method id "dex.Client.List":
// method id "dex.RefreshClient.List":
type RefreshClientListCall struct {
s *Service
@ -417,7 +342,7 @@ type RefreshClientListCall struct {
opt_ map[string]interface{}
}
// List: List all clients that hold refresh tokens for the authenticated
// List: List all clients that hold refresh tokens for the specified
// user.
func (r *RefreshClientService) List(userid string) *RefreshClientListCall {
c := &RefreshClientListCall{s: r.s, opt_: make(map[string]interface{})}
@ -461,9 +386,9 @@ func (c *RefreshClientListCall) Do() (*RefreshClientList, error) {
}
return ret, nil
// {
// "description": "List all clients that hold refresh tokens for the authenticated user.",
// "description": "List all clients that hold refresh tokens for the specified user.",
// "httpMethod": "GET",
// "id": "dex.Client.List",
// "id": "dex.RefreshClient.List",
// "parameterOrder": [
// "userid"
// ],
@ -482,6 +407,81 @@ func (c *RefreshClientListCall) Do() (*RefreshClientList, error) {
}
// method id "dex.RefreshClient.Revoke":
type RefreshClientRevokeCall struct {
s *Service
userid string
clientid string
opt_ map[string]interface{}
}
// Revoke: Revoke all refresh tokens issues to the client for the
// specified user.
func (r *RefreshClientService) Revoke(userid string, clientid string) *RefreshClientRevokeCall {
c := &RefreshClientRevokeCall{s: r.s, opt_: make(map[string]interface{})}
c.userid = userid
c.clientid = clientid
return c
}
// Fields allows partial responses to be retrieved.
// See https://developers.google.com/gdata/docs/2.0/basics#PartialResponse
// for more information.
func (c *RefreshClientRevokeCall) Fields(s ...googleapi.Field) *RefreshClientRevokeCall {
c.opt_["fields"] = googleapi.CombineFields(s)
return c
}
func (c *RefreshClientRevokeCall) Do() error {
var body io.Reader = nil
params := make(url.Values)
params.Set("alt", "json")
if v, ok := c.opt_["fields"]; ok {
params.Set("fields", fmt.Sprintf("%v", v))
}
urls := googleapi.ResolveRelative(c.s.BasePath, "account/{userid}/refresh/{clientid}")
urls += "?" + params.Encode()
req, _ := http.NewRequest("DELETE", urls, body)
googleapi.Expand(req.URL, map[string]string{
"userid": c.userid,
"clientid": c.clientid,
})
req.Header.Set("User-Agent", "google-api-go-client/0.5")
res, err := c.s.client.Do(req)
if err != nil {
return err
}
defer googleapi.CloseBody(res)
if err := googleapi.CheckResponse(res); err != nil {
return err
}
return nil
// {
// "description": "Revoke all refresh tokens issues to the client for the specified user.",
// "httpMethod": "DELETE",
// "id": "dex.RefreshClient.Revoke",
// "parameterOrder": [
// "userid",
// "clientid"
// ],
// "parameters": {
// "clientid": {
// "location": "path",
// "required": true,
// "type": "string"
// },
// "userid": {
// "location": "path",
// "required": true,
// "type": "string"
// }
// },
// "path": "account/{userid}/refresh/{clientid}"
// }
}
// method id "dex.User.Create":
type UsersCreateCall struct {

View file

@ -272,28 +272,6 @@ const DiscoveryJSON = `{
"response": {
"$ref": "ClientWithSecret"
}
},
"Revoke": {
"id": "dex.Client.Revoke",
"description": "Revoke all refresh tokens issues to the client for the authenticated user.",
"httpMethod": "DELETE",
"path": "account/{userid}/refresh/{clientid}",
"parameterOrder": [
"userid",
"clientid"
],
"parameters": {
"clientid": {
"type": "string",
"required": true,
"location": "path"
},
"userid": {
"type": "string",
"required": true,
"location": "path"
}
}
}
}
},
@ -398,8 +376,8 @@ const DiscoveryJSON = `{
"RefreshClient": {
"methods": {
"List": {
"id": "dex.Client.List",
"description": "List all clients that hold refresh tokens for the authenticated user.",
"id": "dex.RefreshClient.List",
"description": "List all clients that hold refresh tokens for the specified user.",
"httpMethod": "GET",
"path": "account/{userid}/refresh",
"parameters": {
@ -415,6 +393,28 @@ const DiscoveryJSON = `{
"response": {
"$ref": "RefreshClientList"
}
},
"Revoke": {
"id": "dex.RefreshClient.Revoke",
"description": "Revoke all refresh tokens issues to the client for the specified user.",
"httpMethod": "DELETE",
"path": "account/{userid}/refresh/{clientid}",
"parameterOrder": [
"userid",
"clientid"
],
"parameters": {
"clientid": {
"type": "string",
"required": true,
"location": "path"
},
"userid": {
"type": "string",
"required": true,
"location": "path"
}
}
}
}
}

View file

@ -266,28 +266,6 @@
"response": {
"$ref": "ClientWithSecret"
}
},
"Revoke": {
"id": "dex.Client.Revoke",
"description": "Revoke all refresh tokens issues to the client for the authenticated user.",
"httpMethod": "DELETE",
"path": "account/{userid}/refresh/{clientid}",
"parameterOrder": [
"userid",
"clientid"
],
"parameters": {
"clientid": {
"type": "string",
"required": true,
"location": "path"
},
"userid": {
"type": "string",
"required": true,
"location": "path"
}
}
}
}
},
@ -392,8 +370,8 @@
"RefreshClient": {
"methods": {
"List": {
"id": "dex.Client.List",
"description": "List all clients that hold refresh tokens for the authenticated user.",
"id": "dex.RefreshClient.List",
"description": "List all clients that hold refresh tokens for the specified user.",
"httpMethod": "GET",
"path": "account/{userid}/refresh",
"parameters": {
@ -409,6 +387,28 @@
"response": {
"$ref": "RefreshClientList"
}
},
"Revoke": {
"id": "dex.RefreshClient.Revoke",
"description": "Revoke all refresh tokens issues to the client for the specified user.",
"httpMethod": "DELETE",
"path": "account/{userid}/refresh/{clientid}",
"parameterOrder": [
"userid",
"clientid"
],
"parameters": {
"clientid": {
"type": "string",
"required": true,
"location": "path"
},
"userid": {
"type": "string",
"required": true,
"location": "path"
}
}
}
}
}

34
scope/scope.go Normal file
View file

@ -0,0 +1,34 @@
package scope
import "strings"
const (
// Scope prefix which indicates initiation of a cross-client authentication flow.
// See https://developers.google.com/identity/protocols/CrossClientAuth
ScopeGoogleCrossClient = "audience:server:client_id:"
)
type Scopes []string
func (s Scopes) OfflineAccess() bool {
return s.HasScope("offline_access")
}
func (s Scopes) HasScope(scope string) bool {
for _, curScope := range s {
if curScope == scope {
return true
}
}
return false
}
func (s Scopes) CrossClientIDs() []string {
clients := []string{}
for _, scope := range s {
if strings.HasPrefix(scope, ScopeGoogleCrossClient) {
clients = append(clients, scope[len(ScopeGoogleCrossClient):])
}
}
return clients
}

View file

@ -37,7 +37,7 @@ func TestClientToken(t *testing.T) {
cli := client.Client{
Metadata: clientMetadata,
}
creds, err := clientManager.New(cli)
creds, err := clientManager.New(cli, nil)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}

View file

@ -42,7 +42,7 @@ func (s *Server) handleClientRegistrationRequest(r *http.Request) (*oidc.ClientR
cli := client.Client{
Metadata: clientMetadata,
}
creds, err := s.ClientManager.New(cli)
creds, err := s.ClientManager.New(cli, nil)
if err != nil {
log.Errorf("Failed to create new client identity: %v", err)
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")

View file

@ -87,7 +87,7 @@ func (c *clientResource) create(w http.ResponseWriter, r *http.Request) {
writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidClientMetadata, err.Error()))
return
}
creds, err := c.manager.New(ci)
creds, err := c.manager.New(ci, nil)
if err != nil {
log.Errorf("Failed creating client: %v", err)

330
server/cross_client_test.go Normal file
View file

@ -0,0 +1,330 @@
package server
import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"sort"
"strings"
"testing"
"github.com/coreos/go-oidc/oidc"
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/scope"
)
func makeCrossClientTestFixtures() (*testFixtures, error) {
f, err := makeTestFixtures()
if err != nil {
return nil, fmt.Errorf("couldn't make test fixtures: %v", err)
}
for _, cliData := range []struct {
id string
authorized []string
}{
{
id: "client_a",
}, {
id: "client_b",
authorized: []string{"client_a"},
}, {
id: "client_c",
authorized: []string{"client_a", "client_b"},
},
} {
u := url.URL{
Scheme: "https://",
Path: cliData.id,
Host: cliData.id,
}
cliCreds, err := f.clientManager.New(client.Client{
Credentials: oidc.ClientCredentials{
ID: cliData.id,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{u},
},
}, &clientmanager.ClientOptions{
TrustedPeers: cliData.authorized,
})
if err != nil {
return nil, fmt.Errorf("Unexpected error creating clients: %v", err)
}
f.clientCreds[cliData.id] = *cliCreds
}
return f, nil
}
func TestServerCrossClientAuthAllowed(t *testing.T) {
f, err := makeCrossClientTestFixtures()
if err != nil {
t.Fatalf("couldn't make test fixtures: %v", err)
}
tests := []struct {
reqClient string
authClient string
wantAuthorized bool
wantErr bool
}{
{
reqClient: "client_b",
authClient: "client_a",
wantAuthorized: false,
wantErr: false,
},
{
reqClient: "client_a",
authClient: "client_b",
wantAuthorized: true,
wantErr: false,
},
{
reqClient: "client_a",
authClient: "client_c",
wantAuthorized: true,
wantErr: false,
},
{
reqClient: "client_c",
authClient: "client_b",
wantAuthorized: false,
wantErr: false,
},
{
reqClient: "client_c",
authClient: "nope",
wantErr: false,
},
}
for i, tt := range tests {
got, err := f.srv.CrossClientAuthAllowed(tt.reqClient, tt.authClient)
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want non-nil err", i)
}
continue
}
if err != nil {
t.Errorf("case %d: unexpected err %v: ", i, err)
}
if got != tt.wantAuthorized {
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAuthorized, got)
}
}
}
func TestHandleAuthCrossClient(t *testing.T) {
f, err := makeCrossClientTestFixtures()
if err != nil {
t.Fatalf("couldn't make test fixtures: %v", err)
}
tests := []struct {
scopes []string
clientID string
wantCode int
}{
{
scopes: []string{scope.ScopeGoogleCrossClient + "client_a"},
clientID: "client_b",
wantCode: http.StatusBadRequest,
},
{
scopes: []string{scope.ScopeGoogleCrossClient + "client_b"},
clientID: "client_a",
wantCode: http.StatusFound,
},
{
scopes: []string{scope.ScopeGoogleCrossClient + "client_b"},
clientID: "client_a",
wantCode: http.StatusFound,
},
{
scopes: []string{scope.ScopeGoogleCrossClient + "client_c"},
clientID: "client_a",
wantCode: http.StatusFound,
},
{
// Two clients that client_a is authorized to mint tokens for.
scopes: []string{
scope.ScopeGoogleCrossClient + "client_c",
scope.ScopeGoogleCrossClient + "client_b",
},
clientID: "client_a",
wantCode: http.StatusFound,
},
{
// Two clients that client_a is authorized to mint tokens for.
scopes: []string{
scope.ScopeGoogleCrossClient + "client_c",
scope.ScopeGoogleCrossClient + "client_a",
},
clientID: "client_b",
wantCode: http.StatusBadRequest,
},
}
idpcs := []connector.Connector{
&fakeConnector{loginURL: "http://fake.example.com"},
}
for i, tt := range tests {
hdlr := handleAuthFunc(f.srv, idpcs, nil, true)
w := httptest.NewRecorder()
query := url.Values{
"response_type": []string{"code"},
"client_id": []string{tt.clientID},
"connector_id": []string{"fake"},
"scope": []string{strings.Join(append([]string{"openid"}, tt.scopes...), " ")},
}
u := fmt.Sprintf("http://server.example.com?%s", query.Encode())
req, err := http.NewRequest("GET", u, nil)
if err != nil {
t.Errorf("case %d: unable to form HTTP request: %v", i, err)
continue
}
hdlr.ServeHTTP(w, req)
if tt.wantCode != w.Code {
t.Errorf("case %d: HTTP code mismatch: want=%d got=%d", i, tt.wantCode, w.Code)
continue
}
}
}
func TestServerCodeTokenCrossClient(t *testing.T) {
f, err := makeCrossClientTestFixtures()
if err != nil {
t.Fatalf("Error creating test fixtures: %v", err)
}
sm := f.sessionManager
tests := []struct {
clientID string
offline bool
refreshToken string
crossClients []string
wantErr bool
wantAUD []string
wantAZP string
}{
// First test the non-cross-client cases, make sure they're undisturbed:
{
// No 'offline_access' in scope, should get empty refresh token.
clientID: testClientID,
refreshToken: "",
wantAUD: []string{testClientID},
},
{
// Have 'offline_access' in scope, should get non-empty refresh token.
clientID: testClientID,
offline: true,
refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
wantAUD: []string{testClientID},
},
// Now test cross-client cases:
{
clientID: "client_a",
crossClients: []string{"client_b"},
wantAUD: []string{"client_b"},
wantAZP: "client_a",
},
{
clientID: "client_a",
crossClients: []string{"client_b", "client_a"},
wantAUD: []string{"client_a", "client_b"},
wantAZP: "client_a",
},
}
for i, tt := range tests {
scopes := []string{"openid"}
if tt.offline {
scopes = append(scopes, "offline_access")
}
for _, client := range tt.crossClients {
scopes = append(scopes, scope.ScopeGoogleCrossClient+client)
}
sessionID, err := sm.NewSession("bogus_idpc", tt.clientID, "bogus", url.URL{}, "", false, scopes)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
_, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{})
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
_, err = sm.AttachUser(sessionID, "ID-1")
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
key, err := sm.NewSessionKey(sessionID)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
jwt, token, err := f.srv.CodeToken(f.clientCreds[tt.clientID], key)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
if jwt == nil {
t.Fatalf("case %d: expect non-nil jwt", i)
}
if token != tt.refreshToken {
t.Errorf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
}
claims, err := jwt.Claims()
if err != nil {
t.Fatalf("case %d: unexpected error getting claims: %v", i, err)
}
var gotAUD []string
if len(tt.wantAUD) < 2 {
aud, _, err := claims.StringClaim("aud")
if err != nil {
t.Fatalf("case %d: unexpected error getting 'aud': %q: raw: %v", i, err, claims["aud"])
}
gotAUD = []string{aud}
} else {
gotAUD, _, err = claims.StringsClaim("aud")
if err != nil {
t.Fatalf("case %d: unexpected error getting 'aud': %v", i, err)
}
}
sort.Strings(gotAUD)
if diff := pretty.Compare(tt.wantAUD, gotAUD); diff != "" {
t.Fatalf("case %d: pretty.Compare(tt.wantAUD, gotAUD): %v", i, diff)
}
gotAZP, _, err := claims.StringClaim("azp")
if err != nil {
if err != nil {
t.Fatalf("case %d: unexpected error getting 'aud': %v", i, err)
}
}
if gotAZP != tt.wantAZP {
t.Errorf("case %d: wantAZP=%v, gotAZP=%v", i, tt.wantAZP, gotAZP)
}
}
}

View file

@ -21,6 +21,7 @@ import (
"github.com/coreos/dex/connector"
phttp "github.com/coreos/dex/pkg/http"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/scope"
)
const (
@ -341,30 +342,9 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
}
// Check scopes.
var scopes []string
foundOpenIDScope := false
for _, scope := range acr.Scope {
switch scope {
case "openid":
foundOpenIDScope = true
scopes = append(scopes, scope)
case "offline_access":
// According to the spec, for offline_access scope, the client must
// use a response_type value that would result in an Authorization Code.
// Currently oauth2.ResponseTypeCode is the only supported response type,
// and it's been checked above, so we don't need to check it again here.
//
// TODO(yifan): Verify that 'consent' should be in 'prompt'.
scopes = append(scopes, scope)
default:
// Pass all other scopes.
scopes = append(scopes, scope)
}
}
if !foundOpenIDScope {
log.Errorf("Invalid auth request: missing 'openid' in 'scope'")
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
if scopeErr := validateScopes(srv, acr.ClientID, acr.Scope); scopeErr != nil {
log.Error(scopeErr)
writeAuthError(w, scopeErr, acr.State)
return
}
@ -410,6 +390,67 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
}
}
func validateScopes(srv OIDCServer, clientID string, scopes []string) error {
foundOpenIDScope := false
for i, curScope := range scopes {
if i > 0 && curScope == scopes[i-1] {
err := oauth2.NewError(oauth2.ErrorInvalidRequest)
err.Description = fmt.Sprintf(
"Duplicate scopes are not allowed: %q",
curScope)
return err
}
switch {
case strings.HasPrefix(curScope, scope.ScopeGoogleCrossClient):
otherClient := curScope[len(scope.ScopeGoogleCrossClient):]
var allowed bool
var err error
if otherClient == clientID {
allowed = true
} else {
allowed, err = srv.CrossClientAuthAllowed(clientID, otherClient)
if err != nil {
return err
}
}
if !allowed {
err := oauth2.NewError(oauth2.ErrorInvalidRequest)
err.Description = fmt.Sprintf(
"%q is not authorized to perform cross-client requests for %q",
clientID, otherClient)
return err
}
case curScope == "openid":
foundOpenIDScope = true
case curScope == "profile":
case curScope == "email":
case curScope == "offline_access":
// According to the spec, for offline_access scope, the client must
// use a response_type value that would result in an Authorization
// Code. Currently oauth2.ResponseTypeCode is the only supported
// response type, and it's been checked above, so we don't need to
// check it again here.
//
// TODO(yifan): Verify that 'consent' should be in 'prompt'.
default:
// Reject all other scopes.
err := oauth2.NewError(oauth2.ErrorInvalidRequest)
err.Description = fmt.Sprintf("%q is not a recognized scope", curScope)
return err
}
}
if !foundOpenIDScope {
log.Errorf("Invalid auth request: missing 'openid' in 'scope'")
err := oauth2.NewError(oauth2.ErrorInvalidRequest)
err.Description = "Invalid auth request: missing 'openid' in 'scope'"
return err
}
return nil
}
func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {

View file

@ -18,6 +18,7 @@ import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/scope"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
@ -308,8 +309,110 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
}
}
func TestHandleTokenFunc(t *testing.T) {
func TestValidateScopes(t *testing.T) {
f, err := makeCrossClientTestFixtures()
if err != nil {
t.Fatalf("couldn't make test fixtures: %v", err)
}
tests := []struct {
clientID string
scopes []string
wantErr bool
}{
{
// ERR: no openid scope
clientID: "XXX",
scopes: []string{},
wantErr: true,
},
{
// OK: minimum scopes
clientID: "XXX",
scopes: []string{"openid"},
wantErr: false,
},
{
// OK: offline_access
clientID: "XXX",
scopes: []string{"openid", "offline_access"},
wantErr: false,
},
{
// ERR: unknown scope
clientID: "XXX",
scopes: []string{"openid", "wat"},
wantErr: true,
},
{
// ERR: invalid cross client auth
clientID: "XXX",
scopes: []string{"openid", scope.ScopeGoogleCrossClient + "client_a"},
wantErr: true,
},
{
// OK: valid cross client auth (though perverse - a client
// requesting cross-client auth for itself)
clientID: "client_a",
scopes: []string{"openid", scope.ScopeGoogleCrossClient + "client_a"},
wantErr: false,
},
{
// OK: valid cross client auth
clientID: "client_a",
scopes: []string{"openid", scope.ScopeGoogleCrossClient + "client_b"},
wantErr: false,
},
{
// ERR: valid cross client auth...but duplicated scope.
clientID: "client_a",
scopes: []string{"openid",
scope.ScopeGoogleCrossClient + "client_b",
scope.ScopeGoogleCrossClient + "client_b",
},
wantErr: true,
},
{
// OK: valid cross client auth with >1 clients including itself
clientID: "client_a",
scopes: []string{
"openid",
scope.ScopeGoogleCrossClient + "client_a",
scope.ScopeGoogleCrossClient + "client_b",
scope.ScopeGoogleCrossClient + "client_c",
},
wantErr: false,
},
{
// ERR: valid cross client auth with >1 clients including itself...but no openid!
clientID: "client_a",
scopes: []string{
scope.ScopeGoogleCrossClient + "client_a",
scope.ScopeGoogleCrossClient + "client_b",
scope.ScopeGoogleCrossClient + "client_c",
},
wantErr: true,
},
}
for i, tt := range tests {
err := validateScopes(f.srv, tt.clientID, tt.scopes)
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want non-nil err", i)
}
continue
}
if err != nil {
t.Errorf("case %d: unexpected err: %v", i, err)
}
}
}
func TestHandleTokenFunc(t *testing.T) {
fx, err := makeTestFixtures()
if err != nil {
t.Fatalf("could not run test fixtures: %v", err)

View file

@ -45,13 +45,19 @@ type OIDCServer interface {
ClientMetadata(string) (*oidc.ClientMetadata, error)
NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error)
Login(oidc.Identity, string) (string, error)
// CodeToken exchanges a code for an ID token and a refresh token string on success.
CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error)
ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error)
// RefreshToken takes a previously generated refresh token and returns a new ID token
// if the token is valid.
RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error)
KillSession(string) error
CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error)
}
type JWTVerifierFactory func(clientID string) oidc.JWTVerifier
@ -438,6 +444,36 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
claims := ses.Claims(s.IssuerURL.String())
user.AddToClaims(claims)
crossClientIDs := ses.Scope.CrossClientIDs()
if len(crossClientIDs) > 0 {
var aud []string
for _, id := range crossClientIDs {
if ses.ClientID == id {
aud = append(aud, id)
continue
}
allowed, err := s.CrossClientAuthAllowed(ses.ClientID, id)
if err != nil {
log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", ses.ClientID, id, err)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
if !allowed {
err := oauth2.NewError(oauth2.ErrorInvalidRequest)
err.Description = fmt.Sprintf(
"%q is not authorized to perform cross-client requests for %q",
ses.ClientID, id)
return nil, "", err
}
aud = append(aud, id)
}
if len(aud) == 1 {
claims.Add("aud", aud[0])
} else {
claims.Add("aud", aud)
}
claims.Add("azp", ses.ClientID)
}
jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil {
log.Errorf("Failed to generate ID token: %v", err)
@ -521,6 +557,19 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose
return jwt, nil
}
func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) {
alloweds, err := s.ClientRepo.GetTrustedPeers(nil, authorizingClientID)
if err != nil {
return false, err
}
for _, allowed := range alloweds {
if requestingClientID == allowed {
return true, nil
}
}
return false, nil
}
func (s *Server) JWTVerifierFactory() JWTVerifierFactory {
noop := func() error { return nil }

View file

@ -9,16 +9,17 @@ import (
"testing"
"time"
"github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
)
var validRedirURL = url.URL{
@ -266,6 +267,12 @@ func TestServerLoginDisabledUser(t *testing.T) {
}
func TestServerCodeToken(t *testing.T) {
f, err := makeTestFixtures()
if err != nil {
t.Fatalf("Error creating test fixtures: %v", err)
}
sm := f.sessionManager
tests := []struct {
scope []string
refreshToken string
@ -277,21 +284,14 @@ func TestServerCodeToken(t *testing.T) {
},
// Have 'offline_access' in scope, should get non-empty refresh token.
{
// NOTE(ericchiang): This test assumes that the database ID of the first
// refresh token will be "1".
// NOTE(ericchiang): This test assumes that the database ID of the
// first refresh token will be "1".
scope: []string{"openid", "offline_access"},
refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
},
}
for i, tt := range tests {
f, err := makeTestFixtures()
if err != nil {
t.Fatalf("error making test fixtures: %v", err)
}
f.srv.RefreshTokenRepo = refreshtest.NewTestRefreshTokenRepo()
sm := f.sessionManager
sessionID, err := sm.NewSession("bogus_idpc", testClientID, "bogus", url.URL{}, "", false, tt.scope)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
@ -311,11 +311,9 @@ func TestServerCodeToken(t *testing.T) {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
jwt, token, err := f.srv.CodeToken(
oidc.ClientCredentials{
ID: testClientID,
Secret: clientTestSecret,
}, key)
jwt, token, err := f.srv.CodeToken(oidc.ClientCredentials{
ID: testClientID,
Secret: clientTestSecret}, key)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}

View file

@ -14,6 +14,7 @@ import (
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email"
"github.com/coreos/dex/refresh/refreshtest"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
@ -83,6 +84,11 @@ var (
}
testPrivKey, _ = key.GeneratePrivateKey()
testClientCreds = oidc.ClientCredentials{
ID: testClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
}
)
type testFixtures struct {
@ -93,6 +99,7 @@ type testFixtures struct {
redirectURL url.URL
clientRepo client.ClientRepo
clientManager *clientmanager.ClientManager
clientCreds map[string]oidc.ClientCredentials
}
type testFixtureOptions struct {
@ -150,6 +157,8 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
emailer, err := email.NewTemplatizedEmailerFromGlobs(
emailTemplatesLocation+"/*.txt",
emailTemplatesLocation+"/*.html",
@ -210,6 +219,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
UserManager: userManager,
ClientManager: clientManager,
KeyManager: km,
RefreshTokenRepo: refreshTokenRepo,
}
err = setTemplates(srv, tpl)
@ -243,5 +253,8 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
emailer: emailer,
clientRepo: clientRepo,
clientManager: clientManager,
clientCreds: map[string]oidc.ClientCredentials{
testClientID: testClientCreds,
},
}, nil
}

View file

@ -6,6 +6,8 @@ import (
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oidc"
"github.com/coreos/dex/scope"
)
const (
@ -46,11 +48,13 @@ type Session struct {
// Regsiter indicates that this session is a registration flow.
Register bool
// Nonce is optionally provided in the initial authorization request, and propogated in such cases to the generated claims.
// Nonce is optionally provided in the initial authorization request, and
// propogated in such cases to the generated claims.
Nonce string
// Scope is the 'scope' field in the authentication request. Example scopes are 'openid', 'email', 'offline', etc.
Scope []string
// Scope is the 'scope' field in the authentication request. Example scopes
// are 'openid', 'email', 'offline', etc.
Scope scope.Scopes
}
// Claims returns a new set of Claims for the current session.