client: add client manager

adds a client manager to handle business logic, leaving the repo
for basic crud operations. Also adds client to the test script
This commit is contained in:
Evan Cordell 2016-05-12 09:53:01 -07:00
parent 3da98fcb8e
commit a418e1c4e7
37 changed files with 1094 additions and 676 deletions

View file

@ -1,31 +1,27 @@
// package admin provides an implementation of the API described in auth/schema/adminschema. // Package admin provides an implementation of the API described in auth/schema/adminschema.
package admin package admin
import ( import (
"net/http" "net/http"
"github.com/coreos/go-oidc/oidc"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/schema/adminschema" "github.com/coreos/dex/schema/adminschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
)
var (
ClientIDGenerator = oidc.GenClientID
) )
// AdminAPI provides the logic necessary to implement the Admin API. // AdminAPI provides the logic necessary to implement the Admin API.
type AdminAPI struct { type AdminAPI struct {
userManager *manager.UserManager userManager *usermanager.UserManager
userRepo user.UserRepo userRepo user.UserRepo
passwordInfoRepo user.PasswordInfoRepo passwordInfoRepo user.PasswordInfoRepo
clientRepo client.ClientRepo clientRepo client.ClientRepo
clientManager *clientmanager.ClientManager
localConnectorID string localConnectorID string
} }
func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRepo client.ClientRepo, userManager *manager.UserManager, localConnectorID string) *AdminAPI { func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRepo client.ClientRepo, userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, localConnectorID string) *AdminAPI {
if localConnectorID == "" { if localConnectorID == "" {
panic("must specify non-blank localConnectorID") panic("must specify non-blank localConnectorID")
} }
@ -34,6 +30,7 @@ func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRe
userRepo: userRepo, userRepo: userRepo,
passwordInfoRepo: pwiRepo, passwordInfoRepo: pwiRepo,
clientRepo: clientRepo, clientRepo: clientRepo,
clientManager: clientManager,
localConnectorID: localConnectorID, localConnectorID: localConnectorID,
} }
} }
@ -141,14 +138,7 @@ func (a *AdminAPI) CreateClient(req adminschema.ClientCreateRequest) (adminschem
} }
// metadata is guaranteed to have at least one redirect_uri by earlier validation. // metadata is guaranteed to have at least one redirect_uri by earlier validation.
id, err := ClientIDGenerator(cli.Metadata.RedirectURIs[0].Host) creds, err := a.clientManager.New(cli.Metadata)
if err != nil {
return adminschema.ClientCreateResponse{}, mapError(err)
}
cli.Credentials.ID = id
creds, err := a.clientRepo.New(nil, cli)
if err != nil { if err != nil {
return adminschema.ClientCreateResponse{}, mapError(err) return adminschema.ClientCreateResponse{}, mapError(err)
} }

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/schema/adminschema" "github.com/coreos/dex/schema/adminschema"
@ -17,6 +18,7 @@ type testFixtures struct {
ur user.UserRepo ur user.UserRepo
pwr user.PasswordInfoRepo pwr user.PasswordInfoRepo
cr client.ClientRepo cr client.ClientRepo
cm *clientmanager.ClientManager
mgr *manager.UserManager mgr *manager.UserManager
adAPI *AdminAPI adAPI *AdminAPI
} }
@ -71,7 +73,8 @@ func makeTestFixtures() *testFixtures {
}() }()
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{}) f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
f.adAPI = NewAdminAPI(f.ur, f.pwr, f.cr, f.mgr, "local") f.cm = clientmanager.NewClientManager(f.cr, db.TransactionFactory(dbMap), clientmanager.ManagerOptions{})
f.adAPI = NewAdminAPI(f.ur, f.pwr, f.cr, f.mgr, f.cm, "local")
return f return f
} }

View file

