db: PrivateKeySetRepo now takes >1 secrets

The first secret is used to encrypt, the rest are for decryption; if the
first doesn't work, the rest are tried in order.

The makes it possible to rotate keys.
This commit is contained in:
Bobby Rullo 2015-08-25 16:41:20 -07:00
parent 72c3b0c31a
commit c8feb5c33d
3 changed files with 102 additions and 43 deletions

View file

@ -18,6 +18,10 @@ const (
keyTableName = "key" keyTableName = "key"
) )
var (
ErrorCannotDecryptKeys = errors.New("Cannot Decrypt Keys")
)
func init() { func init() {
register(table{ register(table{
name: keyTableName, name: keyTableName,
@ -85,23 +89,24 @@ type privateKeySetBlob struct {
Value []byte `db:"value"` Value []byte `db:"value"`
} }
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, error) { func NewPrivateKeySetRepo(dbm *gorp.DbMap, secrets ...[]byte) (*PrivateKeySetRepo, error) {
bsecret := []byte(secret) for i, secret := range secrets {
if len(bsecret) != 32 { if len(secret) != 32 {
return nil, errors.New("expected 32-byte secret") return nil, fmt.Errorf("key secret %d: expected 32-byte secret", i)
}
} }
r := &PrivateKeySetRepo{ r := &PrivateKeySetRepo{
dbMap: dbm, dbMap: dbm,
secret: []byte(secret), secrets: secrets,
} }
return r, nil return r, nil
} }
type PrivateKeySetRepo struct { type PrivateKeySetRepo struct {
dbMap *gorp.DbMap dbMap *gorp.DbMap
secret []byte secrets [][]byte
} }
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error { func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
@ -126,7 +131,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
return err return err
} }
v, err := pcrypto.AESEncrypt(j, r.secret) v, err := pcrypto.AESEncrypt(j, r.active())
if err != nil { if err != nil {
return err return err
} }
@ -151,20 +156,32 @@ func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
return nil, errors.New("unable to cast to KeySet") return nil, errors.New("unable to cast to KeySet")
} }
j, err := pcrypto.AESDecrypt(b.Value, r.secret) var pks *key.PrivateKeySet
for _, secret := range r.secrets {
var j []byte
j, err = pcrypto.AESDecrypt(b.Value, secret)
if err != nil {
continue
}
var m privateKeySetModel
if err = json.Unmarshal(j, &m); err != nil {
continue
}
pks, err = m.PrivateKeySet()
if err != nil {
continue
}
break
}
if err != nil { if err != nil {
return nil, errors.New("unable to decrypt key set") return nil, ErrorCannotDecryptKeys
} }
var m privateKeySetModel
if err := json.Unmarshal(j, &m); err != nil {
return nil, err
}
pks, err := m.PrivateKeySet()
if err != nil {
return nil, err
}
return key.KeySet(pks), nil return key.KeySet(pks), nil
} }
func (r *PrivateKeySetRepo) active() []byte {
return r.secrets[0]
}

View file

@ -5,7 +5,7 @@ import (
) )
func TestNewPrivateKeySetRepoInvalidKey(t *testing.T) { func TestNewPrivateKeySetRepoInvalidKey(t *testing.T) {
_, err := NewPrivateKeySetRepo(nil, "sharks") _, err := NewPrivateKeySetRepo(nil, []byte("sharks"))
if err == nil { if err == nil {
t.Fatalf("Expected non-nil error") t.Fatalf("Expected non-nil error")
} }

View file

@ -114,33 +114,75 @@ func TestDBSessionRepoCreateUpdate(t *testing.T) {
} }
func TestDBPrivateKeySetRepoSetGet(t *testing.T) { func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
r, err := db.NewPrivateKeySetRepo(connect(t), "roflroflroflroflroflroflroflrofl") s1 := []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
if err != nil { s2 := []byte("oooooooooooooooooooooooooooooooo")
t.Fatalf(err.Error()) s3 := []byte("wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww")
keys := []*key.PrivateKey{}
for i := 0; i < 2; i++ {
k, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("Unable to generate RSA key: %v", err)
}
keys = append(keys, k)
} }
k1, err := key.GeneratePrivateKey() ks := key.NewPrivateKeySet(
if err != nil { []*key.PrivateKey{keys[0], keys[1]}, time.Now().Add(time.Minute))
t.Fatalf("Unable to generate RSA key: %v", err)
tests := []struct {
setSecrets [][]byte
getSecrets [][]byte
wantErr bool
}{
{
// same secrets used to encrypt, decrypt
setSecrets: [][]byte{s1, s2},
getSecrets: [][]byte{s1, s2},
},
{
// setSecrets got rotated, but getSecrets didn't yet.
setSecrets: [][]byte{s2, s3},
getSecrets: [][]byte{s1, s2},
},
{
// getSecrets doesn't have s3
setSecrets: [][]byte{s3},
getSecrets: [][]byte{s1, s2},
wantErr: true,
},
} }
k2, err := key.GeneratePrivateKey() for i, tt := range tests {
if err != nil { setRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.setSecrets...)
t.Fatalf("Unable to generate RSA key: %v", err) if err != nil {
} t.Fatalf(err.Error())
}
ks := key.NewPrivateKeySet([]*key.PrivateKey{k1, k2}, time.Now().Add(time.Minute)) getRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.getSecrets...)
if err := r.Set(ks); err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf(err.Error())
} }
got, err := r.Get() if err := setRepo.Set(ks); err != nil {
if err != nil { t.Fatalf("case %d: Unexpected error: %v", i, err)
t.Fatalf("Unexpected error: %v", err) }
}
got, err := getRepo.Get()
if tt.wantErr {
if err == nil {
t.Errorf("case %d: want err, got nil", i)
}
continue
}
if err != nil {
t.Fatalf("case %d: Unexpected error: %v", i, err)
}
if diff := pretty.Compare(ks, got); diff != "" {
t.Fatalf("case %d:Retrieved incorrect KeySet: Compare(want,got): %v", i, diff)
}
if diff := pretty.Compare(ks, got); diff != "" {
t.Fatalf("Retrieved incorrect KeySet: Compare(want,got): %v", diff)
} }
} }