dex/db/client.go

360 lines
7.4 KiB
Go
Raw Normal View History

2015-08-18 05:57:27 +05:30
package db
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/url"
"reflect"
2015-08-18 05:57:27 +05:30
"github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
2015-08-18 05:57:27 +05:30
"github.com/coreos/dex/client"
"github.com/coreos/dex/pkg/log"
2016-05-12 03:05:24 +05:30
"github.com/coreos/dex/repo"
2015-08-18 05:57:27 +05:30
)
const (
	// Table names under which the client models are registered in init
	// and which the repository queries address.
	clientTableName      = "client_identity"
	trustedPeerTableName = "trusted_peers"

	// pgErrorCodeUniqueViolation is the Postgres error code for
	// unique_violation.
	pgErrorCodeUniqueViolation = "23505"
)
var (
localHostRedirectURL = mustParseURL("http://localhost:0")
)
2015-08-18 05:57:27 +05:30
func init() {
register(table{
name: clientTableName,
model: clientModel{},
2015-08-18 05:57:27 +05:30
autoinc: false,
pkey: []string{"id"},
})
register(table{
name: trustedPeerTableName,
model: trustedPeerModel{},
autoinc: false,
pkey: []string{"client_id", "trusted_client_id"},
})
2015-08-18 05:57:27 +05:30
}
func newClientModel(cli client.Client) (*clientModel, error) {
hashed, err := client.HashSecret(cli.Credentials)
if err != nil {
return nil, err
}
if cli.Public {
// Metadata.Valid(), and therefore json.Unmarshal(metadata) complains
// when there's no RedirectURIs, so we set them to a fixed value here,
// and remove it when translating back to a client.Client
cli.Metadata.RedirectURIs = []url.URL{
localHostRedirectURL,
}
}
bmeta, err := json.Marshal(&cli.Metadata)
2015-08-18 05:57:27 +05:30
if err != nil {
return nil, err
}
cim := clientModel{
ID: cli.Credentials.ID,
2015-08-18 05:57:27 +05:30
Secret: hashed,
Metadata: string(bmeta),
DexAdmin: cli.Admin,
Public: cli.Public,
2015-08-18 05:57:27 +05:30
}
return &cim, nil
}
type (
	// clientModel is the database row for a client identity. Secret holds
	// the hashed secret and Metadata the JSON-encoded client metadata, both
	// as produced by newClientModel.
	clientModel struct {
		ID       string `db:"id"`
		Secret   []byte `db:"secret"`
		Metadata string `db:"metadata"`
		DexAdmin bool   `db:"dex_admin"`
		Public   bool   `db:"public"`
	}

	// trustedPeerModel is one row of the trusted-peers table, recording
	// that TrustedClientID is a trusted peer of ClientID.
	trustedPeerModel struct {
		ClientID        string `db:"client_id"`
		TrustedClientID string `db:"trusted_client_id"`
	}
)
func (m *clientModel) Client() (*client.Client, error) {
ci := client.Client{
2015-08-18 05:57:27 +05:30
Credentials: oidc.ClientCredentials{
ID: m.ID,
2015-08-18 05:57:27 +05:30
},
Admin: m.DexAdmin,
Public: m.Public,
2015-08-18 05:57:27 +05:30
}
if err := json.Unmarshal([]byte(m.Metadata), &ci.Metadata); err != nil {
2015-08-18 05:57:27 +05:30
return nil, err
}
if ci.Public {
ci.Metadata.RedirectURIs = []url.URL{}
}
2015-08-18 05:57:27 +05:30
return &ci, nil
}
// NewClientRepo returns a client.ClientRepo that persists clients through
// the given gorp DbMap.
func NewClientRepo(dbm *gorp.DbMap) client.ClientRepo {
	r := newClientRepo(dbm)
	return r
}
func newClientRepo(dbm *gorp.DbMap) *clientRepo {
return &clientRepo{
db: &db{dbm},
2016-02-09 05:31:16 +05:30
}
2015-08-18 05:57:27 +05:30
}
type clientRepo struct {
*db
2015-08-18 05:57:27 +05:30
}
2016-05-12 03:05:24 +05:30
func (r *clientRepo) Get(tx repo.Transaction, clientID string) (client.Client, error) {
m, err := r.executor(tx).Get(clientModel{}, clientID)
2015-08-18 05:57:27 +05:30
if err == sql.ErrNoRows || m == nil {
return client.Client{}, client.ErrorNotFound
2015-08-18 05:57:27 +05:30
}
if err != nil {
return client.Client{}, err
2015-08-18 05:57:27 +05:30
}
cim, ok := m.(*clientModel)
2015-08-18 05:57:27 +05:30
if !ok {
log.Errorf("expected clientModel but found %v", reflect.TypeOf(m))
return client.Client{}, errors.New("unrecognized model")
}
ci, err := cim.Client()
if err != nil {
return client.Client{}, err
2015-08-18 05:57:27 +05:30
}
return *ci, nil
}
func (r *clientRepo) GetSecret(tx repo.Transaction, clientID string) ([]byte, error) {
m, err := r.getModel(tx, clientID)
if err != nil || m == nil {
2015-08-18 05:57:27 +05:30
return nil, err
}
return m.Secret, nil
2015-08-18 05:57:27 +05:30
}
func (r *clientRepo) Update(tx repo.Transaction, cli client.Client) error {
if cli.Credentials.ID == "" {
return client.ErrorNotFound
2015-08-18 05:57:27 +05:30
}
// make sure this client exists already
_, err := r.get(tx, cli.Credentials.ID)
2015-08-18 05:57:27 +05:30
if err != nil {
return err
}
err = r.update(tx, cli)
2015-08-18 05:57:27 +05:30
if err != nil {
return err
}
return nil
2015-08-18 05:57:27 +05:30
}
// alreadyExistsCheckers holds driver-specific predicates that recognize a
// unique-constraint violation, in registration order.
var alreadyExistsCheckers []func(err error) bool

