diff --git a/pkg/crypto/aes.go b/pkg/crypto/aes.go index d5dad127..109cd201 100644 --- a/pkg/crypto/aes.go +++ b/pkg/crypto/aes.go @@ -7,6 +7,8 @@ import ( "errors" ) +const aesKeySize = 32 // force 256-bit AES + // pad uses the PKCS#7 padding scheme to align the a payload to a specific block size func pad(plaintext []byte, bsize int) ([]byte, error) { if bsize >= 256 { @@ -33,7 +35,7 @@ func unpad(paddedtext []byte) ([]byte, error) { return paddedtext[:length-(pad)], nil } -// AESEncrypt encrypts a payloaded with an AES cipher. +// **DEPRECATED** AESEncrypt encrypts a payloaded with an AES cipher. // The returned ciphertext has three notable properties: // 1. ciphertext is aligned to the standard AES block size // 2. ciphertext is padded using PKCS#7 @@ -61,7 +63,7 @@ func AESEncrypt(plaintext, key []byte) ([]byte, error) { return ciphertext, nil } -// AESDecrypt decrypts an encrypted payload with an AES cipher. +// **DEPRECATED** AESDecrypt decrypts an encrypted payload with an AES cipher. // The decryption algorithm makes three assumptions: // 1. ciphertext is aligned to the standard AES block size // 2. ciphertext is padded using PKCS#7 @@ -94,3 +96,49 @@ func AESDecrypt(ciphertext, key []byte) ([]byte, error) { return unpad(plaintext) } + +// Takes plaintext and a key, returns ciphertext or error +// Output takes the form nonce|ciphertext|tag where '|' indicates concatenation +func Encrypt(plaintext, key []byte) (ciphertext []byte, err error) { + if len(key) != aesKeySize { + return nil, aes.KeySizeError(len(key)) + } + + aes, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(aes) + if err != nil { + return nil, err + } + + nonce, err := RandBytes(gcm.NonceSize()) + if err != nil { + return nil, err + } + + return gcm.Seal(nonce, nonce, plaintext, nil), nil +} + +// Takes ciphertext and a key, returns plaintext or error +// Expects input form nonce|ciphertext|tag where '|' indicates concatenation +func Decrypt(ciphertext, key []byte) (plaintext []byte, err error) { + if len(key) != aesKeySize { + return nil, aes.KeySizeError(len(key)) + } + + aes, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(aes) + if err != nil { + return nil, err + } + + return gcm.Open(nil, ciphertext[:gcm.NonceSize()], + ciphertext[gcm.NonceSize():], nil) +} diff --git a/pkg/crypto/aes_test.go b/pkg/crypto/aes_test.go index eafb5b8b..5cb46c5a 100644 --- a/pkg/crypto/aes_test.go +++ b/pkg/crypto/aes_test.go @@ -1,6 +1,7 @@ package crypto import ( + "bytes" "reflect" "testing" ) @@ -91,3 +92,37 @@ func TestAESDecryptWrongKey(t *testing.T) { t.Fatalf("Data decrypted with different key matches original payload") } } + +func TestEncryptDecryptGCM(t *testing.T) { + gcmTests := []struct { + plaintext []byte + key []byte + }{ + { + plaintext: []byte("Hello, world!"), + key: append([]byte("shark"), make([]byte, 27)...), + }, + } + + for _, tt := range gcmTests { + ciphertext, err := Encrypt(tt.plaintext, tt.key) + if err != nil { + t.Fatal(err) + } + + plaintext, err := Decrypt(ciphertext, tt.key) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(plaintext, tt.plaintext) { + t.Errorf("plaintexts don't match") + } + + ciphertext[0] ^= 0xff + plaintext, err = Decrypt(ciphertext, tt.key) + if err == nil { + t.Errorf("gcmOpen should not have worked, but did") + } + } +}