@ -1,12 +1,15 @@
package client package client
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
"net/url" "net/url"
"reflect" "reflect"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
@ -18,6 +21,24 @@ var (
ErrorNotFound = errors.New("no data found") ErrorNotFound = errors.New("no data found")
) )
const (
bcryptHashCost = 10
)
func HashSecret(creds oidc.ClientCredentials) ([]byte, error) {
secretBytes, err := base64.URLEncoding.DecodeString(creds.Secret)
if err != nil {
return nil, err
}
hashed, err := bcrypt.GenerateFromPassword([]byte(
secretBytes),
bcryptHashCost)
if err != nil {
return nil, err
}
return hashed, nil
}
type Client struct { type Client struct {
Credentials oidc.ClientCredentials Credentials oidc.ClientCredentials
Metadata oidc.ClientMetadata Metadata oidc.ClientMetadata
@ -27,16 +48,8 @@ type Client struct {
type ClientRepo interface { type ClientRepo interface {
Get(tx repo.Transaction, clientID string) (Client, error) Get(tx repo.Transaction, clientID string) (Client, error)
// Metadata returns one matching ClientMetadata if the given client // GetSecret returns the (base64 encoded) hashed client secret
// exists, otherwise nil. The returned error will be non-nil only GetSecret(tx repo.Transaction, clientID string) ([]byte, error)
// if the repo was unable to determine client existence.
Metadata(tx repo.Transaction, clientID string) (*oidc.ClientMetadata, error)
// Authenticate asserts that a client with the given ID exists and
// that the provided secret matches. If either of these assertions
// fail, (false, nil) will be returned. Only if the repo is unable
// to make these assertions will a non-nil error be returned.
Authenticate(tx repo.Transaction, creds oidc.ClientCredentials) (bool, error)
// All returns all registered Clients // All returns all registered Clients
All(tx repo.Transaction) ([]Client, error) All(tx repo.Transaction) ([]Client, error)
@ -46,9 +59,7 @@ type ClientRepo interface {
// in a ClientCredentials struct along with the provided ID. // in a ClientCredentials struct along with the provided ID.
New(tx repo.Transaction, client Client) (*oidc.ClientCredentials, error) New(tx repo.Transaction, client Client) (*oidc.ClientCredentials, error)
SetDexAdmin(clientID string, isAdmin bool) error Update(tx repo.Transaction, client Client) error
IsDexAdmin(clientID string) (bool, error)
} }
// ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise. // ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise.

View file

@ -34,7 +34,7 @@ var (
badSecretClient = `{ badSecretClient = `{
"id": "my_id", "id": "my_id",
"secret": "` + "****" + `", "secret": "` + "" + `",
"redirectURLs": ["https://client.example.com"] "redirectURLs": ["https://client.example.com"]
}` }`
@ -64,7 +64,7 @@ func TestClientsFromReader(t *testing.T) {
{ {
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "my_id", ID: "my_id",
Secret: "my_secret", Secret: goodSecret1,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -80,7 +80,7 @@ func TestClientsFromReader(t *testing.T) {
{ {
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "my_id", ID: "my_id",
Secret: "my_secret", Secret: goodSecret1,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -91,7 +91,7 @@ func TestClientsFromReader(t *testing.T) {
{ {
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "my_other_id", ID: "my_other_id",
Secret: "my_other_secret", Secret: goodSecret2,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -101,7 +101,8 @@ func TestClientsFromReader(t *testing.T) {
}, },
}, },
}, },
}, { },
{
json: "[" + badURLClient + "]", json: "[" + badURLClient + "]",
wantErr: true, wantErr: true,
}, },

217
client/manager/manager.go Normal file
View file

@ -0,0 +1,217 @@
package manager
import (
"encoding/base64"
"fmt"
"errors"
"github.com/coreos/dex/client"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc"
"golang.org/x/crypto/bcrypt"
)
const (
// Blowfish, the algorithm underlying bcrypt, has a maximum
// password length of 72. We explicitly track and check this
// since the bcrypt library will silently ignore portions of
// a password past the first 72 characters.
maxSecretLength = 72
)
type SecretGenerator func() ([]byte, error)
func DefaultSecretGenerator() ([]byte, error) {
return pcrypto.RandBytes(maxSecretLength)
}
func CompareHashAndPassword(hashedPassword, password []byte) error {
if len(password) > maxSecretLength {
return errors.New("password length greater than max secret length")
}
return bcrypt.CompareHashAndPassword(hashedPassword, password)
}
// ClientManager performs client-related "business-logic" functions on client and related objects.
// This is in contrast to the Repos which perform little more than CRUD operations.
type ClientManager struct {
clientRepo client.ClientRepo
begin repo.TransactionFactory
secretGenerator SecretGenerator
clientIDGenerator func(string) (string, error)
}
type ManagerOptions struct {
SecretGenerator func() ([]byte, error)
ClientIDGenerator func(string) (string, error)
}
func NewClientManager(clientRepo client.ClientRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *ClientManager {
if options.SecretGenerator == nil {
options.SecretGenerator = DefaultSecretGenerator
}
if options.ClientIDGenerator == nil {
options.ClientIDGenerator = oidc.GenClientID
}
return &ClientManager{
clientRepo: clientRepo,
begin: txnFactory,
secretGenerator: options.SecretGenerator,
clientIDGenerator: options.ClientIDGenerator,
}
}
func NewClientManagerFromClients(clientRepo client.ClientRepo, txnFactory repo.TransactionFactory, clients []client.Client, options ManagerOptions) (*ClientManager, error) {
clientManager := NewClientManager(clientRepo, txnFactory, options)
tx, err := clientManager.begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
for _, c := range clients {
if c.Credentials.Secret == "" {
return nil, fmt.Errorf("client %q has no secret", c.Credentials.ID)
}
cli, err := clientManager.clientFromMetadata(c.Metadata)
if err != nil {
return nil, err
}
cli.Admin = c.Admin
_, err = clientRepo.New(tx, cli)
if err != nil {
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
return clientManager, nil
}
func (m *ClientManager) New(meta oidc.ClientMetadata) (*oidc.ClientCredentials, error) {
tx, err := m.begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
cli, err := m.clientFromMetadata(meta)
if err != nil {
return nil, err
}
creds := cli.Credentials
// Save Client
_, err = m.clientRepo.New(tx, cli)
if err != nil {
return nil, err
}
err = tx.Commit()
if err != nil {
return nil, err
}
// Returns creds with unhashed secret
return &creds, nil
}
func (m *ClientManager) Get(id string) (client.Client, error) {
return m.clientRepo.Get(nil, id)
}
func (m *ClientManager) All() ([]client.Client, error) {
return m.clientRepo.All(nil)
}
func (m *ClientManager) Metadata(clientID string) (*oidc.ClientMetadata, error) {
c, err := m.clientRepo.Get(nil, clientID)
if err != nil {
return nil, err
}
return &c.Metadata, nil
}
func (m *ClientManager) IsDexAdmin(clientID string) (bool, error) {
c, err := m.clientRepo.Get(nil, clientID)
if err != nil {
return false, err
}
return c.Admin, nil
}
func (m *ClientManager) SetDexAdmin(clientID string, isAdmin bool) error {
tx, err := m.begin()
if err != nil {
return err
}
defer tx.Rollback()
c, err := m.clientRepo.Get(tx, clientID)
if err != nil {
return err
}
c.Admin = isAdmin
err = m.clientRepo.Update(tx, c)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
return nil
}
func (m *ClientManager) Authenticate(creds oidc.ClientCredentials) (bool, error) {
clientSecret, err := m.clientRepo.GetSecret(nil, creds.ID)
if err != nil || clientSecret == nil {
return false, nil
}
dec, err := base64.URLEncoding.DecodeString(creds.Secret)
if err != nil {
log.Errorf("error Decoding client creds: %v", err)
return false, nil
}
ok := CompareHashAndPassword(clientSecret, dec) == nil
return ok, nil
}
func (m *ClientManager) clientFromMetadata(meta oidc.ClientMetadata) (client.Client, error) {
// Generate Client ID
if len(meta.RedirectURIs) < 1 {
return client.Client{}, errors.New("no client redirect url given")
}
clientID, err := m.clientIDGenerator(meta.RedirectURIs[0].Host)
if err != nil {
return client.Client{}, err
}
// Generate Secret
secret, err := m.secretGenerator()
if err != nil {
return client.Client{}, err
}
clientSecret := base64.URLEncoding.EncodeToString(secret)
cli := client.Client{
Credentials: oidc.ClientCredentials{
ID: clientID,
Secret: clientSecret,
},
Metadata: meta,
}
return cli, nil
}

View file

@ -0,0 +1,163 @@
package manager
import (
"encoding/base64"
"fmt"
"net/url"
"testing"
"github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/go-oidc/oidc"
)
type testFixtures struct {
clientRepo client.ClientRepo
mgr *ClientManager
}
var (
goodSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
)
func makeTestFixtures() *testFixtures {
f := &testFixtures{}
dbMap := db.NewMemDB()
clients := []client.Client{
{
Credentials: oidc.ClientCredentials{
ID: "client.example.com",
Secret: goodSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
{Scheme: "http", Host: "client.example.com", Path: "/"},
},
},
Admin: true,
},
}
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
f.clientRepo = db.NewClientRepo(dbMap)
clientManager, err := NewClientManagerFromClients(f.clientRepo, db.TransactionFactory(dbMap), clients, ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
panic("Failed to create client manager: " + err.Error())
}
f.mgr = clientManager
return f
}
func TestMetadata(t *testing.T) {
tests := []struct {
clientID string
uri string
wantErr bool
}{
{
clientID: "client.example.com",
uri: "http://client.example.com/",
wantErr: false,
},
}
for i, tt := range tests {
f := makeTestFixtures()
md, err := f.mgr.Metadata(tt.clientID)
if err != nil {
t.Errorf("case %d: unexpected err: %v", i, err)
continue
}
if md.RedirectURIs[0].String() != tt.uri {
t.Errorf("case %d: manager.Metadata.RedirectURIs: want=%q got=%q", i, tt.uri, md.RedirectURIs[0].String())
continue
}
}
}
func TestIsDexAdmin(t *testing.T) {
tests := []struct {
clientID string
isAdmin bool
wantErr bool
}{
{
clientID: "client.example.com",
isAdmin: true,
wantErr: false,
},
}
for i, tt := range tests {
f := makeTestFixtures()
admin, err := f.mgr.IsDexAdmin(tt.clientID)
if err != nil {
t.Errorf("case %d: unexpected err: %v", i, err)
continue
}
if admin != tt.isAdmin {
t.Errorf("case %d: manager.Admin want=%t got=%t", i, tt.isAdmin, admin)
continue
}
}
}
func TestSetDexAdmin(t *testing.T) {
f := makeTestFixtures()
err := f.mgr.SetDexAdmin("client.example.com", false)
if err != nil {
t.Errorf("unexpected err: %v", err)
}
admin, _ := f.mgr.IsDexAdmin("client.example.com")
if admin {
t.Errorf("expected admin to be false")
}
}
func TestAuthenticate(t *testing.T) {
f := makeTestFixtures()
cm := oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com", Path: "/cb"},
},
}
cc, err := f.mgr.New(cm)
if err != nil {
t.Fatalf(err.Error())
}
ok, err := f.mgr.Authenticate(*cc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
} else if !ok {
t.Fatalf("Authentication failed for good creds")
}
creds := []oidc.ClientCredentials{
//completely made up
oidc.ClientCredentials{ID: "foo", Secret: "bar"},
// good client ID, bad secret
oidc.ClientCredentials{ID: cc.ID, Secret: "bar"},
// bad client ID, good secret
oidc.ClientCredentials{ID: "foo", Secret: cc.Secret},
// good client ID, secret with some fluff on the end
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
}
for i, c := range creds {
ok, err := f.mgr.Authenticate(c)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
} else if ok {
t.Errorf("case %d: authentication succeeded for bad creds", i)
}
}
}

View file

@ -15,6 +15,7 @@ import (
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/coreos/dex/admin" "github.com/coreos/dex/admin"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
pflag "github.com/coreos/dex/pkg/flag" pflag "github.com/coreos/dex/pkg/flag"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
@ -119,8 +120,9 @@ func main() {
clientRepo := db.NewClientRepo(dbc) clientRepo := db.NewClientRepo(dbc)
userManager := manager.NewUserManager(userRepo, userManager := manager.NewUserManager(userRepo,
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbc), clientmanager.ManagerOptions{})
adminAPI := admin.NewAdminAPI(userRepo, pwiRepo, clientRepo, userManager, *localConnectorID) adminAPI := admin.NewAdminAPI(userRepo, pwiRepo, clientRepo, userManager, clientManager, *localConnectorID)
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...) kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
if err != nil { if err != nil {
log.Fatalf(err.Error()) log.Fatalf(err.Error())

View file

@ -1,7 +1,7 @@
package main package main
import ( import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
@ -14,15 +14,15 @@ func newDBDriver(dsn string) (driver, error) {
} }
drv := &dbDriver{ drv := &dbDriver{
ciRepo: db.NewClientRepo(dbc),
cfgRepo: db.NewConnectorConfigRepo(dbc), cfgRepo: db.NewConnectorConfigRepo(dbc),
ciManager: manager.NewClientManager(db.NewClientRepo(dbc), db.TransactionFactory(dbc), manager.ManagerOptions{}),
} }
return drv, nil return drv, nil
} }
type dbDriver struct { type dbDriver struct {
ciRepo client.ClientRepo ciManager *manager.ClientManager
cfgRepo *db.ConnectorConfigRepo cfgRepo *db.ConnectorConfigRepo
} }
@ -30,18 +30,7 @@ func (d *dbDriver) NewClient(meta oidc.ClientMetadata) (*oidc.ClientCredentials,
if err := meta.Valid(); err != nil { if err := meta.Valid(); err != nil {
return nil, err return nil, err
} }
return d.ciManager.New(meta)
clientID, err := oidc.GenClientID(meta.RedirectURIs[0].Host)
if err != nil {
return nil, err
}
return d.ciRepo.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: clientID,
},
Metadata: meta,
})
} }
func (d *dbDriver) ConnectorConfigs() ([]connector.ConnectorConfig, error) { func (d *dbDriver) ConnectorConfigs() ([]connector.ConnectorConfig, error) {

View file

@ -2,7 +2,6 @@ package db
import ( import (
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -10,10 +9,8 @@ import (
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
) )
@ -21,14 +18,6 @@ import (
const ( const (
clientTableName = "client_identity" clientTableName = "client_identity"
bcryptHashCost = 10
// Blowfish, the algorithm underlying bcrypt, has a maximum
// password length of 72. We explicitly track and check this
// since the bcrypt library will silently ignore portions of
// a password past the first 72 characters.
maxSecretLength = 72
// postgres error codes // postgres error codes
pgErrorCodeUniqueViolation = "23505" // unique_violation pgErrorCodeUniqueViolation = "23505" // unique_violation
) )
@ -43,17 +32,10 @@ func init() {
} }
func newClientModel(cli client.Client) (*clientModel, error) { func newClientModel(cli client.Client) (*clientModel, error) {
secretBytes, err := base64.URLEncoding.DecodeString(cli.Credentials.Secret) hashed, err := client.HashSecret(cli.Credentials)
if err != nil { if err != nil {
return nil, err return nil, err
} }
hashed, err := bcrypt.GenerateFromPassword([]byte(
secretBytes),
bcryptHashCost)
if err != nil {
return nil, err
}
bmeta, err := json.Marshal(&cli.Metadata) bmeta, err := json.Marshal(&cli.Metadata)
if err != nil { if err != nil {
return nil, err return nil, err
@ -93,52 +75,16 @@ func (m *clientModel) Client() (*client.Client, error) {
func NewClientRepo(dbm *gorp.DbMap) client.ClientRepo { func NewClientRepo(dbm *gorp.DbMap) client.ClientRepo {
return newClientRepo(dbm) return newClientRepo(dbm)
}
func NewClientRepoWithSecretGenerator(dbm *gorp.DbMap, secGen SecretGenerator) client.ClientRepo {
rep := newClientRepo(dbm)
rep.secretGenerator = secGen
return rep
} }
func newClientRepo(dbm *gorp.DbMap) *clientRepo { func newClientRepo(dbm *gorp.DbMap) *clientRepo {
return &clientRepo{ return &clientRepo{
db: &db{dbm}, db: &db{dbm},
secretGenerator: DefaultSecretGenerator,
} }
} }
func NewClientRepoFromClients(dbm *gorp.DbMap, clients []client.Client) (client.ClientRepo, error) {
repo := newClientRepo(dbm)
tx, err := repo.begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
exec := repo.executor(tx)
for _, c := range clients {
if c.Credentials.Secret == "" {
return nil, fmt.Errorf("client %q has no secret", c.Credentials.ID)
}
cm, err := newClientModel(c)
if err != nil {
return nil, err
}
err = exec.Insert(cm)
if err != nil {
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
return repo, nil
}
type clientRepo struct { type clientRepo struct {
*db *db
secretGenerator SecretGenerator
} }
func (r *clientRepo) Get(tx repo.Transaction, clientID string) (client.Client, error) { func (r *clientRepo) Get(tx repo.Transaction, clientID string) (client.Client, error) {
@ -164,82 +110,28 @@ func (r *clientRepo) Get(tx repo.Transaction, clientID string) (client.Client, e
return *ci, nil return *ci, nil
} }
func (r *clientRepo) Metadata(tx repo.Transaction, clientID string) (*oidc.ClientMetadata, error) { func (r *clientRepo) GetSecret(tx repo.Transaction, clientID string) ([]byte, error) {
c, err := r.Get(tx, clientID) m, err := r.getModel(tx, clientID)
if err != nil { if err != nil || m == nil {
return nil, err return nil, err
} }
return m.Secret, nil
return &c.Metadata, nil
} }
func (r *clientRepo) IsDexAdmin(clientID string) (bool, error) { func (r *clientRepo) Update(tx repo.Transaction, cli client.Client) error {
m, err := r.executor(nil).Get(clientModel{}, clientID) if cli.Credentials.ID == "" {
if m == nil || err != nil { return client.ErrorNotFound
return false, err
} }
// make sure this client exists already
cim, ok := m.(*clientModel) _, err := r.get(tx, cli.Credentials.ID)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return false, errors.New("unrecognized model")
}
return cim.DexAdmin, nil
}
func (r *clientRepo) SetDexAdmin(clientID string, isAdmin bool) error {
tx, err := r.begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() err = r.update(tx, cli)
exec := r.executor(tx)
m, err := exec.Get(clientModel{}, clientID)
if m == nil || err != nil {
return err
}
cim, ok := m.(*clientModel)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return errors.New("unrecognized model")
}
cim.DexAdmin = isAdmin
_, err = exec.Update(cim)
if err != nil { if err != nil {
return err return err
} }
return nil
return tx.Commit()
}
func (r *clientRepo) Authenticate(tx repo.Transaction, creds oidc.ClientCredentials) (bool, error) {
m, err := r.executor(tx).Get(clientModel{}, creds.ID)
if m == nil || err != nil {
return false, err
}
cim, ok := m.(*clientModel)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return false, errors.New("unrecognized model")
}
dec, err := base64.URLEncoding.DecodeString(creds.Secret)
if err != nil {
log.Errorf("error Decoding client creds: %v", err)
return false, nil
}
if len(dec) > maxSecretLength {
return false, nil
}
ok = bcrypt.CompareHashAndPassword(cim.Secret, dec) == nil
return ok, nil
} }
var alreadyExistsCheckers []func(err error) bool var alreadyExistsCheckers []func(err error) bool
@ -261,19 +153,7 @@ func isAlreadyExistsErr(err error) bool {
return false return false
} }
type SecretGenerator func() ([]byte, error)
func DefaultSecretGenerator() ([]byte, error) {
return pcrypto.RandBytes(maxSecretLength)
}
func (r *clientRepo) New(tx repo.Transaction, cli client.Client) (*oidc.ClientCredentials, error) { func (r *clientRepo) New(tx repo.Transaction, cli client.Client) (*oidc.ClientCredentials, error) {
secret, err := r.secretGenerator()
if err != nil {
return nil, err
}
cli.Credentials.Secret = base64.URLEncoding.EncodeToString(secret)
cim, err := newClientModel(cli) cim, err := newClientModel(cli)
if err != nil { if err != nil {
@ -318,3 +198,47 @@ func (r *clientRepo) All(tx repo.Transaction) ([]client.Client, error) {
} }
return cs, nil return cs, nil
} }
func (r *clientRepo) get(tx repo.Transaction, clientID string) (client.Client, error) {
cm, err := r.getModel(tx, clientID)
if err != nil {
return client.Client{}, err
}
cli, err := cm.Client()
if err != nil {
return client.Client{}, err
}
return *cli, nil
}
func (r *clientRepo) getModel(tx repo.Transaction, clientID string) (*clientModel, error) {
ex := r.executor(tx)
m, err := ex.Get(clientModel{}, clientID)
if err != nil {
return nil, err
}
if m == nil {
return nil, client.ErrorNotFound
}
cm, ok := m.(*clientModel)
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return nil, errors.New("unrecognized model")
}
return cm, nil
}
func (r *clientRepo) update(tx repo.Transaction, cli client.Client) error {
ex := r.executor(tx)
cm, err := newClientModel(cli)
if err != nil {
return err
}
_, err = ex.Update(cm)
return err
}

View file

@ -14,6 +14,7 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
@ -201,20 +202,22 @@ func TestDBClientRepoMetadata(t *testing.T) {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
got, err := r.Metadata(nil, "foo") got, err := r.Get(nil, "foo")
if err != nil { if err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
if diff := pretty.Compare(cm, *got); diff != "" { if diff := pretty.Compare(cm, got.Metadata); diff != "" {
t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff) t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
} }
} }
func TestDBClientRepoMetadataNoExist(t *testing.T) { func TestDBClientRepoMetadataNoExist(t *testing.T) {
r := db.NewClientRepo(connect(t)) c := connect(t)
r := db.NewClientRepo(c)
m := manager.NewClientManager(r, db.TransactionFactory(c), manager.ManagerOptions{})
got, err := r.Metadata(nil, "noexist") got, err := m.Metadata("noexist")
if err != client.ErrorNotFound { if err != client.ErrorNotFound {
t.Errorf("want==%q, got==%q", client.ErrorNotFound, err) t.Errorf("want==%q, got==%q", client.ErrorNotFound, err)
} }
@ -275,11 +278,11 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
t.Fatalf("expected non-nil error: %v", err) t.Fatalf("expected non-nil error: %v", err)
} }
gotAdmin, err := r.IsDexAdmin("foo") gotAdmin, err := r.Get(nil, "foo")
if err != nil { if err != nil {
t.Fatalf("expected non-nil error") t.Fatalf("expected non-nil error")
} }
if gotAdmin != admin { if gotAdmin.Admin != admin {
t.Errorf("want=%v, gotAdmin=%v", admin, gotAdmin) t.Errorf("want=%v, gotAdmin=%v", admin, gotAdmin)
} }
@ -294,7 +297,16 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
} }
func TestDBClientRepoAuthenticate(t *testing.T) { func TestDBClientRepoAuthenticate(t *testing.T) {
r := db.NewClientRepo(connect(t)) c := connect(t)
r := db.NewClientRepo(c)
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
m := manager.NewClientManager(r, db.TransactionFactory(c), manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
cm := oidc.ClientMetadata{ cm := oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -302,21 +314,16 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
}, },
} }
cc, err := r.New(nil, client.Client{ cc, err := m.New(cm)
Credentials: oidc.ClientCredentials{
ID: "baz",
},
Metadata: cm,
})
if err != nil { if err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
if cc.ID != "baz" { if cc.ID != "127.0.0.1:5556" {
t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID) t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID)
} }
ok, err := r.Authenticate(nil, *cc) ok, err := m.Authenticate(*cc)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} else if !ok { } else if !ok {
@ -337,7 +344,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)}, oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
} }
for i, c := range creds { for i, c := range creds {
ok, err := r.Authenticate(nil, c) ok, err := m.Authenticate(c)
if err != nil { if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err) t.Errorf("case %d: unexpected error: %v", i, err)
} else if ok { } else if ok {

View file

@ -3,14 +3,10 @@ package repo
import ( import (
"encoding/base64" "encoding/base64"
"net/url" "net/url"
"os"
"testing"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db"
) )
var ( var (
@ -47,95 +43,3 @@ var (
}, },
} }
) )
func newClientRepo(t *testing.T) client.ClientRepo {
dsn := os.Getenv("DEX_TEST_DSN")
var dbMap *gorp.DbMap
if dsn == "" {
dbMap = db.NewMemDB()
} else {
dbMap = connect(t)
}
repo, err := db.NewClientRepoFromClients(dbMap, testClients)
if err != nil {
t.Fatalf("failed to create client repo from clients: %v", err)
}
return repo
}
func TestGetSetAdminClient(t *testing.T) {
startAdmins := []string{"client2"}
tests := []struct {
// client ID
cid string
// initial state of client
wantAdmin bool
// final state of client
setAdmin bool
wantErr bool
}{
{
cid: "client1",
wantAdmin: false,
setAdmin: true,
},
{
cid: "client1",
wantAdmin: false,
setAdmin: false,
},
{
cid: "client2",
wantAdmin: true,
setAdmin: true,
},
{
cid: "client2",
wantAdmin: true,
setAdmin: false,
},
}
Tests:
for i, tt := range tests {
repo := newClientRepo(t)
for _, cid := range startAdmins {
err := repo.SetDexAdmin(cid, true)
if err != nil {
t.Errorf("case %d: failed to set dex admin: %v", i, err)
continue Tests
}
}
gotAdmin, err := repo.IsDexAdmin(tt.cid)
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want non-nil err", i)
}
continue
}
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
if gotAdmin != tt.wantAdmin {
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAdmin, gotAdmin)
}
err = repo.SetDexAdmin(tt.cid, tt.setAdmin)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
gotAdmin, err = repo.IsDexAdmin(tt.cid)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
if gotAdmin != tt.setAdmin {
t.Errorf("case %d: want=%v, got=%v", i, tt.setAdmin, gotAdmin)
}
}
}

