diff --git a/awskms.go b/awskms.go index cdf1d01..5e1e931 100644 --- a/awskms.go +++ b/awskms.go @@ -1,6 +1,8 @@ package awskms import ( + "context" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/kms" "github.com/aws/aws-sdk-go/service/kms/kmsiface" @@ -30,6 +32,22 @@ func Encrypt(plaintext []byte) ([]byte, error) { return res.CiphertextBlob, nil } +// EncryptWithContext performs encryption on plaintext with AWS KMS +// Returns empty []byte on error +func EncryptWithContext(ctx context.Context, plaintext []byte) ([]byte, error) { + input := &kms.EncryptInput{ + KeyId: aws.String(KeyID), + Plaintext: plaintext, + } + + res, err := Client.EncryptWithContext(ctx, input) + if err != nil { + return []byte{}, err + } + + return res.CiphertextBlob, nil +} + // Decrypt performs decryption on ciphertext with AWS KMS // Returns empty []byte on error func Decrypt(ciphertext []byte) ([]byte, error) { @@ -44,3 +62,18 @@ func Decrypt(ciphertext []byte) ([]byte, error) { return res.Plaintext, nil } + +// DecryptWithContext performs decryption on ciphertext with AWS KMS +// Returns empty []byte on error +func DecryptWithContext(ctx context.Context, ciphertext []byte) ([]byte, error) { + input := &kms.DecryptInput{ + CiphertextBlob: ciphertext, + } + + res, err := Client.DecryptWithContext(ctx, input) + if err != nil { + return []byte{}, err + } + + return res.Plaintext, nil +} diff --git a/awskms_test.go b/awskms_test.go index cd8d787..efd6d5f 100644 --- a/awskms_test.go +++ b/awskms_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/kms" "github.com/aws/aws-sdk-go/service/kms/kmsiface" "github.com/reddotpay/awskms" @@ -26,6 +28,17 @@ func (m *mockKMSClient) Encrypt(input *kms.EncryptInput) (*kms.EncryptOutput, er } } +func (m *mockKMSClient) EncryptWithContext(ctx aws.Context, input *kms.EncryptInput, req ...request.Option) (*kms.EncryptOutput, error) { + switch { + case bytes.Equal(input.Plaintext, []byte("error")): + return &kms.EncryptOutput{}, errors.New("AWS KMS error!") + default: + return &kms.EncryptOutput{ + CiphertextBlob: []byte("ciphertext"), + }, nil + } +} + func (m *mockKMSClient) Decrypt(input *kms.DecryptInput) (*kms.DecryptOutput, error) { switch { case bytes.Equal(input.CiphertextBlob, []byte("error")): @@ -37,6 +50,17 @@ func (m *mockKMSClient) Decrypt(input *kms.DecryptInput) (*kms.DecryptOutput, er } } +func (m *mockKMSClient) DecryptWithContext(ctx aws.Context, input *kms.DecryptInput, req ...request.Option) (*kms.DecryptOutput, error) { + switch { + case bytes.Equal(input.CiphertextBlob, []byte("error")): + return &kms.DecryptOutput{}, errors.New("AWS KMS error!") + default: + return &kms.DecryptOutput{ + Plaintext: []byte("plaintext"), + }, nil + } +} + func init() { awskms.Client = &mockKMSClient{} awskms.KeyID = "keyID" @@ -57,6 +81,21 @@ func TestAwsKms_Encrypt_Error(t *testing.T) { assert.Equal(t, []byte(""), ciphertext) } +func TestAwsKms_EncryptWithContext(t *testing.T) { + plaintext := []byte("plaintext") + ciphertext, err := awskms.EncryptWithContext(aws.BackgroundContext(), plaintext) + assert.Nil(t, err) + assert.Equal(t, []byte("ciphertext"), ciphertext) +} + +func TestAwsKms_EncryptWithContext_Error(t *testing.T) { + plaintext := []byte("error") + ciphertext, err := awskms.EncryptWithContext(aws.BackgroundContext(), plaintext) + assert.NotNil(t, err) + assert.Equal(t, "AWS KMS error!", err.Error()) + assert.Equal(t, []byte(""), ciphertext) +} + func TestAwsKms_Decrypt(t *testing.T) { ciphertext := []byte("ciphertext") plaintext, err := awskms.Decrypt(ciphertext) @@ -71,3 +110,18 @@ func TestAwsKms_Decrypt_Error(t *testing.T) { assert.Equal(t, "AWS KMS error!", err.Error()) assert.Equal(t, []byte(""), plaintext) } + +func TestAwsKms_DecryptWithContext(t *testing.T) { + ciphertext := []byte("ciphertext") + plaintext, err := awskms.DecryptWithContext(aws.BackgroundContext(), ciphertext) + assert.Nil(t, err) + assert.Equal(t, []byte("plaintext"), plaintext) +} + +func TestAwsKms_DecryptWithContext_Error(t *testing.T) { + ciphertext := []byte("error") + plaintext, err := awskms.DecryptWithContext(aws.BackgroundContext(), ciphertext) + assert.NotNil(t, err) + assert.Equal(t, "AWS KMS error!", err.Error()) + assert.Equal(t, []byte(""), plaintext) +}