forked from mystiq/dex
Merge pull request #442 from ecordell/client-manager
Adds client manager
This commit is contained in:
commit
a846016ceb
37 changed files with 1122 additions and 690 deletions
26
admin/api.go
26
admin/api.go
|
@ -1,31 +1,27 @@
|
|||
// package admin provides an implementation of the API described in auth/schema/adminschema.
|
||||
// Package admin provides an implementation of the API described in auth/schema/adminschema.
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/schema/adminschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
var (
|
||||
ClientIDGenerator = oidc.GenClientID
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
// AdminAPI provides the logic necessary to implement the Admin API.
|
||||
type AdminAPI struct {
|
||||
userManager *manager.UserManager
|
||||
userManager *usermanager.UserManager
|
||||
userRepo user.UserRepo
|
||||
passwordInfoRepo user.PasswordInfoRepo
|
||||
clientRepo client.ClientRepo
|
||||
clientManager *clientmanager.ClientManager
|
||||
localConnectorID string
|
||||
}
|
||||
|
||||
func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRepo client.ClientRepo, userManager *manager.UserManager, localConnectorID string) *AdminAPI {
|
||||
func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRepo client.ClientRepo, userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, localConnectorID string) *AdminAPI {
|
||||
if localConnectorID == "" {
|
||||
panic("must specify non-blank localConnectorID")
|
||||
}
|
||||
|
@ -34,6 +30,7 @@ func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRe
|
|||
userRepo: userRepo,
|
||||
passwordInfoRepo: pwiRepo,
|
||||
clientRepo: clientRepo,
|
||||
clientManager: clientManager,
|
||||
localConnectorID: localConnectorID,
|
||||
}
|
||||
}
|
||||
|
@ -141,14 +138,7 @@ func (a *AdminAPI) CreateClient(req adminschema.ClientCreateRequest) (adminschem
|
|||
}
|
||||
|
||||
// metadata is guaranteed to have at least one redirect_uri by earlier validation.
|
||||
id, err := ClientIDGenerator(cli.Metadata.RedirectURIs[0].Host)
|
||||
if err != nil {
|
||||
return adminschema.ClientCreateResponse{}, mapError(err)
|
||||
}
|
||||
|
||||
cli.Credentials.ID = id
|
||||
|
||||
creds, err := a.clientRepo.New(cli)
|
||||
creds, err := a.clientManager.New(cli)
|
||||
if err != nil {
|
||||
return adminschema.ClientCreateResponse{}, mapError(err)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/schema/adminschema"
|
||||
|
@ -17,6 +18,7 @@ type testFixtures struct {
|
|||
ur user.UserRepo
|
||||
pwr user.PasswordInfoRepo
|
||||
cr client.ClientRepo
|
||||
cm *clientmanager.ClientManager
|
||||
mgr *manager.UserManager
|
||||
adAPI *AdminAPI
|
||||
}
|
||||
|
@ -71,7 +73,8 @@ func makeTestFixtures() *testFixtures {
|
|||
}()
|
||||
|
||||
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
||||
f.adAPI = NewAdminAPI(f.ur, f.pwr, f.cr, f.mgr, "local")
|
||||
f.cm = clientmanager.NewClientManager(f.cr, db.TransactionFactory(dbMap), clientmanager.ManagerOptions{})
|
||||
f.adAPI = NewAdminAPI(f.ur, f.pwr, f.cr, f.mgr, f.cm, "local")
|
||||
|
||||
return f
|
||||
}
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
|
@ -17,6 +21,24 @@ var (
|
|||
ErrorNotFound = errors.New("no data found")
|
||||
)
|
||||
|
||||
const (
|
||||
bcryptHashCost = 10
|
||||
)
|
||||
|
||||
func HashSecret(creds oidc.ClientCredentials) ([]byte, error) {
|
||||
secretBytes, err := base64.URLEncoding.DecodeString(creds.Secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(
|
||||
secretBytes),
|
||||
bcryptHashCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hashed, nil
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
Credentials oidc.ClientCredentials
|
||||
Metadata oidc.ClientMetadata
|
||||
|
@ -24,30 +46,20 @@ type Client struct {
|
|||
}
|
||||
|
||||
type ClientRepo interface {
|
||||
Get(clientID string) (Client, error)
|
||||
Get(tx repo.Transaction, clientID string) (Client, error)
|
||||
|
||||
// Metadata returns one matching ClientMetadata if the given client
|
||||
// exists, otherwise nil. The returned error will be non-nil only
|
||||
// if the repo was unable to determine client existence.
|
||||
Metadata(clientID string) (*oidc.ClientMetadata, error)
|
||||
|
||||
// Authenticate asserts that a client with the given ID exists and
|
||||
// that the provided secret matches. If either of these assertions
|
||||
// fail, (false, nil) will be returned. Only if the repo is unable
|
||||
// to make these assertions will a non-nil error be returned.
|
||||
Authenticate(creds oidc.ClientCredentials) (bool, error)
|
||||
// GetSecret returns the (base64 encoded) hashed client secret
|
||||
GetSecret(tx repo.Transaction, clientID string) ([]byte, error)
|
||||
|
||||
// All returns all registered Clients
|
||||
All() ([]Client, error)
|
||||
All(tx repo.Transaction) ([]Client, error)
|
||||
|
||||
// New registers a Client with the repo.
|
||||
// An unused ID must be provided. A corresponding secret will be returned
|
||||
// in a ClientCredentials struct along with the provided ID.
|
||||
New(client Client) (*oidc.ClientCredentials, error)
|
||||
New(tx repo.Transaction, client Client) (*oidc.ClientCredentials, error)
|
||||
|
||||
SetDexAdmin(clientID string, isAdmin bool) error
|
||||
|
||||
IsDexAdmin(clientID string) (bool, error)
|
||||
Update(tx repo.Transaction, client Client) error
|
||||
}
|
||||
|
||||
// ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise.
|
||||
|
|
|
@ -34,7 +34,7 @@ var (
|
|||
|
||||
badSecretClient = `{
|
||||
"id": "my_id",
|
||||
"secret": "` + "****" + `",
|
||||
"secret": "` + "" + `",
|
||||
"redirectURLs": ["https://client.example.com"]
|
||||
}`
|
||||
|
||||
|
@ -64,7 +64,7 @@ func TestClientsFromReader(t *testing.T) {
|
|||
{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "my_id",
|
||||
Secret: "my_secret",
|
||||
Secret: goodSecret1,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -80,7 +80,7 @@ func TestClientsFromReader(t *testing.T) {
|
|||
{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "my_id",
|
||||
Secret: "my_secret",
|
||||
Secret: goodSecret1,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -91,7 +91,7 @@ func TestClientsFromReader(t *testing.T) {
|
|||
{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "my_other_id",
|
||||
Secret: "my_other_secret",
|
||||
Secret: goodSecret2,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -101,7 +101,8 @@ func TestClientsFromReader(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
},
|
||||
{
|
||||
json: "[" + badURLClient + "]",
|
||||
wantErr: true,
|
||||
},
|
||||
|
|
213
client/manager/manager.go
Normal file
213
client/manager/manager.go
Normal file
|
@ -0,0 +1,213 @@
|
|||
package manager
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"errors"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
pcrypto "github.com/coreos/dex/pkg/crypto"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
// Blowfish, the algorithm underlying bcrypt, has a maximum
|
||||
// password length of 72. We explicitly track and check this
|
||||
// since the bcrypt library will silently ignore portions of
|
||||
// a password past the first 72 characters.
|
||||
maxSecretLength = 72
|
||||
)
|
||||
|
||||
type SecretGenerator func() ([]byte, error)
|
||||
|
||||
func DefaultSecretGenerator() ([]byte, error) {
|
||||
return pcrypto.RandBytes(maxSecretLength)
|
||||
}
|
||||
|
||||
func CompareHashAndPassword(hashedPassword, password []byte) error {
|
||||
if len(password) > maxSecretLength {
|
||||
return errors.New("password length greater than max secret length")
|
||||
}
|
||||
return bcrypt.CompareHashAndPassword(hashedPassword, password)
|
||||
}
|
||||
|
||||
// ClientManager performs client-related "business-logic" functions on client and related objects.
|
||||
// This is in contrast to the Repos which perform little more than CRUD operations.
|
||||
type ClientManager struct {
|
||||
clientRepo client.ClientRepo
|
||||
begin repo.TransactionFactory
|
||||
secretGenerator SecretGenerator
|
||||
clientIDGenerator func(string) (string, error)
|
||||
}
|
||||
|
||||
type ManagerOptions struct {
|
||||
SecretGenerator func() ([]byte, error)
|
||||
ClientIDGenerator func(string) (string, error)
|
||||
}
|
||||
|
||||
func NewClientManager(clientRepo client.ClientRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *ClientManager {
|
||||
if options.SecretGenerator == nil {
|
||||
options.SecretGenerator = DefaultSecretGenerator
|
||||
}
|
||||
if options.ClientIDGenerator == nil {
|
||||
options.ClientIDGenerator = oidc.GenClientID
|
||||
}
|
||||
return &ClientManager{
|
||||
clientRepo: clientRepo,
|
||||
begin: txnFactory,
|
||||
secretGenerator: options.SecretGenerator,
|
||||
clientIDGenerator: options.ClientIDGenerator,
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientManagerFromClients(clientRepo client.ClientRepo, txnFactory repo.TransactionFactory, clients []client.Client, options ManagerOptions) (*ClientManager, error) {
|
||||
clientManager := NewClientManager(clientRepo, txnFactory, options)
|
||||
tx, err := clientManager.begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, c := range clients {
|
||||
if c.Credentials.Secret == "" {
|
||||
return nil, fmt.Errorf("client %q has no secret", c.Credentials.ID)
|
||||
}
|
||||
|
||||
cli, err := clientManager.generateClientCredentials(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = clientRepo.New(tx, cli)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return clientManager, nil
|
||||
}
|
||||
|
||||
func (m *ClientManager) New(cli client.Client) (*oidc.ClientCredentials, error) {
|
||||
tx, err := m.begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
c, err := m.generateClientCredentials(cli)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
creds := c.Credentials
|
||||
|
||||
// Save Client
|
||||
_, err = m.clientRepo.New(tx, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Returns creds with unhashed secret
|
||||
return &creds, nil
|
||||
}
|
||||
|
||||
func (m *ClientManager) Get(id string) (client.Client, error) {
|
||||
return m.clientRepo.Get(nil, id)
|
||||
}
|
||||
|
||||
func (m *ClientManager) All() ([]client.Client, error) {
|
||||
return m.clientRepo.All(nil)
|
||||
}
|
||||
|
||||
func (m *ClientManager) Metadata(clientID string) (*oidc.ClientMetadata, error) {
|
||||
c, err := m.clientRepo.Get(nil, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &c.Metadata, nil
|
||||
}
|
||||
|
||||
func (m *ClientManager) IsDexAdmin(clientID string) (bool, error) {
|
||||
c, err := m.clientRepo.Get(nil, clientID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return c.Admin, nil
|
||||
}
|
||||
|
||||
func (m *ClientManager) SetDexAdmin(clientID string, isAdmin bool) error {
|
||||
tx, err := m.begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
c, err := m.clientRepo.Get(tx, clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Admin = isAdmin
|
||||
err = m.clientRepo.Update(tx, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ClientManager) Authenticate(creds oidc.ClientCredentials) (bool, error) {
|
||||
clientSecret, err := m.clientRepo.GetSecret(nil, creds.ID)
|
||||
if err != nil || clientSecret == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
dec, err := base64.URLEncoding.DecodeString(creds.Secret)
|
||||
if err != nil {
|
||||
log.Errorf("error Decoding client creds: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
ok := CompareHashAndPassword(clientSecret, dec) == nil
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (m *ClientManager) generateClientCredentials(cli client.Client) (client.Client, error) {
|
||||
// Generate Client ID
|
||||
if len(cli.Metadata.RedirectURIs) < 1 {
|
||||
return cli, errors.New("no client redirect url given")
|
||||
}
|
||||
clientID, err := m.clientIDGenerator(cli.Metadata.RedirectURIs[0].Host)
|
||||
if err != nil {
|
||||
return cli, err
|
||||
}
|
||||
|
||||
// Generate Secret
|
||||
secret, err := m.secretGenerator()
|
||||
if err != nil {
|
||||
return cli, err
|
||||
}
|
||||
clientSecret := base64.URLEncoding.EncodeToString(secret)
|
||||
cli.Credentials = oidc.ClientCredentials{
|
||||
ID: clientID,
|
||||
Secret: clientSecret,
|
||||
}
|
||||
return cli, nil
|
||||
}
|
165
client/manager/manager_test.go
Normal file
165
client/manager/manager_test.go
Normal file
|
@ -0,0 +1,165 @@
|
|||
package manager
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
type testFixtures struct {
|
||||
clientRepo client.ClientRepo
|
||||
mgr *ClientManager
|
||||
}
|
||||
|
||||
var (
|
||||
goodSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
|
||||
)
|
||||
|
||||
func makeTestFixtures() *testFixtures {
|
||||
f := &testFixtures{}
|
||||
|
||||
dbMap := db.NewMemDB()
|
||||
clients := []client.Client{
|
||||
{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "client.example.com",
|
||||
Secret: goodSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "http", Host: "client.example.com", Path: "/"},
|
||||
},
|
||||
},
|
||||
Admin: true,
|
||||
},
|
||||
}
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
f.clientRepo = db.NewClientRepo(dbMap)
|
||||
clientManager, err := NewClientManagerFromClients(f.clientRepo, db.TransactionFactory(dbMap), clients, ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
panic("Failed to create client manager: " + err.Error())
|
||||
}
|
||||
f.mgr = clientManager
|
||||
return f
|
||||
}
|
||||
|
||||
func TestMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
clientID string
|
||||
uri string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
clientID: "client.example.com",
|
||||
uri: "http://client.example.com/",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
f := makeTestFixtures()
|
||||
md, err := f.mgr.Metadata(tt.clientID)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected err: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if md.RedirectURIs[0].String() != tt.uri {
|
||||
t.Errorf("case %d: manager.Metadata.RedirectURIs: want=%q got=%q", i, tt.uri, md.RedirectURIs[0].String())
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDexAdmin(t *testing.T) {
|
||||
tests := []struct {
|
||||
clientID string
|
||||
isAdmin bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
clientID: "client.example.com",
|
||||
isAdmin: true,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
f := makeTestFixtures()
|
||||
admin, err := f.mgr.IsDexAdmin(tt.clientID)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected err: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if admin != tt.isAdmin {
|
||||
t.Errorf("case %d: manager.Admin want=%t got=%t", i, tt.isAdmin, admin)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetDexAdmin(t *testing.T) {
|
||||
f := makeTestFixtures()
|
||||
err := f.mgr.SetDexAdmin("client.example.com", false)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected err: %v", err)
|
||||
}
|
||||
admin, _ := f.mgr.IsDexAdmin("client.example.com")
|
||||
if admin {
|
||||
t.Errorf("expected admin to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate(t *testing.T) {
|
||||
f := makeTestFixtures()
|
||||
cm := oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "example.com", Path: "/cb"},
|
||||
},
|
||||
}
|
||||
cli := client.Client{
|
||||
Metadata: cm,
|
||||
}
|
||||
cc, err := f.mgr.New(cli)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
ok, err := f.mgr.Authenticate(*cc)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
} else if !ok {
|
||||
t.Fatalf("Authentication failed for good creds")
|
||||
}
|
||||
|
||||
creds := []oidc.ClientCredentials{
|
||||
//completely made up
|
||||
oidc.ClientCredentials{ID: "foo", Secret: "bar"},
|
||||
|
||||
// good client ID, bad secret
|
||||
oidc.ClientCredentials{ID: cc.ID, Secret: "bar"},
|
||||
|
||||
// bad client ID, good secret
|
||||
oidc.ClientCredentials{ID: "foo", Secret: cc.Secret},
|
||||
|
||||
// good client ID, secret with some fluff on the end
|
||||
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
|
||||
}
|
||||
for i, c := range creds {
|
||||
ok, err := f.mgr.Authenticate(c)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
} else if ok {
|
||||
t.Errorf("case %d: authentication succeeded for bad creds", i)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/admin"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/db"
|
||||
pflag "github.com/coreos/dex/pkg/flag"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
|
@ -119,8 +120,9 @@ func main() {
|
|||
clientRepo := db.NewClientRepo(dbc)
|
||||
userManager := manager.NewUserManager(userRepo,
|
||||
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||
clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbc), clientmanager.ManagerOptions{})
|
||||
|
||||
adminAPI := admin.NewAdminAPI(userRepo, pwiRepo, clientRepo, userManager, *localConnectorID)
|
||||
adminAPI := admin.NewAdminAPI(userRepo, pwiRepo, clientRepo, userManager, clientManager, *localConnectorID)
|
||||
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
|
||||
if err != nil {
|
||||
log.Fatalf(err.Error())
|
||||
|
|
|
@ -2,6 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
|
@ -14,34 +15,26 @@ func newDBDriver(dsn string) (driver, error) {
|
|||
}
|
||||
|
||||
drv := &dbDriver{
|
||||
ciRepo: db.NewClientRepo(dbc),
|
||||
cfgRepo: db.NewConnectorConfigRepo(dbc),
|
||||
cfgRepo: db.NewConnectorConfigRepo(dbc),
|
||||
ciManager: manager.NewClientManager(db.NewClientRepo(dbc), db.TransactionFactory(dbc), manager.ManagerOptions{}),
|
||||
}
|
||||
|
||||
return drv, nil
|
||||
}
|
||||
|
||||
type dbDriver struct {
|
||||
ciRepo client.ClientRepo
|
||||
cfgRepo *db.ConnectorConfigRepo
|
||||
ciManager *manager.ClientManager
|
||||
cfgRepo *db.ConnectorConfigRepo
|
||||
}
|
||||
|
||||
func (d *dbDriver) NewClient(meta oidc.ClientMetadata) (*oidc.ClientCredentials, error) {
|
||||
if err := meta.Valid(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientID, err := oidc.GenClientID(meta.RedirectURIs[0].Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d.ciRepo.New(client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: clientID,
|
||||
},
|
||||
cli := client.Client{
|
||||
Metadata: meta,
|
||||
})
|
||||
}
|
||||
return d.ciManager.New(cli)
|
||||
}
|
||||
|
||||
func (d *dbDriver) ConnectorConfigs() ([]connector.ConnectorConfig, error) {
|
||||
|
|
203
db/client.go
203
db/client.go
|
@ -2,7 +2,6 @@ package db
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -10,24 +9,15 @@ import (
|
|||
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
"github.com/go-gorp/gorp"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
pcrypto "github.com/coreos/dex/pkg/crypto"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/repo"
|
||||
)
|
||||
|
||||
const (
|
||||
clientTableName = "client_identity"
|
||||
|
||||
bcryptHashCost = 10
|
||||
|
||||
// Blowfish, the algorithm underlying bcrypt, has a maximum
|
||||
// password length of 72. We explicitly track and check this
|
||||
// since the bcrypt library will silently ignore portions of
|
||||
// a password past the first 72 characters.
|
||||
maxSecretLength = 72
|
||||
|
||||
// postgres error codes
|
||||
pgErrorCodeUniqueViolation = "23505" // unique_violation
|
||||
)
|
||||
|
@ -42,17 +32,10 @@ func init() {
|
|||
}
|
||||
|
||||
func newClientModel(cli client.Client) (*clientModel, error) {
|
||||
secretBytes, err := base64.URLEncoding.DecodeString(cli.Credentials.Secret)
|
||||
hashed, err := client.HashSecret(cli.Credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(
|
||||
secretBytes),
|
||||
bcryptHashCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bmeta, err := json.Marshal(&cli.Metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -92,56 +75,20 @@ func (m *clientModel) Client() (*client.Client, error) {
|
|||
|
||||
func NewClientRepo(dbm *gorp.DbMap) client.ClientRepo {
|
||||
return newClientRepo(dbm)
|
||||
|
||||
}
|
||||
|
||||
func NewClientRepoWithSecretGenerator(dbm *gorp.DbMap, secGen SecretGenerator) client.ClientRepo {
|
||||
rep := newClientRepo(dbm)
|
||||
rep.secretGenerator = secGen
|
||||
return rep
|
||||
}
|
||||
|
||||
func newClientRepo(dbm *gorp.DbMap) *clientRepo {
|
||||
return &clientRepo{
|
||||
db: &db{dbm},
|
||||
secretGenerator: DefaultSecretGenerator,
|
||||
db: &db{dbm},
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientRepoFromClients(dbm *gorp.DbMap, clients []client.Client) (client.ClientRepo, error) {
|
||||
repo := newClientRepo(dbm)
|
||||
tx, err := repo.begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := repo.executor(tx)
|
||||
for _, c := range clients {
|
||||
if c.Credentials.Secret == "" {
|
||||
return nil, fmt.Errorf("client %q has no secret", c.Credentials.ID)
|
||||
}
|
||||
cm, err := newClientModel(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = exec.Insert(cm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
type clientRepo struct {
|
||||
*db
|
||||
secretGenerator SecretGenerator
|
||||
}
|
||||
|
||||
func (r *clientRepo) Get(clientID string) (client.Client, error) {
|
||||
m, err := r.executor(nil).Get(clientModel{}, clientID)
|
||||
func (r *clientRepo) Get(tx repo.Transaction, clientID string) (client.Client, error) {
|
||||
m, err := r.executor(tx).Get(clientModel{}, clientID)
|
||||
if err == sql.ErrNoRows || m == nil {
|
||||
return client.Client{}, client.ErrorNotFound
|
||||
}
|
||||
|
@ -163,82 +110,28 @@ func (r *clientRepo) Get(clientID string) (client.Client, error) {
|
|||
return *ci, nil
|
||||
}
|
||||
|
||||
func (r *clientRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
|
||||
c, err := r.Get(clientID)
|
||||
if err != nil {
|
||||
func (r *clientRepo) GetSecret(tx repo.Transaction, clientID string) ([]byte, error) {
|
||||
m, err := r.getModel(tx, clientID)
|
||||
if err != nil || m == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &c.Metadata, nil
|
||||
return m.Secret, nil
|
||||
}
|
||||
|
||||
func (r *clientRepo) IsDexAdmin(clientID string) (bool, error) {
|
||||
m, err := r.executor(nil).Get(clientModel{}, clientID)
|
||||
if m == nil || err != nil {
|
||||
return false, err
|
||||
func (r *clientRepo) Update(tx repo.Transaction, cli client.Client) error {
|
||||
if cli.Credentials.ID == "" {
|
||||
return client.ErrorNotFound
|
||||
}
|
||||
|
||||
cim, ok := m.(*clientModel)
|
||||
if !ok {
|
||||
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
|
||||
return false, errors.New("unrecognized model")
|
||||
}
|
||||
|
||||
return cim.DexAdmin, nil
|
||||
}
|
||||
|
||||
func (r *clientRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
||||
tx, err := r.begin()
|
||||
// make sure this client exists already
|
||||
_, err := r.get(tx, cli.Credentials.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
exec := r.executor(tx)
|
||||
|
||||
m, err := exec.Get(clientModel{}, clientID)
|
||||
if m == nil || err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cim, ok := m.(*clientModel)
|
||||
if !ok {
|
||||
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
|
||||
return errors.New("unrecognized model")
|
||||
}
|
||||
|
||||
cim.DexAdmin = isAdmin
|
||||
_, err = exec.Update(cim)
|
||||
err = r.update(tx, cli)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *clientRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
|
||||
m, err := r.executor(nil).Get(clientModel{}, creds.ID)
|
||||
if m == nil || err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
cim, ok := m.(*clientModel)
|
||||
if !ok {
|
||||
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
|
||||
return false, errors.New("unrecognized model")
|
||||
}
|
||||
|
||||
dec, err := base64.URLEncoding.DecodeString(creds.Secret)
|
||||
if err != nil {
|
||||
log.Errorf("error Decoding client creds: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if len(dec) > maxSecretLength {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
ok = bcrypt.CompareHashAndPassword(cim.Secret, dec) == nil
|
||||
return ok, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
var alreadyExistsCheckers []func(err error) bool
|
||||
|
@ -260,26 +153,14 @@ func isAlreadyExistsErr(err error) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
type SecretGenerator func() ([]byte, error)
|
||||
|
||||
func DefaultSecretGenerator() ([]byte, error) {
|
||||
return pcrypto.RandBytes(maxSecretLength)
|
||||
}
|
||||
|
||||
func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
|
||||
secret, err := r.secretGenerator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cli.Credentials.Secret = base64.URLEncoding.EncodeToString(secret)
|
||||
func (r *clientRepo) New(tx repo.Transaction, cli client.Client) (*oidc.ClientCredentials, error) {
|
||||
cim, err := newClientModel(cli)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := r.executor(nil).Insert(cim); err != nil {
|
||||
if err := r.executor(tx).Insert(cim); err != nil {
|
||||
if isAlreadyExistsErr(err) {
|
||||
err = errors.New("client ID already exists")
|
||||
}
|
||||
|
@ -294,10 +175,10 @@ func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
|
|||
return &cc, nil
|
||||
}
|
||||
|
||||
func (r *clientRepo) All() ([]client.Client, error) {
|
||||
func (r *clientRepo) All(tx repo.Transaction) ([]client.Client, error) {
|
||||
qt := r.quote(clientTableName)
|
||||
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
||||
objs, err := r.executor(nil).Select(&clientModel{}, q)
|
||||
objs, err := r.executor(tx).Select(&clientModel{}, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -317,3 +198,47 @@ func (r *clientRepo) All() ([]client.Client, error) {
|
|||
}
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func (r *clientRepo) get(tx repo.Transaction, clientID string) (client.Client, error) {
|
||||
cm, err := r.getModel(tx, clientID)
|
||||
if err != nil {
|
||||
return client.Client{}, err
|
||||
}
|
||||
|
||||
cli, err := cm.Client()
|
||||
if err != nil {
|
||||
return client.Client{}, err
|
||||
}
|
||||
|
||||
return *cli, nil
|
||||
}
|
||||
|
||||
func (r *clientRepo) getModel(tx repo.Transaction, clientID string) (*clientModel, error) {
|
||||
ex := r.executor(tx)
|
||||
|
||||
m, err := ex.Get(clientModel{}, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return nil, client.ErrorNotFound
|
||||
}
|
||||
|
||||
cm, ok := m.(*clientModel)
|
||||
if !ok {
|
||||
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
|
||||
return nil, errors.New("unrecognized model")
|
||||
}
|
||||
return cm, nil
|
||||
}
|
||||
|
||||
func (r *clientRepo) update(tx repo.Transaction, cli client.Client) error {
|
||||
ex := r.executor(tx)
|
||||
cm, err := newClientModel(cli)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = ex.Update(cm)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/refresh"
|
||||
"github.com/coreos/dex/session"
|
||||
|
@ -191,7 +192,7 @@ func TestDBClientRepoMetadata(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := r.New(client.Client{
|
||||
_, err := r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "foo",
|
||||
},
|
||||
|
@ -201,20 +202,22 @@ func TestDBClientRepoMetadata(t *testing.T) {
|
|||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
got, err := r.Metadata("foo")
|
||||
got, err := r.Get(nil, "foo")
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(cm, *got); diff != "" {
|
||||
if diff := pretty.Compare(cm, got.Metadata); diff != "" {
|
||||
t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBClientRepoMetadataNoExist(t *testing.T) {
|
||||
r := db.NewClientRepo(connect(t))
|
||||
c := connect(t)
|
||||
r := db.NewClientRepo(c)
|
||||
m := manager.NewClientManager(r, db.TransactionFactory(c), manager.ManagerOptions{})
|
||||
|
||||
got, err := r.Metadata("noexist")
|
||||
got, err := m.Metadata("noexist")
|
||||
if err != client.ErrorNotFound {
|
||||
t.Errorf("want==%q, got==%q", client.ErrorNotFound, err)
|
||||
}
|
||||
|
@ -232,7 +235,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
if _, err := r.New(client.Client{
|
||||
if _, err := r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "foo",
|
||||
},
|
||||
|
@ -247,7 +250,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
if _, err := r.New(client.Client{
|
||||
if _, err := r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "foo",
|
||||
},
|
||||
|
@ -261,7 +264,7 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
|
|||
|
||||
for _, admin := range []bool{true, false} {
|
||||
r := db.NewClientRepo(connect(t))
|
||||
if _, err := r.New(client.Client{
|
||||
if _, err := r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "foo",
|
||||
},
|
||||
|
@ -275,15 +278,15 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
|
|||
t.Fatalf("expected non-nil error: %v", err)
|
||||
}
|
||||
|
||||
gotAdmin, err := r.IsDexAdmin("foo")
|
||||
gotAdmin, err := r.Get(nil, "foo")
|
||||
if err != nil {
|
||||
t.Fatalf("expected non-nil error")
|
||||
}
|
||||
if gotAdmin != admin {
|
||||
if gotAdmin.Admin != admin {
|
||||
t.Errorf("want=%v, gotAdmin=%v", admin, gotAdmin)
|
||||
}
|
||||
|
||||
cli, err := r.Get("foo")
|
||||
cli, err := r.Get(nil, "foo")
|
||||
if err != nil {
|
||||
t.Fatalf("expected non-nil error")
|
||||
}
|
||||
|
@ -294,29 +297,35 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
|
|||
|
||||
}
|
||||
func TestDBClientRepoAuthenticate(t *testing.T) {
|
||||
r := db.NewClientRepo(connect(t))
|
||||
c := connect(t)
|
||||
r := db.NewClientRepo(c)
|
||||
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
m := manager.NewClientManager(r, db.TransactionFactory(c), manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
|
||||
cm := oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
|
||||
},
|
||||
}
|
||||
|
||||
cc, err := r.New(client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "baz",
|
||||
},
|
||||
cli := client.Client{
|
||||
Metadata: cm,
|
||||
})
|
||||
}
|
||||
cc, err := m.New(cli)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
if cc.ID != "baz" {
|
||||
if cc.ID != "127.0.0.1:5556" {
|
||||
t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID)
|
||||
}
|
||||
|
||||
ok, err := r.Authenticate(*cc)
|
||||
ok, err := m.Authenticate(*cc)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
} else if !ok {
|
||||
|
@ -337,7 +346,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
|
|||
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
|
||||
}
|
||||
for i, c := range creds {
|
||||
ok, err := r.Authenticate(c)
|
||||
ok, err := m.Authenticate(c)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
} else if ok {
|
||||
|
@ -355,7 +364,7 @@ func TestDBClientAll(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := r.New(client.Client{
|
||||
_, err := r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "foo",
|
||||
},
|
||||
|
@ -365,7 +374,7 @@ func TestDBClientAll(t *testing.T) {
|
|||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
got, err := r.All()
|
||||
got, err := r.All(nil)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
@ -383,7 +392,7 @@ func TestDBClientAll(t *testing.T) {
|
|||
url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"},
|
||||
},
|
||||
}
|
||||
_, err = r.New(client.Client{
|
||||
_, err = r.New(nil, client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "bar",
|
||||
},
|
||||
|
@ -393,7 +402,7 @@ func TestDBClientAll(t *testing.T) {
|
|||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
got, err = r.All()
|
||||
got, err = r.All(nil)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
|
|
@ -3,14 +3,10 @@ package repo
|
|||
import (
|
||||
"encoding/base64"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/db"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -47,95 +43,3 @@ var (
|
|||
},
|
||||
}
|
||||
)
|
||||
|
||||
func newClientRepo(t *testing.T) client.ClientRepo {
|
||||
dsn := os.Getenv("DEX_TEST_DSN")
|
||||
var dbMap *gorp.DbMap
|
||||
if dsn == "" {
|
||||
dbMap = db.NewMemDB()
|
||||
} else {
|
||||
dbMap = connect(t)
|
||||
}
|
||||
repo, err := db.NewClientRepoFromClients(dbMap, testClients)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create client repo from clients: %v", err)
|
||||
}
|
||||
return repo
|
||||
}
|
||||
|
||||
func TestGetSetAdminClient(t *testing.T) {
|
||||
startAdmins := []string{"client2"}
|
||||
tests := []struct {
|
||||
// client ID
|
||||
cid string
|
||||
|
||||
// initial state of client
|
||||
wantAdmin bool
|
||||
|
||||
// final state of client
|
||||
setAdmin bool
|
||||
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
cid: "client1",
|
||||
wantAdmin: false,
|
||||
setAdmin: true,
|
||||
},
|
||||
{
|
||||
cid: "client1",
|
||||
wantAdmin: false,
|
||||
setAdmin: false,
|
||||
},
|
||||
{
|
||||
cid: "client2",
|
||||
wantAdmin: true,
|
||||
setAdmin: true,
|
||||
},
|
||||
{
|
||||
cid: "client2",
|
||||
wantAdmin: true,
|
||||
setAdmin: false,
|
||||
},
|
||||
}
|
||||
|
||||
Tests:
|
||||
for i, tt := range tests {
|
||||
repo := newClientRepo(t)
|
||||
for _, cid := range startAdmins {
|
||||
err := repo.SetDexAdmin(cid, true)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to set dex admin: %v", i, err)
|
||||
continue Tests
|
||||
}
|
||||
}
|
||||
|
||||
gotAdmin, err := repo.IsDexAdmin(tt.cid)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil err", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if gotAdmin != tt.wantAdmin {
|
||||
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAdmin, gotAdmin)
|
||||
}
|
||||
|
||||
err = repo.SetDexAdmin(tt.cid, tt.setAdmin)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
gotAdmin, err = repo.IsDexAdmin(tt.cid)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if gotAdmin != tt.setAdmin {
|
||||
t.Errorf("case %d: want=%v, got=%v", i, tt.setAdmin, gotAdmin)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/refresh"
|
||||
"github.com/coreos/dex/user"
|
||||
|
@ -27,7 +28,7 @@ func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients
|
|||
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
|
||||
t.Fatalf("Unable to add users: %v", err)
|
||||
}
|
||||
if _, err := db.NewClientRepoFromClients(dbMap, clients); err != nil {
|
||||
if _, err := manager.NewClientManagerFromClients(db.NewClientRepo(dbMap), db.TransactionFactory(dbMap), clients, manager.ManagerOptions{}); err != nil {
|
||||
t.Fatalf("Unable to add clients: %v", err)
|
||||
}
|
||||
return db.NewRefreshTokenRepo(dbMap)
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/coreos/dex/admin"
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/schema/adminschema"
|
||||
"github.com/coreos/dex/server"
|
||||
|
@ -87,12 +88,16 @@ func makeAdminAPITestFixtures() *adminAPITestFixtures {
|
|||
secGen := func() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf("client_%v", cliCount)), nil
|
||||
}
|
||||
cr := db.NewClientRepoWithSecretGenerator(dbMap, secGen)
|
||||
cr := db.NewClientRepo(dbMap)
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return fmt.Sprintf("client_%v", hostport), nil
|
||||
}
|
||||
cm := manager.NewClientManager(cr, db.TransactionFactory(dbMap), manager.ManagerOptions{SecretGenerator: secGen, ClientIDGenerator: clientIDGenerator})
|
||||
|
||||
f.cr = cr
|
||||
f.ur = ur
|
||||
f.pwr = pwr
|
||||
f.adAPI = admin.NewAdminAPI(ur, pwr, cr, um, "local")
|
||||
f.adAPI = admin.NewAdminAPI(ur, pwr, cr, um, cm, "local")
|
||||
f.adSrv = server.NewAdminServer(f.adAPI, nil, adminAPITestSecret)
|
||||
f.hSrv = httptest.NewServer(f.adSrv.HTTPHandler())
|
||||
f.hc = &http.Client{
|
||||
|
@ -268,14 +273,6 @@ func TestCreateAdmin(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCreateClient(t *testing.T) {
|
||||
oldGen := admin.ClientIDGenerator
|
||||
admin.ClientIDGenerator = func(hostport string) (string, error) {
|
||||
return fmt.Sprintf("client_%v", hostport), nil
|
||||
}
|
||||
defer func() {
|
||||
admin.ClientIDGenerator = oldGen
|
||||
}()
|
||||
|
||||
mustParseURL := func(s string) *url.URL {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
|
@ -402,7 +399,7 @@ func TestCreateClient(t *testing.T) {
|
|||
t.Errorf("case %d: Compare(want, got) = %v", i, diff)
|
||||
}
|
||||
|
||||
repoClient, err := f.cr.Get(resp.Client.Id)
|
||||
repoClient, err := f.cr.Get(nil, resp.Client.Id)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: Unexpected error getting client: %v", i, err)
|
||||
}
|
||||
|
|
|
@ -14,9 +14,10 @@ import (
|
|||
|
||||
func TestClientCreate(t *testing.T) {
|
||||
ci := client.Client{
|
||||
// Credentials are for reference, they are actually generated by the client manager
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "72de74a9",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
|
||||
ID: "authn.example.com",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -73,7 +74,7 @@ func TestClientCreate(t *testing.T) {
|
|||
t.Error("Expected non-empty Client Secret")
|
||||
}
|
||||
|
||||
meta, err := srv.ClientRepo.Metadata(newClient.Id)
|
||||
meta, err := srv.ClientManager.Metadata(newClient.Id)
|
||||
if err != nil {
|
||||
t.Errorf("Error looking up client metadata: %v", err)
|
||||
} else if meta == nil {
|
||||
|
|
|
@ -22,9 +22,10 @@ var (
|
|||
clock = clockwork.NewFakeClock()
|
||||
|
||||
testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"}
|
||||
testClientID = "XXX"
|
||||
testClientSecret = base64.URLEncoding.EncodeToString([]byte("yyy"))
|
||||
testClientID = "client.example.com"
|
||||
testClientSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
|
||||
testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"}
|
||||
testBadRedirectURL = url.URL{Scheme: "https", Host: "bad.example.com", Path: "/redirect"}
|
||||
testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"}
|
||||
testPrivKey, _ = key.GeneratePrivateKey()
|
||||
)
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
phttp "github.com/coreos/dex/pkg/http"
|
||||
|
@ -35,7 +36,15 @@ func mockServer(cis []client.Client) (*server.Server, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientRepo, err := db.NewClientRepoFromClients(dbMap, cis)
|
||||
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbMap)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), cis, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -45,6 +54,7 @@ func mockServer(cis []client.Client) (*server.Server, error) {
|
|||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
KeyManager: km,
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
SessionManager: sm,
|
||||
}
|
||||
|
||||
|
@ -82,15 +92,21 @@ func verifyUserClaims(claims jose.Claims, ci *client.Client, user *user.User, is
|
|||
expectedSub, expectedName = user.ID, user.DisplayName
|
||||
}
|
||||
|
||||
if aud := claims["aud"].(string); aud != ci.Credentials.ID {
|
||||
if aud, ok := claims["aud"].(string); !ok {
|
||||
return fmt.Errorf("unexpected claim value for aud, got=nil, want=%v", ci.Credentials.ID)
|
||||
} else if aud != ci.Credentials.ID {
|
||||
return fmt.Errorf("unexpected claim value for aud, got=%v, want=%v", aud, ci.Credentials.ID)
|
||||
}
|
||||
|
||||
if sub := claims["sub"].(string); sub != expectedSub {
|
||||
if sub, ok := claims["sub"].(string); !ok {
|
||||
return fmt.Errorf("unexpected claim value for sub, got=nil, want=%v", expectedSub)
|
||||
} else if sub != expectedSub {
|
||||
return fmt.Errorf("unexpected claim value for sub, got=%v, want=%v", sub, expectedSub)
|
||||
}
|
||||
|
||||
if name := claims["name"].(string); name != expectedName {
|
||||
if name, ok := claims["name"].(string); !ok {
|
||||
return fmt.Errorf("unexpected claim value for aud, got=nil, want=%v", expectedName)
|
||||
} else if name != expectedName {
|
||||
return fmt.Errorf("unexpected claim value for name, got=%v, want=%v", name, expectedName)
|
||||
}
|
||||
|
||||
|
@ -117,17 +133,34 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
|||
ID: "local",
|
||||
}
|
||||
|
||||
validRedirURL := url.URL{
|
||||
Scheme: "http",
|
||||
Host: "client.example.com",
|
||||
Path: "/callback",
|
||||
}
|
||||
ci := client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "72de74a9",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
|
||||
ID: validRedirURL.Host,
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
validRedirURL,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
dbMap := db.NewMemDB()
|
||||
cir, err := db.NewClientRepoFromClients(dbMap, []client.Client{ci})
|
||||
clientRepo := db.NewClientRepo(dbMap)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), []client.Client{ci}, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: " + err.Error())
|
||||
t.Fatalf("Failed to create client identity manager: " + err.Error())
|
||||
}
|
||||
passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo})
|
||||
if err != nil {
|
||||
|
@ -164,7 +197,8 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
|||
IssuerURL: issuerURL,
|
||||
KeyManager: km,
|
||||
SessionManager: sm,
|
||||
ClientRepo: cir,
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
Templates: template.New(connector.LoginPageTemplateName),
|
||||
Connectors: []connector.Connector{},
|
||||
UserRepo: userRepo,
|
||||
|
@ -188,7 +222,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
|||
HTTPClient: sClient,
|
||||
ProviderConfig: pcfg,
|
||||
Credentials: ci.Credentials,
|
||||
RedirectURL: "http://client.example.com",
|
||||
RedirectURL: validRedirURL.String(),
|
||||
KeySet: *ks,
|
||||
}
|
||||
|
||||
|
@ -263,10 +297,20 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHTTPClientCredsToken(t *testing.T) {
|
||||
validRedirURL := url.URL{
|
||||
Scheme: "http",
|
||||
Host: "client.example.com",
|
||||
Path: "/callback",
|
||||
}
|
||||
ci := client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "72de74a9",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
|
||||
ID: validRedirURL.Host,
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
validRedirURL,
|
||||
},
|
||||
},
|
||||
}
|
||||
cis := []client.Client{ci}
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"google.golang.org/api/googleapi"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/db"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/dex/server"
|
||||
|
@ -79,7 +80,7 @@ var (
|
|||
},
|
||||
}
|
||||
|
||||
userBadClientID = "ZZZ"
|
||||
userBadClientID = testBadRedirectURL.Host
|
||||
|
||||
userGoodToken = makeUserToken(testIssuerURL,
|
||||
"ID-1", testClientID, time.Hour*1, testPrivKey)
|
||||
|
@ -101,38 +102,42 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
|||
f := &userAPITestFixtures{}
|
||||
|
||||
dbMap, _, _, um := makeUserObjects(userUsers, userPasswords)
|
||||
cir := func() client.ClientRepo {
|
||||
repo, err := db.NewClientRepoFromClients(dbMap, []client.Client{
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: testClientID,
|
||||
Secret: testClientSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
testRedirectURL,
|
||||
},
|
||||
clients := []client.Client{
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: testClientID,
|
||||
Secret: testClientSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
testRedirectURL,
|
||||
},
|
||||
},
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: userBadClientID,
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
testRedirectURL,
|
||||
},
|
||||
},
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: userBadClientID,
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
testBadRedirectURL,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic("Failed to create client identity repo: " + err.Error())
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
cir.SetDexAdmin(testClientID, true)
|
||||
},
|
||||
}
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte(testClientSecret), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbMap)
|
||||
clientManager, err := manager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
panic("Failed to create client identity manager: " + err.Error())
|
||||
}
|
||||
clientManager.SetDexAdmin(testClientID, true)
|
||||
|
||||
noop := func() error { return nil }
|
||||
|
||||
|
@ -153,8 +158,9 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
|||
|
||||
f.emailer = &testEmailer{}
|
||||
um.Clock = clock
|
||||
api := api.NewUsersAPI(dbMap, um, f.emailer, "local")
|
||||
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir)
|
||||
|
||||
api := api.NewUsersAPI(um, clientManager, refreshRepo, f.emailer, "local")
|
||||
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, clientManager)
|
||||
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
|
||||
|
||||
f.trans = &tokenHandlerTransport{
|
||||
|
@ -536,7 +542,7 @@ func TestCreateUser(t *testing.T) {
|
|||
wantEmalier := testEmailer{
|
||||
cantEmail: tt.cantEmail,
|
||||
lastEmail: tt.req.User.Email,
|
||||
lastClientID: "XXX",
|
||||
lastClientID: testClientID,
|
||||
lastWasInvite: true,
|
||||
lastRedirectURL: *urlParsed,
|
||||
}
|
||||
|
@ -799,7 +805,7 @@ func TestResendEmailInvitation(t *testing.T) {
|
|||
wantEmalier := testEmailer{
|
||||
cantEmail: tt.cantEmail,
|
||||
lastEmail: strings.ToLower(tt.email),
|
||||
lastClientID: "XXX",
|
||||
lastClientID: testClientID,
|
||||
lastWasInvite: true,
|
||||
lastRedirectURL: *urlParsed,
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
|
@ -14,7 +14,7 @@ import (
|
|||
|
||||
type clientTokenMiddleware struct {
|
||||
issuerURL string
|
||||
ciRepo client.ClientRepo
|
||||
ciManager *manager.ClientManager
|
||||
keysFunc func() ([]key.PublicKey, error)
|
||||
next http.Handler
|
||||
}
|
||||
|
@ -30,8 +30,8 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
|
||||
if c.ciRepo == nil {
|
||||
log.Errorf("Misconfigured clientTokenMiddleware, ClientRepo is not set")
|
||||
if c.ciManager == nil {
|
||||
log.Errorf("Misconfigured clientTokenMiddleware, ClientManager is not set")
|
||||
respondError()
|
||||
return
|
||||
}
|
||||
|
@ -83,7 +83,7 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
|
||||
md, err := c.ciRepo.Metadata(clientID)
|
||||
md, err := c.ciManager.Metadata(clientID)
|
||||
if md == nil || err != nil {
|
||||
log.Errorf("Failed to find clientID: %s, error=%v", clientID, err)
|
||||
respondError()
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -10,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
|
@ -25,22 +25,23 @@ func (h staticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
func TestClientToken(t *testing.T) {
|
||||
now := time.Now()
|
||||
tomorrow := now.Add(24 * time.Hour)
|
||||
validClientID := "valid-client"
|
||||
ci := client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: validClientID,
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
|
||||
},
|
||||
clientMetadata := oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
|
||||
},
|
||||
}
|
||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
|
||||
|
||||
dbm := db.NewMemDB()
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbm), clientmanager.ManagerOptions{})
|
||||
cli := client.Client{
|
||||
Metadata: clientMetadata,
|
||||
}
|
||||
creds, err := clientManager.New(cli)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
validClientID := creds.ID
|
||||
|
||||
privKey, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
@ -65,63 +66,63 @@ func TestClientToken(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
keys []key.PublicKey
|
||||
repo client.ClientRepo
|
||||
manager *clientmanager.ClientManager
|
||||
header string
|
||||
wantCode int
|
||||
}{
|
||||
// valid token
|
||||
{
|
||||
keys: []key.PublicKey{pubKey},
|
||||
repo: repo,
|
||||
manager: clientManager,
|
||||
header: fmt.Sprintf("BEARER %s", validJWT),
|
||||
wantCode: http.StatusOK,
|
||||
},
|
||||
// invalid token
|
||||
{
|
||||
keys: []key.PublicKey{pubKey},
|
||||
repo: repo,
|
||||
manager: clientManager,
|
||||
header: fmt.Sprintf("BEARER %s", invalidJWT),
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
// empty header
|
||||
{
|
||||
keys: []key.PublicKey{pubKey},
|
||||
repo: repo,
|
||||
manager: clientManager,
|
||||
header: "",
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
// unparsable token
|
||||
{
|
||||
keys: []key.PublicKey{pubKey},
|
||||
repo: repo,
|
||||
manager: clientManager,
|
||||
header: "BEARER xxx",
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
// no verification keys
|
||||
{
|
||||
keys: []key.PublicKey{},
|
||||
repo: repo,
|
||||
manager: clientManager,
|
||||
header: fmt.Sprintf("BEARER %s", validJWT),
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
// nil repo
|
||||
{
|
||||
keys: []key.PublicKey{pubKey},
|
||||
repo: nil,
|
||||
manager: nil,
|
||||
header: fmt.Sprintf("BEARER %s", validJWT),
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
// empty repo
|
||||
{
|
||||
keys: []key.PublicKey{pubKey},
|
||||
repo: db.NewClientRepo(db.NewMemDB()),
|
||||
manager: clientmanager.NewClientManager(db.NewClientRepo(db.NewMemDB()), db.TransactionFactory(db.NewMemDB()), clientmanager.ManagerOptions{}),
|
||||
header: fmt.Sprintf("BEARER %s", validJWT),
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
// client not in repo
|
||||
{
|
||||
keys: []key.PublicKey{pubKey},
|
||||
repo: repo,
|
||||
manager: clientManager,
|
||||
header: fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)),
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
|
@ -131,7 +132,7 @@ func TestClientToken(t *testing.T) {
|
|||
w := httptest.NewRecorder()
|
||||
mw := &clientTokenMiddleware{
|
||||
issuerURL: validIss,
|
||||
ciRepo: tt.repo,
|
||||
ciManager: tt.manager,
|
||||
keysFunc: func() ([]key.PublicKey, error) {
|
||||
return tt.keys, nil
|
||||
},
|
||||
|
|
|
@ -39,18 +39,10 @@ func (s *Server) handleClientRegistrationRequest(r *http.Request) (*oidc.ClientR
|
|||
}
|
||||
|
||||
// metadata is guarenteed to have at least one redirect_uri by earlier validation.
|
||||
id, err := oidc.GenClientID(clientMetadata.RedirectURIs[0].Host)
|
||||
if err != nil {
|
||||
log.Errorf("Faild to create client ID: %v", err)
|
||||
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
|
||||
}
|
||||
|
||||
creds, err := s.ClientRepo.New(client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: id,
|
||||
},
|
||||
cli := client.Client{
|
||||
Metadata: clientMetadata,
|
||||
})
|
||||
}
|
||||
creds, err := s.ClientManager.New(cli)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to create new client identity: %v", err)
|
||||
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
|
||||
|
|
|
@ -143,7 +143,7 @@ func TestClientRegistration(t *testing.T) {
|
|||
return fmt.Errorf("no client id in registration response")
|
||||
}
|
||||
|
||||
metadata, err := fixtures.clientRepo.Metadata(r.ClientID)
|
||||
metadata, err := fixtures.clientManager.Metadata(r.ClientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to lookup client id after creation")
|
||||
}
|
||||
|
|
|
@ -6,21 +6,20 @@ import (
|
|||
"net/http"
|
||||
"path"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/client/manager"
|
||||
phttp "github.com/coreos/dex/pkg/http"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
)
|
||||
|
||||
type clientResource struct {
|
||||
repo client.ClientRepo
|
||||
manager *manager.ClientManager
|
||||
}
|
||||
|
||||
func registerClientResource(prefix string, repo client.ClientRepo) (string, http.Handler) {
|
||||
func registerClientResource(prefix string, manager *manager.ClientManager) (string, http.Handler) {
|
||||
mux := http.NewServeMux()
|
||||
c := &clientResource{
|
||||
repo: repo,
|
||||
manager: manager,
|
||||
}
|
||||
relPath := "clients"
|
||||
absPath := path.Join(prefix, relPath)
|
||||
|
@ -41,7 +40,7 @@ func (c *clientResource) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (c *clientResource) list(w http.ResponseWriter, r *http.Request) {
|
||||
cs, err := c.repo.All()
|
||||
cs, err := c.manager.All()
|
||||
if err != nil {
|
||||
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "error listing clients"))
|
||||
return
|
||||
|
@ -88,16 +87,7 @@ func (c *clientResource) create(w http.ResponseWriter, r *http.Request) {
|
|||
writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidClientMetadata, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
clientID, err := oidc.GenClientID(ci.Metadata.RedirectURIs[0].Host)
|
||||
if err != nil {
|
||||
log.Errorf("Failed generating ID for new client: %v", err)
|
||||
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "unable to generate client ID"))
|
||||
return
|
||||
}
|
||||
|
||||
ci.Credentials.ID = clientID
|
||||
creds, err := c.repo.New(ci)
|
||||
creds, err := c.manager.New(ci)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("Failed creating client: %v", err)
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/db"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
|
@ -28,8 +29,10 @@ func makeBody(s string) io.ReadCloser {
|
|||
func TestCreateInvalidRequest(t *testing.T) {
|
||||
u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"}
|
||||
h := http.Header{"Content-Type": []string{"application/json"}}
|
||||
repo := db.NewClientRepo(db.NewMemDB())
|
||||
res := &clientResource{repo: repo}
|
||||
dbm := db.NewMemDB()
|
||||
repo := db.NewClientRepo(dbm)
|
||||
manager := manager.NewClientManager(repo, db.TransactionFactory(dbm), manager.ManagerOptions{})
|
||||
res := &clientResource{manager: manager}
|
||||
tests := []struct {
|
||||
req *http.Request
|
||||
wantCode int
|
||||
|
@ -119,8 +122,10 @@ func TestCreateInvalidRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
repo := db.NewClientRepo(db.NewMemDB())
|
||||
res := &clientResource{repo: repo}
|
||||
dbm := db.NewMemDB()
|
||||
repo := db.NewClientRepo(dbm)
|
||||
manager := manager.NewClientManager(repo, db.TransactionFactory(dbm), manager.ManagerOptions{})
|
||||
res := &clientResource{manager: manager}
|
||||
tests := [][]string{
|
||||
[]string{"http://example.com"},
|
||||
[]string{"https://example.com"},
|
||||
|
@ -190,7 +195,7 @@ func TestList(t *testing.T) {
|
|||
{
|
||||
cs: []client.Client{
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
|
||||
Credentials: oidc.ClientCredentials{ID: "example.com", Secret: b64Encode("secret")},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "example.com"},
|
||||
|
@ -200,7 +205,7 @@ func TestList(t *testing.T) {
|
|||
},
|
||||
want: []*schema.Client{
|
||||
&schema.Client{
|
||||
Id: "foo",
|
||||
Id: "example.com",
|
||||
RedirectURIs: []string{"http://example.com"},
|
||||
},
|
||||
},
|
||||
|
@ -209,7 +214,7 @@ func TestList(t *testing.T) {
|
|||
{
|
||||
cs: []client.Client{
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
|
||||
Credentials: oidc.ClientCredentials{ID: "example.com", Secret: b64Encode("secret")},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "example.com"},
|
||||
|
@ -217,21 +222,21 @@ func TestList(t *testing.T) {
|
|||
},
|
||||
},
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{ID: "biz", Secret: b64Encode("bang")},
|
||||
Credentials: oidc.ClientCredentials{ID: "example2.com", Secret: b64Encode("secret")},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "https", Host: "example.com", Path: "one/two/three"},
|
||||
url.URL{Scheme: "https", Host: "example2.com", Path: "one/two/three"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []*schema.Client{
|
||||
&schema.Client{
|
||||
Id: "biz",
|
||||
RedirectURIs: []string{"https://example.com/one/two/three"},
|
||||
Id: "example2.com",
|
||||
RedirectURIs: []string{"https://example2.com/one/two/three"},
|
||||
},
|
||||
&schema.Client{
|
||||
Id: "foo",
|
||||
Id: "example.com",
|
||||
RedirectURIs: []string{"http://example.com"},
|
||||
},
|
||||
},
|
||||
|
@ -239,12 +244,20 @@ func TestList(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), tt.cs)
|
||||
dbm := db.NewMemDB()
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager, err := manager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), tt.cs, manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
|
||||
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||
continue
|
||||
}
|
||||
res := &clientResource{repo: repo}
|
||||
res := &clientResource{manager: clientManager}
|
||||
|
||||
r, err := http.NewRequest("GET", "http://example.com/clients", nil)
|
||||
if err != nil {
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/email"
|
||||
|
@ -114,9 +115,11 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err)
|
||||
}
|
||||
ciRepo, err := db.NewClientRepoFromClients(dbMap, clients)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create client identity repo: %v", err)
|
||||
|
||||
clientRepo := db.NewClientRepo(dbMap)
|
||||
|
||||
for _, c := range clients {
|
||||
clientRepo.New(nil, c)
|
||||
}
|
||||
|
||||
f, err := os.Open(cfg.ConnectorsFile)
|
||||
|
@ -155,7 +158,12 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
|
||||
txnFactory := db.TransactionFactory(dbMap)
|
||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
||||
srv.ClientRepo = ciRepo
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, clientmanager.ManagerOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create client identity manager: %v", err)
|
||||
}
|
||||
srv.ClientRepo = clientRepo
|
||||
srv.ClientManager = clientManager
|
||||
srv.KeySetRepo = kRepo
|
||||
srv.ConnectorConfigRepo = cfgRepo
|
||||
srv.UserRepo = userRepo
|
||||
|
@ -253,11 +261,13 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
|||
userRepo := db.NewUserRepo(dbc)
|
||||
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
|
||||
clientManager := clientmanager.NewClientManager(ciRepo, db.TransactionFactory(dbc), clientmanager.ManagerOptions{})
|
||||
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
|
||||
|
||||
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
||||
|
||||
srv.ClientRepo = ciRepo
|
||||
srv.ClientManager = clientManager
|
||||
srv.KeySetRepo = kRepo
|
||||
srv.ConnectorConfigRepo = cfgRepo
|
||||
srv.UserRepo = userRepo
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/coreos/go-oidc/oidc"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/user"
|
||||
useremail "github.com/coreos/dex/user/email"
|
||||
|
@ -28,7 +29,7 @@ func handleVerifyEmailResendFunc(
|
|||
srvKeysFunc func() ([]key.PublicKey, error),
|
||||
emailer *useremail.UserEmailer,
|
||||
userRepo user.UserRepo,
|
||||
clientRepo client.ClientRepo) http.HandlerFunc {
|
||||
clientManager *clientmanager.ClientManager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
var params struct {
|
||||
|
@ -57,7 +58,7 @@ func handleVerifyEmailResendFunc(
|
|||
return
|
||||
}
|
||||
|
||||
cm, err := clientRepo.Metadata(clientID)
|
||||
cm, err := clientManager.Metadata(clientID)
|
||||
if err == client.ErrorNotFound {
|
||||
log.Errorf("No such client: %v", err)
|
||||
writeAPIError(w, http.StatusBadRequest,
|
||||
|
|
|
@ -130,7 +130,7 @@ func TestHandleVerifyEmailResend(t *testing.T) {
|
|||
keysFunc,
|
||||
f.srv.UserEmailer,
|
||||
f.userRepo,
|
||||
f.clientRepo)
|
||||
f.clientManager)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
u := "http://example.com"
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/session/manager"
|
||||
|
@ -75,28 +76,37 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
idpcs := []connector.Connector{
|
||||
&fakeConnector{loginURL: "http://fake.example.com"},
|
||||
}
|
||||
dbm := db.NewMemDB()
|
||||
clients := []client.Client{
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "client.example.com",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||
}
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
||||
ClientRepo: func() client.ClientRepo {
|
||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}(),
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
@ -108,7 +118,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"client_id": []string{"XXX"},
|
||||
"client_id": []string{"client.example.com"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
|
@ -121,7 +131,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"client_id": []string{"client.example.com"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
|
@ -134,7 +144,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"client_id": []string{"client.example.com"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
|
@ -157,7 +167,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"token"},
|
||||
"client_id": []string{"XXX"},
|
||||
"client_id": []string{"client.example.com"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
|
@ -170,11 +180,33 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"client_id": []string{"client.example.com"},
|
||||
"connector_id": []string{"fake"},
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
// empty response_type
|
||||
{
|
||||
query: url.Values{
|
||||
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||
"client_id": []string{"client.example.com"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
wantLocation: "http://client.example.com/callback?error=unsupported_response_type&state=",
|
||||
},
|
||||
|
||||
// empty client_id
|
||||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
|
@ -204,29 +236,39 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
idpcs := []connector.Connector{
|
||||
&fakeConnector{loginURL: "http://fake.example.com"},
|
||||
}
|
||||
|
||||
dbm := db.NewMemDB()
|
||||
clients := []client.Client{
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "foo.example.com",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "foo.example.com", Path: "/callback"},
|
||||
url.URL{Scheme: "http", Host: "bar.example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||
}
|
||||
srv := &Server{
|
||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
||||
ClientRepo: func() client.ClientRepo {
|
||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "foo.example.com", Path: "/callback"},
|
||||
url.URL{Scheme: "http", Host: "bar.example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}(),
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
@ -239,7 +281,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{"http://foo.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"client_id": []string{"foo.example.com"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
|
@ -252,7 +294,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{"http://bar.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"client_id": []string{"foo.example.com"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
|
@ -265,7 +307,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"client_id": []string{"foo.example.com"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
|
@ -276,7 +318,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"client_id": []string{"XXX"},
|
||||
"client_id": []string{"foo.example.com"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
|
@ -328,8 +370,8 @@ func TestHandleTokenFunc(t *testing.T) {
|
|||
"grant_type": []string{"invalid!"},
|
||||
"code": []string{"someCode"},
|
||||
},
|
||||
user: "XXX",
|
||||
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
user: testClientID,
|
||||
passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
|
||||
|
@ -338,8 +380,8 @@ func TestHandleTokenFunc(t *testing.T) {
|
|||
query: url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
},
|
||||
user: "XXX",
|
||||
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
user: testClientID,
|
||||
passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
|
||||
|
@ -349,8 +391,8 @@ func TestHandleTokenFunc(t *testing.T) {
|
|||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{""},
|
||||
},
|
||||
user: "XXX",
|
||||
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
user: testClientID,
|
||||
passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
|
||||
|
@ -371,8 +413,8 @@ func TestHandleTokenFunc(t *testing.T) {
|
|||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{"asdasd"},
|
||||
},
|
||||
user: "XXX",
|
||||
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
user: testClientID,
|
||||
passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
|
||||
|
@ -382,8 +424,8 @@ func TestHandleTokenFunc(t *testing.T) {
|
|||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{"code-2"},
|
||||
},
|
||||
user: "XXX",
|
||||
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
user: testClientID,
|
||||
passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
wantCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
@ -402,7 +444,7 @@ func TestHandleTokenFunc(t *testing.T) {
|
|||
|
||||
// need to create session in order to exchange the code (generated by the NewSessionKey func) for token
|
||||
setSession := func() error {
|
||||
sid, err := fx.sessionManager.NewSession("local", "XXX", "", testRedirectURL, "", true, []string{"openid"})
|
||||
sid, err := fx.sessionManager.NewSession("local", testClientID, "", testRedirectURL, "", true, []string{"openid"})
|
||||
if err != nil {
|
||||
return fmt.Errorf("case %d: cannot create session, error=%v", i, err)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/coreos/go-oidc/key"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
sessionmanager "github.com/coreos/dex/session/manager"
|
||||
"github.com/coreos/dex/user"
|
||||
|
@ -29,7 +30,7 @@ type SendResetPasswordEmailHandler struct {
|
|||
tpl *template.Template
|
||||
emailer *useremail.UserEmailer
|
||||
sm *sessionmanager.SessionManager
|
||||
cr client.ClientRepo
|
||||
cm *clientmanager.ClientManager
|
||||
}
|
||||
|
||||
func (h *SendResetPasswordEmailHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -128,7 +129,7 @@ func (h *SendResetPasswordEmailHandler) validateRedirectURL(clientID string, red
|
|||
return url.URL{}, false
|
||||
}
|
||||
|
||||
cm, err := h.cr.Metadata(clientID)
|
||||
cm, err := h.cm.Metadata(clientID)
|
||||
if err != nil || cm == nil {
|
||||
log.Errorf("Error getting ClientMetadata: %v", err)
|
||||
return url.URL{}, false
|
||||
|
|
|
@ -253,7 +253,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
|
|||
t.Fatalf("case %d: could not make test fixtures: %v", i, err)
|
||||
}
|
||||
|
||||
_, err = f.srv.NewSession("local", "XXX", "", f.redirectURL, "", true, []string{"openid"})
|
||||
_, err = f.srv.NewSession("local", testClientID, "", f.redirectURL, "", true, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: could not create new session: %v", i, err)
|
||||
}
|
||||
|
@ -267,7 +267,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
|
|||
tpl: f.srv.SendResetPasswordEmailTemplate,
|
||||
emailer: f.srv.UserEmailer,
|
||||
sm: f.sessionManager,
|
||||
cr: f.clientRepo,
|
||||
cm: f.clientManager,
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
|
|
@ -295,7 +295,7 @@ func TestHandleRegister(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, []string{"openid"})
|
||||
key, err := f.srv.NewSession(tt.connID, testClientID, "", f.redirectURL, "", true, []string{"openid"})
|
||||
t.Logf("case %d: key for NewSession: %v", i, key)
|
||||
|
||||
if tt.attachRemote {
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/refresh"
|
||||
|
@ -72,6 +73,7 @@ type Server struct {
|
|||
Connectors []connector.Connector
|
||||
UserRepo user.UserRepo
|
||||
UserManager *usermanager.UserManager
|
||||
ClientManager *clientmanager.ClientManager
|
||||
PasswordInfoRepo user.PasswordInfoRepo
|
||||
RefreshTokenRepo refresh.RefreshTokenRepo
|
||||
UserEmailer *useremail.UserEmailer
|
||||
|
@ -213,13 +215,13 @@ func (s *Server) HTTPHandler() http.Handler {
|
|||
s.KeyManager.PublicKeys,
|
||||
s.UserEmailer,
|
||||
s.UserRepo,
|
||||
s.ClientRepo)))
|
||||
s.ClientManager)))
|
||||
|
||||
mux.Handle(httpPathSendResetPassword, &SendResetPasswordEmailHandler{
|
||||
tpl: s.SendResetPasswordEmailTemplate,
|
||||
emailer: s.UserEmailer,
|
||||
sm: s.SessionManager,
|
||||
cr: s.ClientRepo,
|
||||
cm: s.ClientManager,
|
||||
})
|
||||
|
||||
mux.Handle(httpPathResetPassword, &ResetPasswordHandler{
|
||||
|
@ -256,11 +258,11 @@ func (s *Server) HTTPHandler() http.Handler {
|
|||
apiBasePath := path.Join(httpPathAPI, APIVersion)
|
||||
registerDiscoveryResource(apiBasePath, mux)
|
||||
|
||||
clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientRepo)
|
||||
clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientManager)
|
||||
mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler))
|
||||
|
||||
usersAPI := usersapi.NewUsersAPI(s.dbMap, s.UserManager, s.UserEmailer, s.localConnectorID)
|
||||
handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientRepo).HTTPHandler()
|
||||
usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientManager, s.RefreshTokenRepo, s.UserEmailer, s.localConnectorID)
|
||||
handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientManager).HTTPHandler()
|
||||
|
||||
mux.Handle(apiBasePath+"/", handler)
|
||||
|
||||
|
@ -271,14 +273,14 @@ func (s *Server) HTTPHandler() http.Handler {
|
|||
func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler {
|
||||
return &clientTokenMiddleware{
|
||||
issuerURL: s.IssuerURL.String(),
|
||||
ciRepo: s.ClientRepo,
|
||||
ciManager: s.ClientManager,
|
||||
keysFunc: s.KeyManager.PublicKeys,
|
||||
next: handler,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) {
|
||||
return s.ClientRepo.Metadata(clientID)
|
||||
return s.ClientManager.Metadata(clientID)
|
||||
}
|
||||
|
||||
func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
|
||||
|
@ -365,9 +367,9 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
|||
}
|
||||
|
||||
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
|
||||
ok, err := s.ClientRepo.Authenticate(creds)
|
||||
ok, err := s.ClientManager.Authenticate(creds)
|
||||
if err != nil {
|
||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||
log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
|
||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||
}
|
||||
if !ok {
|
||||
|
@ -397,7 +399,7 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, erro
|
|||
}
|
||||
|
||||
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
|
||||
ok, err := s.ClientRepo.Authenticate(creds)
|
||||
ok, err := s.ClientManager.Authenticate(creds)
|
||||
if err != nil {
|
||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||
|
@ -466,7 +468,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
|||
}
|
||||
|
||||
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
|
||||
ok, err := s.ClientRepo.Authenticate(creds)
|
||||
ok, err := s.ClientManager.Authenticate(creds)
|
||||
if err != nil {
|
||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/refresh/refreshtest"
|
||||
"github.com/coreos/dex/session/manager"
|
||||
|
@ -21,7 +22,12 @@ import (
|
|||
"github.com/kylelemons/godebug/pretty"
|
||||
)
|
||||
|
||||
var clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secrete"))
|
||||
var clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
|
||||
var validRedirURL = url.URL{
|
||||
Scheme: "http",
|
||||
Host: "client.example.com",
|
||||
Path: "/callback",
|
||||
}
|
||||
|
||||
type StaticKeyManager struct {
|
||||
key.PrivateKeyManager
|
||||
|
@ -132,8 +138,8 @@ func TestServerNewSession(t *testing.T) {
|
|||
nonce := "oncenay"
|
||||
ci := client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: "secrete",
|
||||
ID: testClientID,
|
||||
Secret: clientTestSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -181,7 +187,7 @@ func TestServerNewSession(t *testing.T) {
|
|||
func TestServerLogin(t *testing.T) {
|
||||
ci := client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
ID: testClientID,
|
||||
Secret: clientTestSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
|
@ -194,13 +200,13 @@ func TestServerLogin(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
ciRepo := func() client.ClientRepo {
|
||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
dbm := db.NewMemDB()
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), []client.Client{ci}, clientmanager.ManagerOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||
}
|
||||
|
||||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
|
@ -222,7 +228,8 @@ func TestServerLogin(t *testing.T) {
|
|||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
KeyManager: km,
|
||||
SessionManager: sm,
|
||||
ClientRepo: ciRepo,
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
UserRepo: userRepo,
|
||||
}
|
||||
|
||||
|
@ -244,20 +251,30 @@ func TestServerLogin(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
|
||||
ciRepo := func() client.ClientRepo {
|
||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX", Secret: clientTestSecret,
|
||||
clients := []client.Client{
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: testClientID, Secret: clientTestSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
validRedirURL,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
},
|
||||
}
|
||||
dbm := db.NewMemDB()
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||
}
|
||||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
|
||||
}
|
||||
|
@ -266,11 +283,12 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
|
|||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
KeyManager: km,
|
||||
SessionManager: sm,
|
||||
ClientRepo: ciRepo,
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
}
|
||||
|
||||
ident := oidc.Identity{ID: "YYY", Name: "elroy", Email: "elroy@example.com"}
|
||||
code, err := srv.Login(ident, "XXX")
|
||||
code, err := srv.Login(ident, testClientID)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected non-nil error")
|
||||
}
|
||||
|
@ -283,27 +301,28 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
|
|||
func TestServerLoginDisabledUser(t *testing.T) {
|
||||
ci := client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
ID: testClientID,
|
||||
Secret: clientTestSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{
|
||||
Scheme: "http",
|
||||
Host: "client.example.com",
|
||||
Path: "/callback",
|
||||
},
|
||||
validRedirURL,
|
||||
},
|
||||
},
|
||||
}
|
||||
ciRepo := func() client.ClientRepo {
|
||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
|
||||
clients := []client.Client{ci}
|
||||
dbm := db.NewMemDB()
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||
}
|
||||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
}
|
||||
|
@ -338,7 +357,8 @@ func TestServerLoginDisabledUser(t *testing.T) {
|
|||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
KeyManager: km,
|
||||
SessionManager: sm,
|
||||
ClientRepo: ciRepo,
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
UserRepo: userRepo,
|
||||
}
|
||||
|
||||
|
@ -357,17 +377,28 @@ func TestServerLoginDisabledUser(t *testing.T) {
|
|||
func TestServerCodeToken(t *testing.T) {
|
||||
ci := client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
ID: testClientID,
|
||||
Secret: clientTestSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
validRedirURL,
|
||||
},
|
||||
},
|
||||
}
|
||||
clients := []client.Client{ci}
|
||||
dbm := db.NewMemDB()
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||
}
|
||||
ciRepo := func() client.ClientRepo {
|
||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
}
|
||||
|
@ -384,7 +415,8 @@ func TestServerCodeToken(t *testing.T) {
|
|||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
KeyManager: km,
|
||||
SessionManager: sm,
|
||||
ClientRepo: ciRepo,
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
UserRepo: userRepo,
|
||||
RefreshTokenRepo: refreshTokenRepo,
|
||||
}
|
||||
|
@ -443,17 +475,29 @@ func TestServerCodeToken(t *testing.T) {
|
|||
func TestServerTokenUnrecognizedKey(t *testing.T) {
|
||||
ci := client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
ID: testClientID,
|
||||
Secret: clientTestSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
validRedirURL,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
clients := []client.Client{ci}
|
||||
dbm := db.NewMemDB()
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||
}
|
||||
ciRepo := func() client.ClientRepo {
|
||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}()
|
||||
km := &StaticKeyManager{
|
||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||
}
|
||||
|
@ -463,7 +507,8 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
|
|||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||
KeyManager: km,
|
||||
SessionManager: sm,
|
||||
ClientRepo: ciRepo,
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
}
|
||||
|
||||
sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
|
||||
|
@ -492,7 +537,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
|
||||
keyFixture := "goodkey"
|
||||
ccFixture := oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
ID: testClientID,
|
||||
Secret: clientTestSecret,
|
||||
}
|
||||
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
||||
|
@ -569,14 +614,29 @@ func TestServerTokenFail(t *testing.T) {
|
|||
km := &StaticKeyManager{
|
||||
signer: tt.signer,
|
||||
}
|
||||
ciRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
||||
client.Client{Credentials: ccFixture},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
clients := []client.Client{
|
||||
client.Client{
|
||||
Credentials: ccFixture,
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
validRedirURL,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dbm := db.NewMemDB()
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||
}
|
||||
_, err = sm.AttachUser(sessionID, "testid-1")
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
|
@ -593,7 +653,8 @@ func TestServerTokenFail(t *testing.T) {
|
|||
IssuerURL: issuerURL,
|
||||
KeyManager: km,
|
||||
SessionManager: sm,
|
||||
ClientRepo: ciRepo,
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
UserRepo: userRepo,
|
||||
RefreshTokenRepo: refreshTokenRepo,
|
||||
}
|
||||
|
@ -623,14 +684,27 @@ func TestServerTokenFail(t *testing.T) {
|
|||
|
||||
func TestServerRefreshToken(t *testing.T) {
|
||||
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
|
||||
|
||||
credXXX := oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: clientTestSecret,
|
||||
clientA := client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: testClientID,
|
||||
Secret: clientTestSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "https", Host: "client.example.com", Path: "one/two/three"},
|
||||
},
|
||||
},
|
||||
}
|
||||
credYYY := oidc.ClientCredentials{
|
||||
ID: "YYY",
|
||||
Secret: clientTestSecret,
|
||||
clientB := client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "example2.com",
|
||||
Secret: clientTestSecret,
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "https", Host: "example2.com", Path: "one/two/three"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
||||
|
@ -647,47 +721,47 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
// Everything is good.
|
||||
{
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
clientA.Credentials.ID,
|
||||
clientA.Credentials,
|
||||
signerFixture,
|
||||
nil,
|
||||
},
|
||||
// Invalid refresh token(malformatted).
|
||||
{
|
||||
"invalid-token",
|
||||
"XXX",
|
||||
credXXX,
|
||||
clientA.Credentials.ID,
|
||||
clientA.Credentials,
|
||||
signerFixture,
|
||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||
},
|
||||
// Invalid refresh token(invalid payload content).
|
||||
{
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
clientA.Credentials.ID,
|
||||
clientA.Credentials,
|
||||
signerFixture,
|
||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||
},
|
||||
// Invalid refresh token(invalid ID content).
|
||||
{
|
||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
clientA.Credentials.ID,
|
||||
clientA.Credentials,
|
||||
signerFixture,
|
||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||
},
|
||||
// Invalid client(client is not associated with the token).
|
||||
{
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credYYY,
|
||||
clientA.Credentials.ID,
|
||||
clientB.Credentials,
|
||||
signerFixture,
|
||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||
},
|
||||
// Invalid client(no client ID).
|
||||
{
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
clientA.Credentials.ID,
|
||||
oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
||||
signerFixture,
|
||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||
|
@ -695,7 +769,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
// Invalid client(no such client).
|
||||
{
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
clientA.Credentials.ID,
|
||||
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
||||
signerFixture,
|
||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||
|
@ -703,24 +777,24 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
// Invalid client(no secrets).
|
||||
{
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
oidc.ClientCredentials{ID: "XXX"},
|
||||
clientA.Credentials.ID,
|
||||
oidc.ClientCredentials{ID: testClientID},
|
||||
signerFixture,
|
||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||
},
|
||||
// Invalid client(invalid secret).
|
||||
{
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
|
||||
clientA.Credentials.ID,
|
||||
oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
|
||||
signerFixture,
|
||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||
},
|
||||
// Signing operation fails.
|
||||
{
|
||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||
"XXX",
|
||||
credXXX,
|
||||
clientA.Credentials.ID,
|
||||
clientA.Credentials,
|
||||
&StaticSigner{sig: nil, err: errors.New("fail")},
|
||||
oauth2.NewError(oauth2.ErrorServerError),
|
||||
},
|
||||
|
@ -731,15 +805,23 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
signer: tt.signer,
|
||||
}
|
||||
|
||||
ciRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
||||
client.Client{Credentials: credXXX},
|
||||
client.Client{Credentials: credYYY},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
|
||||
continue
|
||||
clients := []client.Client{
|
||||
clientA,
|
||||
clientB,
|
||||
}
|
||||
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
dbm := db.NewMemDB()
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||
}
|
||||
userRepo, err := makeNewUserRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
|
@ -750,7 +832,8 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
srv := &Server{
|
||||
IssuerURL: issuerURL,
|
||||
KeyManager: km,
|
||||
ClientRepo: ciRepo,
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
UserRepo: userRepo,
|
||||
RefreshTokenRepo: refreshTokenRepo,
|
||||
}
|
||||
|
@ -772,7 +855,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("Case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if claims["iss"] != issuerURL.String() || claims["sub"] != "testid-1" || claims["aud"] != "XXX" {
|
||||
if claims["iss"] != issuerURL.String() || claims["sub"] != "testid-1" || claims["aud"] != testClientID {
|
||||
t.Errorf("Case %d: invalid claims: %v", i, claims)
|
||||
}
|
||||
}
|
||||
|
@ -784,14 +867,22 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
signer: signerFixture,
|
||||
}
|
||||
|
||||
ciRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
||||
client.Client{Credentials: credXXX},
|
||||
client.Client{Credentials: credYYY},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create client identity repo: %v", err)
|
||||
clients := []client.Client{
|
||||
clientA,
|
||||
clientB,
|
||||
}
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
dbm := db.NewMemDB()
|
||||
clientRepo := db.NewClientRepo(dbm)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||
}
|
||||
|
||||
userRepo, err := makeNewUserRepo()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
|
@ -810,12 +901,13 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
srv := &Server{
|
||||
IssuerURL: issuerURL,
|
||||
KeyManager: km,
|
||||
ClientRepo: ciRepo,
|
||||
ClientRepo: clientRepo,
|
||||
ClientManager: clientManager,
|
||||
UserRepo: userRepo,
|
||||
RefreshTokenRepo: refreshTokenRepo,
|
||||
}
|
||||
|
||||
if _, err := refreshTokenRepo.Create("testid-2", credXXX.ID); err != nil {
|
||||
if _, err := refreshTokenRepo.Create("testid-2", clientA.Credentials.ID); err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
|
@ -826,7 +918,7 @@ func TestServerRefreshToken(t *testing.T) {
|
|||
}
|
||||
srv.UserRepo = userRepo
|
||||
|
||||
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
|
||||
_, err = srv.RefreshToken(clientA.Credentials, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
|
||||
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
|
||||
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/coreos/go-oidc/oidc"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
"github.com/coreos/dex/email"
|
||||
|
@ -26,7 +27,7 @@ const (
|
|||
|
||||
var (
|
||||
testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"}
|
||||
testClientID = "XXX"
|
||||
testClientID = "client.example.com"
|
||||
|
||||
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
|
||||
|
||||
|
@ -79,6 +80,7 @@ type testFixtures struct {
|
|||
emailer *email.TemplatizedEmailer
|
||||
redirectURL url.URL
|
||||
clientRepo client.ClientRepo
|
||||
clientManager *clientmanager.ClientManager
|
||||
}
|
||||
|
||||
func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
|
||||
|
@ -123,7 +125,7 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
|
||||
userManager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
|
||||
|
||||
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
||||
|
@ -136,11 +138,11 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
clientRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
||||
clients := []client.Client{
|
||||
client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
ID: testClientID,
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -148,11 +150,19 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbMap)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
km := key.NewPrivateKeyManager()
|
||||
err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute)))
|
||||
if err != nil {
|
||||
|
@ -173,7 +183,8 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
Templates: tpl,
|
||||
UserRepo: userRepo,
|
||||
PasswordInfoRepo: pwRepo,
|
||||
UserManager: manager,
|
||||
UserManager: userManager,
|
||||
ClientManager: clientManager,
|
||||
KeyManager: km,
|
||||
}
|
||||
|
||||
|
@ -207,5 +218,6 @@ func makeTestFixtures() (*testFixtures, error) {
|
|||
sessionManager: sessionManager,
|
||||
emailer: emailer,
|
||||
clientRepo: clientRepo,
|
||||
clientManager: clientManager,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -11,12 +11,12 @@ import (
|
|||
"github.com/coreos/go-oidc/oidc"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/api"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -38,16 +38,16 @@ var (
|
|||
type UserMgmtServer struct {
|
||||
api *api.UsersAPI
|
||||
jwtvFactory JWTVerifierFactory
|
||||
um *manager.UserManager
|
||||
cir client.ClientRepo
|
||||
um *usermanager.UserManager
|
||||
cm *clientmanager.ClientManager
|
||||
}
|
||||
|
||||
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *manager.UserManager, cir client.ClientRepo) *UserMgmtServer {
|
||||
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *usermanager.UserManager, cm *clientmanager.ClientManager) *UserMgmtServer {
|
||||
return &UserMgmtServer{
|
||||
api: userMgmtAPI,
|
||||
jwtvFactory: jwtvFactory,
|
||||
um: um,
|
||||
cir: cir,
|
||||
cm: cm,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -295,7 +295,7 @@ func (s *UserMgmtServer) getCreds(r *http.Request, requiresAdmin bool) (api.Cred
|
|||
return api.Creds{}, err
|
||||
}
|
||||
|
||||
isAdmin, err := s.cir.IsDexAdmin(clientID)
|
||||
isAdmin, err := s.cm.IsDexAdmin(clientID)
|
||||
if err != nil {
|
||||
log.Errorf("userMgmtServer: GetCreds err: %q", err)
|
||||
return api.Creds{}, err
|
||||
|
|
2
test
2
test
|
@ -18,7 +18,7 @@ if [ ! -d $GOPATH/pkg ]; then
|
|||
echo "WARNING: No cached builds detected. Please run the ./build script to speed up future tests."
|
||||
fi
|
||||
|
||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin"
|
||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin client client/manager"
|
||||
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
|
||||
|
||||
# user has not provided PKG override
|
||||
|
|
|
@ -9,15 +9,13 @@ import (
|
|||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/go-gorp/gorp"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
"github.com/coreos/dex/db"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/pkg/log"
|
||||
"github.com/coreos/dex/refresh"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
"github.com/coreos/dex/user"
|
||||
"github.com/coreos/dex/user/manager"
|
||||
usermanager "github.com/coreos/dex/user/manager"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -88,9 +86,9 @@ func (e Error) Error() string {
|
|||
// calling User. It is assumed that the clientID has already validated as an
|
||||
// admin app before calling.
|
||||
type UsersAPI struct {
|
||||
manager *manager.UserManager
|
||||
userManager *usermanager.UserManager
|
||||
localConnectorID string
|
||||
clientRepo client.ClientRepo
|
||||
clientManager *clientmanager.ClientManager
|
||||
refreshRepo refresh.RefreshTokenRepo
|
||||
emailer Emailer
|
||||
}
|
||||
|
@ -105,11 +103,11 @@ type Creds struct {
|
|||
}
|
||||
|
||||
// TODO(ericchiang): Don't pass a dbMap. See #385.
|
||||
func NewUsersAPI(dbMap *gorp.DbMap, userManager *manager.UserManager, emailer Emailer, localConnectorID string) *UsersAPI {
|
||||
func NewUsersAPI(userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, refreshRepo refresh.RefreshTokenRepo, emailer Emailer, localConnectorID string) *UsersAPI {
|
||||
return &UsersAPI{
|
||||
manager: userManager,
|
||||
refreshRepo: db.NewRefreshTokenRepo(dbMap),
|
||||
clientRepo: db.NewClientRepo(dbMap),
|
||||
userManager: userManager,
|
||||
refreshRepo: refreshRepo,
|
||||
clientManager: clientManager,
|
||||
localConnectorID: localConnectorID,
|
||||
emailer: emailer,
|
||||
}
|
||||
|
@ -122,7 +120,7 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) {
|
|||
return schema.User{}, ErrorUnauthorized
|
||||
}
|
||||
|
||||
usr, err := u.manager.Get(id)
|
||||
usr, err := u.userManager.Get(id)
|
||||
|
||||
if err != nil {
|
||||
return schema.User{}, mapError(err)
|
||||
|
@ -137,7 +135,7 @@ func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema
|
|||
return schema.UserDisableResponse{}, ErrorUnauthorized
|
||||
}
|
||||
|
||||
if err := u.manager.Disable(userID, disable); err != nil {
|
||||
if err := u.userManager.Disable(userID, disable); err != nil {
|
||||
return schema.UserDisableResponse{}, mapError(err)
|
||||
}
|
||||
|
||||
|
@ -157,7 +155,7 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
|
|||
return schema.UserCreateResponse{}, mapError(err)
|
||||
}
|
||||
|
||||
metadata, err := u.clientRepo.Metadata(creds.ClientID)
|
||||
metadata, err := u.clientManager.Metadata(creds.ClientID)
|
||||
if err != nil {
|
||||
return schema.UserCreateResponse{}, mapError(err)
|
||||
}
|
||||
|
@ -167,12 +165,12 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
|
|||
return schema.UserCreateResponse{}, ErrorInvalidRedirectURL
|
||||
}
|
||||
|
||||
id, err := u.manager.CreateUser(schemaUserToUser(usr), user.Password(hash), u.localConnectorID)
|
||||
id, err := u.userManager.CreateUser(schemaUserToUser(usr), user.Password(hash), u.localConnectorID)
|
||||
if err != nil {
|
||||
return schema.UserCreateResponse{}, mapError(err)
|
||||
}
|
||||
|
||||
userUser, err := u.manager.Get(id)
|
||||
userUser, err := u.userManager.Get(id)
|
||||
if err != nil {
|
||||
return schema.UserCreateResponse{}, mapError(err)
|
||||
}
|
||||
|
@ -202,7 +200,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
|
|||
return schema.ResendEmailInvitationResponse{}, ErrorUnauthorized
|
||||
}
|
||||
|
||||
metadata, err := u.clientRepo.Metadata(creds.ClientID)
|
||||
metadata, err := u.clientManager.Metadata(creds.ClientID)
|
||||
if err != nil {
|
||||
return schema.ResendEmailInvitationResponse{}, mapError(err)
|
||||
}
|
||||
|
@ -213,7 +211,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
|
|||
}
|
||||
|
||||
// Retrieve user to check if it's already created
|
||||
userUser, err := u.manager.Get(userID)
|
||||
userUser, err := u.userManager.Get(userID)
|
||||
if err != nil {
|
||||
return schema.ResendEmailInvitationResponse{}, mapError(err)
|
||||
}
|
||||
|
@ -251,7 +249,7 @@ func (u *UsersAPI) ListUsers(creds Creds, maxResults int, nextPageToken string)
|
|||
return nil, "", ErrorMaxResultsTooHigh
|
||||
}
|
||||
|
||||
users, tok, err := u.manager.List(user.UserFilter{}, maxResults, nextPageToken)
|
||||
users, tok, err := u.userManager.List(user.UserFilter{}, maxResults, nextPageToken)
|
||||
if err != nil {
|
||||
return nil, "", mapError(err)
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/dex/client"
|
||||
clientmanager "github.com/coreos/dex/client/manager"
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/db"
|
||||
schema "github.com/coreos/dex/schema/workerschema"
|
||||
|
@ -50,14 +51,15 @@ func (t *testEmailer) sendEmail(email string, redirectURL url.URL, clientID stri
|
|||
}
|
||||
|
||||
var (
|
||||
clock = clockwork.NewFakeClock()
|
||||
clock = clockwork.NewFakeClock()
|
||||
goodClientID = "client.example.com"
|
||||
|
||||
goodCreds = Creds{
|
||||
User: user.User{
|
||||
ID: "ID-1",
|
||||
Admin: true,
|
||||
},
|
||||
ClientID: "XXX",
|
||||
ClientID: goodClientID,
|
||||
}
|
||||
|
||||
badCreds = Creds{
|
||||
|
@ -72,7 +74,7 @@ var (
|
|||
Admin: true,
|
||||
Disabled: true,
|
||||
},
|
||||
ClientID: "XXX",
|
||||
ClientID: goodClientID,
|
||||
}
|
||||
|
||||
resetPasswordURL = url.URL{
|
||||
|
@ -82,7 +84,7 @@ var (
|
|||
|
||||
validRedirURL = url.URL{
|
||||
Scheme: "http",
|
||||
Host: "client.example.com",
|
||||
Host: goodClientID,
|
||||
Path: "/callback",
|
||||
}
|
||||
)
|
||||
|
@ -158,8 +160,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
mgr.Clock = clock
|
||||
ci := client.Client{
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: "XXX",
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||
ID: goodClientID,
|
||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||
},
|
||||
Metadata: oidc.ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
|
@ -167,8 +169,17 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
},
|
||||
},
|
||||
}
|
||||
if _, err := db.NewClientRepoFromClients(dbMap, []client.Client{ci}); err != nil {
|
||||
panic("Failed to create client repo: " + err.Error())
|
||||
|
||||
clientIDGenerator := func(hostport string) (string, error) {
|
||||
return hostport, nil
|
||||
}
|
||||
secGen := func() ([]byte, error) {
|
||||
return []byte("secret"), nil
|
||||
}
|
||||
clientRepo := db.NewClientRepo(dbMap)
|
||||
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), []client.Client{ci}, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||
if err != nil {
|
||||
panic("Failed to create client manager: " + err.Error())
|
||||
}
|
||||
|
||||
// Used in TestRevokeRefreshToken test.
|
||||
|
@ -176,8 +187,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
clientID string
|
||||
userID string
|
||||
}{
|
||||
{"XXX", "ID-1"},
|
||||
{"XXX", "ID-2"},
|
||||
{goodClientID, "ID-1"},
|
||||
{goodClientID, "ID-2"},
|
||||
}
|
||||
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||
for _, token := range refreshTokens {
|
||||
|
@ -187,7 +198,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
|||
}
|
||||
|
||||
emailer := &testEmailer{}
|
||||
api := NewUsersAPI(dbMap, mgr, emailer, "local")
|
||||
api := NewUsersAPI(mgr, clientManager, refreshRepo, emailer, "local")
|
||||
return api, emailer
|
||||
|
||||
}
|
||||
|
@ -582,8 +593,8 @@ func TestRevokeRefreshToken(t *testing.T) {
|
|||
before []string // clientIDs expected before the change.
|
||||
after []string // clientIDs expected after the change.
|
||||
}{
|
||||
{"ID-1", "XXX", []string{"XXX"}, []string{}},
|
||||
{"ID-2", "XXX", []string{"XXX"}, []string{}},
|
||||
{"ID-1", goodClientID, []string{goodClientID}, []string{}},
|
||||
{"ID-2", goodClientID, []string{goodClientID}, []string{}},
|
||||
}
|
||||
|
||||
api, _ := makeTestFixtures()
|
||||
|
|
Loading…
Reference in a new issue