View file

@ -12,6 +12,7 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
@ -27,7 +28,7 @@ func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil { if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
t.Fatalf("Unable to add users: %v", err) t.Fatalf("Unable to add users: %v", err)
} }
if _, err := db.NewClientRepoFromClients(dbMap, clients); err != nil { if _, err := manager.NewClientManagerFromClients(db.NewClientRepo(dbMap), db.TransactionFactory(dbMap), clients, manager.ManagerOptions{}); err != nil {
t.Fatalf("Unable to add clients: %v", err) t.Fatalf("Unable to add clients: %v", err)
} }
return db.NewRefreshTokenRepo(dbMap) return db.NewRefreshTokenRepo(dbMap)

View file

@ -14,6 +14,7 @@ import (
"github.com/coreos/dex/admin" "github.com/coreos/dex/admin"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/schema/adminschema" "github.com/coreos/dex/schema/adminschema"
"github.com/coreos/dex/server" "github.com/coreos/dex/server"
@ -87,12 +88,16 @@ func makeAdminAPITestFixtures() *adminAPITestFixtures {
secGen := func() ([]byte, error) { secGen := func() ([]byte, error) {
return []byte(fmt.Sprintf("client_%v", cliCount)), nil return []byte(fmt.Sprintf("client_%v", cliCount)), nil
} }
cr := db.NewClientRepoWithSecretGenerator(dbMap, secGen) cr := db.NewClientRepo(dbMap)
clientIDGenerator := func(hostport string) (string, error) {
return fmt.Sprintf("client_%v", hostport), nil
}
cm := manager.NewClientManager(cr, db.TransactionFactory(dbMap), manager.ManagerOptions{SecretGenerator: secGen, ClientIDGenerator: clientIDGenerator})
f.cr = cr f.cr = cr
f.ur = ur f.ur = ur
f.pwr = pwr f.pwr = pwr
f.adAPI = admin.NewAdminAPI(ur, pwr, cr, um, "local") f.adAPI = admin.NewAdminAPI(ur, pwr, cr, um, cm, "local")
f.adSrv = server.NewAdminServer(f.adAPI, nil, adminAPITestSecret) f.adSrv = server.NewAdminServer(f.adAPI, nil, adminAPITestSecret)
f.hSrv = httptest.NewServer(f.adSrv.HTTPHandler()) f.hSrv = httptest.NewServer(f.adSrv.HTTPHandler())
f.hc = &http.Client{ f.hc = &http.Client{
@ -268,14 +273,6 @@ func TestCreateAdmin(t *testing.T) {
} }
func TestCreateClient(t *testing.T) { func TestCreateClient(t *testing.T) {
oldGen := admin.ClientIDGenerator
admin.ClientIDGenerator = func(hostport string) (string, error) {
return fmt.Sprintf("client_%v", hostport), nil
}
defer func() {
admin.ClientIDGenerator = oldGen
}()
mustParseURL := func(s string) *url.URL { mustParseURL := func(s string) *url.URL {
u, err := url.Parse(s) u, err := url.Parse(s)
if err != nil { if err != nil {

View file

@ -14,9 +14,10 @@ import (
func TestClientCreate(t *testing.T) { func TestClientCreate(t *testing.T) {
ci := client.Client{ ci := client.Client{
// Credentials are for reference, they are actually generated by the client manager
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: "authn.example.com",
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")), Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -73,7 +74,7 @@ func TestClientCreate(t *testing.T) {
t.Error("Expected non-empty Client Secret") t.Error("Expected non-empty Client Secret")
} }
meta, err := srv.ClientRepo.Metadata(nil, newClient.Id) meta, err := srv.ClientManager.Metadata(newClient.Id)
if err != nil { if err != nil {
t.Errorf("Error looking up client metadata: %v", err) t.Errorf("Error looking up client metadata: %v", err)
} else if meta == nil { } else if meta == nil {

View file

@ -22,9 +22,10 @@ var (
clock = clockwork.NewFakeClock() clock = clockwork.NewFakeClock()
testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"} testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"}
testClientID = "XXX" testClientID = "client.example.com"
testClientSecret = base64.URLEncoding.EncodeToString([]byte("yyy")) testClientSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"} testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"}
testBadRedirectURL = url.URL{Scheme: "https", Host: "bad.example.com", Path: "/redirect"}
testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"} testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"}
testPrivKey, _ = key.GeneratePrivateKey() testPrivKey, _ = key.GeneratePrivateKey()
) )

View file

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
phttp "github.com/coreos/dex/pkg/http" phttp "github.com/coreos/dex/pkg/http"
@ -35,7 +36,15 @@ func mockServer(cis []client.Client) (*server.Server, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
clientRepo, err := db.NewClientRepoFromClients(dbMap, cis)
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), cis, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -45,6 +54,7 @@ func mockServer(cis []client.Client) (*server.Server, error) {
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
ClientRepo: clientRepo, ClientRepo: clientRepo,
ClientManager: clientManager,
SessionManager: sm, SessionManager: sm,
} }
@ -82,15 +92,21 @@ func verifyUserClaims(claims jose.Claims, ci *client.Client, user *user.User, is
expectedSub, expectedName = user.ID, user.DisplayName expectedSub, expectedName = user.ID, user.DisplayName
} }
if aud := claims["aud"].(string); aud != ci.Credentials.ID { if aud, ok := claims["aud"].(string); !ok {
return fmt.Errorf("unexpected claim value for aud, got=nil, want=%v", ci.Credentials.ID)
} else if aud != ci.Credentials.ID {
return fmt.Errorf("unexpected claim value for aud, got=%v, want=%v", aud, ci.Credentials.ID) return fmt.Errorf("unexpected claim value for aud, got=%v, want=%v", aud, ci.Credentials.ID)
} }
if sub := claims["sub"].(string); sub != expectedSub { if sub, ok := claims["sub"].(string); !ok {
return fmt.Errorf("unexpected claim value for sub, got=nil, want=%v", expectedSub)
} else if sub != expectedSub {
return fmt.Errorf("unexpected claim value for sub, got=%v, want=%v", sub, expectedSub) return fmt.Errorf("unexpected claim value for sub, got=%v, want=%v", sub, expectedSub)
} }
if name := claims["name"].(string); name != expectedName { if name, ok := claims["name"].(string); !ok {
return fmt.Errorf("unexpected claim value for aud, got=nil, want=%v", expectedName)
} else if name != expectedName {
return fmt.Errorf("unexpected claim value for name, got=%v, want=%v", name, expectedName) return fmt.Errorf("unexpected claim value for name, got=%v, want=%v", name, expectedName)
} }
@ -117,17 +133,34 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
ID: "local", ID: "local",
} }
validRedirURL := url.URL{
Scheme: "http",
Host: "client.example.com",
Path: "/callback",
}
ci := client.Client{ ci := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: validRedirURL.Host,
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")), Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
}, },
} }
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
dbMap := db.NewMemDB() dbMap := db.NewMemDB()
cir, err := db.NewClientRepoFromClients(dbMap, []client.Client{ci}) clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), []client.Client{ci}, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity repo: " + err.Error()) t.Fatalf("Failed to create client identity manager: " + err.Error())
} }
passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo}) passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo})
if err != nil { if err != nil {
@ -164,7 +197,8 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
IssuerURL: issuerURL, IssuerURL: issuerURL,
KeyManager: km, KeyManager: km,
SessionManager: sm, SessionManager: sm,
ClientRepo: cir, ClientRepo: clientRepo,
ClientManager: clientManager,
Templates: template.New(connector.LoginPageTemplateName), Templates: template.New(connector.LoginPageTemplateName),
Connectors: []connector.Connector{}, Connectors: []connector.Connector{},
UserRepo: userRepo, UserRepo: userRepo,
@ -188,7 +222,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
HTTPClient: sClient, HTTPClient: sClient,
ProviderConfig: pcfg, ProviderConfig: pcfg,
Credentials: ci.Credentials, Credentials: ci.Credentials,
RedirectURL: "http://client.example.com", RedirectURL: validRedirURL.String(),
KeySet: *ks, KeySet: *ks,
} }
@ -263,10 +297,20 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
} }
func TestHTTPClientCredsToken(t *testing.T) { func TestHTTPClientCredsToken(t *testing.T) {
validRedirURL := url.URL{
Scheme: "http",
Host: "client.example.com",
Path: "/callback",
}
ci := client.Client{ ci := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "72de74a9", ID: validRedirURL.Host,
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")), Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
}, },
} }
cis := []client.Client{ci} cis := []client.Client{ci}

