dex/db/key.go

212 lines
3.8 KiB
Go

package db
import (
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/go-gorp/gorp"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/go-oidc/key"
)
// keyTableName is the name of the table holding the single encrypted
// private key set blob written by PrivateKeySetRepo.Set.
const keyTableName = "key"
// ErrorCannotDecryptKeys is returned by PrivateKeySetRepo.Get when the stored
// blob could not be decrypted, decoded, and parsed with any configured secret.
var ErrorCannotDecryptKeys = errors.New("Cannot Decrypt Keys")
func init() {
register(table{
name: keyTableName,
model: privateKeySetBlob{},
autoinc: false,
})
}
// newPrivateKeySetModel converts pks into its JSON-serializable model.
// Each key is recorded by ID together with its PKCS#1 DER encoding, in the
// order returned by pks.Keys(), and the set's expiry time is carried over.
// The returned error is currently always nil.
func newPrivateKeySetModel(pks *key.PrivateKeySet) (*privateKeySetModel, error) {
	src := pks.Keys()
	models := make([]privateKeyModel, 0, len(src))
	for _, k := range src {
		models = append(models, privateKeyModel{
			ID:    k.ID(),
			PKCS1: x509.MarshalPKCS1PrivateKey(k.PrivateKey),
		})
	}
	return &privateKeySetModel{
		Keys:      models,
		ExpiresAt: pks.ExpiresAt(),
	}, nil
}
// privateKeyModel is the serialized form of one private key: its key ID and
// the PKCS#1 DER bytes produced by x509.MarshalPKCS1PrivateKey.
type privateKeyModel struct {
	ID    string `json:"id"`    // key ID
	PKCS1 []byte `json:"pkcs1"` // PKCS#1 DER-encoded RSA private key
}

// PrivateKey decodes m back into a key.PrivateKey. It returns the error from
// x509.ParsePKCS1PrivateKey if the stored bytes are malformed.
func (m *privateKeyModel) PrivateKey() (*key.PrivateKey, error) {
	rsaKey, err := x509.ParsePKCS1PrivateKey(m.PKCS1)
	if err != nil {
		return nil, err
	}
	return &key.PrivateKey{KeyID: m.ID, PrivateKey: rsaKey}, nil
}
// privateKeySetModel is the JSON document that PrivateKeySetRepo encrypts
// and stores: the list of serialized keys plus the set's expiry time.
type privateKeySetModel struct {
	Keys      []privateKeyModel `json:"keys"`
	ExpiresAt time.Time         `json:"expires_at"`
}

// PrivateKeySet rebuilds a key.PrivateKeySet from m, keeping key order and
// the expiry time. It stops and returns the error of the first key that
// fails to decode.
func (m *privateKeySetModel) PrivateKeySet() (*key.PrivateKeySet, error) {
	decoded := make([]*key.PrivateKey, 0, len(m.Keys))
	for i := range m.Keys {
		pk, err := m.Keys[i].PrivateKey()
		if err != nil {
			return nil, err
		}
		decoded = append(decoded, pk)
	}
	return key.NewPrivateKeySet(decoded, m.ExpiresAt), nil
}
// privateKeySetBlob is one row of the key table: the encrypted JSON encoding
// of a privateKeySetModel, stored in the "value" column.
type privateKeySetBlob struct {
	// Value is the ciphertext written by Set and decrypted by Get.
	Value []byte `db:"value"`
}
// NewPrivateKeySetRepo returns a PrivateKeySetRepo that stores keys via dbm.
//
// At least one secret is required and each must be exactly 32 bytes long.
// The first secret encrypts on Set; Get tries every secret in order. When
// useOldFormat is true, pcrypto.AESEncrypt/AESDecrypt are used instead of
// pcrypto.Encrypt/Decrypt.
func NewPrivateKeySetRepo(dbm *gorp.DbMap, useOldFormat bool, secrets ...[]byte) (*PrivateKeySetRepo, error) {
	if len(secrets) < 1 {
		return nil, errors.New("must provide at least one key secret")
	}
	for idx := range secrets {
		if len(secrets[idx]) != 32 {
			return nil, fmt.Errorf("key secret %d: expected 32-byte secret", idx)
		}
	}
	repo := PrivateKeySetRepo{
		db:           &db{dbm},
		useOldFormat: useOldFormat,
		secrets:      secrets,
	}
	return &repo, nil
}
// PrivateKeySetRepo persists a key.PrivateKeySet in the database as a single
// encrypted blob. Create instances with NewPrivateKeySetRepo.
type PrivateKeySetRepo struct {
	*db

	// useOldFormat selects pcrypto.AESEncrypt/AESDecrypt rather than
	// pcrypto.Encrypt/Decrypt.
	useOldFormat bool

	// secrets are the 32-byte keys validated by NewPrivateKeySetRepo;
	// secrets[0] encrypts, and all are tried in order to decrypt.
	secrets [][]byte
}
// Set replaces the stored key set with ks.
//
// ks must be a *key.PrivateKeySet. It is serialized to JSON and encrypted
// with the active (first) secret before any transaction is opened. Inside
// the transaction all existing rows are deleted and the new blob is
// inserted, so at most one key set is stored. Any failure after the
// transaction starts returns early, and the deferred Rollback undoes the
// partial work.
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
	// Validate and encrypt first. An invalid ks then never touches the
	// database, and the CPU-bound work stays outside the transaction.
	v, err := r.encryptKeySet(ks)
	if err != nil {
		return err
	}

	qt := r.quote(keyTableName)
	tx, err := r.begin()
	if err != nil {
		return err
	}
	// Deferred so every early return below rolls back. Its error is
	// intentionally ignored.
	defer tx.Rollback()

	exec := r.executor(tx)
	if _, err := exec.Exec(fmt.Sprintf("DELETE FROM %s", qt)); err != nil {
		return err
	}
	if err := exec.Insert(&privateKeySetBlob{Value: v}); err != nil {
		return err
	}
	return tx.Commit()
}

