diff --git a/encrypt/aes_cfb.go b/encrypt/aes_cfb.go index c6d1d61d..5221f616 100644 --- a/encrypt/aes_cfb.go +++ b/encrypt/aes_cfb.go @@ -13,6 +13,10 @@ type encryptionAESCFB struct { } func (e *encryptionAESCFB) Encrypt(src []byte, key string) ([]byte, error) { + if len(key) != e.level/8 { + return nil, fmt.Errorf("key length must be %d bytes", e.level/8) + } + block, err := aes.NewCipher([]byte(key)) if err != nil { return nil, err @@ -31,6 +35,10 @@ func (e *encryptionAESCFB) Encrypt(src []byte, key string) ([]byte, error) { } func (e *encryptionAESCFB) Decrypt(src []byte, key string) ([]byte, error) { + if len(key) != e.level/8 { + return nil, fmt.Errorf("key length must be %d bytes", e.level/8) + } + block, err := aes.NewCipher([]byte(key)) if err != nil { return nil, err diff --git a/encrypt/aes_gcm.go b/encrypt/aes_gcm.go index 437d2191..10cb2460 100644 --- a/encrypt/aes_gcm.go +++ b/encrypt/aes_gcm.go @@ -13,6 +13,10 @@ type encryptionAESGCM struct { } func (e *encryptionAESGCM) Encrypt(src []byte, key string) ([]byte, error) { + if len(key) != e.level/8 { + return nil, fmt.Errorf("key length must be %d bytes", e.level/8) + } + block, err := aes.NewCipher([]byte(key)) if err != nil { return nil, err @@ -33,6 +37,10 @@ func (e *encryptionAESGCM) Encrypt(src []byte, key string) ([]byte, error) { } func (e *encryptionAESGCM) Decrypt(src []byte, key string) ([]byte, error) { + if len(key) != e.level/8 { + return nil, fmt.Errorf("key length must be %d bytes", e.level/8) + } + block, err := aes.NewCipher([]byte(key)) if err != nil { return nil, err diff --git a/encrypt/encrypt_test.go b/encrypt/encrypt_test.go index bf0efaf3..6f2d11cf 100644 --- a/encrypt/encrypt_test.go +++ b/encrypt/encrypt_test.go @@ -44,6 +44,9 @@ func Test_EncryptDecrypt(t *testing.T) { require.NoError(t, err) plaintext := loremIpsumDolor + _, err = encrypter.Encrypt(plaintext, key[:len(key)-1]) + require.Error(t, err) + ciphertext, err := encrypter.Encrypt(plaintext, key) require.NoError(t, err)