View file

@ -18,6 +18,7 @@ import (
"google.golang.org/api/googleapi" "google.golang.org/api/googleapi"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/server" "github.com/coreos/dex/server"
@ -79,7 +80,7 @@ var (
}, },
} }
userBadClientID = "ZZZ" userBadClientID = testBadRedirectURL.Host
userGoodToken = makeUserToken(testIssuerURL, userGoodToken = makeUserToken(testIssuerURL,
"ID-1", testClientID, time.Hour*1, testPrivKey) "ID-1", testClientID, time.Hour*1, testPrivKey)
@ -101,8 +102,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
f := &userAPITestFixtures{} f := &userAPITestFixtures{}
dbMap, _, _, um := makeUserObjects(userUsers, userPasswords) dbMap, _, _, um := makeUserObjects(userUsers, userPasswords)
cir := func() client.ClientRepo { clients := []client.Client{
repo, err := db.NewClientRepoFromClients(dbMap, []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: testClientID, ID: testClientID,
@ -121,18 +121,23 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
testRedirectURL, testBadRedirectURL,
}, },
}, },
}, },
})
if err != nil {
panic("Failed to create client identity repo: " + err.Error())
} }
return repo clientIDGenerator := func(hostport string) (string, error) {
}() return hostport, nil
}
cir.SetDexAdmin(testClientID, true) secGen := func() ([]byte, error) {
return []byte(testClientSecret), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := manager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
panic("Failed to create client identity manager: " + err.Error())
}
clientManager.SetDexAdmin(testClientID, true)
noop := func() error { return nil } noop := func() error { return nil }
@ -153,8 +158,9 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
f.emailer = &testEmailer{} f.emailer = &testEmailer{}
um.Clock = clock um.Clock = clock
api := api.NewUsersAPI(dbMap, um, f.emailer, "local")
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir) api := api.NewUsersAPI(um, clientManager, refreshRepo, f.emailer, "local")
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, clientManager)
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler()) f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
f.trans = &tokenHandlerTransport{ f.trans = &tokenHandlerTransport{
@ -536,7 +542,7 @@ func TestCreateUser(t *testing.T) {
wantEmalier := testEmailer{ wantEmalier := testEmailer{
cantEmail: tt.cantEmail, cantEmail: tt.cantEmail,
lastEmail: tt.req.User.Email, lastEmail: tt.req.User.Email,
lastClientID: "XXX", lastClientID: testClientID,
lastWasInvite: true, lastWasInvite: true,
lastRedirectURL: *urlParsed, lastRedirectURL: *urlParsed,
} }
@ -799,7 +805,7 @@ func TestResendEmailInvitation(t *testing.T) {
wantEmalier := testEmailer{ wantEmalier := testEmailer{
cantEmail: tt.cantEmail, cantEmail: tt.cantEmail,
lastEmail: strings.ToLower(tt.email), lastEmail: strings.ToLower(tt.email),
lastClientID: "XXX", lastClientID: testClientID,
lastWasInvite: true, lastWasInvite: true,
lastRedirectURL: *urlParsed, lastRedirectURL: *urlParsed,
} }

View file

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/coreos/dex/client" "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
@ -14,7 +14,7 @@ import (
type clientTokenMiddleware struct { type clientTokenMiddleware struct {
issuerURL string issuerURL string
ciRepo client.ClientRepo ciManager *manager.ClientManager
keysFunc func() ([]key.PublicKey, error) keysFunc func() ([]key.PublicKey, error)
next http.Handler next http.Handler
} }
@ -30,8 +30,8 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
return return
} }
if c.ciRepo == nil { if c.ciManager == nil {
log.Errorf("Misconfigured clientTokenMiddleware, ClientRepo is not set") log.Errorf("Misconfigured clientTokenMiddleware, ClientManager is not set")
respondError() respondError()
return return
} }
@ -83,7 +83,7 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
return return
} }
md, err := c.ciRepo.Metadata(nil, clientID) md, err := c.ciManager.Metadata(clientID)
if md == nil || err != nil { if md == nil || err != nil {
log.Errorf("Failed to find clientID: %s, error=%v", clientID, err) log.Errorf("Failed to find clientID: %s, error=%v", clientID, err)
respondError() respondError()

View file

@ -1,7 +1,6 @@
package server package server
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -9,7 +8,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/coreos/dex/client" clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
@ -25,22 +24,20 @@ func (h staticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func TestClientToken(t *testing.T) { func TestClientToken(t *testing.T) {
now := time.Now() now := time.Now()
tomorrow := now.Add(24 * time.Hour) tomorrow := now.Add(24 * time.Hour)
validClientID := "valid-client" clientMetadata := oidc.ClientMetadata{
ci := client.Client{
Credentials: oidc.ClientCredentials{
ID: validClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
{Scheme: "https", Host: "authn.example.com", Path: "/callback"}, {Scheme: "https", Host: "authn.example.com", Path: "/callback"},
}, },
},
} }
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
dbm := db.NewMemDB()
clientRepo := db.NewClientRepo(dbm)
clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbm), clientmanager.ManagerOptions{})
creds, err := clientManager.New(clientMetadata)
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err) t.Fatalf("Failed to create client: %v", err)
} }
validClientID := creds.ID
privKey, err := key.GeneratePrivateKey() privKey, err := key.GeneratePrivateKey()
if err != nil { if err != nil {
@ -65,63 +62,63 @@ func TestClientToken(t *testing.T) {
tests := []struct { tests := []struct {
keys []key.PublicKey keys []key.PublicKey
repo client.ClientRepo manager *clientmanager.ClientManager
header string header string
wantCode int wantCode int
}{ }{
// valid token // valid token
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: repo, manager: clientManager,
header: fmt.Sprintf("BEARER %s", validJWT), header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusOK, wantCode: http.StatusOK,
}, },
// invalid token // invalid token
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: repo, manager: clientManager,
header: fmt.Sprintf("BEARER %s", invalidJWT), header: fmt.Sprintf("BEARER %s", invalidJWT),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// empty header // empty header
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: repo, manager: clientManager,
header: "", header: "",
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// unparsable token // unparsable token
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: repo, manager: clientManager,
header: "BEARER xxx", header: "BEARER xxx",
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// no verification keys // no verification keys
{ {
keys: []key.PublicKey{}, keys: []key.PublicKey{},
repo: repo, manager: clientManager,
header: fmt.Sprintf("BEARER %s", validJWT), header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// nil repo // nil repo
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: nil, manager: nil,
header: fmt.Sprintf("BEARER %s", validJWT), header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// empty repo // empty repo
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: db.NewClientRepo(db.NewMemDB()), manager: clientmanager.NewClientManager(db.NewClientRepo(db.NewMemDB()), db.TransactionFactory(db.NewMemDB()), clientmanager.ManagerOptions{}),
header: fmt.Sprintf("BEARER %s", validJWT), header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
// client not in repo // client not in repo
{ {
keys: []key.PublicKey{pubKey}, keys: []key.PublicKey{pubKey},
repo: repo, manager: clientManager,
header: fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)), header: fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)),
wantCode: http.StatusUnauthorized, wantCode: http.StatusUnauthorized,
}, },
@ -131,7 +128,7 @@ func TestClientToken(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
mw := &clientTokenMiddleware{ mw := &clientTokenMiddleware{
issuerURL: validIss, issuerURL: validIss,
ciRepo: tt.repo, ciManager: tt.manager,
keysFunc: func() ([]key.PublicKey, error) { keysFunc: func() ([]key.PublicKey, error) {
return tt.keys, nil return tt.keys, nil
}, },

View file

@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/coreos/dex/client"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/go-oidc/oauth2" "github.com/coreos/go-oidc/oauth2"
@ -39,18 +38,7 @@ func (s *Server) handleClientRegistrationRequest(r *http.Request) (*oidc.ClientR
} }
// metadata is guarenteed to have at least one redirect_uri by earlier validation. // metadata is guarenteed to have at least one redirect_uri by earlier validation.
id, err := oidc.GenClientID(clientMetadata.RedirectURIs[0].Host) creds, err := s.ClientManager.New(clientMetadata)
if err != nil {
log.Errorf("Faild to create client ID: %v", err)
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
}
creds, err := s.ClientRepo.New(nil, client.Client{
Credentials: oidc.ClientCredentials{
ID: id,
},
Metadata: clientMetadata,
})
if err != nil { if err != nil {
log.Errorf("Failed to create new client identity: %v", err) log.Errorf("Failed to create new client identity: %v", err)
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata") return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")

View file

@ -143,7 +143,7 @@ func TestClientRegistration(t *testing.T) {
return fmt.Errorf("no client id in registration response") return fmt.Errorf("no client id in registration response")
} }
metadata, err := fixtures.clientRepo.Metadata(nil, r.ClientID) metadata, err := fixtures.clientManager.Metadata(r.ClientID)
if err != nil { if err != nil {
return fmt.Errorf("failed to lookup client id after creation") return fmt.Errorf("failed to lookup client id after creation")
} }

View file

@ -6,21 +6,20 @@ import (
"net/http" "net/http"
"path" "path"
"github.com/coreos/dex/client" "github.com/coreos/dex/client/manager"
phttp "github.com/coreos/dex/pkg/http" phttp "github.com/coreos/dex/pkg/http"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/go-oidc/oidc"
) )
type clientResource struct { type clientResource struct {
repo client.ClientRepo manager *manager.ClientManager
} }
func registerClientResource(prefix string, repo client.ClientRepo) (string, http.Handler) { func registerClientResource(prefix string, manager *manager.ClientManager) (string, http.Handler) {
mux := http.NewServeMux() mux := http.NewServeMux()
c := &clientResource{ c := &clientResource{
repo: repo, manager: manager,
} }
relPath := "clients" relPath := "clients"
absPath := path.Join(prefix, relPath) absPath := path.Join(prefix, relPath)
@ -41,7 +40,7 @@ func (c *clientResource) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
func (c *clientResource) list(w http.ResponseWriter, r *http.Request) { func (c *clientResource) list(w http.ResponseWriter, r *http.Request) {
cs, err := c.repo.All(nil) cs, err := c.manager.All()
if err != nil { if err != nil {
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "error listing clients")) writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "error listing clients"))
return return
@ -88,16 +87,7 @@ func (c *clientResource) create(w http.ResponseWriter, r *http.Request) {
writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidClientMetadata, err.Error())) writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidClientMetadata, err.Error()))
return return
} }
creds, err := c.manager.New(ci.Metadata)
clientID, err := oidc.GenClientID(ci.Metadata.RedirectURIs[0].Host)
if err != nil {
log.Errorf("Failed generating ID for new client: %v", err)
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "unable to generate client ID"))
return
}
ci.Credentials.ID = clientID
creds, err := c.repo.New(nil, ci)
if err != nil { if err != nil {
log.Errorf("Failed creating client: %v", err) log.Errorf("Failed creating client: %v", err)

View file

@ -15,6 +15,7 @@ import (
"testing" "testing"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
@ -28,8 +29,10 @@ func makeBody(s string) io.ReadCloser {
func TestCreateInvalidRequest(t *testing.T) { func TestCreateInvalidRequest(t *testing.T) {
u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"} u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"}
h := http.Header{"Content-Type": []string{"application/json"}} h := http.Header{"Content-Type": []string{"application/json"}}
repo := db.NewClientRepo(db.NewMemDB()) dbm := db.NewMemDB()
res := &clientResource{repo: repo} repo := db.NewClientRepo(dbm)
manager := manager.NewClientManager(repo, db.TransactionFactory(dbm), manager.ManagerOptions{})
res := &clientResource{manager: manager}
tests := []struct { tests := []struct {
req *http.Request req *http.Request
wantCode int wantCode int
@ -119,8 +122,10 @@ func TestCreateInvalidRequest(t *testing.T) {
} }
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
repo := db.NewClientRepo(db.NewMemDB()) dbm := db.NewMemDB()
res := &clientResource{repo: repo} repo := db.NewClientRepo(dbm)
manager := manager.NewClientManager(repo, db.TransactionFactory(dbm), manager.ManagerOptions{})
res := &clientResource{manager: manager}
tests := [][]string{ tests := [][]string{
[]string{"http://example.com"}, []string{"http://example.com"},
[]string{"https://example.com"}, []string{"https://example.com"},
@ -190,7 +195,7 @@ func TestList(t *testing.T) {
{ {
cs: []client.Client{ cs: []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")}, Credentials: oidc.ClientCredentials{ID: "example.com", Secret: b64Encode("secret")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com"}, url.URL{Scheme: "http", Host: "example.com"},
@ -200,7 +205,7 @@ func TestList(t *testing.T) {
}, },
want: []*schema.Client{ want: []*schema.Client{
&schema.Client{ &schema.Client{
Id: "foo", Id: "example.com",
RedirectURIs: []string{"http://example.com"}, RedirectURIs: []string{"http://example.com"},
}, },
}, },
@ -209,7 +214,7 @@ func TestList(t *testing.T) {
{ {
cs: []client.Client{ cs: []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")}, Credentials: oidc.ClientCredentials{ID: "example.com", Secret: b64Encode("secret")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "http", Host: "example.com"}, url.URL{Scheme: "http", Host: "example.com"},
@ -217,21 +222,21 @@ func TestList(t *testing.T) {
}, },
}, },
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ID: "biz", Secret: b64Encode("bang")}, Credentials: oidc.ClientCredentials{ID: "example2.com", Secret: b64Encode("secret")},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "example.com", Path: "one/two/three"}, url.URL{Scheme: "https", Host: "example2.com", Path: "one/two/three"},
}, },
}, },
}, },
}, },
want: []*schema.Client{ want: []*schema.Client{
&schema.Client{ &schema.Client{
Id: "biz", Id: "example2.com",
RedirectURIs: []string{"https://example.com/one/two/three"}, RedirectURIs: []string{"https://example2.com/one/two/three"},
}, },
&schema.Client{ &schema.Client{
Id: "foo", Id: "example.com",
RedirectURIs: []string{"http://example.com"}, RedirectURIs: []string{"http://example.com"},
}, },
}, },
@ -239,12 +244,20 @@ func TestList(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), tt.cs) dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := manager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), tt.cs, manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
t.Errorf("case %d: failed to create client identity repo: %v", i, err) t.Fatalf("Failed to create client identity manager: %v", err)
continue continue
} }
res := &clientResource{repo: repo} res := &clientResource{manager: clientManager}
r, err := http.NewRequest("GET", "http://example.com/clients", nil) r, err := http.NewRequest("GET", "http://example.com/clients", nil)
if err != nil { if err != nil {

View file

@ -17,6 +17,7 @@ import (
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/email" "github.com/coreos/dex/email"
@ -114,9 +115,11 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
if err != nil { if err != nil {
return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err) return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err)
} }
ciRepo, err := db.NewClientRepoFromClients(dbMap, clients)
if err != nil { clientRepo := db.NewClientRepo(dbMap)
return fmt.Errorf("failed to create client identity repo: %v", err)
for _, c := range clients {
clientRepo.New(nil, c)
} }
f, err := os.Open(cfg.ConnectorsFile) f, err := os.Open(cfg.ConnectorsFile)
@ -155,7 +158,12 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
txnFactory := db.TransactionFactory(dbMap) txnFactory := db.TransactionFactory(dbMap)
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{}) userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
srv.ClientRepo = ciRepo clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, clientmanager.ManagerOptions{})
if err != nil {
return fmt.Errorf("Failed to create client identity manager: %v", err)
}
srv.ClientRepo = clientRepo
srv.ClientManager = clientManager
srv.KeySetRepo = kRepo srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo srv.ConnectorConfigRepo = cfgRepo
srv.UserRepo = userRepo srv.UserRepo = userRepo
@ -253,11 +261,13 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
userRepo := db.NewUserRepo(dbc) userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{}) userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
clientManager := clientmanager.NewClientManager(ciRepo, db.TransactionFactory(dbc), clientmanager.ManagerOptions{})
refreshTokenRepo := db.NewRefreshTokenRepo(dbc) refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
sm := sessionmanager.NewSessionManager(sRepo, skRepo) sm := sessionmanager.NewSessionManager(sRepo, skRepo)
srv.ClientRepo = ciRepo srv.ClientRepo = ciRepo
srv.ClientManager = clientManager
srv.KeySetRepo = kRepo srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo srv.ConnectorConfigRepo = cfgRepo
srv.UserRepo = userRepo srv.UserRepo = userRepo