// encryptKeySet serializes ks, which must be a *key.PrivateKeySet, to JSON
// and encrypts it with the active secret. The cipher functions are chosen by
// useOldFormat.
func (r *PrivateKeySetRepo) encryptKeySet(ks key.KeySet) ([]byte, error) {
	pks, ok := ks.(*key.PrivateKeySet)
	if !ok {
		return nil, errors.New("unable to cast to PrivateKeySet")
	}
	m, err := newPrivateKeySetModel(pks)
	if err != nil {
		return nil, err
	}
	j, err := json.Marshal(m)
	if err != nil {
		return nil, err
	}
	if r.useOldFormat {
		return pcrypto.AESEncrypt(j, r.active())
	}
	return pcrypto.Encrypt(j, r.active())
}
// Get loads the stored key set.
//
// It returns key.ErrorNoKeys when the table is empty. It returns
// ErrorCannotDecryptKeys when no configured secret, tried in order, can
// decrypt, unmarshal, and parse the stored blob. This includes the case of a
// repo with no secrets. Database errors are returned unchanged.
func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
	qt := r.quote(keyTableName)
	objs, err := r.executor(nil).Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
	if err != nil {
		return nil, err
	}
	if len(objs) == 0 {
		return nil, key.ErrorNoKeys
	}
	b, ok := objs[0].(*privateKeySetBlob)
	if !ok {
		return nil, errors.New("unable to cast to privateKeySetBlob")
	}
	pks, ok := r.decryptKeySet(b.Value)
	if !ok {
		// Success is tracked explicitly. The old loop relied on the leftover
		// err, so with zero secrets it returned a non-nil KeySet wrapping a
		// nil pointer together with a nil error.
		return nil, ErrorCannotDecryptKeys
	}
	return pks, nil
}

// decryptKeySet tries each configured secret in order. It returns the first
// key set that decrypts, unmarshals, and parses without error. The boolean
// is false when no secret succeeds, including when there are no secrets.
func (r *PrivateKeySetRepo) decryptKeySet(blob []byte) (*key.PrivateKeySet, bool) {
	for _, secret := range r.secrets {
		var (
			j   []byte
			err error
		)
		if r.useOldFormat {
			j, err = pcrypto.AESDecrypt(blob, secret)
		} else {
			j, err = pcrypto.Decrypt(blob, secret)
		}
		if err != nil {
			continue
		}
		var m privateKeySetModel
		if err := json.Unmarshal(j, &m); err != nil {
			continue
		}
		pks, err := m.PrivateKeySet()
		if err != nil {
			continue
		}
		return pks, true
	}
	return nil, false
}
// active returns the secret used to encrypt new key sets: the first entry of
// the secrets given to NewPrivateKeySetRepo.
func (r *PrivateKeySetRepo) active() []byte {
	first := r.secrets[0]
	return first
}