Skip to content

Commit

Permalink
add methods with context
Browse files Browse the repository at this point in the history
  • Loading branch information
darylnwk committed Jul 31, 2018
1 parent 8bdfb1b commit c5199a9
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
33 changes: 33 additions & 0 deletions awskms.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package awskms

import (
"context"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
Expand Down Expand Up @@ -30,6 +32,22 @@ func Encrypt(plaintext []byte) ([]byte, error) {
return res.CiphertextBlob, nil
}

// EncryptWithContext performs encryption on plaintext with AWS KMS
// Returns empty []byte on error
func EncryptWithContext(ctx context.Context, plaintext []byte) ([]byte, error) {
input := &kms.EncryptInput{
KeyId: aws.String(KeyID),
Plaintext: plaintext,
}

res, err := Client.EncryptWithContext(ctx, input)
if err != nil {
return []byte{}, err
}

return res.CiphertextBlob, nil
}

// Decrypt performs decryption on ciphertext with AWS KMS
// Returns empty []byte on error
func Decrypt(ciphertext []byte) ([]byte, error) {
Expand All @@ -44,3 +62,18 @@ func Decrypt(ciphertext []byte) ([]byte, error) {

return res.Plaintext, nil
}

// DecryptWithContext performs decryption on ciphertext with AWS KMS
// Returns empty []byte on error
func DecryptWithContext(ctx context.Context, ciphertext []byte) ([]byte, error) {
input := &kms.DecryptInput{
CiphertextBlob: ciphertext,
}

res, err := Client.DecryptWithContext(ctx, input)
if err != nil {
return []byte{}, err
}

return res.Plaintext, nil
}
54 changes: 54 additions & 0 deletions awskms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
"github.com/reddotpay/awskms"
Expand All @@ -26,6 +28,17 @@ func (m *mockKMSClient) Encrypt(input *kms.EncryptInput) (*kms.EncryptOutput, er
}
}

func (m *mockKMSClient) EncryptWithContext(ctx aws.Context, input *kms.EncryptInput, req ...request.Option) (*kms.EncryptOutput, error) {
switch {
case bytes.Equal(input.Plaintext, []byte("error")):
return &kms.EncryptOutput{}, errors.New("AWS KMS error!")
default:
return &kms.EncryptOutput{
CiphertextBlob: []byte("ciphertext"),
}, nil
}
}

func (m *mockKMSClient) Decrypt(input *kms.DecryptInput) (*kms.DecryptOutput, error) {
switch {
case bytes.Equal(input.CiphertextBlob, []byte("error")):
Expand All @@ -37,6 +50,17 @@ func (m *mockKMSClient) Decrypt(input *kms.DecryptInput) (*kms.DecryptOutput, er
}
}

func (m *mockKMSClient) DecryptWithContext(ctx aws.Context, input *kms.DecryptInput, req ...request.Option) (*kms.DecryptOutput, error) {
switch {
case bytes.Equal(input.CiphertextBlob, []byte("error")):
return &kms.DecryptOutput{}, errors.New("AWS KMS error!")
default:
return &kms.DecryptOutput{
Plaintext: []byte("plaintext"),
}, nil
}
}

func init() {
awskms.Client = &mockKMSClient{}
awskms.KeyID = "keyID"
Expand All @@ -57,6 +81,21 @@ func TestAwsKms_Encrypt_Error(t *testing.T) {
assert.Equal(t, []byte(""), ciphertext)
}

func TestAwsKms_EncryptWithContext(t *testing.T) {
plaintext := []byte("plaintext")
ciphertext, err := awskms.EncryptWithContext(aws.BackgroundContext(), plaintext)
assert.Nil(t, err)
assert.Equal(t, []byte("ciphertext"), ciphertext)
}

func TestAwsKms_EncryptWithContext_Error(t *testing.T) {
plaintext := []byte("error")
ciphertext, err := awskms.EncryptWithContext(aws.BackgroundContext(), plaintext)
assert.NotNil(t, err)
assert.Equal(t, "AWS KMS error!", err.Error())
assert.Equal(t, []byte(""), ciphertext)
}

func TestAwsKms_Decrypt(t *testing.T) {
ciphertext := []byte("ciphertext")
plaintext, err := awskms.Decrypt(ciphertext)
Expand All @@ -71,3 +110,18 @@ func TestAwsKms_Decrypt_Error(t *testing.T) {
assert.Equal(t, "AWS KMS error!", err.Error())
assert.Equal(t, []byte(""), plaintext)
}

func TestAwsKms_DecryptWithContext(t *testing.T) {
ciphertext := []byte("ciphertext")
plaintext, err := awskms.DecryptWithContext(aws.BackgroundContext(), ciphertext)
assert.Nil(t, err)
assert.Equal(t, []byte("plaintext"), plaintext)
}

func TestAwsKms_DecryptWithContext_Error(t *testing.T) {
ciphertext := []byte("error")
plaintext, err := awskms.DecryptWithContext(aws.BackgroundContext(), ciphertext)
assert.NotNil(t, err)
assert.Equal(t, "AWS KMS error!", err.Error())
assert.Equal(t, []byte(""), plaintext)
}

0 comments on commit c5199a9

Please sign in to comment.