View file

@ -12,6 +12,7 @@ import (
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
@ -28,7 +29,7 @@ func handleVerifyEmailResendFunc(
srvKeysFunc func() ([]key.PublicKey, error), srvKeysFunc func() ([]key.PublicKey, error),
emailer *useremail.UserEmailer, emailer *useremail.UserEmailer,
userRepo user.UserRepo, userRepo user.UserRepo,
clientRepo client.ClientRepo) http.HandlerFunc { clientManager *clientmanager.ClientManager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body) decoder := json.NewDecoder(r.Body)
var params struct { var params struct {
@ -57,7 +58,7 @@ func handleVerifyEmailResendFunc(
return return
} }
cm, err := clientRepo.Metadata(nil, clientID) cm, err := clientManager.Metadata(clientID)
if err == client.ErrorNotFound { if err == client.ErrorNotFound {
log.Errorf("No such client: %v", err) log.Errorf("No such client: %v", err)
writeAPIError(w, http.StatusBadRequest, writeAPIError(w, http.StatusBadRequest,

View file

@ -130,7 +130,7 @@ func TestHandleVerifyEmailResend(t *testing.T) {
keysFunc, keysFunc,
f.srv.UserEmailer, f.srv.UserEmailer,
f.userRepo, f.userRepo,
f.clientRepo) f.clientManager)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u := "http://example.com" u := "http://example.com"

View file

@ -17,6 +17,7 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/session/manager" "github.com/coreos/dex/session/manager"
@ -75,15 +76,12 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
idpcs := []connector.Connector{ idpcs := []connector.Connector{
&fakeConnector{loginURL: "http://fake.example.com"}, &fakeConnector{loginURL: "http://fake.example.com"},
} }
srv := &Server{ dbm := db.NewMemDB()
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, clients := []client.Client{
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientRepo: func() client.ClientRepo {
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "client.example.com",
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")), Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -91,12 +89,24 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
}, },
}, },
}, },
})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
} }
return repo
}(), clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err)
}
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientRepo: clientRepo,
ClientManager: clientManager,
} }
tests := []struct { tests := []struct {
@ -108,7 +118,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
{ {
query: url.Values{ query: url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"client_id": []string{"XXX"}, "client_id": []string{"client.example.com"},
"connector_id": []string{"fake"}, "connector_id": []string{"fake"},
"scope": []string{"openid"}, "scope": []string{"openid"},
}, },
@ -121,7 +131,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
query: url.Values{ query: url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"redirect_uri": []string{"http://client.example.com/callback"}, "redirect_uri": []string{"http://client.example.com/callback"},
"client_id": []string{"XXX"}, "client_id": []string{"client.example.com"},
"connector_id": []string{"fake"}, "connector_id": []string{"fake"},
"scope": []string{"openid"}, "scope": []string{"openid"},
}, },
@ -134,7 +144,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
query: url.Values{ query: url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"redirect_uri": []string{"http://unrecognized.example.com/callback"}, "redirect_uri": []string{"http://unrecognized.example.com/callback"},
"client_id": []string{"XXX"}, "client_id": []string{"client.example.com"},
"connector_id": []string{"fake"}, "connector_id": []string{"fake"},
"scope": []string{"openid"}, "scope": []string{"openid"},
}, },
@ -157,7 +167,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
{ {
query: url.Values{ query: url.Values{
"response_type": []string{"token"}, "response_type": []string{"token"},
"client_id": []string{"XXX"}, "client_id": []string{"client.example.com"},
"connector_id": []string{"fake"}, "connector_id": []string{"fake"},
"scope": []string{"openid"}, "scope": []string{"openid"},
}, },
@ -170,11 +180,33 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
query: url.Values{ query: url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"redirect_uri": []string{"http://client.example.com/callback"}, "redirect_uri": []string{"http://client.example.com/callback"},
"client_id": []string{"XXX"}, "client_id": []string{"client.example.com"},
"connector_id": []string{"fake"}, "connector_id": []string{"fake"},
}, },
wantCode: http.StatusBadRequest, wantCode: http.StatusBadRequest,
}, },
// empty response_type
{
query: url.Values{
"redirect_uri": []string{"http://client.example.com/callback"},
"client_id": []string{"client.example.com"},
"connector_id": []string{"fake"},
"scope": []string{"openid"},
},
wantCode: http.StatusFound,
wantLocation: "http://client.example.com/callback?error=unsupported_response_type&state=",
},
// empty client_id
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
"connector_id": []string{"fake"},
"scope": []string{"openid"},
},
wantCode: http.StatusBadRequest,
},
} }
for i, tt := range tests { for i, tt := range tests {
@ -204,14 +236,12 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
idpcs := []connector.Connector{ idpcs := []connector.Connector{
&fakeConnector{loginURL: "http://fake.example.com"}, &fakeConnector{loginURL: "http://fake.example.com"},
} }
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, dbm := db.NewMemDB()
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())), clients := []client.Client{
ClientRepo: func() client.ClientRepo {
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: "foo.example.com",
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")), Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
@ -221,12 +251,24 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
}, },
}, },
}, },
})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
} }
return repo
}(), clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err)
}
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientRepo: clientRepo,
ClientManager: clientManager,
} }
tests := []struct { tests := []struct {
@ -239,7 +281,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
query: url.Values{ query: url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"redirect_uri": []string{"http://foo.example.com/callback"}, "redirect_uri": []string{"http://foo.example.com/callback"},
"client_id": []string{"XXX"}, "client_id": []string{"foo.example.com"},
"connector_id": []string{"fake"}, "connector_id": []string{"fake"},
"scope": []string{"openid"}, "scope": []string{"openid"},
}, },
@ -252,7 +294,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
query: url.Values{ query: url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"redirect_uri": []string{"http://bar.example.com/callback"}, "redirect_uri": []string{"http://bar.example.com/callback"},
"client_id": []string{"XXX"}, "client_id": []string{"foo.example.com"},
"connector_id": []string{"fake"}, "connector_id": []string{"fake"},
"scope": []string{"openid"}, "scope": []string{"openid"},
}, },
@ -265,7 +307,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
query: url.Values{ query: url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"redirect_uri": []string{"http://unrecognized.example.com/callback"}, "redirect_uri": []string{"http://unrecognized.example.com/callback"},
"client_id": []string{"XXX"}, "client_id": []string{"foo.example.com"},
"connector_id": []string{"fake"}, "connector_id": []string{"fake"},
"scope": []string{"openid"}, "scope": []string{"openid"},
}, },
@ -276,7 +318,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
{ {
query: url.Values{ query: url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"client_id": []string{"XXX"}, "client_id": []string{"foo.example.com"},
"connector_id": []string{"fake"}, "connector_id": []string{"fake"},
"scope": []string{"openid"}, "scope": []string{"openid"},
}, },
@ -328,8 +370,8 @@ func TestHandleTokenFunc(t *testing.T) {
"grant_type": []string{"invalid!"}, "grant_type": []string{"invalid!"},
"code": []string{"someCode"}, "code": []string{"someCode"},
}, },
user: "XXX", user: testClientID,
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")), passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
wantCode: http.StatusBadRequest, wantCode: http.StatusBadRequest,
}, },
@ -338,8 +380,8 @@ func TestHandleTokenFunc(t *testing.T) {
query: url.Values{ query: url.Values{
"grant_type": []string{"authorization_code"}, "grant_type": []string{"authorization_code"},
}, },
user: "XXX", user: testClientID,
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")), passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
wantCode: http.StatusBadRequest, wantCode: http.StatusBadRequest,
}, },
@ -349,8 +391,8 @@ func TestHandleTokenFunc(t *testing.T) {
"grant_type": []string{"authorization_code"}, "grant_type": []string{"authorization_code"},
"code": []string{""}, "code": []string{""},
}, },
user: "XXX", user: testClientID,
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")), passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
wantCode: http.StatusBadRequest, wantCode: http.StatusBadRequest,
}, },
@ -371,8 +413,8 @@ func TestHandleTokenFunc(t *testing.T) {
"grant_type": []string{"authorization_code"}, "grant_type": []string{"authorization_code"},
"code": []string{"asdasd"}, "code": []string{"asdasd"},
}, },
user: "XXX", user: testClientID,
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")), passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
wantCode: http.StatusBadRequest, wantCode: http.StatusBadRequest,
}, },
@ -382,8 +424,8 @@ func TestHandleTokenFunc(t *testing.T) {
"grant_type": []string{"authorization_code"}, "grant_type": []string{"authorization_code"},
"code": []string{"code-2"}, "code": []string{"code-2"},
}, },
user: "XXX", user: testClientID,
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")), passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
wantCode: http.StatusOK, wantCode: http.StatusOK,
}, },
} }
@ -402,7 +444,7 @@ func TestHandleTokenFunc(t *testing.T) {
// need to create session in order to exchange the code (generated by the NewSessionKey func) for token // need to create session in order to exchange the code (generated by the NewSessionKey func) for token
setSession := func() error { setSession := func() error {
sid, err := fx.sessionManager.NewSession("local", "XXX", "", testRedirectURL, "", true, []string{"openid"}) sid, err := fx.sessionManager.NewSession("local", testClientID, "", testRedirectURL, "", true, []string{"openid"})
if err != nil { if err != nil {
return fmt.Errorf("case %d: cannot create session, error=%v", i, err) return fmt.Errorf("case %d: cannot create session, error=%v", i, err)
} }

View file

@ -8,6 +8,7 @@ import (
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
sessionmanager "github.com/coreos/dex/session/manager" sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
@ -29,7 +30,7 @@ type SendResetPasswordEmailHandler struct {
tpl *template.Template tpl *template.Template
emailer *useremail.UserEmailer emailer *useremail.UserEmailer
sm *sessionmanager.SessionManager sm *sessionmanager.SessionManager
cr client.ClientRepo cm *clientmanager.ClientManager
} }
func (h *SendResetPasswordEmailHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *SendResetPasswordEmailHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -128,7 +129,7 @@ func (h *SendResetPasswordEmailHandler) validateRedirectURL(clientID string, red
return url.URL{}, false return url.URL{}, false
} }
cm, err := h.cr.Metadata(nil, clientID) cm, err := h.cm.Metadata(clientID)
if err != nil || cm == nil { if err != nil || cm == nil {
log.Errorf("Error getting ClientMetadata: %v", err) log.Errorf("Error getting ClientMetadata: %v", err)
return url.URL{}, false return url.URL{}, false

View file

@ -253,7 +253,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
t.Fatalf("case %d: could not make test fixtures: %v", i, err) t.Fatalf("case %d: could not make test fixtures: %v", i, err)
} }
_, err = f.srv.NewSession("local", "XXX", "", f.redirectURL, "", true, []string{"openid"}) _, err = f.srv.NewSession("local", testClientID, "", f.redirectURL, "", true, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("case %d: could not create new session: %v", i, err) t.Fatalf("case %d: could not create new session: %v", i, err)
} }
@ -267,7 +267,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
tpl: f.srv.SendResetPasswordEmailTemplate, tpl: f.srv.SendResetPasswordEmailTemplate,
emailer: f.srv.UserEmailer, emailer: f.srv.UserEmailer,
sm: f.sessionManager, sm: f.sessionManager,
cr: f.clientRepo, cm: f.clientManager,
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()

View file

@ -295,7 +295,7 @@ func TestHandleRegister(t *testing.T) {
}) })
} }
key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, []string{"openid"}) key, err := f.srv.NewSession(tt.connID, testClientID, "", f.redirectURL, "", true, []string{"openid"})
t.Logf("case %d: key for NewSession: %v", i, key) t.Logf("case %d: key for NewSession: %v", i, key)
if tt.attachRemote { if tt.attachRemote {

View file

@ -19,6 +19,7 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
@ -72,6 +73,7 @@ type Server struct {
Connectors []connector.Connector Connectors []connector.Connector
UserRepo user.UserRepo UserRepo user.UserRepo
UserManager *usermanager.UserManager UserManager *usermanager.UserManager
ClientManager *clientmanager.ClientManager
PasswordInfoRepo user.PasswordInfoRepo PasswordInfoRepo user.PasswordInfoRepo
RefreshTokenRepo refresh.RefreshTokenRepo RefreshTokenRepo refresh.RefreshTokenRepo
UserEmailer *useremail.UserEmailer UserEmailer *useremail.UserEmailer
@ -213,13 +215,13 @@ func (s *Server) HTTPHandler() http.Handler {
s.KeyManager.PublicKeys, s.KeyManager.PublicKeys,
s.UserEmailer, s.UserEmailer,
s.UserRepo, s.UserRepo,
s.ClientRepo))) s.ClientManager)))
mux.Handle(httpPathSendResetPassword, &SendResetPasswordEmailHandler{ mux.Handle(httpPathSendResetPassword, &SendResetPasswordEmailHandler{
tpl: s.SendResetPasswordEmailTemplate, tpl: s.SendResetPasswordEmailTemplate,
emailer: s.UserEmailer, emailer: s.UserEmailer,
sm: s.SessionManager, sm: s.SessionManager,
cr: s.ClientRepo, cm: s.ClientManager,
}) })
mux.Handle(httpPathResetPassword, &ResetPasswordHandler{ mux.Handle(httpPathResetPassword, &ResetPasswordHandler{
@ -256,11 +258,11 @@ func (s *Server) HTTPHandler() http.Handler {
apiBasePath := path.Join(httpPathAPI, APIVersion) apiBasePath := path.Join(httpPathAPI, APIVersion)
registerDiscoveryResource(apiBasePath, mux) registerDiscoveryResource(apiBasePath, mux)
clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientRepo) clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientManager)
mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler)) mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler))
usersAPI := usersapi.NewUsersAPI(s.dbMap, s.UserManager, s.UserEmailer, s.localConnectorID) usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientManager, s.RefreshTokenRepo, s.UserEmailer, s.localConnectorID)
handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientRepo).HTTPHandler() handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientManager).HTTPHandler()
mux.Handle(apiBasePath+"/", handler) mux.Handle(apiBasePath+"/", handler)
@ -271,14 +273,14 @@ func (s *Server) HTTPHandler() http.Handler {
func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler { func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler {
return &clientTokenMiddleware{ return &clientTokenMiddleware{
issuerURL: s.IssuerURL.String(), issuerURL: s.IssuerURL.String(),
ciRepo: s.ClientRepo, ciManager: s.ClientManager,
keysFunc: s.KeyManager.PublicKeys, keysFunc: s.KeyManager.PublicKeys,
next: handler, next: handler,
} }
} }
func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) { func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) {
return s.ClientRepo.Metadata(nil, clientID) return s.ClientManager.Metadata(clientID)
} }
func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) { func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
@ -365,9 +367,9 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
} }
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) { func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
ok, err := s.ClientRepo.Authenticate(nil, creds) ok, err := s.ClientManager.Authenticate(creds)
if err != nil { if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, oauth2.NewError(oauth2.ErrorServerError)
} }
if !ok { if !ok {
@ -397,7 +399,7 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, erro
} }
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) { func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
ok, err := s.ClientRepo.Authenticate(nil, creds) ok, err := s.ClientManager.Authenticate(creds)
if err != nil { if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, "", oauth2.NewError(oauth2.ErrorServerError) return nil, "", oauth2.NewError(oauth2.ErrorServerError)
@ -466,7 +468,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
} }
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) { func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
ok, err := s.ClientRepo.Authenticate(nil, creds) ok, err := s.ClientManager.Authenticate(creds)
if err != nil { if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, oauth2.NewError(oauth2.ErrorServerError)

View file

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/refresh/refreshtest" "github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/session/manager" "github.com/coreos/dex/session/manager"
@ -21,7 +22,12 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
) )
var clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secrete")) var clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
var validRedirURL = url.URL{
Scheme: "http",
Host: "client.example.com",
Path: "/callback",
}
type StaticKeyManager struct { type StaticKeyManager struct {
key.PrivateKeyManager key.PrivateKeyManager
@ -132,8 +138,8 @@ func TestServerNewSession(t *testing.T) {
nonce := "oncenay" nonce := "oncenay"
ci := client.Client{ ci := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: testClientID,
Secret: "secrete", Secret: clientTestSecret,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -181,7 +187,7 @@ func TestServerNewSession(t *testing.T) {
func TestServerLogin(t *testing.T) { func TestServerLogin(t *testing.T) {
ci := client.Client{ ci := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: testClientID,
Secret: clientTestSecret, Secret: clientTestSecret,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
@ -194,13 +200,13 @@ func TestServerLogin(t *testing.T) {
}, },
}, },
} }
ciRepo := func() client.ClientRepo {
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci}) dbm := db.NewMemDB()
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), []client.Client{ci}, clientmanager.ManagerOptions{})
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err) t.Fatalf("Failed to create client identity manager: %v", err)
} }
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
@ -222,7 +228,8 @@ func TestServerLogin(t *testing.T) {
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
SessionManager: sm, SessionManager: sm,
ClientRepo: ciRepo, ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo, UserRepo: userRepo,
} }
@ -244,20 +251,30 @@ func TestServerLogin(t *testing.T) {
} }
func TestServerLoginUnrecognizedSessionKey(t *testing.T) { func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
ciRepo := func() client.ClientRepo { clients := []client.Client{
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", Secret: clientTestSecret, ID: testClientID, Secret: clientTestSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
}, },
}, },
})
if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err)
} }
return repo dbm := db.NewMemDB()
}() clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err)
}
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: nil, err: errors.New("fail")}, signer: &StaticSigner{sig: nil, err: errors.New("fail")},
} }
@ -266,11 +283,12 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
SessionManager: sm, SessionManager: sm,
ClientRepo: ciRepo, ClientRepo: clientRepo,
ClientManager: clientManager,
} }
ident := oidc.Identity{ID: "YYY", Name: "elroy", Email: "elroy@example.com"} ident := oidc.Identity{ID: "YYY", Name: "elroy", Email: "elroy@example.com"}
code, err := srv.Login(ident, "XXX") code, err := srv.Login(ident, testClientID)
if err == nil { if err == nil {
t.Fatalf("Expected non-nil error") t.Fatalf("Expected non-nil error")
} }
@ -283,27 +301,28 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
func TestServerLoginDisabledUser(t *testing.T) { func TestServerLoginDisabledUser(t *testing.T) {
ci := client.Client{ ci := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: testClientID,
Secret: clientTestSecret, Secret: clientTestSecret,
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
url.URL{ validRedirURL,
Scheme: "http",
Host: "client.example.com",
Path: "/callback",
},
}, },
}, },
} }
ciRepo := func() client.ClientRepo { clients := []client.Client{ci}
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci}) dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err) t.Fatalf("Failed to create client identity manager: %v", err)
} }
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
@ -338,7 +357,8 @@ func TestServerLoginDisabledUser(t *testing.T) {
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
SessionManager: sm, SessionManager: sm,
ClientRepo: ciRepo, ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo, UserRepo: userRepo,
} }
@ -357,17 +377,28 @@ func TestServerLoginDisabledUser(t *testing.T) {
func TestServerCodeToken(t *testing.T) { func TestServerCodeToken(t *testing.T) {
ci := client.Client{ ci := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: testClientID,
Secret: clientTestSecret, Secret: clientTestSecret,
}, },
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
},
} }
ciRepo := func() client.ClientRepo { clients := []client.Client{ci}
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci}) dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err) t.Fatalf("Failed to create client identity manager: %v", err)
} }
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
@ -384,7 +415,8 @@ func TestServerCodeToken(t *testing.T) {
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
SessionManager: sm, SessionManager: sm,
ClientRepo: ciRepo, ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo, UserRepo: userRepo,
RefreshTokenRepo: refreshTokenRepo, RefreshTokenRepo: refreshTokenRepo,
} }
@ -443,17 +475,29 @@ func TestServerCodeToken(t *testing.T) {
func TestServerTokenUnrecognizedKey(t *testing.T) { func TestServerTokenUnrecognizedKey(t *testing.T) {
ci := client.Client{ ci := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: testClientID,
Secret: clientTestSecret, Secret: clientTestSecret,
}, },
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
},
} }
ciRepo := func() client.ClientRepo {
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci}) clients := []client.Client{ci}
dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity repo: %v", err) t.Fatalf("Failed to create client identity manager: %v", err)
} }
return repo
}()
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
@ -463,7 +507,8 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
SessionManager: sm, SessionManager: sm,
ClientRepo: ciRepo, ClientRepo: clientRepo,
ClientManager: clientManager,
} }
sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"}) sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
@ -492,7 +537,7 @@ func TestServerTokenFail(t *testing.T) {
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
keyFixture := "goodkey" keyFixture := "goodkey"
ccFixture := oidc.ClientCredentials{ ccFixture := oidc.ClientCredentials{
ID: "XXX", ID: testClientID,
Secret: clientTestSecret, Secret: clientTestSecret,
} }
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
@ -569,14 +614,29 @@ func TestServerTokenFail(t *testing.T) {
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: tt.signer, signer: tt.signer,
} }
ciRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
client.Client{Credentials: ccFixture},
})
if err != nil {
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
continue
}
clients := []client.Client{
client.Client{
Credentials: ccFixture,
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
},
},
}
dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err)
}
_, err = sm.AttachUser(sessionID, "testid-1") _, err = sm.AttachUser(sessionID, "testid-1")
if err != nil { if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Fatalf("case %d: unexpected error: %v", i, err)
@ -593,7 +653,8 @@ func TestServerTokenFail(t *testing.T) {
IssuerURL: issuerURL, IssuerURL: issuerURL,
KeyManager: km, KeyManager: km,
SessionManager: sm, SessionManager: sm,
ClientRepo: ciRepo, ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo, UserRepo: userRepo,
RefreshTokenRepo: refreshTokenRepo, RefreshTokenRepo: refreshTokenRepo,
} }
@ -623,14 +684,27 @@ func TestServerTokenFail(t *testing.T) {
func TestServerRefreshToken(t *testing.T) { func TestServerRefreshToken(t *testing.T) {
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
clientA := client.Client{
credXXX := oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: testClientID,
Secret: clientTestSecret, Secret: clientTestSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "client.example.com", Path: "one/two/three"},
},
},
} }
credYYY := oidc.ClientCredentials{ clientB := client.Client{
ID: "YYY", Credentials: oidc.ClientCredentials{
ID: "example2.com",
Secret: clientTestSecret, Secret: clientTestSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "example2.com", Path: "one/two/three"},
},
},
} }
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
@ -647,47 +721,47 @@ func TestServerRefreshToken(t *testing.T) {
// Everything is good. // Everything is good.
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", clientA.Credentials.ID,
credXXX, clientA.Credentials,
signerFixture, signerFixture,
nil, nil,
}, },
// Invalid refresh token(malformatted). // Invalid refresh token(malformatted).
{ {
"invalid-token", "invalid-token",
"XXX", clientA.Credentials.ID,
credXXX, clientA.Credentials,
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidRequest), oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid refresh token(invalid payload content). // Invalid refresh token(invalid payload content).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
"XXX", clientA.Credentials.ID,
credXXX, clientA.Credentials,
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidRequest), oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid refresh token(invalid ID content). // Invalid refresh token(invalid ID content).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", clientA.Credentials.ID,
credXXX, clientA.Credentials,
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidRequest), oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid client(client is not associated with the token). // Invalid client(client is not associated with the token).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", clientA.Credentials.ID,
credYYY, clientB.Credentials,
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(no client ID). // Invalid client(no client ID).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", clientA.Credentials.ID,
oidc.ClientCredentials{ID: "", Secret: "aaa"}, oidc.ClientCredentials{ID: "", Secret: "aaa"},
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), oauth2.NewError(oauth2.ErrorInvalidClient),
@ -695,7 +769,7 @@ func TestServerRefreshToken(t *testing.T) {
// Invalid client(no such client). // Invalid client(no such client).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", clientA.Credentials.ID,
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), oauth2.NewError(oauth2.ErrorInvalidClient),
@ -703,24 +777,24 @@ func TestServerRefreshToken(t *testing.T) {
// Invalid client(no secrets). // Invalid client(no secrets).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", clientA.Credentials.ID,
oidc.ClientCredentials{ID: "XXX"}, oidc.ClientCredentials{ID: testClientID},
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(invalid secret). // Invalid client(invalid secret).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", clientA.Credentials.ID,
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"}, oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Signing operation fails. // Signing operation fails.
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", clientA.Credentials.ID,
credXXX, clientA.Credentials,
&StaticSigner{sig: nil, err: errors.New("fail")}, &StaticSigner{sig: nil, err: errors.New("fail")},
oauth2.NewError(oauth2.ErrorServerError), oauth2.NewError(oauth2.ErrorServerError),
}, },
@ -731,15 +805,23 @@ func TestServerRefreshToken(t *testing.T) {
signer: tt.signer, signer: tt.signer,
} }
ciRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ clients := []client.Client{
client.Client{Credentials: credXXX}, clientA,
client.Client{Credentials: credYYY}, clientB,
})
if err != nil {
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
continue
} }
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
dbm := db.NewMemDB()
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err)
}
userRepo, err := makeNewUserRepo() userRepo, err := makeNewUserRepo()
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
@ -750,7 +832,8 @@ func TestServerRefreshToken(t *testing.T) {
srv := &Server{ srv := &Server{
IssuerURL: issuerURL, IssuerURL: issuerURL,
KeyManager: km, KeyManager: km,
ClientRepo: ciRepo, ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo, UserRepo: userRepo,
RefreshTokenRepo: refreshTokenRepo, RefreshTokenRepo: refreshTokenRepo,
} }
@ -772,7 +855,7 @@ func TestServerRefreshToken(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Case %d: unexpected error: %v", i, err) t.Errorf("Case %d: unexpected error: %v", i, err)
} }
if claims["iss"] != issuerURL.String() || claims["sub"] != "testid-1" || claims["aud"] != "XXX" { if claims["iss"] != issuerURL.String() || claims["sub"] != "testid-1" || claims["aud"] != testClientID {
t.Errorf("Case %d: invalid claims: %v", i, claims) t.Errorf("Case %d: invalid claims: %v", i, claims)
} }
} }
@ -784,14 +867,22 @@ func TestServerRefreshToken(t *testing.T) {
signer: signerFixture, signer: signerFixture,
} }
ciRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ clients := []client.Client{
client.Client{Credentials: credXXX}, clientA,
client.Client{Credentials: credYYY}, clientB,
}) }
if err != nil { clientIDGenerator := func(hostport string) (string, error) {
t.Fatalf("failed to create client identity repo: %v", err) return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
dbm := db.NewMemDB()
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err)
} }
userRepo, err := makeNewUserRepo() userRepo, err := makeNewUserRepo()
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
@ -810,12 +901,13 @@ func TestServerRefreshToken(t *testing.T) {
srv := &Server{ srv := &Server{
IssuerURL: issuerURL, IssuerURL: issuerURL,
KeyManager: km, KeyManager: km,
ClientRepo: ciRepo, ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo, UserRepo: userRepo,
RefreshTokenRepo: refreshTokenRepo, RefreshTokenRepo: refreshTokenRepo,
} }
if _, err := refreshTokenRepo.Create("testid-2", credXXX.ID); err != nil { if _, err := refreshTokenRepo.Create("testid-2", clientA.Credentials.ID); err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
@ -826,7 +918,7 @@ func TestServerRefreshToken(t *testing.T) {
} }
srv.UserRepo = userRepo srv.UserRepo = userRepo
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1")))) _, err = srv.RefreshToken(clientA.Credentials, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) { if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err) t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
} }

