171 lines
3.1 KiB
Go
171 lines
3.1 KiB
Go
|
package db
|
||
|
|
||
|
import (
|
||
|
"crypto/x509"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"time"
|
||
|
|
||
|
"github.com/coopernurse/gorp"
|
||
|
"github.com/lib/pq"
|
||
|
|
||
|
pcrypto "github.com/coreos/dex/pkg/crypto"
|
||
|
"github.com/coreos/go-oidc/key"
|
||
|
)
|
||
|
|
||
|
// keyTableName is the name of the database table that holds the encrypted
// private key set.
const keyTableName = "key"
|
||
|
|
||
|
func init() {
|
||
|
register(table{
|
||
|
name: keyTableName,
|
||
|
model: privateKeySetBlob{},
|
||
|
autoinc: false,
|
||
|
pkey: []string{"value"},
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// newPrivateKeySetModel converts pks into its serializable model. Each key is
// recorded by ID together with the bytes produced by
// x509.MarshalPKCS1PrivateKey. The set's expiration time is carried over
// unchanged. The returned error is currently always nil.
func newPrivateKeySetModel(pks *key.PrivateKeySet) (*privateKeySetModel, error) {
	src := pks.Keys()

	encoded := make([]privateKeyModel, 0, len(src))
	for _, k := range src {
		encoded = append(encoded, privateKeyModel{
			ID:    k.ID(),
			PKCS1: x509.MarshalPKCS1PrivateKey(k.PrivateKey),
		})
	}

	return &privateKeySetModel{
		Keys:      encoded,
		ExpiresAt: pks.ExpiresAt(),
	}, nil
}
|
||
|
|
||
|
// privateKeyModel is the JSON form of a single private key: its key ID and
// its PKCS #1 encoded private key bytes (encoding/json base64-encodes the
// byte slice).
type privateKeyModel struct {
	ID    string `json:"id"`
	PKCS1 []byte `json:"pkcs1"`
}
|
||
|
|
||
|
// PrivateKey decodes the stored PKCS #1 bytes and returns them as a
// key.PrivateKey tagged with the model's ID. A parse error from
// x509.ParsePKCS1PrivateKey is returned unchanged.
func (m *privateKeyModel) PrivateKey() (*key.PrivateKey, error) {
	rsaKey, err := x509.ParsePKCS1PrivateKey(m.PKCS1)
	if err != nil {
		return nil, err
	}
	return &key.PrivateKey{
		KeyID:      m.ID,
		PrivateKey: rsaKey,
	}, nil
}
|
||
|
|
||
|
type privateKeySetModel struct {
|
||
|
Keys []privateKeyModel `json:"keys"`
|
||
|
ExpiresAt time.Time `json:"expires_at"`
|
||
|
}
|
||
|
|
||
|
func (m *privateKeySetModel) PrivateKeySet() (*key.PrivateKeySet, error) {
|
||
|
keys := make([]*key.PrivateKey, len(m.Keys))
|
||
|
for i, pkm := range m.Keys {
|
||
|
pk, err := pkm.PrivateKey()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
keys[i] = pk
|
||
|
}
|
||
|
return key.NewPrivateKeySet(keys, m.ExpiresAt), nil
|
||
|
}
|
||
|
|
||
|
// privateKeySetBlob is the row type of the key table. Its single "value"
// column holds the encrypted, JSON-encoded key set.
type privateKeySetBlob struct {
	Value []byte `db:"value"`
}
|
||
|
|
||
|
// NewPrivateKeySetRepo returns a PrivateKeySetRepo that stores key sets in
// dbm, encrypted with secret. secret must be exactly 32 bytes long (presumably
// an AES-256 key for pcrypto.AESEncrypt; confirm). Any other length yields an
// error.
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, error) {
	raw := []byte(secret)
	if len(raw) != 32 {
		return nil, errors.New("expected 32-byte secret")
	}
	return &PrivateKeySetRepo{dbMap: dbm, secret: raw}, nil
}
|
||
|
|
||
|
type PrivateKeySetRepo struct {
|
||
|
dbMap *gorp.DbMap
|
||
|
secret []byte
|
||
|
}
|
||
|
|
||
|
// Set replaces the stored key set with ks, which must be a non-nil
// *key.PrivateKeySet. The set is converted to privateKeySetModel, encoded as
// JSON, encrypted with the repo's secret, and written as a single row of the
// key table after all existing rows are deleted.
//
// All fallible preparation (type check, serialization, encryption) happens
// before the database is touched. The DELETE of the old rows and the INSERT
// of the new one run in one transaction. A failed Set therefore leaves the
// previously stored set in place instead of an empty table.
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
	pks, ok := ks.(*key.PrivateKeySet)
	if !ok {
		return errors.New("unable to cast to PrivateKeySet")
	}
	if pks == nil {
		// A typed-nil set would otherwise panic inside pks.Keys().
		return errors.New("cannot store nil PrivateKeySet")
	}

	m, err := newPrivateKeySetModel(pks)
	if err != nil {
		return err
	}

	j, err := json.Marshal(m)
	if err != nil {
		return err
	}

	v, err := pcrypto.AESEncrypt(j, r.secret)
	if err != nil {
		return err
	}

	tx, err := r.dbMap.Begin()
	if err != nil {
		return err
	}

	qt := pq.QuoteIdentifier(keyTableName)
	if _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", qt)); err != nil {
		// Best-effort rollback: the Exec error is the one worth reporting.
		_ = tx.Rollback()
		return err
	}
	if err := tx.Insert(&privateKeySetBlob{Value: v}); err != nil {
		// Best-effort rollback: the Insert error is the one worth reporting.
		_ = tx.Rollback()
		return err
	}
	return tx.Commit()
}
|
||
|
|
||
|
// Get loads the stored key set. It selects the rows of the key table, takes
// the first one, decrypts it with the repo's secret, decodes the JSON model
// and rebuilds the *key.PrivateKeySet, which it returns as a key.KeySet.
//
// It returns key.ErrorNoKeys when the table is empty. A decryption failure is
// reported with a fixed message that does not include the underlying error.
func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
	query := fmt.Sprintf("SELECT * FROM %s", pq.QuoteIdentifier(keyTableName))
	rows, err := r.dbMap.Select(&privateKeySetBlob{}, query)
	switch {
	case err != nil:
		return nil, err
	case len(rows) == 0:
		return nil, key.ErrorNoKeys
	}

	blob, ok := rows[0].(*privateKeySetBlob)
	if !ok {
		return nil, errors.New("unable to cast to KeySet")
	}

	plaintext, err := pcrypto.AESDecrypt(blob.Value, r.secret)
	if err != nil {
		return nil, errors.New("unable to decrypt key set")
	}

	var model privateKeySetModel
	if err = json.Unmarshal(plaintext, &model); err != nil {
		return nil, err
	}

	pks, err := model.PrivateKeySet()
	if err != nil {
		return nil, err
	}
	return pks, nil
}
|