db: PrivateKeySetRepo now takes >1 secrets
The first secret is used to encrypt, the rest are for decryption; if the first doesn't work, the rest are tried in order. The makes it possible to rotate keys.
This commit is contained in:
parent
72c3b0c31a
commit
c8feb5c33d
3 changed files with 102 additions and 43 deletions
43
db/key.go
43
db/key.go
|
@ -18,6 +18,10 @@ const (
|
||||||
keyTableName = "key"
|
keyTableName = "key"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrorCannotDecryptKeys = errors.New("Cannot Decrypt Keys")
|
||||||
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
register(table{
|
register(table{
|
||||||
name: keyTableName,
|
name: keyTableName,
|
||||||
|
@ -85,15 +89,16 @@ type privateKeySetBlob struct {
|
||||||
Value []byte `db:"value"`
|
Value []byte `db:"value"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, error) {
|
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secrets ...[]byte) (*PrivateKeySetRepo, error) {
|
||||||
bsecret := []byte(secret)
|
for i, secret := range secrets {
|
||||||
if len(bsecret) != 32 {
|
if len(secret) != 32 {
|
||||||
return nil, errors.New("expected 32-byte secret")
|
return nil, fmt.Errorf("key secret %d: expected 32-byte secret", i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r := &PrivateKeySetRepo{
|
r := &PrivateKeySetRepo{
|
||||||
dbMap: dbm,
|
dbMap: dbm,
|
||||||
secret: []byte(secret),
|
secrets: secrets,
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
|
@ -101,7 +106,7 @@ func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, e
|
||||||
|
|
||||||
type PrivateKeySetRepo struct {
|
type PrivateKeySetRepo struct {
|
||||||
dbMap *gorp.DbMap
|
dbMap *gorp.DbMap
|
||||||
secret []byte
|
secrets [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
||||||
|
@ -126,7 +131,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
v, err := pcrypto.AESEncrypt(j, r.secret)
|
v, err := pcrypto.AESEncrypt(j, r.active())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -151,20 +156,32 @@ func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
|
||||||
return nil, errors.New("unable to cast to KeySet")
|
return nil, errors.New("unable to cast to KeySet")
|
||||||
}
|
}
|
||||||
|
|
||||||
j, err := pcrypto.AESDecrypt(b.Value, r.secret)
|
var pks *key.PrivateKeySet
|
||||||
|
for _, secret := range r.secrets {
|
||||||
|
var j []byte
|
||||||
|
j, err = pcrypto.AESDecrypt(b.Value, secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("unable to decrypt key set")
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var m privateKeySetModel
|
var m privateKeySetModel
|
||||||
if err := json.Unmarshal(j, &m); err != nil {
|
if err = json.Unmarshal(j, &m); err != nil {
|
||||||
return nil, err
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
pks, err := m.PrivateKeySet()
|
pks, err = m.PrivateKeySet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrorCannotDecryptKeys
|
||||||
|
}
|
||||||
return key.KeySet(pks), nil
|
return key.KeySet(pks), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *PrivateKeySetRepo) active() []byte {
|
||||||
|
return r.secrets[0]
|
||||||
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewPrivateKeySetRepoInvalidKey(t *testing.T) {
|
func TestNewPrivateKeySetRepoInvalidKey(t *testing.T) {
|
||||||
_, err := NewPrivateKeySetRepo(nil, "sharks")
|
_, err := NewPrivateKeySetRepo(nil, []byte("sharks"))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Expected non-nil error")
|
t.Fatalf("Expected non-nil error")
|
||||||
}
|
}
|
||||||
|
|
|
@ -114,33 +114,75 @@ func TestDBSessionRepoCreateUpdate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
|
func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
|
||||||
r, err := db.NewPrivateKeySetRepo(connect(t), "roflroflroflroflroflroflroflrofl")
|
s1 := []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
|
||||||
|
s2 := []byte("oooooooooooooooooooooooooooooooo")
|
||||||
|
s3 := []byte("wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww")
|
||||||
|
|
||||||
|
keys := []*key.PrivateKey{}
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
k, err := key.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to generate RSA key: %v", err)
|
||||||
|
}
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
ks := key.NewPrivateKeySet(
|
||||||
|
[]*key.PrivateKey{keys[0], keys[1]}, time.Now().Add(time.Minute))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
setSecrets [][]byte
|
||||||
|
getSecrets [][]byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// same secrets used to encrypt, decrypt
|
||||||
|
setSecrets: [][]byte{s1, s2},
|
||||||
|
getSecrets: [][]byte{s1, s2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// setSecrets got rotated, but getSecrets didn't yet.
|
||||||
|
setSecrets: [][]byte{s2, s3},
|
||||||
|
getSecrets: [][]byte{s1, s2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// getSecrets doesn't have s3
|
||||||
|
setSecrets: [][]byte{s3},
|
||||||
|
getSecrets: [][]byte{s1, s2},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
setRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.setSecrets...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf(err.Error())
|
t.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
k1, err := key.GeneratePrivateKey()
|
getRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.getSecrets...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to generate RSA key: %v", err)
|
t.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
k2, err := key.GeneratePrivateKey()
|
if err := setRepo.Set(ks); err != nil {
|
||||||
if err != nil {
|
t.Fatalf("case %d: Unexpected error: %v", i, err)
|
||||||
t.Fatalf("Unable to generate RSA key: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ks := key.NewPrivateKeySet([]*key.PrivateKey{k1, k2}, time.Now().Add(time.Minute))
|
got, err := getRepo.Get()
|
||||||
if err := r.Set(ks); err != nil {
|
if tt.wantErr {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
if err == nil {
|
||||||
|
t.Errorf("case %d: want err, got nil", i)
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := r.Get()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("case %d: Unexpected error: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := pretty.Compare(ks, got); diff != "" {
|
if diff := pretty.Compare(ks, got); diff != "" {
|
||||||
t.Fatalf("Retrieved incorrect KeySet: Compare(want,got): %v", diff)
|
t.Fatalf("case %d:Retrieved incorrect KeySet: Compare(want,got): %v", i, diff)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Reference in a new issue