dex/db/key.go

171 lines
3.1 KiB
Go
Raw Normal View History

2015-08-18 05:57:27 +05:30
package db
import (
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/go-gorp/gorp"
2015-08-18 05:57:27 +05:30
"github.com/lib/pq"
pcrypto "github.com/coreos/dex/pkg/crypto"
"github.com/coreos/go-oidc/key"
)
// keyTableName is the name of the database table that holds the single
// encrypted private key set row.
const keyTableName = "key"
func init() {
register(table{
name: keyTableName,
model: privateKeySetBlob{},
autoinc: false,
pkey: []string{"value"},
})
}
// newPrivateKeySetModel converts pks into its serializable model. Each key is
// stored as its ID plus a PKCS#1 DER encoding of its RSA private key, in the
// order returned by pks.Keys(). The expiry is copied from pks.ExpiresAt().
// The returned error is currently always nil.
func newPrivateKeySetModel(pks *key.PrivateKeySet) (*privateKeySetModel, error) {
	src := pks.Keys()
	models := make([]privateKeyModel, 0, len(src))
	for _, k := range src {
		models = append(models, privateKeyModel{
			ID:    k.ID(),
			PKCS1: x509.MarshalPKCS1PrivateKey(k.PrivateKey),
		})
	}
	return &privateKeySetModel{
		Keys:      models,
		ExpiresAt: pks.ExpiresAt(),
	}, nil
}
// privateKeyModel is the JSON form of a single private key.
// Field order determines the JSON output order, so keep it stable.
type privateKeyModel struct {
// ID is the key's identifier (key.PrivateKey.ID()).
ID string `json:"id"`
// PKCS1 is the RSA private key as produced by x509.MarshalPKCS1PrivateKey;
// encoding/json writes []byte values as base64.
PKCS1 []byte `json:"pkcs1"`
}
// PrivateKey decodes the stored PKCS#1 bytes and returns them as a
// key.PrivateKey carrying the model's ID. Any parse error from
// x509.ParsePKCS1PrivateKey is returned unchanged.
func (m *privateKeyModel) PrivateKey() (*key.PrivateKey, error) {
	rsaKey, err := x509.ParsePKCS1PrivateKey(m.PKCS1)
	if err != nil {
		return nil, err
	}
	return &key.PrivateKey{KeyID: m.ID, PrivateKey: rsaKey}, nil
}
// privateKeySetModel is the JSON form of a key.PrivateKeySet. Set marshals it
// and AES-encrypts the result before storing it. Field order determines the
// JSON output order, so keep it stable.
type privateKeySetModel struct {
// Keys holds the set's keys in the order returned by PrivateKeySet.Keys().
Keys []privateKeyModel `json:"keys"`
// ExpiresAt is the set's expiry as returned by PrivateKeySet.ExpiresAt().
ExpiresAt time.Time `json:"expires_at"`
}
// PrivateKeySet rebuilds a key.PrivateKeySet from the decoded model. Stored
// keys are parsed in order, and the first parse error is returned.
//
// A model with no keys yields key.ErrorNoKeys. It is never handed to
// key.NewPrivateKeySet, which in go-oidc reads keys[0] to choose the active
// key and would panic on an empty slice.
// NOTE(review): confirm against the vendored go-oidc version.
func (m *privateKeySetModel) PrivateKeySet() (*key.PrivateKeySet, error) {
	if len(m.Keys) == 0 {
		return nil, key.ErrorNoKeys
	}

	keys := make([]*key.PrivateKey, len(m.Keys))
	for i := range m.Keys {
		pk, err := m.Keys[i].PrivateKey()
		if err != nil {
			return nil, err
		}
		keys[i] = pk
	}
	return key.NewPrivateKeySet(keys, m.ExpiresAt), nil
}
// privateKeySetBlob is the row type stored in the key table.
type privateKeySetBlob struct {
// Value is the AES-encrypted JSON encoding of a privateKeySetModel. It maps
// to the "value" column, which init registers as the table's primary key.
Value []byte `db:"value"`
}
// NewPrivateKeySetRepo returns a PrivateKeySetRepo that stores key sets
// through dbm.
//
// secret is passed as the key to pcrypto.AESEncrypt and pcrypto.AESDecrypt
// and must be exactly 32 bytes long; any other length is rejected.
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, error) {
	bsecret := []byte(secret)
	if len(bsecret) != 32 {
		return nil, errors.New("expected 32-byte secret")
	}
	// Reuse the bytes already converted above rather than converting secret a
	// second time.
	return &PrivateKeySetRepo{
		dbMap:  dbm,
		secret: bsecret,
	}, nil
}
// PrivateKeySetRepo stores a single private key set in the database,
// encrypted with a 32-byte secret. Construct it with NewPrivateKeySetRepo.
type PrivateKeySetRepo struct {
// dbMap is the gorp mapping used for all queries.
dbMap *gorp.DbMap
// secret is the 32-byte key passed to pcrypto.AESEncrypt/AESDecrypt.
secret []byte
}
// Set replaces the stored private key set with ks, which must be a non-nil
// *key.PrivateKeySet.
//
// The set is serialized to JSON and AES-encrypted with the repo secret
// before the table is touched. Deleting the old row and inserting the new one
// then happen in a single transaction. A failure at any step therefore
// leaves the previously stored key set in place instead of an empty table.
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
	pks, ok := ks.(*key.PrivateKeySet)
	if !ok {
		return errors.New("unable to cast to PrivateKeySet")
	}
	// A typed nil pointer passes the assertion above, but it would panic
	// inside newPrivateKeySetModel.
	if pks == nil {
		return errors.New("nil PrivateKeySet")
	}

	m, err := newPrivateKeySetModel(pks)
	if err != nil {
		return err
	}

	j, err := json.Marshal(m)
	if err != nil {
		return err
	}

	v, err := pcrypto.AESEncrypt(j, r.secret)
	if err != nil {
		return err
	}

	tx, err := r.dbMap.Begin()
	if err != nil {
		return err
	}

	qt := pq.QuoteIdentifier(keyTableName)
	if _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", qt)); err != nil {
		// Report the original failure; a rollback error adds no useful detail.
		_ = tx.Rollback()
		return err
	}

	b := &privateKeySetBlob{Value: v}
	if err := tx.Insert(b); err != nil {
		_ = tx.Rollback()
		return err
	}

	return tx.Commit()
}
// Get loads, decrypts and decodes the stored private key set.
//
// It returns key.ErrorNoKeys when the key table has no rows. Only the first
// row returned by the SELECT is used; Set deletes all rows before inserting.
// Decryption failures are reported with a generic message, and the
// underlying error is not included.
func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
	qt := pq.QuoteIdentifier(keyTableName)
	objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt))
	if err != nil {
		return nil, err
	}
	if len(objs) == 0 {
		return nil, key.ErrorNoKeys
	}

	// The assertion is on the row type, so the error names that type rather
	// than key.KeySet.
	b, ok := objs[0].(*privateKeySetBlob)
	if !ok {
		return nil, errors.New("unable to cast to privateKeySetBlob")
	}

	j, err := pcrypto.AESDecrypt(b.Value, r.secret)
	if err != nil {
		return nil, errors.New("unable to decrypt key set")
	}

	var m privateKeySetModel
	if err := json.Unmarshal(j, &m); err != nil {
		return nil, err
	}

	pks, err := m.PrivateKeySet()
	if err != nil {
		return nil, err
	}
	return key.KeySet(pks), nil
}