dex/db/client.go

271 lines
6 KiB
Go
Raw Normal View History

2015-08-18 05:57:27 +05:30
package db
import (
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"reflect"
2015-08-18 05:57:27 +05:30
"github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
2015-08-18 05:57:27 +05:30
"github.com/lib/pq"
2016-02-09 05:31:16 +05:30
"github.com/mattn/go-sqlite3"
2015-08-18 05:57:27 +05:30
"golang.org/x/crypto/bcrypt"
"github.com/coreos/dex/client"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/dex/pkg/log"
)
// clientIdentityTableName is the database table that stores client identities.
const clientIdentityTableName = "client_identity"

// Parameters governing how client secrets are generated and hashed.
const (
	// bcryptHashCost is the work factor handed to bcrypt when hashing a
	// client secret.
	bcryptHashCost = 10

	// maxSecretLength is the longest client secret, in bytes, that is
	// accepted (and the length of freshly generated secrets). Blowfish, the
	// algorithm underlying bcrypt, has a maximum password length of 72, and
	// the bcrypt library silently ignores anything past that, so the limit
	// is tracked and checked explicitly.
	maxSecretLength = 72
)

// pgErrorCodeUniqueViolation is the Postgres error code for unique_violation.
const pgErrorCodeUniqueViolation = "23505"
func init() {
register(table{
name: clientIdentityTableName,
model: clientIdentityModel{},
autoinc: false,
pkey: []string{"id"},
})
}
func newClientIdentityModel(id string, secret []byte, meta *oidc.ClientMetadata) (*clientIdentityModel, error) {
hashed, err := bcrypt.GenerateFromPassword(secret, bcryptHashCost)
if err != nil {
return nil, err
}
bmeta, err := json.Marshal(meta)
2015-08-18 05:57:27 +05:30
if err != nil {
return nil, err
}
cim := clientIdentityModel{
ID: id,
Secret: hashed,
Metadata: string(bmeta),
}
return &cim, nil
}
// clientIdentityModel is the row representation of a client identity in
// the client identity table. Column names come from the `db` struct tags.
type clientIdentityModel struct {
	// ID is the client ID; the table's primary key.
	ID string `db:"id"`

	// Secret is the bcrypt hash of the client secret, not the plaintext.
	Secret []byte `db:"secret"`

	// Metadata is the JSON-encoded oidc.ClientMetadata of the client.
	Metadata string `db:"metadata"`

	// DexAdmin is the flag reported by IsDexAdmin and written by SetDexAdmin.
	DexAdmin bool `db:"dex_admin"`
}
func (m *clientIdentityModel) ClientIdentity() (*oidc.ClientIdentity, error) {
ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{
ID: m.ID,
Secret: string(m.Secret),
},
}
if err := json.Unmarshal([]byte(m.Metadata), &ci.Metadata); err != nil {
2015-08-18 05:57:27 +05:30
return nil, err
}
return &ci, nil
}
// NewClientIdentityRepo returns a client.ClientIdentityRepo that stores
// client identities through dbm.
func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo {
	repo := new(clientIdentityRepo)
	repo.dbMap = dbm
	return repo
}
func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) {
2016-02-09 05:31:16 +05:30
tx, err := dbm.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
2015-08-18 05:57:27 +05:30
for _, c := range clients {
dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret)
if err != nil {
return nil, err
}
cm, err := newClientIdentityModel(c.Credentials.ID, dec, &c.Metadata)
if err != nil {
return nil, err
}
2016-02-09 05:31:16 +05:30
err = tx.Insert(cm)
2015-08-18 05:57:27 +05:30
if err != nil {
return nil, err
}
}
2016-02-09 05:31:16 +05:30
if err := tx.Commit(); err != nil {
return nil, err
}
return NewClientIdentityRepo(dbm), nil
2015-08-18 05:57:27 +05:30
}
// clientIdentityRepo is the client.ClientIdentityRepo implementation returned
// by NewClientIdentityRepo; all of its operations go through dbMap.
type clientIdentityRepo struct {
	dbMap *gorp.DbMap
}
func (r *clientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) {
m, err := r.dbMap.Get(clientIdentityModel{}, clientID)
if err == sql.ErrNoRows || m == nil {
return nil, client.ErrorNotFound
}
if err != nil {
return nil, err
}
cim, ok := m.(*clientIdentityModel)
if !ok {
log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m))
2015-08-18 05:57:27 +05:30
return nil, errors.New("unrecognized model")
}
ci, err := cim.ClientIdentity()
if err != nil {
return nil, err
}
return &ci.Metadata, nil
}
func (r *clientIdentityRepo) IsDexAdmin(clientID string) (bool, error) {
m, err := r.dbMap.Get(clientIdentityModel{}, clientID)
if m == nil || err != nil {
return false, err
}
cim, ok := m.(*clientIdentityModel)
if !ok {
log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m))
2015-08-18 05:57:27 +05:30
return false, errors.New("unrecognized model")
}
return cim.DexAdmin, nil
}
func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error {
tx, err := r.dbMap.Begin()
if err != nil {
return err
}
2016-02-09 05:31:16 +05:30
defer tx.Rollback()
2015-08-18 05:57:27 +05:30
2016-02-09 05:31:16 +05:30
m, err := tx.Get(clientIdentityModel{}, clientID)
2015-08-18 05:57:27 +05:30
if m == nil || err != nil {
rollback(tx)
return err
}
cim, ok := m.(*clientIdentityModel)
if !ok {
log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m))
2015-08-18 05:57:27 +05:30
return errors.New("unrecognized model")
}
cim.DexAdmin = isAdmin
2016-02-09 05:31:16 +05:30
_, err = tx.Update(cim)
2015-08-18 05:57:27 +05:30
if err != nil {
return err
}
2016-02-09 05:31:16 +05:30
return tx.Commit()
2015-08-18 05:57:27 +05:30
}
func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) {
m, err := r.dbMap.Get(clientIdentityModel{}, creds.ID)
if m == nil || err != nil {
return false, err
}
cim, ok := m.(*clientIdentityModel)
if !ok {
log.Errorf("expected clientIdentityModel but found %v", reflect.TypeOf(m))
2015-08-18 05:57:27 +05:30
return false, errors.New("unrecognized model")
}
dec, err := base64.URLEncoding.DecodeString(creds.Secret)
if err != nil {
log.Errorf("error Decoding client creds: %v", err)
2015-08-18 05:57:27 +05:30
return false, nil
}
if len(dec) > maxSecretLength {
return false, nil
}
ok = bcrypt.CompareHashAndPassword(cim.Secret, dec) == nil
return ok, nil
}
func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.ClientCredentials, error) {
secret, err := pcrypto.RandBytes(maxSecretLength)
if err != nil {
return nil, err
}
cim, err := newClientIdentityModel(id, secret, &meta)
if err != nil {
return nil, err
}
if err := r.dbMap.Insert(cim); err != nil {
2016-02-09 05:31:16 +05:30
switch sqlErr := err.(type) {
case *pq.Error:
if sqlErr.Code == pgErrorCodeUniqueViolation {
err = errors.New("client ID already exists")
}
case *sqlite3.Error:
if sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique {
err = errors.New("client ID already exists")
}
2015-08-18 05:57:27 +05:30
}
return nil, err
}
cc := oidc.ClientCredentials{
ID: id,
Secret: base64.URLEncoding.EncodeToString(secret),
}
return &cc, nil
}
func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) {
2016-02-09 05:31:16 +05:30
qt := r.dbMap.Dialect.QuotedTableForQuery("", clientIdentityTableName)
2015-08-18 05:57:27 +05:30
q := fmt.Sprintf("SELECT * FROM %s", qt)
objs, err := r.dbMap.Select(&clientIdentityModel{}, q)
if err != nil {
return nil, err
}
cs := make([]oidc.ClientIdentity, len(objs))
for i, obj := range objs {
m, ok := obj.(*clientIdentityModel)
if !ok {
return nil, errors.New("unable to cast client identity to clientIdentityModel")
}
ci, err := m.ClientIdentity()
if err != nil {
return nil, err
}
cs[i] = *ci
}
return cs, nil
}