db: PrivateKeySetRepo now takes >1 secrets
The first secret is used to encrypt, the rest are for decryption; if the first doesn't work, the rest are tried in order. The makes it possible to rotate keys.
This commit is contained in:
parent
72c3b0c31a
commit
c8feb5c33d
3 changed files with 102 additions and 43 deletions
61
db/key.go
61
db/key.go
|
@ -18,6 +18,10 @@ const (
|
|||
keyTableName = "key"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorCannotDecryptKeys = errors.New("Cannot Decrypt Keys")
|
||||
)
|
||||
|
||||
func init() {
|
||||
register(table{
|
||||
name: keyTableName,
|
||||
|
@ -85,23 +89,24 @@ type privateKeySetBlob struct {
|
|||
Value []byte `db:"value"`
|
||||
}
|
||||
|
||||
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, error) {
|
||||
bsecret := []byte(secret)
|
||||
if len(bsecret) != 32 {
|
||||
return nil, errors.New("expected 32-byte secret")
|
||||
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secrets ...[]byte) (*PrivateKeySetRepo, error) {
|
||||
for i, secret := range secrets {
|
||||
if len(secret) != 32 {
|
||||
return nil, fmt.Errorf("key secret %d: expected 32-byte secret", i)
|
||||
}
|
||||
}
|
||||
|
||||
r := &PrivateKeySetRepo{
|
||||
dbMap: dbm,
|
||||
secret: []byte(secret),
|
||||
dbMap: dbm,
|
||||
secrets: secrets,
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
type PrivateKeySetRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
secret []byte
|
||||
dbMap *gorp.DbMap
|
||||
secrets [][]byte
|
||||
}
|
||||
|
||||
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
||||
|
@ -126,7 +131,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
|||
return err
|
||||
}
|
||||
|
||||
v, err := pcrypto.AESEncrypt(j, r.secret)
|
||||
v, err := pcrypto.AESEncrypt(j, r.active())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -151,20 +156,32 @@ func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
|
|||
return nil, errors.New("unable to cast to KeySet")
|
||||
}
|
||||
|
||||
j, err := pcrypto.AESDecrypt(b.Value, r.secret)
|
||||
var pks *key.PrivateKeySet
|
||||
for _, secret := range r.secrets {
|
||||
var j []byte
|
||||
j, err = pcrypto.AESDecrypt(b.Value, secret)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var m privateKeySetModel
|
||||
if err = json.Unmarshal(j, &m); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
pks, err = m.PrivateKeySet()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("unable to decrypt key set")
|
||||
return nil, ErrorCannotDecryptKeys
|
||||
}
|
||||
|
||||
var m privateKeySetModel
|
||||
if err := json.Unmarshal(j, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pks, err := m.PrivateKeySet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return key.KeySet(pks), nil
|
||||
}
|
||||
|
||||
func (r *PrivateKeySetRepo) active() []byte {
|
||||
return r.secrets[0]
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
)
|
||||
|
||||
func TestNewPrivateKeySetRepoInvalidKey(t *testing.T) {
|
||||
_, err := NewPrivateKeySetRepo(nil, "sharks")
|
||||
_, err := NewPrivateKeySetRepo(nil, []byte("sharks"))
|
||||
if err == nil {
|
||||
t.Fatalf("Expected non-nil error")
|
||||
}
|
||||
|
|
|
@ -114,33 +114,75 @@ func TestDBSessionRepoCreateUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
|
||||
r, err := db.NewPrivateKeySetRepo(connect(t), "roflroflroflroflroflroflroflrofl")
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
s1 := []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
|
||||
s2 := []byte("oooooooooooooooooooooooooooooooo")
|
||||
s3 := []byte("wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww")
|
||||
|
||||
keys := []*key.PrivateKey{}
|
||||
for i := 0; i < 2; i++ {
|
||||
k, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate RSA key: %v", err)
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
k1, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate RSA key: %v", err)
|
||||
ks := key.NewPrivateKeySet(
|
||||
[]*key.PrivateKey{keys[0], keys[1]}, time.Now().Add(time.Minute))
|
||||
|
||||
tests := []struct {
|
||||
setSecrets [][]byte
|
||||
getSecrets [][]byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
// same secrets used to encrypt, decrypt
|
||||
setSecrets: [][]byte{s1, s2},
|
||||
getSecrets: [][]byte{s1, s2},
|
||||
},
|
||||
{
|
||||
// setSecrets got rotated, but getSecrets didn't yet.
|
||||
setSecrets: [][]byte{s2, s3},
|
||||
getSecrets: [][]byte{s1, s2},
|
||||
},
|
||||
{
|
||||
// getSecrets doesn't have s3
|
||||
setSecrets: [][]byte{s3},
|
||||
getSecrets: [][]byte{s1, s2},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
k2, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate RSA key: %v", err)
|
||||
}
|
||||
for i, tt := range tests {
|
||||
setRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.setSecrets...)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
ks := key.NewPrivateKeySet([]*key.PrivateKey{k1, k2}, time.Now().Add(time.Minute))
|
||||
if err := r.Set(ks); err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
getRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.getSecrets...)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
got, err := r.Get()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if err := setRepo.Set(ks); err != nil {
|
||||
t.Fatalf("case %d: Unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
got, err := getRepo.Get()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want err, got nil", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: Unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(ks, got); diff != "" {
|
||||
t.Fatalf("case %d:Retrieved incorrect KeySet: Compare(want,got): %v", i, diff)
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(ks, got); diff != "" {
|
||||
t.Fatalf("Retrieved incorrect KeySet: Compare(want,got): %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Reference in a new issue