View file

@ -10,6 +10,7 @@ import (
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/email" "github.com/coreos/dex/email"
@ -26,7 +27,7 @@ const (
var ( var (
testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"} testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"}
testClientID = "XXX" testClientID = "client.example.com"
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"} testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
@ -79,6 +80,7 @@ type testFixtures struct {
emailer *email.TemplatizedEmailer emailer *email.TemplatizedEmailer
redirectURL url.URL redirectURL url.URL
clientRepo client.ClientRepo clientRepo client.ClientRepo
clientManager *clientmanager.ClientManager
} }
func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc { func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
@ -123,7 +125,7 @@ func makeTestFixtures() (*testFixtures, error) {
return nil, err return nil, err
} }
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{}) userManager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sessionManager.GenerateCode = sequentialGenerateCodeFunc() sessionManager.GenerateCode = sequentialGenerateCodeFunc()
@ -136,11 +138,11 @@ func makeTestFixtures() (*testFixtures, error) {
return nil, err return nil, err
} }
clientRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ clients := []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: testClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")), Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -148,11 +150,19 @@ func makeTestFixtures() (*testFixtures, error) {
}, },
}, },
}, },
}) }
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
return nil, err return nil, err
} }
km := key.NewPrivateKeyManager() km := key.NewPrivateKeyManager()
err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute))) err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute)))
if err != nil { if err != nil {
@ -173,7 +183,8 @@ func makeTestFixtures() (*testFixtures, error) {
Templates: tpl, Templates: tpl,
UserRepo: userRepo, UserRepo: userRepo,
PasswordInfoRepo: pwRepo, PasswordInfoRepo: pwRepo,
UserManager: manager, UserManager: userManager,
ClientManager: clientManager,
KeyManager: km, KeyManager: km,
} }
@ -207,5 +218,6 @@ func makeTestFixtures() (*testFixtures, error) {
sessionManager: sessionManager, sessionManager: sessionManager,
emailer: emailer, emailer: emailer,
clientRepo: clientRepo, clientRepo: clientRepo,
clientManager: clientManager,
}, nil }, nil
} }