// registerAlreadyExistsChecker appends f to the list consulted by
// isAlreadyExistsErr.
func registerAlreadyExistsChecker(f func(err error) bool) {
	alreadyExistsCheckers = append(alreadyExistsCheckers, f)
}

// isAlreadyExistsErr detects database error codes for failing a unique constraint.
//
// Because database drivers are optionally compiled, use registerAlreadyExistsChecker to
// register driver specific implementations. Checkers run in registration
// order and evaluation stops at the first one that reports a match.
func isAlreadyExistsErr(err error) bool {
	for i := 0; i < len(alreadyExistsCheckers); i++ {
		if matches := alreadyExistsCheckers[i](err); matches {
			return true
		}
	}
	return false
}
func (r *clientRepo) New(tx repo.Transaction, cli client.Client) (*oidc.ClientCredentials, error) {
cim, err := newClientModel(cli)
2015-08-18 05:57:27 +05:30
if err != nil {
return nil, err
}
2016-05-12 03:05:24 +05:30
if err := r.executor(tx).Insert(cim); err != nil {
if isAlreadyExistsErr(err) {
2016-06-29 04:39:20 +05:30
return nil, client.ErrorDuplicateClientID
2015-08-18 05:57:27 +05:30
}
return nil, err
}
cc := oidc.ClientCredentials{
ID: cli.Credentials.ID,
Secret: cli.Credentials.Secret,
2015-08-18 05:57:27 +05:30
}
return &cc, nil
}
2016-05-12 03:05:24 +05:30
func (r *clientRepo) All(tx repo.Transaction) ([]client.Client, error) {
qt := r.quote(clientTableName)
2015-08-18 05:57:27 +05:30
q := fmt.Sprintf("SELECT * FROM %s", qt)
2016-05-12 03:05:24 +05:30
objs, err := r.executor(tx).Select(&clientModel{}, q)
2015-08-18 05:57:27 +05:30
if err != nil {
return nil, err
}
cs := make([]client.Client, len(objs))
2015-08-18 05:57:27 +05:30
for i, obj := range objs {
m, ok := obj.(*clientModel)
2015-08-18 05:57:27 +05:30
if !ok {
return nil, errors.New("unable to cast client identity to clientModel")
2015-08-18 05:57:27 +05:30
}
ci, err := m.Client()
2015-08-18 05:57:27 +05:30
if err != nil {
return nil, err
}
cs[i] = *ci
}
return cs, nil
}
// NewClientRepoFromClients builds a client repository on dbm and seeds it
// with cs. For each client the row is inserted and its trusted peers are
// set. Both writes use a nil transaction. The first error aborts the loop
// and is returned. Clients inserted before that point are not removed here.
func NewClientRepoFromClients(dbm *gorp.DbMap, cs []client.LoadableClient) (client.ClientRepo, error) {
	// Named cr rather than repo so the imported repo package is not shadowed.
	cr := newClientRepo(dbm)
	for _, lc := range cs {
		model, err := newClientModel(lc.Client)
		if err != nil {
			return nil, err
		}
		if err := cr.executor(nil).Insert(model); err != nil {
			return nil, err
		}
		if err := cr.SetTrustedPeers(nil, lc.Client.Credentials.ID, lc.TrustedPeers); err != nil {
			return nil, err
		}
	}
	return cr, nil
}
// get loads and decodes the client with the given ID. It returns the zero
// client plus the error when the lookup or the decode fails.
func (r *clientRepo) get(tx repo.Transaction, clientID string) (client.Client, error) {
	var out client.Client

	model, err := r.getModel(tx, clientID)
	if err != nil {
		return out, err
	}
	decoded, err := model.Client()
	if err != nil {
		return out, err
	}

	out = *decoded
	return out, nil
}
// getModel fetches the raw row for clientID. A missing row (nil model and
// nil error) is reported as client.ErrorNotFound. Database errors are
// returned unchanged.
func (r *clientRepo) getModel(tx repo.Transaction, clientID string) (*clientModel, error) {
	obj, err := r.executor(tx).Get(clientModel{}, clientID)
	switch {
	case err != nil:
		return nil, err
	case obj == nil:
		return nil, client.ErrorNotFound
	}

	model, isModel := obj.(*clientModel)
	if !isModel {
		log.Errorf("expected clientModel but found %v", reflect.TypeOf(obj))
		return nil, errors.New("unrecognized model")
	}
	return model, nil
}
// update rebuilds the row for cli with newClientModel, which re-hashes the
// secret from cli.Credentials, and writes it over the existing row.
func (r *clientRepo) update(tx repo.Transaction, cli client.Client) error {
	ex := r.executor(tx)
	model, err := newClientModel(cli)
	if err != nil {
		return err
	}
	if _, err = ex.Update(model); err != nil {
		return err
	}
	return nil
}
// GetTrustedPeers lists the IDs of the clients trusted by clientID. An empty
// clientID is rejected with client.ErrorInvalidClientID. If the query
// reports sql.ErrNoRows, nil is returned with no error.
func (r *clientRepo) GetTrustedPeers(tx repo.Transaction, clientID string) ([]string, error) {
	ex := r.executor(tx)
	if clientID == "" {
		return nil, client.ErrorInvalidClientID
	}

	query := fmt.Sprintf("SELECT trusted_client_id from %s where client_id = $1", r.quote(trustedPeerTableName))
	var peers []string
	_, err := ex.Select(&peers, query, clientID)
	switch {
	case err == nil:
		return peers, nil
	case err == sql.ErrNoRows:
		return nil, nil
	default:
		return nil, err
	}
}
// SetTrustedPeers replaces the trusted peers of clientID with clientIDs.
// An empty clientIDs clears them.
//
// The client must already exist. That is checked before any row is
// touched, so a call for an unknown client fails without deleting the
// existing trusted-peer rows.
func (r *clientRepo) SetTrustedPeers(tx repo.Transaction, clientID string, clientIDs []string) error {
	// Validate first. The original order (delete, then check) performed a
	// destructive write even when the call was going to fail.
	if _, err := r.get(tx, clientID); err != nil {
		return err
	}

	ex := r.executor(tx)
	qt := r.quote(trustedPeerTableName)

	// Replace semantics: drop every existing row, then insert the new set.
	if _, err := ex.Exec(fmt.Sprintf("DELETE from %s where client_id = $1", qt), clientID); err != nil {
		return err
	}
	if len(clientIDs) == 0 {
		return nil
	}

	rows := make([]interface{}, 0, len(clientIDs))
	for _, peerID := range clientIDs {
		rows = append(rows, &trustedPeerModel{
			ClientID:        clientID,
			TrustedClientID: peerID,
		})
	}
	return ex.Insert(rows...)
}
// mustParseURL parses s with url.Parse and returns the result by value. It
// panics if s does not parse, so it is meant only for constant,
// known-good inputs such as package-level variables.
func mustParseURL(s string) url.URL {
	parsed, err := url.Parse(s)
	if err == nil {
		return *parsed
	}
	panic(err)
}