diff --git a/admin/api.go b/admin/api.go index 1e782d2b..b7e2f39c 100644 --- a/admin/api.go +++ b/admin/api.go @@ -141,7 +141,7 @@ func (a *AdminAPI) CreateClient(req adminschema.ClientCreateRequest) (adminschem } // metadata is guaranteed to have at least one redirect_uri by earlier validation. - creds, err := a.clientManager.New(cli) + creds, err := a.clientManager.New(cli, nil) if err != nil { return adminschema.ClientCreateResponse{}, mapError(err) } diff --git a/client/client.go b/client/client.go index b7abb93a..9a862ef6 100644 --- a/client/client.go +++ b/client/client.go @@ -15,6 +15,7 @@ import ( ) var ( + ErrorInvalidClientID = errors.New("not a valid client ID") ErrorInvalidRedirectURL = errors.New("not a valid redirect url for the given client") ErrorCantChooseRedirectURL = errors.New("must provide a redirect url; client has many") ErrorNoValidRedirectURLs = errors.New("no valid redirect URLs for this client.") @@ -60,6 +61,12 @@ type ClientRepo interface { New(tx repo.Transaction, client Client) (*oidc.ClientCredentials, error) Update(tx repo.Transaction, client Client) error + + // GetTrustedPeers returns the list of clients authorized to mint ID token for the given client. + GetTrustedPeers(tx repo.Transaction, clientID string) ([]string, error) + + // SetTrustedPeers sets the list of clients authorized to mint ID token for the given client. + SetTrustedPeers(tx repo.Transaction, clientID string, clientIDs []string) error } // ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise. diff --git a/client/manager/manager.go b/client/manager/manager.go index 28335936..5ee0790b 100644 --- a/client/manager/manager.go +++ b/client/manager/manager.go @@ -21,6 +21,10 @@ const ( maxSecretLength = 72 ) +type ClientOptions struct { + TrustedPeers []string +} + type SecretGenerator func() ([]byte, error) func DefaultSecretGenerator() ([]byte, error) { @@ -63,7 +67,7 @@ func NewClientManager(clientRepo client.ClientRepo, txnFactory repo.TransactionF } } -func (m *ClientManager) New(cli client.Client) (*oidc.ClientCredentials, error) { +func (m *ClientManager) New(cli client.Client, options *ClientOptions) (*oidc.ClientCredentials, error) { tx, err := m.begin() if err != nil { return nil, err @@ -83,6 +87,13 @@ func (m *ClientManager) New(cli client.Client) (*oidc.ClientCredentials, error) return nil, err } + if options != nil && len(options.TrustedPeers) > 0 { + err = m.clientRepo.SetTrustedPeers(tx, creds.ID, options.TrustedPeers) + if err != nil { + return nil, err + } + } + err = tx.Commit() if err != nil { return nil, err diff --git a/client/manager/manager_test.go b/client/manager/manager_test.go index 2ab2083c..b00a4d22 100644 --- a/client/manager/manager_test.go +++ b/client/manager/manager_test.go @@ -132,7 +132,7 @@ func TestAuthenticate(t *testing.T) { cli := client.Client{ Metadata: cm, } - cc, err := f.mgr.New(cli) + cc, err := f.mgr.New(cli, nil) if err != nil { t.Fatalf(err.Error()) } diff --git a/cmd/dexctl/driver_db.go b/cmd/dexctl/driver_db.go index 19bfc9f1..b0471e83 100644 --- a/cmd/dexctl/driver_db.go +++ b/cmd/dexctl/driver_db.go @@ -34,7 +34,7 @@ func (d *dbDriver) NewClient(meta oidc.ClientMetadata) (*oidc.ClientCredentials, cli := client.Client{ Metadata: meta, } - return d.ciManager.New(cli) + return d.ciManager.New(cli, nil) } func (d *dbDriver) ConnectorConfigs() ([]connector.ConnectorConfig, error) { diff --git a/db/client.go b/db/client.go index edbd87c0..895e9e79 100644 --- a/db/client.go +++ b/db/client.go @@ -16,7 +16,8 @@ import ( ) const ( - clientTableName = "client_identity" + clientTableName = "client_identity" + trustedPeerTableName = "trusted_peers" // postgres error codes pgErrorCodeUniqueViolation = "23505" // unique_violation @@ -29,6 +30,13 @@ func init() { autoinc: false, pkey: []string{"id"}, }) + + register(table{ + name: trustedPeerTableName, + model: trustedPeerModel{}, + autoinc: false, + pkey: []string{"client_id", "trusted_client_id"}, + }) } func newClientModel(cli client.Client) (*clientModel, error) { @@ -58,6 +66,11 @@ type clientModel struct { DexAdmin bool `db:"dex_admin"` } +type trustedPeerModel struct { + ClientID string `db:"client_id"` + TrustedClientID string `db:"trusted_client_id"` +} + func (m *clientModel) Client() (*client.Client, error) { ci := client.Client{ Credentials: oidc.ClientCredentials{ @@ -254,3 +267,63 @@ func (r *clientRepo) update(tx repo.Transaction, cli client.Client) error { _, err = ex.Update(cm) return err } + +func (r *clientRepo) GetTrustedPeers(tx repo.Transaction, clientID string) ([]string, error) { + ex := r.executor(tx) + if clientID == "" { + return nil, client.ErrorInvalidClientID + } + + qt := r.quote(trustedPeerTableName) + var ids []string + _, err := ex.Select(&ids, fmt.Sprintf("SELECT trusted_client_id from %s where client_id = $1", qt), clientID) + + if err != nil { + if err != sql.ErrNoRows { + return nil, err + } + return nil, nil + } + + return ids, nil +} + +func (r *clientRepo) SetTrustedPeers(tx repo.Transaction, clientID string, clientIDs []string) error { + ex := r.executor(tx) + qt := r.quote(trustedPeerTableName) + + // First delete all existing rows + _, err := ex.Exec(fmt.Sprintf("DELETE from %s where client_id = $1", qt), clientID) + if err != nil { + return err + } + + // Ensure that the client exists. + _, err = r.get(tx, clientID) + if err != nil { + return err + } + + // Verify that all the clients are valid + for _, curID := range clientIDs { + _, err := r.get(tx, curID) + if err != nil { + return err + } + } + + // Set the clients + rows := []interface{}{} + for _, curID := range clientIDs { + rows = append(rows, &trustedPeerModel{ + ClientID: clientID, + TrustedClientID: curID, + }) + } + err = ex.Insert(rows...) + if err != nil { + return err + } + + return nil +} diff --git a/db/migrate.go b/db/migrate.go index b6ebaefb..8a6e8488 100644 --- a/db/migrate.go +++ b/db/migrate.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/go-gorp/gorp" - "github.com/rubenv/sql-migrate" + migrate "github.com/rubenv/sql-migrate" "github.com/coreos/dex/db/migrations" ) diff --git a/db/migrate_sqlite3.go b/db/migrate_sqlite3.go index af00fcf8..6e2142a4 100644 --- a/db/migrate_sqlite3.go +++ b/db/migrate_sqlite3.go @@ -70,4 +70,10 @@ CREATE TABLE session_key ( expires_at bigint, stale integer ); + +CREATE TABLE trusted_peers ( + client_id text NOT NULL, + trusted_client_id text NOT NULL +); + ` diff --git a/db/migrations/0012_add_cross_client_authorizers.sql b/db/migrations/0012_add_cross_client_authorizers.sql new file mode 100644 index 00000000..6939f4c0 --- /dev/null +++ b/db/migrations/0012_add_cross_client_authorizers.sql @@ -0,0 +1,5 @@ +-- +migrate Up +CREATE TABLE IF NOT EXISTS "trusted_peers" ( + "client_id" text not null, + "trusted_client_id" text not null, + primary key ("client_id", "trusted_client_id")) ; diff --git a/db/migrations/assets.go b/db/migrations/assets.go index 6b3821e7..e0d995b4 100644 --- a/db/migrations/assets.go +++ b/db/migrations/assets.go @@ -72,5 +72,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{ "-- +migrate Up\n\n-- This migration is a fix for a bug that allowed duplicate emails if they used different cases (see #338).\n-- When migrating, dex will not take the liberty of deleting rows for duplicate cases. Instead it will\n-- raise an exception and call for an admin to remove duplicates manually.\n\nCREATE OR REPLACE FUNCTION raise_exp() RETURNS VOID AS $$\nBEGIN\n RAISE EXCEPTION 'Found duplicate emails when using case insensitive comparision, cannot perform migration.';\nEND;\n$$ LANGUAGE plpgsql;\n\nSELECT LOWER(email),\n COUNT(email),\n CASE\n WHEN COUNT(email) > 1 THEN raise_exp()\n ELSE NULL\n END\nFROM authd_user\nGROUP BY LOWER(email);\n\nUPDATE authd_user SET email = LOWER(email);\n", }, }, + { + Id: "0012_add_cross_client_authorizers.sql", + Up: []string{ + "-- +migrate Up\nCREATE TABLE IF NOT EXISTS \"trusted_peers\" (\n \"client_id\" text not null,\n \"trusted_client_id\" text not null,\n primary key (\"client_id\", \"trusted_client_id\")) ;\n", + }, + }, }, } diff --git a/db/user.go b/db/user.go index 00991668..6aecdd63 100644 --- a/db/user.go +++ b/db/user.go @@ -248,7 +248,7 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us if err != sql.ErrNoRows { return nil, err } - return nil, err + return nil, nil } if len(rims) == 0 { return nil, nil diff --git a/functional/db_test.go b/functional/db_test.go index c8322afd..cc78d8a0 100644 --- a/functional/db_test.go +++ b/functional/db_test.go @@ -316,7 +316,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) { cli := client.Client{ Metadata: cm, } - cc, err := m.New(cli) + cc, err := m.New(cli, nil) if err != nil { t.Fatalf(err.Error()) } diff --git a/server/auth_middleware_test.go b/server/auth_middleware_test.go index 568f0564..0be50af7 100644 --- a/server/auth_middleware_test.go +++ b/server/auth_middleware_test.go @@ -37,7 +37,7 @@ func TestClientToken(t *testing.T) { cli := client.Client{ Metadata: clientMetadata, } - creds, err := clientManager.New(cli) + creds, err := clientManager.New(cli, nil) if err != nil { t.Fatalf("Failed to create client: %v", err) } diff --git a/server/client_registration.go b/server/client_registration.go index 0de3490a..6a95760a 100644 --- a/server/client_registration.go +++ b/server/client_registration.go @@ -42,7 +42,7 @@ func (s *Server) handleClientRegistrationRequest(r *http.Request) (*oidc.ClientR cli := client.Client{ Metadata: clientMetadata, } - creds, err := s.ClientManager.New(cli) + creds, err := s.ClientManager.New(cli, nil) if err != nil { log.Errorf("Failed to create new client identity: %v", err) return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata") diff --git a/server/client_resource.go b/server/client_resource.go index b00cbee9..b4fad488 100644 --- a/server/client_resource.go +++ b/server/client_resource.go @@ -87,7 +87,7 @@ func (c *clientResource) create(w http.ResponseWriter, r *http.Request) { writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidClientMetadata, err.Error())) return } - creds, err := c.manager.New(ci) + creds, err := c.manager.New(ci, nil) if err != nil { log.Errorf("Failed creating client: %v", err)