View file

@ -11,12 +11,12 @@ import (
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/coreos/dex/client" clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/api" "github.com/coreos/dex/user/api"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
const ( const (
@ -38,16 +38,16 @@ var (
type UserMgmtServer struct { type UserMgmtServer struct {
api *api.UsersAPI api *api.UsersAPI
jwtvFactory JWTVerifierFactory jwtvFactory JWTVerifierFactory
um *manager.UserManager um *usermanager.UserManager
cir client.ClientRepo cm *clientmanager.ClientManager
} }
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *manager.UserManager, cir client.ClientRepo) *UserMgmtServer { func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *usermanager.UserManager, cm *clientmanager.ClientManager) *UserMgmtServer {
return &UserMgmtServer{ return &UserMgmtServer{
api: userMgmtAPI, api: userMgmtAPI,
jwtvFactory: jwtvFactory, jwtvFactory: jwtvFactory,
um: um, um: um,
cir: cir, cm: cm,
} }
} }
@ -295,7 +295,7 @@ func (s *UserMgmtServer) getCreds(r *http.Request, requiresAdmin bool) (api.Cred
return api.Creds{}, err return api.Creds{}, err
} }
isAdmin, err := s.cir.IsDexAdmin(clientID) isAdmin, err := s.cm.IsDexAdmin(clientID)
if err != nil { if err != nil {
log.Errorf("userMgmtServer: GetCreds err: %q", err) log.Errorf("userMgmtServer: GetCreds err: %q", err)
return api.Creds{}, err return api.Creds{}, err

2
test
View file

@ -18,7 +18,7 @@ if [ ! -d $GOPATH/pkg ]; then
echo "WARNING: No cached builds detected. Please run the ./build script to speed up future tests." echo "WARNING: No cached builds detected. Please run the ./build script to speed up future tests."
fi fi
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin" TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin client client/manager"
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log" FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
# user has not provided PKG override # user has not provided PKG override

View file

@ -9,15 +9,13 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/go-gorp/gorp"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db" clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
var ( var (
@ -88,9 +86,9 @@ func (e Error) Error() string {
// calling User. It is assumed that the clientID has already validated as an // calling User. It is assumed that the clientID has already validated as an
// admin app before calling. // admin app before calling.
type UsersAPI struct { type UsersAPI struct {
manager *manager.UserManager userManager *usermanager.UserManager
localConnectorID string localConnectorID string
clientRepo client.ClientRepo clientManager *clientmanager.ClientManager
refreshRepo refresh.RefreshTokenRepo refreshRepo refresh.RefreshTokenRepo
emailer Emailer emailer Emailer
} }
@ -105,11 +103,11 @@ type Creds struct {
} }
// TODO(ericchiang): Don't pass a dbMap. See #385. // TODO(ericchiang): Don't pass a dbMap. See #385.
func NewUsersAPI(dbMap *gorp.DbMap, userManager *manager.UserManager, emailer Emailer, localConnectorID string) *UsersAPI { func NewUsersAPI(userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, refreshRepo refresh.RefreshTokenRepo, emailer Emailer, localConnectorID string) *UsersAPI {
return &UsersAPI{ return &UsersAPI{
manager: userManager, userManager: userManager,
refreshRepo: db.NewRefreshTokenRepo(dbMap), refreshRepo: refreshRepo,
clientRepo: db.NewClientRepo(dbMap), clientManager: clientManager,
localConnectorID: localConnectorID, localConnectorID: localConnectorID,
emailer: emailer, emailer: emailer,
} }
@ -122,7 +120,7 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) {
return schema.User{}, ErrorUnauthorized return schema.User{}, ErrorUnauthorized
} }
usr, err := u.manager.Get(id) usr, err := u.userManager.Get(id)
if err != nil { if err != nil {
return schema.User{}, mapError(err) return schema.User{}, mapError(err)
@ -137,7 +135,7 @@ func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema
return schema.UserDisableResponse{}, ErrorUnauthorized return schema.UserDisableResponse{}, ErrorUnauthorized
} }
if err := u.manager.Disable(userID, disable); err != nil { if err := u.userManager.Disable(userID, disable); err != nil {
return schema.UserDisableResponse{}, mapError(err) return schema.UserDisableResponse{}, mapError(err)
} }
@ -157,7 +155,7 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
return schema.UserCreateResponse{}, mapError(err) return schema.UserCreateResponse{}, mapError(err)
} }
metadata, err := u.clientRepo.Metadata(nil, creds.ClientID) metadata, err := u.clientManager.Metadata(creds.ClientID)
if err != nil { if err != nil {
return schema.UserCreateResponse{}, mapError(err) return schema.UserCreateResponse{}, mapError(err)
} }
@ -167,12 +165,12 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
return schema.UserCreateResponse{}, ErrorInvalidRedirectURL return schema.UserCreateResponse{}, ErrorInvalidRedirectURL
} }
id, err := u.manager.CreateUser(schemaUserToUser(usr), user.Password(hash), u.localConnectorID) id, err := u.userManager.CreateUser(schemaUserToUser(usr), user.Password(hash), u.localConnectorID)
if err != nil { if err != nil {
return schema.UserCreateResponse{}, mapError(err) return schema.UserCreateResponse{}, mapError(err)
} }
userUser, err := u.manager.Get(id) userUser, err := u.userManager.Get(id)
if err != nil { if err != nil {
return schema.UserCreateResponse{}, mapError(err) return schema.UserCreateResponse{}, mapError(err)
} }
@ -202,7 +200,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
return schema.ResendEmailInvitationResponse{}, ErrorUnauthorized return schema.ResendEmailInvitationResponse{}, ErrorUnauthorized
} }
metadata, err := u.clientRepo.Metadata(nil, creds.ClientID) metadata, err := u.clientManager.Metadata(creds.ClientID)
if err != nil { if err != nil {
return schema.ResendEmailInvitationResponse{}, mapError(err) return schema.ResendEmailInvitationResponse{}, mapError(err)
} }
@ -213,7 +211,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
} }
// Retrieve user to check if it's already created // Retrieve user to check if it's already created
userUser, err := u.manager.Get(userID) userUser, err := u.userManager.Get(userID)
if err != nil { if err != nil {
return schema.ResendEmailInvitationResponse{}, mapError(err) return schema.ResendEmailInvitationResponse{}, mapError(err)
} }
@ -251,7 +249,7 @@ func (u *UsersAPI) ListUsers(creds Creds, maxResults int, nextPageToken string)
return nil, "", ErrorMaxResultsTooHigh return nil, "", ErrorMaxResultsTooHigh
} }
users, tok, err := u.manager.List(user.UserFilter{}, maxResults, nextPageToken) users, tok, err := u.userManager.List(user.UserFilter{}, maxResults, nextPageToken)
if err != nil { if err != nil {
return nil, "", mapError(err) return nil, "", mapError(err)
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
schema "github.com/coreos/dex/schema/workerschema" schema "github.com/coreos/dex/schema/workerschema"
@ -51,13 +52,14 @@ func (t *testEmailer) sendEmail(email string, redirectURL url.URL, clientID stri
var ( var (
clock = clockwork.NewFakeClock() clock = clockwork.NewFakeClock()
goodClientID = "client.example.com"
goodCreds = Creds{ goodCreds = Creds{
User: user.User{ User: user.User{
ID: "ID-1", ID: "ID-1",
Admin: true, Admin: true,
}, },
ClientID: "XXX", ClientID: goodClientID,
} }
badCreds = Creds{ badCreds = Creds{
@ -72,7 +74,7 @@ var (
Admin: true, Admin: true,
Disabled: true, Disabled: true,
}, },
ClientID: "XXX", ClientID: goodClientID,
} }
resetPasswordURL = url.URL{ resetPasswordURL = url.URL{
@ -82,7 +84,7 @@ var (
validRedirURL = url.URL{ validRedirURL = url.URL{
Scheme: "http", Scheme: "http",
Host: "client.example.com", Host: goodClientID,
Path: "/callback", Path: "/callback",
} }
) )
@ -158,8 +160,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
mgr.Clock = clock mgr.Clock = clock
ci := client.Client{ ci := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "XXX", ID: goodClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")), Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
@ -167,8 +169,17 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
}, },
}, },
} }
if _, err := db.NewClientRepoFromClients(dbMap, []client.Client{ci}); err != nil {
panic("Failed to create client repo: " + err.Error()) clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbMap)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), []client.Client{ci}, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
panic("Failed to create client manager: " + err.Error())
} }
// Used in TestRevokeRefreshToken test. // Used in TestRevokeRefreshToken test.
@ -176,8 +187,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
clientID string clientID string
userID string userID string
}{ }{
{"XXX", "ID-1"}, {goodClientID, "ID-1"},
{"XXX", "ID-2"}, {goodClientID, "ID-2"},
} }
refreshRepo := db.NewRefreshTokenRepo(dbMap) refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, token := range refreshTokens { for _, token := range refreshTokens {
@ -187,7 +198,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
} }
emailer := &testEmailer{} emailer := &testEmailer{}
api := NewUsersAPI(dbMap, mgr, emailer, "local") api := NewUsersAPI(mgr, clientManager, refreshRepo, emailer, "local")
return api, emailer return api, emailer
} }
@ -582,8 +593,8 @@ func TestRevokeRefreshToken(t *testing.T) {
before []string // clientIDs expected before the change. before []string // clientIDs expected before the change.
after []string // clientIDs expected after the change. after []string // clientIDs expected after the change.
}{ }{
{"ID-1", "XXX", []string{"XXX"}, []string{}}, {"ID-1", goodClientID, []string{goodClientID}, []string{}},
{"ID-2", "XXX", []string{"XXX"}, []string{}}, {"ID-2", goodClientID, []string{goodClientID}, []string{}},
} }
api, _ := makeTestFixtures() api, _ := makeTestFixtures()