Merge pull request #442 from ecordell/client-manager
Adds client manager
This commit is contained in:
commit
a846016ceb
37 changed files with 1122 additions and 690 deletions
26
admin/api.go
26
admin/api.go
|
@ -1,31 +1,27 @@
|
||||||
// package admin provides an implementation of the API described in auth/schema/adminschema.
|
// Package admin provides an implementation of the API described in auth/schema/adminschema.
|
||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/oidc"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/schema/adminschema"
|
"github.com/coreos/dex/schema/adminschema"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/coreos/dex/user/manager"
|
usermanager "github.com/coreos/dex/user/manager"
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ClientIDGenerator = oidc.GenClientID
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// AdminAPI provides the logic necessary to implement the Admin API.
|
// AdminAPI provides the logic necessary to implement the Admin API.
|
||||||
type AdminAPI struct {
|
type AdminAPI struct {
|
||||||
userManager *manager.UserManager
|
userManager *usermanager.UserManager
|
||||||
userRepo user.UserRepo
|
userRepo user.UserRepo
|
||||||
passwordInfoRepo user.PasswordInfoRepo
|
passwordInfoRepo user.PasswordInfoRepo
|
||||||
clientRepo client.ClientRepo
|
clientRepo client.ClientRepo
|
||||||
|
clientManager *clientmanager.ClientManager
|
||||||
localConnectorID string
|
localConnectorID string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRepo client.ClientRepo, userManager *manager.UserManager, localConnectorID string) *AdminAPI {
|
func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRepo client.ClientRepo, userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, localConnectorID string) *AdminAPI {
|
||||||
if localConnectorID == "" {
|
if localConnectorID == "" {
|
||||||
panic("must specify non-blank localConnectorID")
|
panic("must specify non-blank localConnectorID")
|
||||||
}
|
}
|
||||||
|
@ -34,6 +30,7 @@ func NewAdminAPI(userRepo user.UserRepo, pwiRepo user.PasswordInfoRepo, clientRe
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
passwordInfoRepo: pwiRepo,
|
passwordInfoRepo: pwiRepo,
|
||||||
clientRepo: clientRepo,
|
clientRepo: clientRepo,
|
||||||
|
clientManager: clientManager,
|
||||||
localConnectorID: localConnectorID,
|
localConnectorID: localConnectorID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -141,14 +138,7 @@ func (a *AdminAPI) CreateClient(req adminschema.ClientCreateRequest) (adminschem
|
||||||
}
|
}
|
||||||
|
|
||||||
// metadata is guaranteed to have at least one redirect_uri by earlier validation.
|
// metadata is guaranteed to have at least one redirect_uri by earlier validation.
|
||||||
id, err := ClientIDGenerator(cli.Metadata.RedirectURIs[0].Host)
|
creds, err := a.clientManager.New(cli)
|
||||||
if err != nil {
|
|
||||||
return adminschema.ClientCreateResponse{}, mapError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cli.Credentials.ID = id
|
|
||||||
|
|
||||||
creds, err := a.clientRepo.New(cli)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return adminschema.ClientCreateResponse{}, mapError(err)
|
return adminschema.ClientCreateResponse{}, mapError(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/schema/adminschema"
|
"github.com/coreos/dex/schema/adminschema"
|
||||||
|
@ -17,6 +18,7 @@ type testFixtures struct {
|
||||||
ur user.UserRepo
|
ur user.UserRepo
|
||||||
pwr user.PasswordInfoRepo
|
pwr user.PasswordInfoRepo
|
||||||
cr client.ClientRepo
|
cr client.ClientRepo
|
||||||
|
cm *clientmanager.ClientManager
|
||||||
mgr *manager.UserManager
|
mgr *manager.UserManager
|
||||||
adAPI *AdminAPI
|
adAPI *AdminAPI
|
||||||
}
|
}
|
||||||
|
@ -71,7 +73,8 @@ func makeTestFixtures() *testFixtures {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
f.mgr = manager.NewUserManager(f.ur, f.pwr, ccr, db.TransactionFactory(dbMap), manager.ManagerOptions{})
|
||||||
f.adAPI = NewAdminAPI(f.ur, f.pwr, f.cr, f.mgr, "local")
|
f.cm = clientmanager.NewClientManager(f.cr, db.TransactionFactory(dbMap), clientmanager.ManagerOptions{})
|
||||||
|
f.adAPI = NewAdminAPI(f.ur, f.pwr, f.cr, f.mgr, f.cm, "local")
|
||||||
|
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,16 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/repo"
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,6 +21,24 @@ var (
|
||||||
ErrorNotFound = errors.New("no data found")
|
ErrorNotFound = errors.New("no data found")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
bcryptHashCost = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
func HashSecret(creds oidc.ClientCredentials) ([]byte, error) {
|
||||||
|
secretBytes, err := base64.URLEncoding.DecodeString(creds.Secret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
hashed, err := bcrypt.GenerateFromPassword([]byte(
|
||||||
|
secretBytes),
|
||||||
|
bcryptHashCost)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return hashed, nil
|
||||||
|
}
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
Credentials oidc.ClientCredentials
|
Credentials oidc.ClientCredentials
|
||||||
Metadata oidc.ClientMetadata
|
Metadata oidc.ClientMetadata
|
||||||
|
@ -24,30 +46,20 @@ type Client struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientRepo interface {
|
type ClientRepo interface {
|
||||||
Get(clientID string) (Client, error)
|
Get(tx repo.Transaction, clientID string) (Client, error)
|
||||||
|
|
||||||
// Metadata returns one matching ClientMetadata if the given client
|
// GetSecret returns the (base64 encoded) hashed client secret
|
||||||
// exists, otherwise nil. The returned error will be non-nil only
|
GetSecret(tx repo.Transaction, clientID string) ([]byte, error)
|
||||||
// if the repo was unable to determine client existence.
|
|
||||||
Metadata(clientID string) (*oidc.ClientMetadata, error)
|
|
||||||
|
|
||||||
// Authenticate asserts that a client with the given ID exists and
|
|
||||||
// that the provided secret matches. If either of these assertions
|
|
||||||
// fail, (false, nil) will be returned. Only if the repo is unable
|
|
||||||
// to make these assertions will a non-nil error be returned.
|
|
||||||
Authenticate(creds oidc.ClientCredentials) (bool, error)
|
|
||||||
|
|
||||||
// All returns all registered Clients
|
// All returns all registered Clients
|
||||||
All() ([]Client, error)
|
All(tx repo.Transaction) ([]Client, error)
|
||||||
|
|
||||||
// New registers a Client with the repo.
|
// New registers a Client with the repo.
|
||||||
// An unused ID must be provided. A corresponding secret will be returned
|
// An unused ID must be provided. A corresponding secret will be returned
|
||||||
// in a ClientCredentials struct along with the provided ID.
|
// in a ClientCredentials struct along with the provided ID.
|
||||||
New(client Client) (*oidc.ClientCredentials, error)
|
New(tx repo.Transaction, client Client) (*oidc.ClientCredentials, error)
|
||||||
|
|
||||||
SetDexAdmin(clientID string, isAdmin bool) error
|
Update(tx repo.Transaction, client Client) error
|
||||||
|
|
||||||
IsDexAdmin(clientID string) (bool, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise.
|
// ValidRedirectURL returns the passed in URL if it is present in the redirectURLs list, and returns an error otherwise.
|
||||||
|
|
|
@ -34,7 +34,7 @@ var (
|
||||||
|
|
||||||
badSecretClient = `{
|
badSecretClient = `{
|
||||||
"id": "my_id",
|
"id": "my_id",
|
||||||
"secret": "` + "****" + `",
|
"secret": "` + "" + `",
|
||||||
"redirectURLs": ["https://client.example.com"]
|
"redirectURLs": ["https://client.example.com"]
|
||||||
}`
|
}`
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ func TestClientsFromReader(t *testing.T) {
|
||||||
{
|
{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "my_id",
|
ID: "my_id",
|
||||||
Secret: "my_secret",
|
Secret: goodSecret1,
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
|
@ -80,7 +80,7 @@ func TestClientsFromReader(t *testing.T) {
|
||||||
{
|
{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "my_id",
|
ID: "my_id",
|
||||||
Secret: "my_secret",
|
Secret: goodSecret1,
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
|
@ -91,7 +91,7 @@ func TestClientsFromReader(t *testing.T) {
|
||||||
{
|
{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "my_other_id",
|
ID: "my_other_id",
|
||||||
Secret: "my_other_secret",
|
Secret: goodSecret2,
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
|
@ -101,7 +101,8 @@ func TestClientsFromReader(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, {
|
},
|
||||||
|
{
|
||||||
json: "[" + badURLClient + "]",
|
json: "[" + badURLClient + "]",
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
|
213
client/manager/manager.go
Normal file
213
client/manager/manager.go
Normal file
|
@ -0,0 +1,213 @@
|
||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/client"
|
||||||
|
pcrypto "github.com/coreos/dex/pkg/crypto"
|
||||||
|
"github.com/coreos/dex/pkg/log"
|
||||||
|
"github.com/coreos/dex/repo"
|
||||||
|
"github.com/coreos/go-oidc/oidc"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Blowfish, the algorithm underlying bcrypt, has a maximum
|
||||||
|
// password length of 72. We explicitly track and check this
|
||||||
|
// since the bcrypt library will silently ignore portions of
|
||||||
|
// a password past the first 72 characters.
|
||||||
|
maxSecretLength = 72
|
||||||
|
)
|
||||||
|
|
||||||
|
type SecretGenerator func() ([]byte, error)
|
||||||
|
|
||||||
|
func DefaultSecretGenerator() ([]byte, error) {
|
||||||
|
return pcrypto.RandBytes(maxSecretLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CompareHashAndPassword(hashedPassword, password []byte) error {
|
||||||
|
if len(password) > maxSecretLength {
|
||||||
|
return errors.New("password length greater than max secret length")
|
||||||
|
}
|
||||||
|
return bcrypt.CompareHashAndPassword(hashedPassword, password)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientManager performs client-related "business-logic" functions on client and related objects.
|
||||||
|
// This is in contrast to the Repos which perform little more than CRUD operations.
|
||||||
|
type ClientManager struct {
|
||||||
|
clientRepo client.ClientRepo
|
||||||
|
begin repo.TransactionFactory
|
||||||
|
secretGenerator SecretGenerator
|
||||||
|
clientIDGenerator func(string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ManagerOptions struct {
|
||||||
|
SecretGenerator func() ([]byte, error)
|
||||||
|
ClientIDGenerator func(string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClientManager(clientRepo client.ClientRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *ClientManager {
|
||||||
|
if options.SecretGenerator == nil {
|
||||||
|
options.SecretGenerator = DefaultSecretGenerator
|
||||||
|
}
|
||||||
|
if options.ClientIDGenerator == nil {
|
||||||
|
options.ClientIDGenerator = oidc.GenClientID
|
||||||
|
}
|
||||||
|
return &ClientManager{
|
||||||
|
clientRepo: clientRepo,
|
||||||
|
begin: txnFactory,
|
||||||
|
secretGenerator: options.SecretGenerator,
|
||||||
|
clientIDGenerator: options.ClientIDGenerator,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClientManagerFromClients(clientRepo client.ClientRepo, txnFactory repo.TransactionFactory, clients []client.Client, options ManagerOptions) (*ClientManager, error) {
|
||||||
|
clientManager := NewClientManager(clientRepo, txnFactory, options)
|
||||||
|
tx, err := clientManager.begin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
for _, c := range clients {
|
||||||
|
if c.Credentials.Secret == "" {
|
||||||
|
return nil, fmt.Errorf("client %q has no secret", c.Credentials.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
cli, err := clientManager.generateClientCredentials(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = clientRepo.New(tx, cli)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return clientManager, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ClientManager) New(cli client.Client) (*oidc.ClientCredentials, error) {
|
||||||
|
tx, err := m.begin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
c, err := m.generateClientCredentials(cli)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
creds := c.Credentials
|
||||||
|
|
||||||
|
// Save Client
|
||||||
|
_, err = m.clientRepo.New(tx, c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns creds with unhashed secret
|
||||||
|
return &creds, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ClientManager) Get(id string) (client.Client, error) {
|
||||||
|
return m.clientRepo.Get(nil, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ClientManager) All() ([]client.Client, error) {
|
||||||
|
return m.clientRepo.All(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ClientManager) Metadata(clientID string) (*oidc.ClientMetadata, error) {
|
||||||
|
c, err := m.clientRepo.Get(nil, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &c.Metadata, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ClientManager) IsDexAdmin(clientID string) (bool, error) {
|
||||||
|
c, err := m.clientRepo.Get(nil, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Admin, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ClientManager) SetDexAdmin(clientID string, isAdmin bool) error {
|
||||||
|
tx, err := m.begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
c, err := m.clientRepo.Get(tx, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Admin = isAdmin
|
||||||
|
err = m.clientRepo.Update(tx, c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ClientManager) Authenticate(creds oidc.ClientCredentials) (bool, error) {
|
||||||
|
clientSecret, err := m.clientRepo.GetSecret(nil, creds.ID)
|
||||||
|
if err != nil || clientSecret == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dec, err := base64.URLEncoding.DecodeString(creds.Secret)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("error Decoding client creds: %v", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := CompareHashAndPassword(clientSecret, dec) == nil
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ClientManager) generateClientCredentials(cli client.Client) (client.Client, error) {
|
||||||
|
// Generate Client ID
|
||||||
|
if len(cli.Metadata.RedirectURIs) < 1 {
|
||||||
|
return cli, errors.New("no client redirect url given")
|
||||||
|
}
|
||||||
|
clientID, err := m.clientIDGenerator(cli.Metadata.RedirectURIs[0].Host)
|
||||||
|
if err != nil {
|
||||||
|
return cli, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate Secret
|
||||||
|
secret, err := m.secretGenerator()
|
||||||
|
if err != nil {
|
||||||
|
return cli, err
|
||||||
|
}
|
||||||
|
clientSecret := base64.URLEncoding.EncodeToString(secret)
|
||||||
|
cli.Credentials = oidc.ClientCredentials{
|
||||||
|
ID: clientID,
|
||||||
|
Secret: clientSecret,
|
||||||
|
}
|
||||||
|
return cli, nil
|
||||||
|
}
|
165
client/manager/manager_test.go
Normal file
165
client/manager/manager_test.go
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/client"
|
||||||
|
"github.com/coreos/dex/db"
|
||||||
|
"github.com/coreos/go-oidc/oidc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testFixtures struct {
|
||||||
|
clientRepo client.ClientRepo
|
||||||
|
mgr *ClientManager
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
goodSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeTestFixtures() *testFixtures {
|
||||||
|
f := &testFixtures{}
|
||||||
|
|
||||||
|
dbMap := db.NewMemDB()
|
||||||
|
clients := []client.Client{
|
||||||
|
{
|
||||||
|
Credentials: oidc.ClientCredentials{
|
||||||
|
ID: "client.example.com",
|
||||||
|
Secret: goodSecret,
|
||||||
|
},
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
{Scheme: "http", Host: "client.example.com", Path: "/"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Admin: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
f.clientRepo = db.NewClientRepo(dbMap)
|
||||||
|
clientManager, err := NewClientManagerFromClients(f.clientRepo, db.TransactionFactory(dbMap), clients, ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
|
if err != nil {
|
||||||
|
panic("Failed to create client manager: " + err.Error())
|
||||||
|
}
|
||||||
|
f.mgr = clientManager
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetadata(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
clientID string
|
||||||
|
uri string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
clientID: "client.example.com",
|
||||||
|
uri: "http://client.example.com/",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
f := makeTestFixtures()
|
||||||
|
md, err := f.mgr.Metadata(tt.clientID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("case %d: unexpected err: %v", i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if md.RedirectURIs[0].String() != tt.uri {
|
||||||
|
t.Errorf("case %d: manager.Metadata.RedirectURIs: want=%q got=%q", i, tt.uri, md.RedirectURIs[0].String())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsDexAdmin(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
clientID string
|
||||||
|
isAdmin bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
clientID: "client.example.com",
|
||||||
|
isAdmin: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
f := makeTestFixtures()
|
||||||
|
admin, err := f.mgr.IsDexAdmin(tt.clientID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("case %d: unexpected err: %v", i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if admin != tt.isAdmin {
|
||||||
|
t.Errorf("case %d: manager.Admin want=%t got=%t", i, tt.isAdmin, admin)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetDexAdmin(t *testing.T) {
|
||||||
|
f := makeTestFixtures()
|
||||||
|
err := f.mgr.SetDexAdmin("client.example.com", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
admin, _ := f.mgr.IsDexAdmin("client.example.com")
|
||||||
|
if admin {
|
||||||
|
t.Errorf("expected admin to be false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticate(t *testing.T) {
|
||||||
|
f := makeTestFixtures()
|
||||||
|
cm := oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
url.URL{Scheme: "http", Host: "example.com", Path: "/cb"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cli := client.Client{
|
||||||
|
Metadata: cm,
|
||||||
|
}
|
||||||
|
cc, err := f.mgr.New(cli)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := f.mgr.Authenticate(*cc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
} else if !ok {
|
||||||
|
t.Fatalf("Authentication failed for good creds")
|
||||||
|
}
|
||||||
|
|
||||||
|
creds := []oidc.ClientCredentials{
|
||||||
|
//completely made up
|
||||||
|
oidc.ClientCredentials{ID: "foo", Secret: "bar"},
|
||||||
|
|
||||||
|
// good client ID, bad secret
|
||||||
|
oidc.ClientCredentials{ID: cc.ID, Secret: "bar"},
|
||||||
|
|
||||||
|
// bad client ID, good secret
|
||||||
|
oidc.ClientCredentials{ID: "foo", Secret: cc.Secret},
|
||||||
|
|
||||||
|
// good client ID, secret with some fluff on the end
|
||||||
|
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
|
||||||
|
}
|
||||||
|
for i, c := range creds {
|
||||||
|
ok, err := f.mgr.Authenticate(c)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||||
|
} else if ok {
|
||||||
|
t.Errorf("case %d: authentication succeeded for bad creds", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
|
|
||||||
"github.com/coreos/dex/admin"
|
"github.com/coreos/dex/admin"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
pflag "github.com/coreos/dex/pkg/flag"
|
pflag "github.com/coreos/dex/pkg/flag"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
|
@ -119,8 +120,9 @@ func main() {
|
||||||
clientRepo := db.NewClientRepo(dbc)
|
clientRepo := db.NewClientRepo(dbc)
|
||||||
userManager := manager.NewUserManager(userRepo,
|
userManager := manager.NewUserManager(userRepo,
|
||||||
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
|
||||||
|
clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbc), clientmanager.ManagerOptions{})
|
||||||
|
|
||||||
adminAPI := admin.NewAdminAPI(userRepo, pwiRepo, clientRepo, userManager, *localConnectorID)
|
adminAPI := admin.NewAdminAPI(userRepo, pwiRepo, clientRepo, userManager, clientManager, *localConnectorID)
|
||||||
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
|
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf(err.Error())
|
log.Fatalf(err.Error())
|
||||||
|
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
"github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
|
@ -14,15 +15,15 @@ func newDBDriver(dsn string) (driver, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
drv := &dbDriver{
|
drv := &dbDriver{
|
||||||
ciRepo: db.NewClientRepo(dbc),
|
|
||||||
cfgRepo: db.NewConnectorConfigRepo(dbc),
|
cfgRepo: db.NewConnectorConfigRepo(dbc),
|
||||||
|
ciManager: manager.NewClientManager(db.NewClientRepo(dbc), db.TransactionFactory(dbc), manager.ManagerOptions{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
return drv, nil
|
return drv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type dbDriver struct {
|
type dbDriver struct {
|
||||||
ciRepo client.ClientRepo
|
ciManager *manager.ClientManager
|
||||||
cfgRepo *db.ConnectorConfigRepo
|
cfgRepo *db.ConnectorConfigRepo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -30,18 +31,10 @@ func (d *dbDriver) NewClient(meta oidc.ClientMetadata) (*oidc.ClientCredentials,
|
||||||
if err := meta.Valid(); err != nil {
|
if err := meta.Valid(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
cli := client.Client{
|
||||||
clientID, err := oidc.GenClientID(meta.RedirectURIs[0].Host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return d.ciRepo.New(client.Client{
|
|
||||||
Credentials: oidc.ClientCredentials{
|
|
||||||
ID: clientID,
|
|
||||||
},
|
|
||||||
Metadata: meta,
|
Metadata: meta,
|
||||||
})
|
}
|
||||||
|
return d.ciManager.New(cli)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbDriver) ConnectorConfigs() ([]connector.ConnectorConfig, error) {
|
func (d *dbDriver) ConnectorConfigs() ([]connector.ConnectorConfig, error) {
|
||||||
|
|
201
db/client.go
201
db/client.go
|
@ -2,7 +2,6 @@ package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -10,24 +9,15 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
pcrypto "github.com/coreos/dex/pkg/crypto"
|
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
|
"github.com/coreos/dex/repo"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
clientTableName = "client_identity"
|
clientTableName = "client_identity"
|
||||||
|
|
||||||
bcryptHashCost = 10
|
|
||||||
|
|
||||||
// Blowfish, the algorithm underlying bcrypt, has a maximum
|
|
||||||
// password length of 72. We explicitly track and check this
|
|
||||||
// since the bcrypt library will silently ignore portions of
|
|
||||||
// a password past the first 72 characters.
|
|
||||||
maxSecretLength = 72
|
|
||||||
|
|
||||||
// postgres error codes
|
// postgres error codes
|
||||||
pgErrorCodeUniqueViolation = "23505" // unique_violation
|
pgErrorCodeUniqueViolation = "23505" // unique_violation
|
||||||
)
|
)
|
||||||
|
@ -42,17 +32,10 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientModel(cli client.Client) (*clientModel, error) {
|
func newClientModel(cli client.Client) (*clientModel, error) {
|
||||||
secretBytes, err := base64.URLEncoding.DecodeString(cli.Credentials.Secret)
|
hashed, err := client.HashSecret(cli.Credentials)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
hashed, err := bcrypt.GenerateFromPassword([]byte(
|
|
||||||
secretBytes),
|
|
||||||
bcryptHashCost)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
bmeta, err := json.Marshal(&cli.Metadata)
|
bmeta, err := json.Marshal(&cli.Metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -92,56 +75,20 @@ func (m *clientModel) Client() (*client.Client, error) {
|
||||||
|
|
||||||
func NewClientRepo(dbm *gorp.DbMap) client.ClientRepo {
|
func NewClientRepo(dbm *gorp.DbMap) client.ClientRepo {
|
||||||
return newClientRepo(dbm)
|
return newClientRepo(dbm)
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClientRepoWithSecretGenerator(dbm *gorp.DbMap, secGen SecretGenerator) client.ClientRepo {
|
|
||||||
rep := newClientRepo(dbm)
|
|
||||||
rep.secretGenerator = secGen
|
|
||||||
return rep
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientRepo(dbm *gorp.DbMap) *clientRepo {
|
func newClientRepo(dbm *gorp.DbMap) *clientRepo {
|
||||||
return &clientRepo{
|
return &clientRepo{
|
||||||
db: &db{dbm},
|
db: &db{dbm},
|
||||||
secretGenerator: DefaultSecretGenerator,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClientRepoFromClients(dbm *gorp.DbMap, clients []client.Client) (client.ClientRepo, error) {
|
|
||||||
repo := newClientRepo(dbm)
|
|
||||||
tx, err := repo.begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
exec := repo.executor(tx)
|
|
||||||
for _, c := range clients {
|
|
||||||
if c.Credentials.Secret == "" {
|
|
||||||
return nil, fmt.Errorf("client %q has no secret", c.Credentials.ID)
|
|
||||||
}
|
|
||||||
cm, err := newClientModel(c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = exec.Insert(cm)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return repo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type clientRepo struct {
|
type clientRepo struct {
|
||||||
*db
|
*db
|
||||||
secretGenerator SecretGenerator
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *clientRepo) Get(clientID string) (client.Client, error) {
|
func (r *clientRepo) Get(tx repo.Transaction, clientID string) (client.Client, error) {
|
||||||
m, err := r.executor(nil).Get(clientModel{}, clientID)
|
m, err := r.executor(tx).Get(clientModel{}, clientID)
|
||||||
if err == sql.ErrNoRows || m == nil {
|
if err == sql.ErrNoRows || m == nil {
|
||||||
return client.Client{}, client.ErrorNotFound
|
return client.Client{}, client.ErrorNotFound
|
||||||
}
|
}
|
||||||
|
@ -163,82 +110,28 @@ func (r *clientRepo) Get(clientID string) (client.Client, error) {
|
||||||
return *ci, nil
|
return *ci, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *clientRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
|
func (r *clientRepo) GetSecret(tx repo.Transaction, clientID string) ([]byte, error) {
|
||||||
c, err := r.Get(clientID)
|
m, err := r.getModel(tx, clientID)
|
||||||
if err != nil {
|
if err != nil || m == nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return m.Secret, nil
|
||||||
return &c.Metadata, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *clientRepo) IsDexAdmin(clientID string) (bool, error) {
|
func (r *clientRepo) Update(tx repo.Transaction, cli client.Client) error {
|
||||||
m, err := r.executor(nil).Get(clientModel{}, clientID)
|
if cli.Credentials.ID == "" {
|
||||||
if m == nil || err != nil {
|
return client.ErrorNotFound
|
||||||
return false, err
|
|
||||||
}
|
}
|
||||||
|
// make sure this client exists already
|
||||||
cim, ok := m.(*clientModel)
|
_, err := r.get(tx, cli.Credentials.ID)
|
||||||
if !ok {
|
|
||||||
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
|
|
||||||
return false, errors.New("unrecognized model")
|
|
||||||
}
|
|
||||||
|
|
||||||
return cim.DexAdmin, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *clientRepo) SetDexAdmin(clientID string, isAdmin bool) error {
|
|
||||||
tx, err := r.begin()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
err = r.update(tx, cli)
|
||||||
exec := r.executor(tx)
|
|
||||||
|
|
||||||
m, err := exec.Get(clientModel{}, clientID)
|
|
||||||
if m == nil || err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cim, ok := m.(*clientModel)
|
|
||||||
if !ok {
|
|
||||||
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
|
|
||||||
return errors.New("unrecognized model")
|
|
||||||
}
|
|
||||||
|
|
||||||
cim.DexAdmin = isAdmin
|
|
||||||
_, err = exec.Update(cim)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
return tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *clientRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
|
|
||||||
m, err := r.executor(nil).Get(clientModel{}, creds.ID)
|
|
||||||
if m == nil || err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cim, ok := m.(*clientModel)
|
|
||||||
if !ok {
|
|
||||||
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
|
|
||||||
return false, errors.New("unrecognized model")
|
|
||||||
}
|
|
||||||
|
|
||||||
dec, err := base64.URLEncoding.DecodeString(creds.Secret)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error Decoding client creds: %v", err)
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(dec) > maxSecretLength {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ok = bcrypt.CompareHashAndPassword(cim.Secret, dec) == nil
|
|
||||||
return ok, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var alreadyExistsCheckers []func(err error) bool
|
var alreadyExistsCheckers []func(err error) bool
|
||||||
|
@ -260,26 +153,14 @@ func isAlreadyExistsErr(err error) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
type SecretGenerator func() ([]byte, error)
|
func (r *clientRepo) New(tx repo.Transaction, cli client.Client) (*oidc.ClientCredentials, error) {
|
||||||
|
|
||||||
func DefaultSecretGenerator() ([]byte, error) {
|
|
||||||
return pcrypto.RandBytes(maxSecretLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
|
|
||||||
secret, err := r.secretGenerator()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cli.Credentials.Secret = base64.URLEncoding.EncodeToString(secret)
|
|
||||||
cim, err := newClientModel(cli)
|
cim, err := newClientModel(cli)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.executor(nil).Insert(cim); err != nil {
|
if err := r.executor(tx).Insert(cim); err != nil {
|
||||||
if isAlreadyExistsErr(err) {
|
if isAlreadyExistsErr(err) {
|
||||||
err = errors.New("client ID already exists")
|
err = errors.New("client ID already exists")
|
||||||
}
|
}
|
||||||
|
@ -294,10 +175,10 @@ func (r *clientRepo) New(cli client.Client) (*oidc.ClientCredentials, error) {
|
||||||
return &cc, nil
|
return &cc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *clientRepo) All() ([]client.Client, error) {
|
func (r *clientRepo) All(tx repo.Transaction) ([]client.Client, error) {
|
||||||
qt := r.quote(clientTableName)
|
qt := r.quote(clientTableName)
|
||||||
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
q := fmt.Sprintf("SELECT * FROM %s", qt)
|
||||||
objs, err := r.executor(nil).Select(&clientModel{}, q)
|
objs, err := r.executor(tx).Select(&clientModel{}, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -317,3 +198,47 @@ func (r *clientRepo) All() ([]client.Client, error) {
|
||||||
}
|
}
|
||||||
return cs, nil
|
return cs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *clientRepo) get(tx repo.Transaction, clientID string) (client.Client, error) {
|
||||||
|
cm, err := r.getModel(tx, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return client.Client{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cli, err := cm.Client()
|
||||||
|
if err != nil {
|
||||||
|
return client.Client{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return *cli, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *clientRepo) getModel(tx repo.Transaction, clientID string) (*clientModel, error) {
|
||||||
|
ex := r.executor(tx)
|
||||||
|
|
||||||
|
m, err := ex.Get(clientModel{}, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if m == nil {
|
||||||
|
return nil, client.ErrorNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
cm, ok := m.(*clientModel)
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
|
||||||
|
return nil, errors.New("unrecognized model")
|
||||||
|
}
|
||||||
|
return cm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *clientRepo) update(tx repo.Transaction, cli client.Client) error {
|
||||||
|
ex := r.executor(tx)
|
||||||
|
cm, err := newClientModel(cli)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = ex.Update(cm)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
"github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/refresh"
|
"github.com/coreos/dex/refresh"
|
||||||
"github.com/coreos/dex/session"
|
"github.com/coreos/dex/session"
|
||||||
|
@ -191,7 +192,7 @@ func TestDBClientRepoMetadata(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := r.New(client.Client{
|
_, err := r.New(nil, client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "foo",
|
ID: "foo",
|
||||||
},
|
},
|
||||||
|
@ -201,20 +202,22 @@ func TestDBClientRepoMetadata(t *testing.T) {
|
||||||
t.Fatalf(err.Error())
|
t.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := r.Metadata("foo")
|
got, err := r.Get(nil, "foo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf(err.Error())
|
t.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := pretty.Compare(cm, *got); diff != "" {
|
if diff := pretty.Compare(cm, got.Metadata); diff != "" {
|
||||||
t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
|
t.Fatalf("Retrieved incorrect ClientMetadata: Compare(want,got): %v", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDBClientRepoMetadataNoExist(t *testing.T) {
|
func TestDBClientRepoMetadataNoExist(t *testing.T) {
|
||||||
r := db.NewClientRepo(connect(t))
|
c := connect(t)
|
||||||
|
r := db.NewClientRepo(c)
|
||||||
|
m := manager.NewClientManager(r, db.TransactionFactory(c), manager.ManagerOptions{})
|
||||||
|
|
||||||
got, err := r.Metadata("noexist")
|
got, err := m.Metadata("noexist")
|
||||||
if err != client.ErrorNotFound {
|
if err != client.ErrorNotFound {
|
||||||
t.Errorf("want==%q, got==%q", client.ErrorNotFound, err)
|
t.Errorf("want==%q, got==%q", client.ErrorNotFound, err)
|
||||||
}
|
}
|
||||||
|
@ -232,7 +235,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := r.New(client.Client{
|
if _, err := r.New(nil, client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "foo",
|
ID: "foo",
|
||||||
},
|
},
|
||||||
|
@ -247,7 +250,7 @@ func TestDBClientRepoNewDuplicate(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := r.New(client.Client{
|
if _, err := r.New(nil, client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "foo",
|
ID: "foo",
|
||||||
},
|
},
|
||||||
|
@ -261,7 +264,7 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
|
||||||
|
|
||||||
for _, admin := range []bool{true, false} {
|
for _, admin := range []bool{true, false} {
|
||||||
r := db.NewClientRepo(connect(t))
|
r := db.NewClientRepo(connect(t))
|
||||||
if _, err := r.New(client.Client{
|
if _, err := r.New(nil, client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "foo",
|
ID: "foo",
|
||||||
},
|
},
|
||||||
|
@ -275,15 +278,15 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
|
||||||
t.Fatalf("expected non-nil error: %v", err)
|
t.Fatalf("expected non-nil error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
gotAdmin, err := r.IsDexAdmin("foo")
|
gotAdmin, err := r.Get(nil, "foo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected non-nil error")
|
t.Fatalf("expected non-nil error")
|
||||||
}
|
}
|
||||||
if gotAdmin != admin {
|
if gotAdmin.Admin != admin {
|
||||||
t.Errorf("want=%v, gotAdmin=%v", admin, gotAdmin)
|
t.Errorf("want=%v, gotAdmin=%v", admin, gotAdmin)
|
||||||
}
|
}
|
||||||
|
|
||||||
cli, err := r.Get("foo")
|
cli, err := r.Get(nil, "foo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected non-nil error")
|
t.Fatalf("expected non-nil error")
|
||||||
}
|
}
|
||||||
|
@ -294,29 +297,35 @@ func TestDBClientRepoNewAdmin(t *testing.T) {
|
||||||
|
|
||||||
}
|
}
|
||||||
func TestDBClientRepoAuthenticate(t *testing.T) {
|
func TestDBClientRepoAuthenticate(t *testing.T) {
|
||||||
r := db.NewClientRepo(connect(t))
|
c := connect(t)
|
||||||
|
r := db.NewClientRepo(c)
|
||||||
|
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
m := manager.NewClientManager(r, db.TransactionFactory(c), manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
|
|
||||||
cm := oidc.ClientMetadata{
|
cm := oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
|
url.URL{Scheme: "http", Host: "127.0.0.1:5556", Path: "/cb"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
cli := client.Client{
|
||||||
cc, err := r.New(client.Client{
|
|
||||||
Credentials: oidc.ClientCredentials{
|
|
||||||
ID: "baz",
|
|
||||||
},
|
|
||||||
Metadata: cm,
|
Metadata: cm,
|
||||||
})
|
}
|
||||||
|
cc, err := m.New(cli)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf(err.Error())
|
t.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if cc.ID != "baz" {
|
if cc.ID != "127.0.0.1:5556" {
|
||||||
t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID)
|
t.Fatalf("Returned ClientCredentials has incorrect ID: want=baz got=%s", cc.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := r.Authenticate(*cc)
|
ok, err := m.Authenticate(*cc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
} else if !ok {
|
} else if !ok {
|
||||||
|
@ -337,7 +346,7 @@ func TestDBClientRepoAuthenticate(t *testing.T) {
|
||||||
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
|
oidc.ClientCredentials{ID: cc.ID, Secret: fmt.Sprintf("%sfluff", cc.Secret)},
|
||||||
}
|
}
|
||||||
for i, c := range creds {
|
for i, c := range creds {
|
||||||
ok, err := r.Authenticate(c)
|
ok, err := m.Authenticate(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||||
} else if ok {
|
} else if ok {
|
||||||
|
@ -355,7 +364,7 @@ func TestDBClientAll(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := r.New(client.Client{
|
_, err := r.New(nil, client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "foo",
|
ID: "foo",
|
||||||
},
|
},
|
||||||
|
@ -365,7 +374,7 @@ func TestDBClientAll(t *testing.T) {
|
||||||
t.Fatalf(err.Error())
|
t.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := r.All()
|
got, err := r.All(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf(err.Error())
|
t.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
@ -383,7 +392,7 @@ func TestDBClientAll(t *testing.T) {
|
||||||
url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"},
|
url.URL{Scheme: "http", Host: "foo.com", Path: "/cb"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, err = r.New(client.Client{
|
_, err = r.New(nil, client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "bar",
|
ID: "bar",
|
||||||
},
|
},
|
||||||
|
@ -393,7 +402,7 @@ func TestDBClientAll(t *testing.T) {
|
||||||
t.Fatalf(err.Error())
|
t.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err = r.All()
|
got, err = r.All(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf(err.Error())
|
t.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,14 +3,10 @@ package repo
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
"github.com/go-gorp/gorp"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
"github.com/coreos/dex/db"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -47,95 +43,3 @@ var (
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func newClientRepo(t *testing.T) client.ClientRepo {
|
|
||||||
dsn := os.Getenv("DEX_TEST_DSN")
|
|
||||||
var dbMap *gorp.DbMap
|
|
||||||
if dsn == "" {
|
|
||||||
dbMap = db.NewMemDB()
|
|
||||||
} else {
|
|
||||||
dbMap = connect(t)
|
|
||||||
}
|
|
||||||
repo, err := db.NewClientRepoFromClients(dbMap, testClients)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create client repo from clients: %v", err)
|
|
||||||
}
|
|
||||||
return repo
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetSetAdminClient(t *testing.T) {
|
|
||||||
startAdmins := []string{"client2"}
|
|
||||||
tests := []struct {
|
|
||||||
// client ID
|
|
||||||
cid string
|
|
||||||
|
|
||||||
// initial state of client
|
|
||||||
wantAdmin bool
|
|
||||||
|
|
||||||
// final state of client
|
|
||||||
setAdmin bool
|
|
||||||
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
cid: "client1",
|
|
||||||
wantAdmin: false,
|
|
||||||
setAdmin: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
cid: "client1",
|
|
||||||
wantAdmin: false,
|
|
||||||
setAdmin: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
cid: "client2",
|
|
||||||
wantAdmin: true,
|
|
||||||
setAdmin: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
cid: "client2",
|
|
||||||
wantAdmin: true,
|
|
||||||
setAdmin: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
Tests:
|
|
||||||
for i, tt := range tests {
|
|
||||||
repo := newClientRepo(t)
|
|
||||||
for _, cid := range startAdmins {
|
|
||||||
err := repo.SetDexAdmin(cid, true)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("case %d: failed to set dex admin: %v", i, err)
|
|
||||||
continue Tests
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
gotAdmin, err := repo.IsDexAdmin(tt.cid)
|
|
||||||
if tt.wantErr {
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("case %d: want non-nil err", i)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
|
||||||
}
|
|
||||||
if gotAdmin != tt.wantAdmin {
|
|
||||||
t.Errorf("case %d: want=%v, got=%v", i, tt.wantAdmin, gotAdmin)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = repo.SetDexAdmin(tt.cid, tt.setAdmin)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
gotAdmin, err = repo.IsDexAdmin(tt.cid)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
|
||||||
}
|
|
||||||
if gotAdmin != tt.setAdmin {
|
|
||||||
t.Errorf("case %d: want=%v, got=%v", i, tt.setAdmin, gotAdmin)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
"github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/refresh"
|
"github.com/coreos/dex/refresh"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
|
@ -27,7 +28,7 @@ func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients
|
||||||
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
|
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
|
||||||
t.Fatalf("Unable to add users: %v", err)
|
t.Fatalf("Unable to add users: %v", err)
|
||||||
}
|
}
|
||||||
if _, err := db.NewClientRepoFromClients(dbMap, clients); err != nil {
|
if _, err := manager.NewClientManagerFromClients(db.NewClientRepo(dbMap), db.TransactionFactory(dbMap), clients, manager.ManagerOptions{}); err != nil {
|
||||||
t.Fatalf("Unable to add clients: %v", err)
|
t.Fatalf("Unable to add clients: %v", err)
|
||||||
}
|
}
|
||||||
return db.NewRefreshTokenRepo(dbMap)
|
return db.NewRefreshTokenRepo(dbMap)
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/dex/admin"
|
"github.com/coreos/dex/admin"
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
"github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/schema/adminschema"
|
"github.com/coreos/dex/schema/adminschema"
|
||||||
"github.com/coreos/dex/server"
|
"github.com/coreos/dex/server"
|
||||||
|
@ -87,12 +88,16 @@ func makeAdminAPITestFixtures() *adminAPITestFixtures {
|
||||||
secGen := func() ([]byte, error) {
|
secGen := func() ([]byte, error) {
|
||||||
return []byte(fmt.Sprintf("client_%v", cliCount)), nil
|
return []byte(fmt.Sprintf("client_%v", cliCount)), nil
|
||||||
}
|
}
|
||||||
cr := db.NewClientRepoWithSecretGenerator(dbMap, secGen)
|
cr := db.NewClientRepo(dbMap)
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return fmt.Sprintf("client_%v", hostport), nil
|
||||||
|
}
|
||||||
|
cm := manager.NewClientManager(cr, db.TransactionFactory(dbMap), manager.ManagerOptions{SecretGenerator: secGen, ClientIDGenerator: clientIDGenerator})
|
||||||
|
|
||||||
f.cr = cr
|
f.cr = cr
|
||||||
f.ur = ur
|
f.ur = ur
|
||||||
f.pwr = pwr
|
f.pwr = pwr
|
||||||
f.adAPI = admin.NewAdminAPI(ur, pwr, cr, um, "local")
|
f.adAPI = admin.NewAdminAPI(ur, pwr, cr, um, cm, "local")
|
||||||
f.adSrv = server.NewAdminServer(f.adAPI, nil, adminAPITestSecret)
|
f.adSrv = server.NewAdminServer(f.adAPI, nil, adminAPITestSecret)
|
||||||
f.hSrv = httptest.NewServer(f.adSrv.HTTPHandler())
|
f.hSrv = httptest.NewServer(f.adSrv.HTTPHandler())
|
||||||
f.hc = &http.Client{
|
f.hc = &http.Client{
|
||||||
|
@ -268,14 +273,6 @@ func TestCreateAdmin(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateClient(t *testing.T) {
|
func TestCreateClient(t *testing.T) {
|
||||||
oldGen := admin.ClientIDGenerator
|
|
||||||
admin.ClientIDGenerator = func(hostport string) (string, error) {
|
|
||||||
return fmt.Sprintf("client_%v", hostport), nil
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
admin.ClientIDGenerator = oldGen
|
|
||||||
}()
|
|
||||||
|
|
||||||
mustParseURL := func(s string) *url.URL {
|
mustParseURL := func(s string) *url.URL {
|
||||||
u, err := url.Parse(s)
|
u, err := url.Parse(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -402,7 +399,7 @@ func TestCreateClient(t *testing.T) {
|
||||||
t.Errorf("case %d: Compare(want, got) = %v", i, diff)
|
t.Errorf("case %d: Compare(want, got) = %v", i, diff)
|
||||||
}
|
}
|
||||||
|
|
||||||
repoClient, err := f.cr.Get(resp.Client.Id)
|
repoClient, err := f.cr.Get(nil, resp.Client.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("case %d: Unexpected error getting client: %v", i, err)
|
t.Errorf("case %d: Unexpected error getting client: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,9 +14,10 @@ import (
|
||||||
|
|
||||||
func TestClientCreate(t *testing.T) {
|
func TestClientCreate(t *testing.T) {
|
||||||
ci := client.Client{
|
ci := client.Client{
|
||||||
|
// Credentials are for reference, they are actually generated by the client manager
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "72de74a9",
|
ID: "authn.example.com",
|
||||||
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
|
@ -73,7 +74,7 @@ func TestClientCreate(t *testing.T) {
|
||||||
t.Error("Expected non-empty Client Secret")
|
t.Error("Expected non-empty Client Secret")
|
||||||
}
|
}
|
||||||
|
|
||||||
meta, err := srv.ClientRepo.Metadata(newClient.Id)
|
meta, err := srv.ClientManager.Metadata(newClient.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error looking up client metadata: %v", err)
|
t.Errorf("Error looking up client metadata: %v", err)
|
||||||
} else if meta == nil {
|
} else if meta == nil {
|
||||||
|
|
|
@ -22,9 +22,10 @@ var (
|
||||||
clock = clockwork.NewFakeClock()
|
clock = clockwork.NewFakeClock()
|
||||||
|
|
||||||
testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"}
|
testIssuerURL = url.URL{Scheme: "https", Host: "auth.example.com"}
|
||||||
testClientID = "XXX"
|
testClientID = "client.example.com"
|
||||||
testClientSecret = base64.URLEncoding.EncodeToString([]byte("yyy"))
|
testClientSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
|
||||||
testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"}
|
testRedirectURL = url.URL{Scheme: "https", Host: "client.example.com", Path: "/redirect"}
|
||||||
|
testBadRedirectURL = url.URL{Scheme: "https", Host: "bad.example.com", Path: "/redirect"}
|
||||||
testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"}
|
testResetPasswordURL = url.URL{Scheme: "https", Host: "auth.example.com", Path: "/resetPassword"}
|
||||||
testPrivKey, _ = key.GeneratePrivateKey()
|
testPrivKey, _ = key.GeneratePrivateKey()
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
phttp "github.com/coreos/dex/pkg/http"
|
phttp "github.com/coreos/dex/pkg/http"
|
||||||
|
@ -35,7 +36,15 @@ func mockServer(cis []client.Client) (*server.Server, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
clientRepo, err := db.NewClientRepoFromClients(dbMap, cis)
|
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbMap)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), cis, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -45,6 +54,7 @@ func mockServer(cis []client.Client) (*server.Server, error) {
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
ClientRepo: clientRepo,
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
SessionManager: sm,
|
SessionManager: sm,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,15 +92,21 @@ func verifyUserClaims(claims jose.Claims, ci *client.Client, user *user.User, is
|
||||||
expectedSub, expectedName = user.ID, user.DisplayName
|
expectedSub, expectedName = user.ID, user.DisplayName
|
||||||
}
|
}
|
||||||
|
|
||||||
if aud := claims["aud"].(string); aud != ci.Credentials.ID {
|
if aud, ok := claims["aud"].(string); !ok {
|
||||||
|
return fmt.Errorf("unexpected claim value for aud, got=nil, want=%v", ci.Credentials.ID)
|
||||||
|
} else if aud != ci.Credentials.ID {
|
||||||
return fmt.Errorf("unexpected claim value for aud, got=%v, want=%v", aud, ci.Credentials.ID)
|
return fmt.Errorf("unexpected claim value for aud, got=%v, want=%v", aud, ci.Credentials.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if sub := claims["sub"].(string); sub != expectedSub {
|
if sub, ok := claims["sub"].(string); !ok {
|
||||||
|
return fmt.Errorf("unexpected claim value for sub, got=nil, want=%v", expectedSub)
|
||||||
|
} else if sub != expectedSub {
|
||||||
return fmt.Errorf("unexpected claim value for sub, got=%v, want=%v", sub, expectedSub)
|
return fmt.Errorf("unexpected claim value for sub, got=%v, want=%v", sub, expectedSub)
|
||||||
}
|
}
|
||||||
|
|
||||||
if name := claims["name"].(string); name != expectedName {
|
if name, ok := claims["name"].(string); !ok {
|
||||||
|
return fmt.Errorf("unexpected claim value for aud, got=nil, want=%v", expectedName)
|
||||||
|
} else if name != expectedName {
|
||||||
return fmt.Errorf("unexpected claim value for name, got=%v, want=%v", name, expectedName)
|
return fmt.Errorf("unexpected claim value for name, got=%v, want=%v", name, expectedName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,17 +133,34 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
||||||
ID: "local",
|
ID: "local",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
validRedirURL := url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: "client.example.com",
|
||||||
|
Path: "/callback",
|
||||||
|
}
|
||||||
ci := client.Client{
|
ci := client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "72de74a9",
|
ID: validRedirURL.Host,
|
||||||
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||||
|
},
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
validRedirURL,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
dbMap := db.NewMemDB()
|
dbMap := db.NewMemDB()
|
||||||
cir, err := db.NewClientRepoFromClients(dbMap, []client.Client{ci})
|
clientRepo := db.NewClientRepo(dbMap)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), []client.Client{ci}, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create client identity repo: " + err.Error())
|
t.Fatalf("Failed to create client identity manager: " + err.Error())
|
||||||
}
|
}
|
||||||
passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo})
|
passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -164,7 +197,8 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
||||||
IssuerURL: issuerURL,
|
IssuerURL: issuerURL,
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
SessionManager: sm,
|
SessionManager: sm,
|
||||||
ClientRepo: cir,
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
Templates: template.New(connector.LoginPageTemplateName),
|
Templates: template.New(connector.LoginPageTemplateName),
|
||||||
Connectors: []connector.Connector{},
|
Connectors: []connector.Connector{},
|
||||||
UserRepo: userRepo,
|
UserRepo: userRepo,
|
||||||
|
@ -188,7 +222,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
||||||
HTTPClient: sClient,
|
HTTPClient: sClient,
|
||||||
ProviderConfig: pcfg,
|
ProviderConfig: pcfg,
|
||||||
Credentials: ci.Credentials,
|
Credentials: ci.Credentials,
|
||||||
RedirectURL: "http://client.example.com",
|
RedirectURL: validRedirURL.String(),
|
||||||
KeySet: *ks,
|
KeySet: *ks,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -263,10 +297,20 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPClientCredsToken(t *testing.T) {
|
func TestHTTPClientCredsToken(t *testing.T) {
|
||||||
|
validRedirURL := url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: "client.example.com",
|
||||||
|
Path: "/callback",
|
||||||
|
}
|
||||||
ci := client.Client{
|
ci := client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "72de74a9",
|
ID: validRedirURL.Host,
|
||||||
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||||
|
},
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
validRedirURL,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cis := []client.Client{ci}
|
cis := []client.Client{ci}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"google.golang.org/api/googleapi"
|
"google.golang.org/api/googleapi"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
"github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
schema "github.com/coreos/dex/schema/workerschema"
|
schema "github.com/coreos/dex/schema/workerschema"
|
||||||
"github.com/coreos/dex/server"
|
"github.com/coreos/dex/server"
|
||||||
|
@ -79,7 +80,7 @@ var (
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
userBadClientID = "ZZZ"
|
userBadClientID = testBadRedirectURL.Host
|
||||||
|
|
||||||
userGoodToken = makeUserToken(testIssuerURL,
|
userGoodToken = makeUserToken(testIssuerURL,
|
||||||
"ID-1", testClientID, time.Hour*1, testPrivKey)
|
"ID-1", testClientID, time.Hour*1, testPrivKey)
|
||||||
|
@ -101,8 +102,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
||||||
f := &userAPITestFixtures{}
|
f := &userAPITestFixtures{}
|
||||||
|
|
||||||
dbMap, _, _, um := makeUserObjects(userUsers, userPasswords)
|
dbMap, _, _, um := makeUserObjects(userUsers, userPasswords)
|
||||||
cir := func() client.ClientRepo {
|
clients := []client.Client{
|
||||||
repo, err := db.NewClientRepoFromClients(dbMap, []client.Client{
|
|
||||||
client.Client{
|
client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: testClientID,
|
ID: testClientID,
|
||||||
|
@ -121,18 +121,23 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
testRedirectURL,
|
testBadRedirectURL,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
panic("Failed to create client identity repo: " + err.Error())
|
|
||||||
}
|
}
|
||||||
return repo
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
}()
|
return hostport, nil
|
||||||
|
}
|
||||||
cir.SetDexAdmin(testClientID, true)
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte(testClientSecret), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbMap)
|
||||||
|
clientManager, err := manager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
|
if err != nil {
|
||||||
|
panic("Failed to create client identity manager: " + err.Error())
|
||||||
|
}
|
||||||
|
clientManager.SetDexAdmin(testClientID, true)
|
||||||
|
|
||||||
noop := func() error { return nil }
|
noop := func() error { return nil }
|
||||||
|
|
||||||
|
@ -153,8 +158,9 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
|
||||||
|
|
||||||
f.emailer = &testEmailer{}
|
f.emailer = &testEmailer{}
|
||||||
um.Clock = clock
|
um.Clock = clock
|
||||||
api := api.NewUsersAPI(dbMap, um, f.emailer, "local")
|
|
||||||
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, cir)
|
api := api.NewUsersAPI(um, clientManager, refreshRepo, f.emailer, "local")
|
||||||
|
usrSrv := server.NewUserMgmtServer(api, jwtvFactory, um, clientManager)
|
||||||
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
|
f.hSrv = httptest.NewServer(usrSrv.HTTPHandler())
|
||||||
|
|
||||||
f.trans = &tokenHandlerTransport{
|
f.trans = &tokenHandlerTransport{
|
||||||
|
@ -536,7 +542,7 @@ func TestCreateUser(t *testing.T) {
|
||||||
wantEmalier := testEmailer{
|
wantEmalier := testEmailer{
|
||||||
cantEmail: tt.cantEmail,
|
cantEmail: tt.cantEmail,
|
||||||
lastEmail: tt.req.User.Email,
|
lastEmail: tt.req.User.Email,
|
||||||
lastClientID: "XXX",
|
lastClientID: testClientID,
|
||||||
lastWasInvite: true,
|
lastWasInvite: true,
|
||||||
lastRedirectURL: *urlParsed,
|
lastRedirectURL: *urlParsed,
|
||||||
}
|
}
|
||||||
|
@ -799,7 +805,7 @@ func TestResendEmailInvitation(t *testing.T) {
|
||||||
wantEmalier := testEmailer{
|
wantEmalier := testEmailer{
|
||||||
cantEmail: tt.cantEmail,
|
cantEmail: tt.cantEmail,
|
||||||
lastEmail: strings.ToLower(tt.email),
|
lastEmail: strings.ToLower(tt.email),
|
||||||
lastClientID: "XXX",
|
lastClientID: testClientID,
|
||||||
lastWasInvite: true,
|
lastWasInvite: true,
|
||||||
lastRedirectURL: *urlParsed,
|
lastRedirectURL: *urlParsed,
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/go-oidc/jose"
|
"github.com/coreos/go-oidc/jose"
|
||||||
"github.com/coreos/go-oidc/key"
|
"github.com/coreos/go-oidc/key"
|
||||||
|
@ -14,7 +14,7 @@ import (
|
||||||
|
|
||||||
type clientTokenMiddleware struct {
|
type clientTokenMiddleware struct {
|
||||||
issuerURL string
|
issuerURL string
|
||||||
ciRepo client.ClientRepo
|
ciManager *manager.ClientManager
|
||||||
keysFunc func() ([]key.PublicKey, error)
|
keysFunc func() ([]key.PublicKey, error)
|
||||||
next http.Handler
|
next http.Handler
|
||||||
}
|
}
|
||||||
|
@ -30,8 +30,8 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.ciRepo == nil {
|
if c.ciManager == nil {
|
||||||
log.Errorf("Misconfigured clientTokenMiddleware, ClientRepo is not set")
|
log.Errorf("Misconfigured clientTokenMiddleware, ClientManager is not set")
|
||||||
respondError()
|
respondError()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -83,7 +83,7 @@ func (c *clientTokenMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
md, err := c.ciRepo.Metadata(clientID)
|
md, err := c.ciManager.Metadata(clientID)
|
||||||
if md == nil || err != nil {
|
if md == nil || err != nil {
|
||||||
log.Errorf("Failed to find clientID: %s, error=%v", clientID, err)
|
log.Errorf("Failed to find clientID: %s, error=%v", clientID, err)
|
||||||
respondError()
|
respondError()
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -10,6 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/go-oidc/jose"
|
"github.com/coreos/go-oidc/jose"
|
||||||
"github.com/coreos/go-oidc/key"
|
"github.com/coreos/go-oidc/key"
|
||||||
|
@ -25,22 +25,23 @@ func (h staticHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
func TestClientToken(t *testing.T) {
|
func TestClientToken(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
tomorrow := now.Add(24 * time.Hour)
|
tomorrow := now.Add(24 * time.Hour)
|
||||||
validClientID := "valid-client"
|
clientMetadata := oidc.ClientMetadata{
|
||||||
ci := client.Client{
|
|
||||||
Credentials: oidc.ClientCredentials{
|
|
||||||
ID: validClientID,
|
|
||||||
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
|
||||||
},
|
|
||||||
Metadata: oidc.ClientMetadata{
|
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
|
{Scheme: "https", Host: "authn.example.com", Path: "/callback"},
|
||||||
},
|
},
|
||||||
},
|
|
||||||
}
|
}
|
||||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
|
|
||||||
|
dbm := db.NewMemDB()
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager := clientmanager.NewClientManager(clientRepo, db.TransactionFactory(dbm), clientmanager.ManagerOptions{})
|
||||||
|
cli := client.Client{
|
||||||
|
Metadata: clientMetadata,
|
||||||
|
}
|
||||||
|
creds, err := clientManager.New(cli)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
t.Fatalf("Failed to create client: %v", err)
|
||||||
}
|
}
|
||||||
|
validClientID := creds.ID
|
||||||
|
|
||||||
privKey, err := key.GeneratePrivateKey()
|
privKey, err := key.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -65,63 +66,63 @@ func TestClientToken(t *testing.T) {
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
keys []key.PublicKey
|
keys []key.PublicKey
|
||||||
repo client.ClientRepo
|
manager *clientmanager.ClientManager
|
||||||
header string
|
header string
|
||||||
wantCode int
|
wantCode int
|
||||||
}{
|
}{
|
||||||
// valid token
|
// valid token
|
||||||
{
|
{
|
||||||
keys: []key.PublicKey{pubKey},
|
keys: []key.PublicKey{pubKey},
|
||||||
repo: repo,
|
manager: clientManager,
|
||||||
header: fmt.Sprintf("BEARER %s", validJWT),
|
header: fmt.Sprintf("BEARER %s", validJWT),
|
||||||
wantCode: http.StatusOK,
|
wantCode: http.StatusOK,
|
||||||
},
|
},
|
||||||
// invalid token
|
// invalid token
|
||||||
{
|
{
|
||||||
keys: []key.PublicKey{pubKey},
|
keys: []key.PublicKey{pubKey},
|
||||||
repo: repo,
|
manager: clientManager,
|
||||||
header: fmt.Sprintf("BEARER %s", invalidJWT),
|
header: fmt.Sprintf("BEARER %s", invalidJWT),
|
||||||
wantCode: http.StatusUnauthorized,
|
wantCode: http.StatusUnauthorized,
|
||||||
},
|
},
|
||||||
// empty header
|
// empty header
|
||||||
{
|
{
|
||||||
keys: []key.PublicKey{pubKey},
|
keys: []key.PublicKey{pubKey},
|
||||||
repo: repo,
|
manager: clientManager,
|
||||||
header: "",
|
header: "",
|
||||||
wantCode: http.StatusUnauthorized,
|
wantCode: http.StatusUnauthorized,
|
||||||
},
|
},
|
||||||
// unparsable token
|
// unparsable token
|
||||||
{
|
{
|
||||||
keys: []key.PublicKey{pubKey},
|
keys: []key.PublicKey{pubKey},
|
||||||
repo: repo,
|
manager: clientManager,
|
||||||
header: "BEARER xxx",
|
header: "BEARER xxx",
|
||||||
wantCode: http.StatusUnauthorized,
|
wantCode: http.StatusUnauthorized,
|
||||||
},
|
},
|
||||||
// no verification keys
|
// no verification keys
|
||||||
{
|
{
|
||||||
keys: []key.PublicKey{},
|
keys: []key.PublicKey{},
|
||||||
repo: repo,
|
manager: clientManager,
|
||||||
header: fmt.Sprintf("BEARER %s", validJWT),
|
header: fmt.Sprintf("BEARER %s", validJWT),
|
||||||
wantCode: http.StatusUnauthorized,
|
wantCode: http.StatusUnauthorized,
|
||||||
},
|
},
|
||||||
// nil repo
|
// nil repo
|
||||||
{
|
{
|
||||||
keys: []key.PublicKey{pubKey},
|
keys: []key.PublicKey{pubKey},
|
||||||
repo: nil,
|
manager: nil,
|
||||||
header: fmt.Sprintf("BEARER %s", validJWT),
|
header: fmt.Sprintf("BEARER %s", validJWT),
|
||||||
wantCode: http.StatusUnauthorized,
|
wantCode: http.StatusUnauthorized,
|
||||||
},
|
},
|
||||||
// empty repo
|
// empty repo
|
||||||
{
|
{
|
||||||
keys: []key.PublicKey{pubKey},
|
keys: []key.PublicKey{pubKey},
|
||||||
repo: db.NewClientRepo(db.NewMemDB()),
|
manager: clientmanager.NewClientManager(db.NewClientRepo(db.NewMemDB()), db.TransactionFactory(db.NewMemDB()), clientmanager.ManagerOptions{}),
|
||||||
header: fmt.Sprintf("BEARER %s", validJWT),
|
header: fmt.Sprintf("BEARER %s", validJWT),
|
||||||
wantCode: http.StatusUnauthorized,
|
wantCode: http.StatusUnauthorized,
|
||||||
},
|
},
|
||||||
// client not in repo
|
// client not in repo
|
||||||
{
|
{
|
||||||
keys: []key.PublicKey{pubKey},
|
keys: []key.PublicKey{pubKey},
|
||||||
repo: repo,
|
manager: clientManager,
|
||||||
header: fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)),
|
header: fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)),
|
||||||
wantCode: http.StatusUnauthorized,
|
wantCode: http.StatusUnauthorized,
|
||||||
},
|
},
|
||||||
|
@ -131,7 +132,7 @@ func TestClientToken(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
mw := &clientTokenMiddleware{
|
mw := &clientTokenMiddleware{
|
||||||
issuerURL: validIss,
|
issuerURL: validIss,
|
||||||
ciRepo: tt.repo,
|
ciManager: tt.manager,
|
||||||
keysFunc: func() ([]key.PublicKey, error) {
|
keysFunc: func() ([]key.PublicKey, error) {
|
||||||
return tt.keys, nil
|
return tt.keys, nil
|
||||||
},
|
},
|
||||||
|
|
|
@ -39,18 +39,10 @@ func (s *Server) handleClientRegistrationRequest(r *http.Request) (*oidc.ClientR
|
||||||
}
|
}
|
||||||
|
|
||||||
// metadata is guarenteed to have at least one redirect_uri by earlier validation.
|
// metadata is guarenteed to have at least one redirect_uri by earlier validation.
|
||||||
id, err := oidc.GenClientID(clientMetadata.RedirectURIs[0].Host)
|
cli := client.Client{
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Faild to create client ID: %v", err)
|
|
||||||
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
|
|
||||||
}
|
|
||||||
|
|
||||||
creds, err := s.ClientRepo.New(client.Client{
|
|
||||||
Credentials: oidc.ClientCredentials{
|
|
||||||
ID: id,
|
|
||||||
},
|
|
||||||
Metadata: clientMetadata,
|
Metadata: clientMetadata,
|
||||||
})
|
}
|
||||||
|
creds, err := s.ClientManager.New(cli)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to create new client identity: %v", err)
|
log.Errorf("Failed to create new client identity: %v", err)
|
||||||
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
|
return nil, newAPIError(oauth2.ErrorServerError, "unable to save client metadata")
|
||||||
|
|
|
@ -143,7 +143,7 @@ func TestClientRegistration(t *testing.T) {
|
||||||
return fmt.Errorf("no client id in registration response")
|
return fmt.Errorf("no client id in registration response")
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata, err := fixtures.clientRepo.Metadata(r.ClientID)
|
metadata, err := fixtures.clientManager.Metadata(r.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to lookup client id after creation")
|
return fmt.Errorf("failed to lookup client id after creation")
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,21 +6,20 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client/manager"
|
||||||
phttp "github.com/coreos/dex/pkg/http"
|
phttp "github.com/coreos/dex/pkg/http"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
schema "github.com/coreos/dex/schema/workerschema"
|
schema "github.com/coreos/dex/schema/workerschema"
|
||||||
"github.com/coreos/go-oidc/oidc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type clientResource struct {
|
type clientResource struct {
|
||||||
repo client.ClientRepo
|
manager *manager.ClientManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerClientResource(prefix string, repo client.ClientRepo) (string, http.Handler) {
|
func registerClientResource(prefix string, manager *manager.ClientManager) (string, http.Handler) {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
c := &clientResource{
|
c := &clientResource{
|
||||||
repo: repo,
|
manager: manager,
|
||||||
}
|
}
|
||||||
relPath := "clients"
|
relPath := "clients"
|
||||||
absPath := path.Join(prefix, relPath)
|
absPath := path.Join(prefix, relPath)
|
||||||
|
@ -41,7 +40,7 @@ func (c *clientResource) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientResource) list(w http.ResponseWriter, r *http.Request) {
|
func (c *clientResource) list(w http.ResponseWriter, r *http.Request) {
|
||||||
cs, err := c.repo.All()
|
cs, err := c.manager.All()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "error listing clients"))
|
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "error listing clients"))
|
||||||
return
|
return
|
||||||
|
@ -88,16 +87,7 @@ func (c *clientResource) create(w http.ResponseWriter, r *http.Request) {
|
||||||
writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidClientMetadata, err.Error()))
|
writeAPIError(w, http.StatusBadRequest, newAPIError(errorInvalidClientMetadata, err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
creds, err := c.manager.New(ci)
|
||||||
clientID, err := oidc.GenClientID(ci.Metadata.RedirectURIs[0].Host)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed generating ID for new client: %v", err)
|
|
||||||
writeAPIError(w, http.StatusInternalServerError, newAPIError(errorServerError, "unable to generate client ID"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ci.Credentials.ID = clientID
|
|
||||||
creds, err := c.repo.New(ci)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed creating client: %v", err)
|
log.Errorf("Failed creating client: %v", err)
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
"github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
schema "github.com/coreos/dex/schema/workerschema"
|
schema "github.com/coreos/dex/schema/workerschema"
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
|
@ -28,8 +29,10 @@ func makeBody(s string) io.ReadCloser {
|
||||||
func TestCreateInvalidRequest(t *testing.T) {
|
func TestCreateInvalidRequest(t *testing.T) {
|
||||||
u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"}
|
u := &url.URL{Scheme: "http", Host: "example.com", Path: "clients"}
|
||||||
h := http.Header{"Content-Type": []string{"application/json"}}
|
h := http.Header{"Content-Type": []string{"application/json"}}
|
||||||
repo := db.NewClientRepo(db.NewMemDB())
|
dbm := db.NewMemDB()
|
||||||
res := &clientResource{repo: repo}
|
repo := db.NewClientRepo(dbm)
|
||||||
|
manager := manager.NewClientManager(repo, db.TransactionFactory(dbm), manager.ManagerOptions{})
|
||||||
|
res := &clientResource{manager: manager}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
req *http.Request
|
req *http.Request
|
||||||
wantCode int
|
wantCode int
|
||||||
|
@ -119,8 +122,10 @@ func TestCreateInvalidRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreate(t *testing.T) {
|
func TestCreate(t *testing.T) {
|
||||||
repo := db.NewClientRepo(db.NewMemDB())
|
dbm := db.NewMemDB()
|
||||||
res := &clientResource{repo: repo}
|
repo := db.NewClientRepo(dbm)
|
||||||
|
manager := manager.NewClientManager(repo, db.TransactionFactory(dbm), manager.ManagerOptions{})
|
||||||
|
res := &clientResource{manager: manager}
|
||||||
tests := [][]string{
|
tests := [][]string{
|
||||||
[]string{"http://example.com"},
|
[]string{"http://example.com"},
|
||||||
[]string{"https://example.com"},
|
[]string{"https://example.com"},
|
||||||
|
@ -190,7 +195,7 @@ func TestList(t *testing.T) {
|
||||||
{
|
{
|
||||||
cs: []client.Client{
|
cs: []client.Client{
|
||||||
client.Client{
|
client.Client{
|
||||||
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
|
Credentials: oidc.ClientCredentials{ID: "example.com", Secret: b64Encode("secret")},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
url.URL{Scheme: "http", Host: "example.com"},
|
url.URL{Scheme: "http", Host: "example.com"},
|
||||||
|
@ -200,7 +205,7 @@ func TestList(t *testing.T) {
|
||||||
},
|
},
|
||||||
want: []*schema.Client{
|
want: []*schema.Client{
|
||||||
&schema.Client{
|
&schema.Client{
|
||||||
Id: "foo",
|
Id: "example.com",
|
||||||
RedirectURIs: []string{"http://example.com"},
|
RedirectURIs: []string{"http://example.com"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -209,7 +214,7 @@ func TestList(t *testing.T) {
|
||||||
{
|
{
|
||||||
cs: []client.Client{
|
cs: []client.Client{
|
||||||
client.Client{
|
client.Client{
|
||||||
Credentials: oidc.ClientCredentials{ID: "foo", Secret: b64Encode("bar")},
|
Credentials: oidc.ClientCredentials{ID: "example.com", Secret: b64Encode("secret")},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
url.URL{Scheme: "http", Host: "example.com"},
|
url.URL{Scheme: "http", Host: "example.com"},
|
||||||
|
@ -217,21 +222,21 @@ func TestList(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
client.Client{
|
client.Client{
|
||||||
Credentials: oidc.ClientCredentials{ID: "biz", Secret: b64Encode("bang")},
|
Credentials: oidc.ClientCredentials{ID: "example2.com", Secret: b64Encode("secret")},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
url.URL{Scheme: "https", Host: "example.com", Path: "one/two/three"},
|
url.URL{Scheme: "https", Host: "example2.com", Path: "one/two/three"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: []*schema.Client{
|
want: []*schema.Client{
|
||||||
&schema.Client{
|
&schema.Client{
|
||||||
Id: "biz",
|
Id: "example2.com",
|
||||||
RedirectURIs: []string{"https://example.com/one/two/three"},
|
RedirectURIs: []string{"https://example2.com/one/two/three"},
|
||||||
},
|
},
|
||||||
&schema.Client{
|
&schema.Client{
|
||||||
Id: "foo",
|
Id: "example.com",
|
||||||
RedirectURIs: []string{"http://example.com"},
|
RedirectURIs: []string{"http://example.com"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -239,12 +244,20 @@ func TestList(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), tt.cs)
|
dbm := db.NewMemDB()
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager, err := manager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), tt.cs, manager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
|
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
res := &clientResource{repo: repo}
|
res := &clientResource{manager: clientManager}
|
||||||
|
|
||||||
r, err := http.NewRequest("GET", "http://example.com/clients", nil)
|
r, err := http.NewRequest("GET", "http://example.com/clients", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/go-gorp/gorp"
|
"github.com/go-gorp/gorp"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/email"
|
"github.com/coreos/dex/email"
|
||||||
|
@ -114,9 +115,11 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err)
|
return fmt.Errorf("unable to read clients from file %s: %v", cfg.ClientsFile, err)
|
||||||
}
|
}
|
||||||
ciRepo, err := db.NewClientRepoFromClients(dbMap, clients)
|
|
||||||
if err != nil {
|
clientRepo := db.NewClientRepo(dbMap)
|
||||||
return fmt.Errorf("failed to create client identity repo: %v", err)
|
|
||||||
|
for _, c := range clients {
|
||||||
|
clientRepo.New(nil, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := os.Open(cfg.ConnectorsFile)
|
f, err := os.Open(cfg.ConnectorsFile)
|
||||||
|
@ -155,7 +158,12 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
||||||
|
|
||||||
txnFactory := db.TransactionFactory(dbMap)
|
txnFactory := db.TransactionFactory(dbMap)
|
||||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
|
||||||
srv.ClientRepo = ciRepo
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, clientmanager.ManagerOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to create client identity manager: %v", err)
|
||||||
|
}
|
||||||
|
srv.ClientRepo = clientRepo
|
||||||
|
srv.ClientManager = clientManager
|
||||||
srv.KeySetRepo = kRepo
|
srv.KeySetRepo = kRepo
|
||||||
srv.ConnectorConfigRepo = cfgRepo
|
srv.ConnectorConfigRepo = cfgRepo
|
||||||
srv.UserRepo = userRepo
|
srv.UserRepo = userRepo
|
||||||
|
@ -253,11 +261,13 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
||||||
userRepo := db.NewUserRepo(dbc)
|
userRepo := db.NewUserRepo(dbc)
|
||||||
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
pwiRepo := db.NewPasswordInfoRepo(dbc)
|
||||||
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
|
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
|
||||||
|
clientManager := clientmanager.NewClientManager(ciRepo, db.TransactionFactory(dbc), clientmanager.ManagerOptions{})
|
||||||
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
|
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
|
||||||
|
|
||||||
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
|
||||||
|
|
||||||
srv.ClientRepo = ciRepo
|
srv.ClientRepo = ciRepo
|
||||||
|
srv.ClientManager = clientManager
|
||||||
srv.KeySetRepo = kRepo
|
srv.KeySetRepo = kRepo
|
||||||
srv.ConnectorConfigRepo = cfgRepo
|
srv.ConnectorConfigRepo = cfgRepo
|
||||||
srv.UserRepo = userRepo
|
srv.UserRepo = userRepo
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
useremail "github.com/coreos/dex/user/email"
|
useremail "github.com/coreos/dex/user/email"
|
||||||
|
@ -28,7 +29,7 @@ func handleVerifyEmailResendFunc(
|
||||||
srvKeysFunc func() ([]key.PublicKey, error),
|
srvKeysFunc func() ([]key.PublicKey, error),
|
||||||
emailer *useremail.UserEmailer,
|
emailer *useremail.UserEmailer,
|
||||||
userRepo user.UserRepo,
|
userRepo user.UserRepo,
|
||||||
clientRepo client.ClientRepo) http.HandlerFunc {
|
clientManager *clientmanager.ClientManager) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
decoder := json.NewDecoder(r.Body)
|
decoder := json.NewDecoder(r.Body)
|
||||||
var params struct {
|
var params struct {
|
||||||
|
@ -57,7 +58,7 @@ func handleVerifyEmailResendFunc(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cm, err := clientRepo.Metadata(clientID)
|
cm, err := clientManager.Metadata(clientID)
|
||||||
if err == client.ErrorNotFound {
|
if err == client.ErrorNotFound {
|
||||||
log.Errorf("No such client: %v", err)
|
log.Errorf("No such client: %v", err)
|
||||||
writeAPIError(w, http.StatusBadRequest,
|
writeAPIError(w, http.StatusBadRequest,
|
||||||
|
|
|
@ -130,7 +130,7 @@ func TestHandleVerifyEmailResend(t *testing.T) {
|
||||||
keysFunc,
|
keysFunc,
|
||||||
f.srv.UserEmailer,
|
f.srv.UserEmailer,
|
||||||
f.userRepo,
|
f.userRepo,
|
||||||
f.clientRepo)
|
f.clientManager)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
u := "http://example.com"
|
u := "http://example.com"
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/jonboulle/clockwork"
|
"github.com/jonboulle/clockwork"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/session/manager"
|
"github.com/coreos/dex/session/manager"
|
||||||
|
@ -75,15 +76,12 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
idpcs := []connector.Connector{
|
idpcs := []connector.Connector{
|
||||||
&fakeConnector{loginURL: "http://fake.example.com"},
|
&fakeConnector{loginURL: "http://fake.example.com"},
|
||||||
}
|
}
|
||||||
srv := &Server{
|
dbm := db.NewMemDB()
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
clients := []client.Client{
|
||||||
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
|
||||||
ClientRepo: func() client.ClientRepo {
|
|
||||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
|
||||||
client.Client{
|
client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "XXX",
|
ID: "client.example.com",
|
||||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
|
@ -91,12 +89,24 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
|
||||||
}
|
}
|
||||||
return repo
|
|
||||||
}(),
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||||
|
}
|
||||||
|
srv := &Server{
|
||||||
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
|
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
||||||
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
@ -108,7 +118,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
{
|
{
|
||||||
query: url.Values{
|
query: url.Values{
|
||||||
"response_type": []string{"code"},
|
"response_type": []string{"code"},
|
||||||
"client_id": []string{"XXX"},
|
"client_id": []string{"client.example.com"},
|
||||||
"connector_id": []string{"fake"},
|
"connector_id": []string{"fake"},
|
||||||
"scope": []string{"openid"},
|
"scope": []string{"openid"},
|
||||||
},
|
},
|
||||||
|
@ -121,7 +131,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
query: url.Values{
|
query: url.Values{
|
||||||
"response_type": []string{"code"},
|
"response_type": []string{"code"},
|
||||||
"redirect_uri": []string{"http://client.example.com/callback"},
|
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||||
"client_id": []string{"XXX"},
|
"client_id": []string{"client.example.com"},
|
||||||
"connector_id": []string{"fake"},
|
"connector_id": []string{"fake"},
|
||||||
"scope": []string{"openid"},
|
"scope": []string{"openid"},
|
||||||
},
|
},
|
||||||
|
@ -134,7 +144,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
query: url.Values{
|
query: url.Values{
|
||||||
"response_type": []string{"code"},
|
"response_type": []string{"code"},
|
||||||
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
|
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
|
||||||
"client_id": []string{"XXX"},
|
"client_id": []string{"client.example.com"},
|
||||||
"connector_id": []string{"fake"},
|
"connector_id": []string{"fake"},
|
||||||
"scope": []string{"openid"},
|
"scope": []string{"openid"},
|
||||||
},
|
},
|
||||||
|
@ -157,7 +167,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
{
|
{
|
||||||
query: url.Values{
|
query: url.Values{
|
||||||
"response_type": []string{"token"},
|
"response_type": []string{"token"},
|
||||||
"client_id": []string{"XXX"},
|
"client_id": []string{"client.example.com"},
|
||||||
"connector_id": []string{"fake"},
|
"connector_id": []string{"fake"},
|
||||||
"scope": []string{"openid"},
|
"scope": []string{"openid"},
|
||||||
},
|
},
|
||||||
|
@ -170,11 +180,33 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
query: url.Values{
|
query: url.Values{
|
||||||
"response_type": []string{"code"},
|
"response_type": []string{"code"},
|
||||||
"redirect_uri": []string{"http://client.example.com/callback"},
|
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||||
"client_id": []string{"XXX"},
|
"client_id": []string{"client.example.com"},
|
||||||
"connector_id": []string{"fake"},
|
"connector_id": []string{"fake"},
|
||||||
},
|
},
|
||||||
wantCode: http.StatusBadRequest,
|
wantCode: http.StatusBadRequest,
|
||||||
},
|
},
|
||||||
|
// empty response_type
|
||||||
|
{
|
||||||
|
query: url.Values{
|
||||||
|
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||||
|
"client_id": []string{"client.example.com"},
|
||||||
|
"connector_id": []string{"fake"},
|
||||||
|
"scope": []string{"openid"},
|
||||||
|
},
|
||||||
|
wantCode: http.StatusFound,
|
||||||
|
wantLocation: "http://client.example.com/callback?error=unsupported_response_type&state=",
|
||||||
|
},
|
||||||
|
|
||||||
|
// empty client_id
|
||||||
|
{
|
||||||
|
query: url.Values{
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
|
||||||
|
"connector_id": []string{"fake"},
|
||||||
|
"scope": []string{"openid"},
|
||||||
|
},
|
||||||
|
wantCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
|
@ -204,14 +236,12 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
||||||
idpcs := []connector.Connector{
|
idpcs := []connector.Connector{
|
||||||
&fakeConnector{loginURL: "http://fake.example.com"},
|
&fakeConnector{loginURL: "http://fake.example.com"},
|
||||||
}
|
}
|
||||||
srv := &Server{
|
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
dbm := db.NewMemDB()
|
||||||
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
clients := []client.Client{
|
||||||
ClientRepo: func() client.ClientRepo {
|
|
||||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
|
||||||
client.Client{
|
client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "XXX",
|
ID: "foo.example.com",
|
||||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
|
@ -221,12 +251,24 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
|
||||||
}
|
}
|
||||||
return repo
|
|
||||||
}(),
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||||
|
}
|
||||||
|
srv := &Server{
|
||||||
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
|
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
|
||||||
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
@ -239,7 +281,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
||||||
query: url.Values{
|
query: url.Values{
|
||||||
"response_type": []string{"code"},
|
"response_type": []string{"code"},
|
||||||
"redirect_uri": []string{"http://foo.example.com/callback"},
|
"redirect_uri": []string{"http://foo.example.com/callback"},
|
||||||
"client_id": []string{"XXX"},
|
"client_id": []string{"foo.example.com"},
|
||||||
"connector_id": []string{"fake"},
|
"connector_id": []string{"fake"},
|
||||||
"scope": []string{"openid"},
|
"scope": []string{"openid"},
|
||||||
},
|
},
|
||||||
|
@ -252,7 +294,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
||||||
query: url.Values{
|
query: url.Values{
|
||||||
"response_type": []string{"code"},
|
"response_type": []string{"code"},
|
||||||
"redirect_uri": []string{"http://bar.example.com/callback"},
|
"redirect_uri": []string{"http://bar.example.com/callback"},
|
||||||
"client_id": []string{"XXX"},
|
"client_id": []string{"foo.example.com"},
|
||||||
"connector_id": []string{"fake"},
|
"connector_id": []string{"fake"},
|
||||||
"scope": []string{"openid"},
|
"scope": []string{"openid"},
|
||||||
},
|
},
|
||||||
|
@ -265,7 +307,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
||||||
query: url.Values{
|
query: url.Values{
|
||||||
"response_type": []string{"code"},
|
"response_type": []string{"code"},
|
||||||
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
|
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
|
||||||
"client_id": []string{"XXX"},
|
"client_id": []string{"foo.example.com"},
|
||||||
"connector_id": []string{"fake"},
|
"connector_id": []string{"fake"},
|
||||||
"scope": []string{"openid"},
|
"scope": []string{"openid"},
|
||||||
},
|
},
|
||||||
|
@ -276,7 +318,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
||||||
{
|
{
|
||||||
query: url.Values{
|
query: url.Values{
|
||||||
"response_type": []string{"code"},
|
"response_type": []string{"code"},
|
||||||
"client_id": []string{"XXX"},
|
"client_id": []string{"foo.example.com"},
|
||||||
"connector_id": []string{"fake"},
|
"connector_id": []string{"fake"},
|
||||||
"scope": []string{"openid"},
|
"scope": []string{"openid"},
|
||||||
},
|
},
|
||||||
|
@ -328,8 +370,8 @@ func TestHandleTokenFunc(t *testing.T) {
|
||||||
"grant_type": []string{"invalid!"},
|
"grant_type": []string{"invalid!"},
|
||||||
"code": []string{"someCode"},
|
"code": []string{"someCode"},
|
||||||
},
|
},
|
||||||
user: "XXX",
|
user: testClientID,
|
||||||
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||||
wantCode: http.StatusBadRequest,
|
wantCode: http.StatusBadRequest,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -338,8 +380,8 @@ func TestHandleTokenFunc(t *testing.T) {
|
||||||
query: url.Values{
|
query: url.Values{
|
||||||
"grant_type": []string{"authorization_code"},
|
"grant_type": []string{"authorization_code"},
|
||||||
},
|
},
|
||||||
user: "XXX",
|
user: testClientID,
|
||||||
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||||
wantCode: http.StatusBadRequest,
|
wantCode: http.StatusBadRequest,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -349,8 +391,8 @@ func TestHandleTokenFunc(t *testing.T) {
|
||||||
"grant_type": []string{"authorization_code"},
|
"grant_type": []string{"authorization_code"},
|
||||||
"code": []string{""},
|
"code": []string{""},
|
||||||
},
|
},
|
||||||
user: "XXX",
|
user: testClientID,
|
||||||
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||||
wantCode: http.StatusBadRequest,
|
wantCode: http.StatusBadRequest,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -371,8 +413,8 @@ func TestHandleTokenFunc(t *testing.T) {
|
||||||
"grant_type": []string{"authorization_code"},
|
"grant_type": []string{"authorization_code"},
|
||||||
"code": []string{"asdasd"},
|
"code": []string{"asdasd"},
|
||||||
},
|
},
|
||||||
user: "XXX",
|
user: testClientID,
|
||||||
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||||
wantCode: http.StatusBadRequest,
|
wantCode: http.StatusBadRequest,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -382,8 +424,8 @@ func TestHandleTokenFunc(t *testing.T) {
|
||||||
"grant_type": []string{"authorization_code"},
|
"grant_type": []string{"authorization_code"},
|
||||||
"code": []string{"code-2"},
|
"code": []string{"code-2"},
|
||||||
},
|
},
|
||||||
user: "XXX",
|
user: testClientID,
|
||||||
passwd: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
passwd: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||||
wantCode: http.StatusOK,
|
wantCode: http.StatusOK,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -402,7 +444,7 @@ func TestHandleTokenFunc(t *testing.T) {
|
||||||
|
|
||||||
// need to create session in order to exchange the code (generated by the NewSessionKey func) for token
|
// need to create session in order to exchange the code (generated by the NewSessionKey func) for token
|
||||||
setSession := func() error {
|
setSession := func() error {
|
||||||
sid, err := fx.sessionManager.NewSession("local", "XXX", "", testRedirectURL, "", true, []string{"openid"})
|
sid, err := fx.sessionManager.NewSession("local", testClientID, "", testRedirectURL, "", true, []string{"openid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("case %d: cannot create session, error=%v", i, err)
|
return fmt.Errorf("case %d: cannot create session, error=%v", i, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"github.com/coreos/go-oidc/key"
|
"github.com/coreos/go-oidc/key"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
sessionmanager "github.com/coreos/dex/session/manager"
|
sessionmanager "github.com/coreos/dex/session/manager"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
|
@ -29,7 +30,7 @@ type SendResetPasswordEmailHandler struct {
|
||||||
tpl *template.Template
|
tpl *template.Template
|
||||||
emailer *useremail.UserEmailer
|
emailer *useremail.UserEmailer
|
||||||
sm *sessionmanager.SessionManager
|
sm *sessionmanager.SessionManager
|
||||||
cr client.ClientRepo
|
cm *clientmanager.ClientManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SendResetPasswordEmailHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *SendResetPasswordEmailHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -128,7 +129,7 @@ func (h *SendResetPasswordEmailHandler) validateRedirectURL(clientID string, red
|
||||||
return url.URL{}, false
|
return url.URL{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
cm, err := h.cr.Metadata(clientID)
|
cm, err := h.cm.Metadata(clientID)
|
||||||
if err != nil || cm == nil {
|
if err != nil || cm == nil {
|
||||||
log.Errorf("Error getting ClientMetadata: %v", err)
|
log.Errorf("Error getting ClientMetadata: %v", err)
|
||||||
return url.URL{}, false
|
return url.URL{}, false
|
||||||
|
|
|
@ -253,7 +253,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
|
||||||
t.Fatalf("case %d: could not make test fixtures: %v", i, err)
|
t.Fatalf("case %d: could not make test fixtures: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = f.srv.NewSession("local", "XXX", "", f.redirectURL, "", true, []string{"openid"})
|
_, err = f.srv.NewSession("local", testClientID, "", f.redirectURL, "", true, []string{"openid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("case %d: could not create new session: %v", i, err)
|
t.Fatalf("case %d: could not create new session: %v", i, err)
|
||||||
}
|
}
|
||||||
|
@ -267,7 +267,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
|
||||||
tpl: f.srv.SendResetPasswordEmailTemplate,
|
tpl: f.srv.SendResetPasswordEmailTemplate,
|
||||||
emailer: f.srv.UserEmailer,
|
emailer: f.srv.UserEmailer,
|
||||||
sm: f.sessionManager,
|
sm: f.sessionManager,
|
||||||
cr: f.clientRepo,
|
cm: f.clientManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
|
@ -295,7 +295,7 @@ func TestHandleRegister(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, []string{"openid"})
|
key, err := f.srv.NewSession(tt.connID, testClientID, "", f.redirectURL, "", true, []string{"openid"})
|
||||||
t.Logf("case %d: key for NewSession: %v", i, key)
|
t.Logf("case %d: key for NewSession: %v", i, key)
|
||||||
|
|
||||||
if tt.attachRemote {
|
if tt.attachRemote {
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"github.com/jonboulle/clockwork"
|
"github.com/jonboulle/clockwork"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/refresh"
|
"github.com/coreos/dex/refresh"
|
||||||
|
@ -72,6 +73,7 @@ type Server struct {
|
||||||
Connectors []connector.Connector
|
Connectors []connector.Connector
|
||||||
UserRepo user.UserRepo
|
UserRepo user.UserRepo
|
||||||
UserManager *usermanager.UserManager
|
UserManager *usermanager.UserManager
|
||||||
|
ClientManager *clientmanager.ClientManager
|
||||||
PasswordInfoRepo user.PasswordInfoRepo
|
PasswordInfoRepo user.PasswordInfoRepo
|
||||||
RefreshTokenRepo refresh.RefreshTokenRepo
|
RefreshTokenRepo refresh.RefreshTokenRepo
|
||||||
UserEmailer *useremail.UserEmailer
|
UserEmailer *useremail.UserEmailer
|
||||||
|
@ -213,13 +215,13 @@ func (s *Server) HTTPHandler() http.Handler {
|
||||||
s.KeyManager.PublicKeys,
|
s.KeyManager.PublicKeys,
|
||||||
s.UserEmailer,
|
s.UserEmailer,
|
||||||
s.UserRepo,
|
s.UserRepo,
|
||||||
s.ClientRepo)))
|
s.ClientManager)))
|
||||||
|
|
||||||
mux.Handle(httpPathSendResetPassword, &SendResetPasswordEmailHandler{
|
mux.Handle(httpPathSendResetPassword, &SendResetPasswordEmailHandler{
|
||||||
tpl: s.SendResetPasswordEmailTemplate,
|
tpl: s.SendResetPasswordEmailTemplate,
|
||||||
emailer: s.UserEmailer,
|
emailer: s.UserEmailer,
|
||||||
sm: s.SessionManager,
|
sm: s.SessionManager,
|
||||||
cr: s.ClientRepo,
|
cm: s.ClientManager,
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.Handle(httpPathResetPassword, &ResetPasswordHandler{
|
mux.Handle(httpPathResetPassword, &ResetPasswordHandler{
|
||||||
|
@ -256,11 +258,11 @@ func (s *Server) HTTPHandler() http.Handler {
|
||||||
apiBasePath := path.Join(httpPathAPI, APIVersion)
|
apiBasePath := path.Join(httpPathAPI, APIVersion)
|
||||||
registerDiscoveryResource(apiBasePath, mux)
|
registerDiscoveryResource(apiBasePath, mux)
|
||||||
|
|
||||||
clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientRepo)
|
clientPath, clientHandler := registerClientResource(apiBasePath, s.ClientManager)
|
||||||
mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler))
|
mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler))
|
||||||
|
|
||||||
usersAPI := usersapi.NewUsersAPI(s.dbMap, s.UserManager, s.UserEmailer, s.localConnectorID)
|
usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientManager, s.RefreshTokenRepo, s.UserEmailer, s.localConnectorID)
|
||||||
handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientRepo).HTTPHandler()
|
handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientManager).HTTPHandler()
|
||||||
|
|
||||||
mux.Handle(apiBasePath+"/", handler)
|
mux.Handle(apiBasePath+"/", handler)
|
||||||
|
|
||||||
|
@ -271,14 +273,14 @@ func (s *Server) HTTPHandler() http.Handler {
|
||||||
func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler {
|
func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler {
|
||||||
return &clientTokenMiddleware{
|
return &clientTokenMiddleware{
|
||||||
issuerURL: s.IssuerURL.String(),
|
issuerURL: s.IssuerURL.String(),
|
||||||
ciRepo: s.ClientRepo,
|
ciManager: s.ClientManager,
|
||||||
keysFunc: s.KeyManager.PublicKeys,
|
keysFunc: s.KeyManager.PublicKeys,
|
||||||
next: handler,
|
next: handler,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) {
|
func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) {
|
||||||
return s.ClientRepo.Metadata(clientID)
|
return s.ClientManager.Metadata(clientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
|
func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
|
||||||
|
@ -365,9 +367,9 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
|
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
|
||||||
ok, err := s.ClientRepo.Authenticate(creds)
|
ok, err := s.ClientManager.Authenticate(creds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -397,7 +399,7 @@ func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, erro
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
|
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
|
||||||
ok, err := s.ClientRepo.Authenticate(creds)
|
ok, err := s.ClientManager.Authenticate(creds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||||
|
@ -466,7 +468,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
|
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
|
||||||
ok, err := s.ClientRepo.Authenticate(creds)
|
ok, err := s.ClientManager.Authenticate(creds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
|
||||||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/refresh/refreshtest"
|
"github.com/coreos/dex/refresh/refreshtest"
|
||||||
"github.com/coreos/dex/session/manager"
|
"github.com/coreos/dex/session/manager"
|
||||||
|
@ -21,7 +22,12 @@ import (
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
)
|
)
|
||||||
|
|
||||||
var clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secrete"))
|
var clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
|
||||||
|
var validRedirURL = url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: "client.example.com",
|
||||||
|
Path: "/callback",
|
||||||
|
}
|
||||||
|
|
||||||
type StaticKeyManager struct {
|
type StaticKeyManager struct {
|
||||||
key.PrivateKeyManager
|
key.PrivateKeyManager
|
||||||
|
@ -132,8 +138,8 @@ func TestServerNewSession(t *testing.T) {
|
||||||
nonce := "oncenay"
|
nonce := "oncenay"
|
||||||
ci := client.Client{
|
ci := client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "XXX",
|
ID: testClientID,
|
||||||
Secret: "secrete",
|
Secret: clientTestSecret,
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
|
@ -181,7 +187,7 @@ func TestServerNewSession(t *testing.T) {
|
||||||
func TestServerLogin(t *testing.T) {
|
func TestServerLogin(t *testing.T) {
|
||||||
ci := client.Client{
|
ci := client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "XXX",
|
ID: testClientID,
|
||||||
Secret: clientTestSecret,
|
Secret: clientTestSecret,
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
|
@ -194,13 +200,13 @@ func TestServerLogin(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
ciRepo := func() client.ClientRepo {
|
|
||||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
|
dbm := db.NewMemDB()
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), []client.Client{ci}, clientmanager.ManagerOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||||
}
|
}
|
||||||
return repo
|
|
||||||
}()
|
|
||||||
|
|
||||||
km := &StaticKeyManager{
|
km := &StaticKeyManager{
|
||||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||||
|
@ -222,7 +228,8 @@ func TestServerLogin(t *testing.T) {
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
SessionManager: sm,
|
SessionManager: sm,
|
||||||
ClientRepo: ciRepo,
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
UserRepo: userRepo,
|
UserRepo: userRepo,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,20 +251,30 @@ func TestServerLogin(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
|
func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
|
||||||
ciRepo := func() client.ClientRepo {
|
clients := []client.Client{
|
||||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
|
||||||
client.Client{
|
client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "XXX", Secret: clientTestSecret,
|
ID: testClientID, Secret: clientTestSecret,
|
||||||
|
},
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
validRedirURL,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
|
||||||
}
|
}
|
||||||
return repo
|
dbm := db.NewMemDB()
|
||||||
}()
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||||
|
}
|
||||||
km := &StaticKeyManager{
|
km := &StaticKeyManager{
|
||||||
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
|
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
|
||||||
}
|
}
|
||||||
|
@ -266,11 +283,12 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
SessionManager: sm,
|
SessionManager: sm,
|
||||||
ClientRepo: ciRepo,
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
ident := oidc.Identity{ID: "YYY", Name: "elroy", Email: "elroy@example.com"}
|
ident := oidc.Identity{ID: "YYY", Name: "elroy", Email: "elroy@example.com"}
|
||||||
code, err := srv.Login(ident, "XXX")
|
code, err := srv.Login(ident, testClientID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Expected non-nil error")
|
t.Fatalf("Expected non-nil error")
|
||||||
}
|
}
|
||||||
|
@ -283,27 +301,28 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
|
||||||
func TestServerLoginDisabledUser(t *testing.T) {
|
func TestServerLoginDisabledUser(t *testing.T) {
|
||||||
ci := client.Client{
|
ci := client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "XXX",
|
ID: testClientID,
|
||||||
Secret: clientTestSecret,
|
Secret: clientTestSecret,
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
url.URL{
|
validRedirURL,
|
||||||
Scheme: "http",
|
|
||||||
Host: "client.example.com",
|
|
||||||
Path: "/callback",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
ciRepo := func() client.ClientRepo {
|
clients := []client.Client{ci}
|
||||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
|
dbm := db.NewMemDB()
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||||
}
|
}
|
||||||
return repo
|
|
||||||
}()
|
|
||||||
|
|
||||||
km := &StaticKeyManager{
|
km := &StaticKeyManager{
|
||||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||||
}
|
}
|
||||||
|
@ -338,7 +357,8 @@ func TestServerLoginDisabledUser(t *testing.T) {
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
SessionManager: sm,
|
SessionManager: sm,
|
||||||
ClientRepo: ciRepo,
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
UserRepo: userRepo,
|
UserRepo: userRepo,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -357,17 +377,28 @@ func TestServerLoginDisabledUser(t *testing.T) {
|
||||||
func TestServerCodeToken(t *testing.T) {
|
func TestServerCodeToken(t *testing.T) {
|
||||||
ci := client.Client{
|
ci := client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "XXX",
|
ID: testClientID,
|
||||||
Secret: clientTestSecret,
|
Secret: clientTestSecret,
|
||||||
},
|
},
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
validRedirURL,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
ciRepo := func() client.ClientRepo {
|
clients := []client.Client{ci}
|
||||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
|
dbm := db.NewMemDB()
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||||
}
|
}
|
||||||
return repo
|
|
||||||
}()
|
|
||||||
km := &StaticKeyManager{
|
km := &StaticKeyManager{
|
||||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||||
}
|
}
|
||||||
|
@ -384,7 +415,8 @@ func TestServerCodeToken(t *testing.T) {
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
SessionManager: sm,
|
SessionManager: sm,
|
||||||
ClientRepo: ciRepo,
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
UserRepo: userRepo,
|
UserRepo: userRepo,
|
||||||
RefreshTokenRepo: refreshTokenRepo,
|
RefreshTokenRepo: refreshTokenRepo,
|
||||||
}
|
}
|
||||||
|
@ -443,17 +475,29 @@ func TestServerCodeToken(t *testing.T) {
|
||||||
func TestServerTokenUnrecognizedKey(t *testing.T) {
|
func TestServerTokenUnrecognizedKey(t *testing.T) {
|
||||||
ci := client.Client{
|
ci := client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "XXX",
|
ID: testClientID,
|
||||||
Secret: clientTestSecret,
|
Secret: clientTestSecret,
|
||||||
},
|
},
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
validRedirURL,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
ciRepo := func() client.ClientRepo {
|
|
||||||
repo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{ci})
|
clients := []client.Client{ci}
|
||||||
|
dbm := db.NewMemDB()
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create client identity repo: %v", err)
|
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||||
}
|
}
|
||||||
return repo
|
|
||||||
}()
|
|
||||||
km := &StaticKeyManager{
|
km := &StaticKeyManager{
|
||||||
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
signer: &StaticSigner{sig: []byte("beer"), err: nil},
|
||||||
}
|
}
|
||||||
|
@ -463,7 +507,8 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
|
||||||
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
SessionManager: sm,
|
SessionManager: sm,
|
||||||
ClientRepo: ciRepo,
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
|
sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
|
||||||
|
@ -492,7 +537,7 @@ func TestServerTokenFail(t *testing.T) {
|
||||||
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
|
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
|
||||||
keyFixture := "goodkey"
|
keyFixture := "goodkey"
|
||||||
ccFixture := oidc.ClientCredentials{
|
ccFixture := oidc.ClientCredentials{
|
||||||
ID: "XXX",
|
ID: testClientID,
|
||||||
Secret: clientTestSecret,
|
Secret: clientTestSecret,
|
||||||
}
|
}
|
||||||
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
||||||
|
@ -569,14 +614,29 @@ func TestServerTokenFail(t *testing.T) {
|
||||||
km := &StaticKeyManager{
|
km := &StaticKeyManager{
|
||||||
signer: tt.signer,
|
signer: tt.signer,
|
||||||
}
|
}
|
||||||
ciRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
|
||||||
client.Client{Credentials: ccFixture},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
|
clients := []client.Client{
|
||||||
|
client.Client{
|
||||||
|
Credentials: ccFixture,
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
validRedirURL,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dbm := db.NewMemDB()
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||||
|
}
|
||||||
_, err = sm.AttachUser(sessionID, "testid-1")
|
_, err = sm.AttachUser(sessionID, "testid-1")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||||
|
@ -593,7 +653,8 @@ func TestServerTokenFail(t *testing.T) {
|
||||||
IssuerURL: issuerURL,
|
IssuerURL: issuerURL,
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
SessionManager: sm,
|
SessionManager: sm,
|
||||||
ClientRepo: ciRepo,
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
UserRepo: userRepo,
|
UserRepo: userRepo,
|
||||||
RefreshTokenRepo: refreshTokenRepo,
|
RefreshTokenRepo: refreshTokenRepo,
|
||||||
}
|
}
|
||||||
|
@ -623,14 +684,27 @@ func TestServerTokenFail(t *testing.T) {
|
||||||
|
|
||||||
func TestServerRefreshToken(t *testing.T) {
|
func TestServerRefreshToken(t *testing.T) {
|
||||||
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
|
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
|
||||||
|
clientA := client.Client{
|
||||||
credXXX := oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "XXX",
|
ID: testClientID,
|
||||||
Secret: clientTestSecret,
|
Secret: clientTestSecret,
|
||||||
|
},
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
url.URL{Scheme: "https", Host: "client.example.com", Path: "one/two/three"},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
credYYY := oidc.ClientCredentials{
|
clientB := client.Client{
|
||||||
ID: "YYY",
|
Credentials: oidc.ClientCredentials{
|
||||||
|
ID: "example2.com",
|
||||||
Secret: clientTestSecret,
|
Secret: clientTestSecret,
|
||||||
|
},
|
||||||
|
Metadata: oidc.ClientMetadata{
|
||||||
|
RedirectURIs: []url.URL{
|
||||||
|
url.URL{Scheme: "https", Host: "example2.com", Path: "one/two/three"},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
||||||
|
@ -647,47 +721,47 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
// Everything is good.
|
// Everything is good.
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
clientA.Credentials.ID,
|
||||||
credXXX,
|
clientA.Credentials,
|
||||||
signerFixture,
|
signerFixture,
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
// Invalid refresh token(malformatted).
|
// Invalid refresh token(malformatted).
|
||||||
{
|
{
|
||||||
"invalid-token",
|
"invalid-token",
|
||||||
"XXX",
|
clientA.Credentials.ID,
|
||||||
credXXX,
|
clientA.Credentials,
|
||||||
signerFixture,
|
signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||||
},
|
},
|
||||||
// Invalid refresh token(invalid payload content).
|
// Invalid refresh token(invalid payload content).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
|
||||||
"XXX",
|
clientA.Credentials.ID,
|
||||||
credXXX,
|
clientA.Credentials,
|
||||||
signerFixture,
|
signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||||
},
|
},
|
||||||
// Invalid refresh token(invalid ID content).
|
// Invalid refresh token(invalid ID content).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
clientA.Credentials.ID,
|
||||||
credXXX,
|
clientA.Credentials,
|
||||||
signerFixture,
|
signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||||||
},
|
},
|
||||||
// Invalid client(client is not associated with the token).
|
// Invalid client(client is not associated with the token).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
clientA.Credentials.ID,
|
||||||
credYYY,
|
clientB.Credentials,
|
||||||
signerFixture,
|
signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||||
},
|
},
|
||||||
// Invalid client(no client ID).
|
// Invalid client(no client ID).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
clientA.Credentials.ID,
|
||||||
oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
oidc.ClientCredentials{ID: "", Secret: "aaa"},
|
||||||
signerFixture,
|
signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||||
|
@ -695,7 +769,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
// Invalid client(no such client).
|
// Invalid client(no such client).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
clientA.Credentials.ID,
|
||||||
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
|
||||||
signerFixture,
|
signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||||
|
@ -703,24 +777,24 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
// Invalid client(no secrets).
|
// Invalid client(no secrets).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
clientA.Credentials.ID,
|
||||||
oidc.ClientCredentials{ID: "XXX"},
|
oidc.ClientCredentials{ID: testClientID},
|
||||||
signerFixture,
|
signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||||
},
|
},
|
||||||
// Invalid client(invalid secret).
|
// Invalid client(invalid secret).
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
clientA.Credentials.ID,
|
||||||
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
|
oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
|
||||||
signerFixture,
|
signerFixture,
|
||||||
oauth2.NewError(oauth2.ErrorInvalidClient),
|
oauth2.NewError(oauth2.ErrorInvalidClient),
|
||||||
},
|
},
|
||||||
// Signing operation fails.
|
// Signing operation fails.
|
||||||
{
|
{
|
||||||
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
|
||||||
"XXX",
|
clientA.Credentials.ID,
|
||||||
credXXX,
|
clientA.Credentials,
|
||||||
&StaticSigner{sig: nil, err: errors.New("fail")},
|
&StaticSigner{sig: nil, err: errors.New("fail")},
|
||||||
oauth2.NewError(oauth2.ErrorServerError),
|
oauth2.NewError(oauth2.ErrorServerError),
|
||||||
},
|
},
|
||||||
|
@ -731,15 +805,23 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
signer: tt.signer,
|
signer: tt.signer,
|
||||||
}
|
}
|
||||||
|
|
||||||
ciRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
clients := []client.Client{
|
||||||
client.Client{Credentials: credXXX},
|
clientA,
|
||||||
client.Client{Credentials: credYYY},
|
clientB,
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("case %d: failed to create client identity repo: %v", i, err)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
dbm := db.NewMemDB()
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||||
|
}
|
||||||
userRepo, err := makeNewUserRepo()
|
userRepo, err := makeNewUserRepo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
@ -750,7 +832,8 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
IssuerURL: issuerURL,
|
IssuerURL: issuerURL,
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
ClientRepo: ciRepo,
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
UserRepo: userRepo,
|
UserRepo: userRepo,
|
||||||
RefreshTokenRepo: refreshTokenRepo,
|
RefreshTokenRepo: refreshTokenRepo,
|
||||||
}
|
}
|
||||||
|
@ -772,7 +855,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Case %d: unexpected error: %v", i, err)
|
t.Errorf("Case %d: unexpected error: %v", i, err)
|
||||||
}
|
}
|
||||||
if claims["iss"] != issuerURL.String() || claims["sub"] != "testid-1" || claims["aud"] != "XXX" {
|
if claims["iss"] != issuerURL.String() || claims["sub"] != "testid-1" || claims["aud"] != testClientID {
|
||||||
t.Errorf("Case %d: invalid claims: %v", i, claims)
|
t.Errorf("Case %d: invalid claims: %v", i, claims)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -784,14 +867,22 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
signer: signerFixture,
|
signer: signerFixture,
|
||||||
}
|
}
|
||||||
|
|
||||||
ciRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
clients := []client.Client{
|
||||||
client.Client{Credentials: credXXX},
|
clientA,
|
||||||
client.Client{Credentials: credYYY},
|
clientB,
|
||||||
})
|
}
|
||||||
if err != nil {
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
t.Fatalf("failed to create client identity repo: %v", err)
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
dbm := db.NewMemDB()
|
||||||
|
clientRepo := db.NewClientRepo(dbm)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create client identity manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
userRepo, err := makeNewUserRepo()
|
userRepo, err := makeNewUserRepo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
@ -810,12 +901,13 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
IssuerURL: issuerURL,
|
IssuerURL: issuerURL,
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
ClientRepo: ciRepo,
|
ClientRepo: clientRepo,
|
||||||
|
ClientManager: clientManager,
|
||||||
UserRepo: userRepo,
|
UserRepo: userRepo,
|
||||||
RefreshTokenRepo: refreshTokenRepo,
|
RefreshTokenRepo: refreshTokenRepo,
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := refreshTokenRepo.Create("testid-2", credXXX.ID); err != nil {
|
if _, err := refreshTokenRepo.Create("testid-2", clientA.Credentials.ID); err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -826,7 +918,7 @@ func TestServerRefreshToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
srv.UserRepo = userRepo
|
srv.UserRepo = userRepo
|
||||||
|
|
||||||
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
|
_, err = srv.RefreshToken(clientA.Credentials, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
|
||||||
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
|
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
|
||||||
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
|
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
"github.com/coreos/dex/email"
|
"github.com/coreos/dex/email"
|
||||||
|
@ -26,7 +27,7 @@ const (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"}
|
testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"}
|
||||||
testClientID = "XXX"
|
testClientID = "client.example.com"
|
||||||
|
|
||||||
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
|
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
|
||||||
|
|
||||||
|
@ -79,6 +80,7 @@ type testFixtures struct {
|
||||||
emailer *email.TemplatizedEmailer
|
emailer *email.TemplatizedEmailer
|
||||||
redirectURL url.URL
|
redirectURL url.URL
|
||||||
clientRepo client.ClientRepo
|
clientRepo client.ClientRepo
|
||||||
|
clientManager *clientmanager.ClientManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
|
func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
|
||||||
|
@ -123,7 +125,7 @@ func makeTestFixtures() (*testFixtures, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
|
userManager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, db.TransactionFactory(dbMap), usermanager.ManagerOptions{})
|
||||||
|
|
||||||
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
|
||||||
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
|
||||||
|
@ -136,11 +138,11 @@ func makeTestFixtures() (*testFixtures, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
clientRepo, err := db.NewClientRepoFromClients(db.NewMemDB(), []client.Client{
|
clients := []client.Client{
|
||||||
client.Client{
|
client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "XXX",
|
ID: testClientID,
|
||||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
|
@ -148,11 +150,19 @@ func makeTestFixtures() (*testFixtures, error) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
}
|
||||||
|
|
||||||
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbMap)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
km := key.NewPrivateKeyManager()
|
km := key.NewPrivateKeyManager()
|
||||||
err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute)))
|
err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{testPrivKey}, time.Now().Add(time.Minute)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -173,7 +183,8 @@ func makeTestFixtures() (*testFixtures, error) {
|
||||||
Templates: tpl,
|
Templates: tpl,
|
||||||
UserRepo: userRepo,
|
UserRepo: userRepo,
|
||||||
PasswordInfoRepo: pwRepo,
|
PasswordInfoRepo: pwRepo,
|
||||||
UserManager: manager,
|
UserManager: userManager,
|
||||||
|
ClientManager: clientManager,
|
||||||
KeyManager: km,
|
KeyManager: km,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,5 +218,6 @@ func makeTestFixtures() (*testFixtures, error) {
|
||||||
sessionManager: sessionManager,
|
sessionManager: sessionManager,
|
||||||
emailer: emailer,
|
emailer: emailer,
|
||||||
clientRepo: clientRepo,
|
clientRepo: clientRepo,
|
||||||
|
clientManager: clientManager,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,12 +11,12 @@ import (
|
||||||
"github.com/coreos/go-oidc/oidc"
|
"github.com/coreos/go-oidc/oidc"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
schema "github.com/coreos/dex/schema/workerschema"
|
schema "github.com/coreos/dex/schema/workerschema"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/coreos/dex/user/api"
|
"github.com/coreos/dex/user/api"
|
||||||
"github.com/coreos/dex/user/manager"
|
usermanager "github.com/coreos/dex/user/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -38,16 +38,16 @@ var (
|
||||||
type UserMgmtServer struct {
|
type UserMgmtServer struct {
|
||||||
api *api.UsersAPI
|
api *api.UsersAPI
|
||||||
jwtvFactory JWTVerifierFactory
|
jwtvFactory JWTVerifierFactory
|
||||||
um *manager.UserManager
|
um *usermanager.UserManager
|
||||||
cir client.ClientRepo
|
cm *clientmanager.ClientManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *manager.UserManager, cir client.ClientRepo) *UserMgmtServer {
|
func NewUserMgmtServer(userMgmtAPI *api.UsersAPI, jwtvFactory JWTVerifierFactory, um *usermanager.UserManager, cm *clientmanager.ClientManager) *UserMgmtServer {
|
||||||
return &UserMgmtServer{
|
return &UserMgmtServer{
|
||||||
api: userMgmtAPI,
|
api: userMgmtAPI,
|
||||||
jwtvFactory: jwtvFactory,
|
jwtvFactory: jwtvFactory,
|
||||||
um: um,
|
um: um,
|
||||||
cir: cir,
|
cm: cm,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -295,7 +295,7 @@ func (s *UserMgmtServer) getCreds(r *http.Request, requiresAdmin bool) (api.Cred
|
||||||
return api.Creds{}, err
|
return api.Creds{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
isAdmin, err := s.cir.IsDexAdmin(clientID)
|
isAdmin, err := s.cm.IsDexAdmin(clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("userMgmtServer: GetCreds err: %q", err)
|
log.Errorf("userMgmtServer: GetCreds err: %q", err)
|
||||||
return api.Creds{}, err
|
return api.Creds{}, err
|
||||||
|
|
2
test
2
test
|
@ -18,7 +18,7 @@ if [ ! -d $GOPATH/pkg ]; then
|
||||||
echo "WARNING: No cached builds detected. Please run the ./build script to speed up future tests."
|
echo "WARNING: No cached builds detected. Please run the ./build script to speed up future tests."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin"
|
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/time pkg/html functional/repo server session session/manager user user/api user/manager user/email email admin client client/manager"
|
||||||
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
|
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
|
||||||
|
|
||||||
# user has not provided PKG override
|
# user has not provided PKG override
|
||||||
|
|
|
@ -9,15 +9,13 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gorp/gorp"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
"github.com/coreos/dex/db"
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/pkg/log"
|
"github.com/coreos/dex/pkg/log"
|
||||||
"github.com/coreos/dex/refresh"
|
"github.com/coreos/dex/refresh"
|
||||||
schema "github.com/coreos/dex/schema/workerschema"
|
schema "github.com/coreos/dex/schema/workerschema"
|
||||||
"github.com/coreos/dex/user"
|
"github.com/coreos/dex/user"
|
||||||
"github.com/coreos/dex/user/manager"
|
usermanager "github.com/coreos/dex/user/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -88,9 +86,9 @@ func (e Error) Error() string {
|
||||||
// calling User. It is assumed that the clientID has already validated as an
|
// calling User. It is assumed that the clientID has already validated as an
|
||||||
// admin app before calling.
|
// admin app before calling.
|
||||||
type UsersAPI struct {
|
type UsersAPI struct {
|
||||||
manager *manager.UserManager
|
userManager *usermanager.UserManager
|
||||||
localConnectorID string
|
localConnectorID string
|
||||||
clientRepo client.ClientRepo
|
clientManager *clientmanager.ClientManager
|
||||||
refreshRepo refresh.RefreshTokenRepo
|
refreshRepo refresh.RefreshTokenRepo
|
||||||
emailer Emailer
|
emailer Emailer
|
||||||
}
|
}
|
||||||
|
@ -105,11 +103,11 @@ type Creds struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(ericchiang): Don't pass a dbMap. See #385.
|
// TODO(ericchiang): Don't pass a dbMap. See #385.
|
||||||
func NewUsersAPI(dbMap *gorp.DbMap, userManager *manager.UserManager, emailer Emailer, localConnectorID string) *UsersAPI {
|
func NewUsersAPI(userManager *usermanager.UserManager, clientManager *clientmanager.ClientManager, refreshRepo refresh.RefreshTokenRepo, emailer Emailer, localConnectorID string) *UsersAPI {
|
||||||
return &UsersAPI{
|
return &UsersAPI{
|
||||||
manager: userManager,
|
userManager: userManager,
|
||||||
refreshRepo: db.NewRefreshTokenRepo(dbMap),
|
refreshRepo: refreshRepo,
|
||||||
clientRepo: db.NewClientRepo(dbMap),
|
clientManager: clientManager,
|
||||||
localConnectorID: localConnectorID,
|
localConnectorID: localConnectorID,
|
||||||
emailer: emailer,
|
emailer: emailer,
|
||||||
}
|
}
|
||||||
|
@ -122,7 +120,7 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) {
|
||||||
return schema.User{}, ErrorUnauthorized
|
return schema.User{}, ErrorUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
usr, err := u.manager.Get(id)
|
usr, err := u.userManager.Get(id)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return schema.User{}, mapError(err)
|
return schema.User{}, mapError(err)
|
||||||
|
@ -137,7 +135,7 @@ func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema
|
||||||
return schema.UserDisableResponse{}, ErrorUnauthorized
|
return schema.UserDisableResponse{}, ErrorUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := u.manager.Disable(userID, disable); err != nil {
|
if err := u.userManager.Disable(userID, disable); err != nil {
|
||||||
return schema.UserDisableResponse{}, mapError(err)
|
return schema.UserDisableResponse{}, mapError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,7 +155,7 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
|
||||||
return schema.UserCreateResponse{}, mapError(err)
|
return schema.UserCreateResponse{}, mapError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata, err := u.clientRepo.Metadata(creds.ClientID)
|
metadata, err := u.clientManager.Metadata(creds.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return schema.UserCreateResponse{}, mapError(err)
|
return schema.UserCreateResponse{}, mapError(err)
|
||||||
}
|
}
|
||||||
|
@ -167,12 +165,12 @@ func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (s
|
||||||
return schema.UserCreateResponse{}, ErrorInvalidRedirectURL
|
return schema.UserCreateResponse{}, ErrorInvalidRedirectURL
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := u.manager.CreateUser(schemaUserToUser(usr), user.Password(hash), u.localConnectorID)
|
id, err := u.userManager.CreateUser(schemaUserToUser(usr), user.Password(hash), u.localConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return schema.UserCreateResponse{}, mapError(err)
|
return schema.UserCreateResponse{}, mapError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
userUser, err := u.manager.Get(id)
|
userUser, err := u.userManager.Get(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return schema.UserCreateResponse{}, mapError(err)
|
return schema.UserCreateResponse{}, mapError(err)
|
||||||
}
|
}
|
||||||
|
@ -202,7 +200,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
|
||||||
return schema.ResendEmailInvitationResponse{}, ErrorUnauthorized
|
return schema.ResendEmailInvitationResponse{}, ErrorUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata, err := u.clientRepo.Metadata(creds.ClientID)
|
metadata, err := u.clientManager.Metadata(creds.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return schema.ResendEmailInvitationResponse{}, mapError(err)
|
return schema.ResendEmailInvitationResponse{}, mapError(err)
|
||||||
}
|
}
|
||||||
|
@ -213,7 +211,7 @@ func (u *UsersAPI) ResendEmailInvitation(creds Creds, userID string, redirURL ur
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve user to check if it's already created
|
// Retrieve user to check if it's already created
|
||||||
userUser, err := u.manager.Get(userID)
|
userUser, err := u.userManager.Get(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return schema.ResendEmailInvitationResponse{}, mapError(err)
|
return schema.ResendEmailInvitationResponse{}, mapError(err)
|
||||||
}
|
}
|
||||||
|
@ -251,7 +249,7 @@ func (u *UsersAPI) ListUsers(creds Creds, maxResults int, nextPageToken string)
|
||||||
return nil, "", ErrorMaxResultsTooHigh
|
return nil, "", ErrorMaxResultsTooHigh
|
||||||
}
|
}
|
||||||
|
|
||||||
users, tok, err := u.manager.List(user.UserFilter{}, maxResults, nextPageToken)
|
users, tok, err := u.userManager.List(user.UserFilter{}, maxResults, nextPageToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", mapError(err)
|
return nil, "", mapError(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/kylelemons/godebug/pretty"
|
"github.com/kylelemons/godebug/pretty"
|
||||||
|
|
||||||
"github.com/coreos/dex/client"
|
"github.com/coreos/dex/client"
|
||||||
|
clientmanager "github.com/coreos/dex/client/manager"
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/db"
|
"github.com/coreos/dex/db"
|
||||||
schema "github.com/coreos/dex/schema/workerschema"
|
schema "github.com/coreos/dex/schema/workerschema"
|
||||||
|
@ -51,13 +52,14 @@ func (t *testEmailer) sendEmail(email string, redirectURL url.URL, clientID stri
|
||||||
|
|
||||||
var (
|
var (
|
||||||
clock = clockwork.NewFakeClock()
|
clock = clockwork.NewFakeClock()
|
||||||
|
goodClientID = "client.example.com"
|
||||||
|
|
||||||
goodCreds = Creds{
|
goodCreds = Creds{
|
||||||
User: user.User{
|
User: user.User{
|
||||||
ID: "ID-1",
|
ID: "ID-1",
|
||||||
Admin: true,
|
Admin: true,
|
||||||
},
|
},
|
||||||
ClientID: "XXX",
|
ClientID: goodClientID,
|
||||||
}
|
}
|
||||||
|
|
||||||
badCreds = Creds{
|
badCreds = Creds{
|
||||||
|
@ -72,7 +74,7 @@ var (
|
||||||
Admin: true,
|
Admin: true,
|
||||||
Disabled: true,
|
Disabled: true,
|
||||||
},
|
},
|
||||||
ClientID: "XXX",
|
ClientID: goodClientID,
|
||||||
}
|
}
|
||||||
|
|
||||||
resetPasswordURL = url.URL{
|
resetPasswordURL = url.URL{
|
||||||
|
@ -82,7 +84,7 @@ var (
|
||||||
|
|
||||||
validRedirURL = url.URL{
|
validRedirURL = url.URL{
|
||||||
Scheme: "http",
|
Scheme: "http",
|
||||||
Host: "client.example.com",
|
Host: goodClientID,
|
||||||
Path: "/callback",
|
Path: "/callback",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -158,8 +160,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||||
mgr.Clock = clock
|
mgr.Clock = clock
|
||||||
ci := client.Client{
|
ci := client.Client{
|
||||||
Credentials: oidc.ClientCredentials{
|
Credentials: oidc.ClientCredentials{
|
||||||
ID: "XXX",
|
ID: goodClientID,
|
||||||
Secret: base64.URLEncoding.EncodeToString([]byte("secrete")),
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
|
||||||
},
|
},
|
||||||
Metadata: oidc.ClientMetadata{
|
Metadata: oidc.ClientMetadata{
|
||||||
RedirectURIs: []url.URL{
|
RedirectURIs: []url.URL{
|
||||||
|
@ -167,8 +169,17 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if _, err := db.NewClientRepoFromClients(dbMap, []client.Client{ci}); err != nil {
|
|
||||||
panic("Failed to create client repo: " + err.Error())
|
clientIDGenerator := func(hostport string) (string, error) {
|
||||||
|
return hostport, nil
|
||||||
|
}
|
||||||
|
secGen := func() ([]byte, error) {
|
||||||
|
return []byte("secret"), nil
|
||||||
|
}
|
||||||
|
clientRepo := db.NewClientRepo(dbMap)
|
||||||
|
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbMap), []client.Client{ci}, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
|
||||||
|
if err != nil {
|
||||||
|
panic("Failed to create client manager: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Used in TestRevokeRefreshToken test.
|
// Used in TestRevokeRefreshToken test.
|
||||||
|
@ -176,8 +187,8 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||||
clientID string
|
clientID string
|
||||||
userID string
|
userID string
|
||||||
}{
|
}{
|
||||||
{"XXX", "ID-1"},
|
{goodClientID, "ID-1"},
|
||||||
{"XXX", "ID-2"},
|
{goodClientID, "ID-2"},
|
||||||
}
|
}
|
||||||
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
refreshRepo := db.NewRefreshTokenRepo(dbMap)
|
||||||
for _, token := range refreshTokens {
|
for _, token := range refreshTokens {
|
||||||
|
@ -187,7 +198,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
|
||||||
}
|
}
|
||||||
|
|
||||||
emailer := &testEmailer{}
|
emailer := &testEmailer{}
|
||||||
api := NewUsersAPI(dbMap, mgr, emailer, "local")
|
api := NewUsersAPI(mgr, clientManager, refreshRepo, emailer, "local")
|
||||||
return api, emailer
|
return api, emailer
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -582,8 +593,8 @@ func TestRevokeRefreshToken(t *testing.T) {
|
||||||
before []string // clientIDs expected before the change.
|
before []string // clientIDs expected before the change.
|
||||||
after []string // clientIDs expected after the change.
|
after []string // clientIDs expected after the change.
|
||||||
}{
|
}{
|
||||||
{"ID-1", "XXX", []string{"XXX"}, []string{}},
|
{"ID-1", goodClientID, []string{goodClientID}, []string{}},
|
||||||
{"ID-2", "XXX", []string{"XXX"}, []string{}},
|
{"ID-2", goodClientID, []string{goodClientID}, []string{}},
|
||||||
}
|
}
|
||||||
|
|
||||||
api, _ := makeTestFixtures()
|
api, _ := makeTestFixtures()
|
||||||
|
|
Reference in a new issue