diff --git a/pom.xml b/pom.xml
index c426db5e6..a8774878b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -42,7 +42,7 @@
com.amazonaws
aws-java-sdk
- 1.11.561
+ 1.11.677
true
diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java
new file mode 100644
index 000000000..58f78833c
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.exception;
+
+/**
+ * This exception is thrown when an Amazon Resource Name is provided that does not
+ * match the CMK Alias or ARN format.
+ */
+public class MalformedArnException extends AwsCryptoException {
+
+ private static final long serialVersionUID = -1L;
+
+ public MalformedArnException() {
+ super();
+ }
+
+ public MalformedArnException(final String message) {
+ super(message);
+ }
+
+ public MalformedArnException(final Throwable cause) {
+ super(cause);
+ }
+
+ public MalformedArnException(final String message, final Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java
new file mode 100644
index 000000000..aa1c87799
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.exception;
+
+/**
+ * This exception is thrown when the key used by KMS to decrypt a data key does not
+ * match the provider information contained within the encrypted data key.
+ */
+public class MismatchedDataKeyException extends IllegalStateException {
+
+ private static final long serialVersionUID = -1L;
+
+ public MismatchedDataKeyException() {
+ super();
+ }
+
+ public MismatchedDataKeyException(final String message) {
+ super(message);
+ }
+
+ public MismatchedDataKeyException(final Throwable cause) {
+ super(cause);
+ }
+
+ public MismatchedDataKeyException(final String message, final Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java
new file mode 100644
index 000000000..1ad8543d8
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java
@@ -0,0 +1,171 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.keyrings;
+
+import com.amazonaws.encryptionsdk.EncryptedDataKey;
+import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
+import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException;
+import com.amazonaws.encryptionsdk.exception.MalformedArnException;
+import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao;
+import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.DecryptDataKeyResult;
+import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.GenerateDataKeyResult;
+import com.amazonaws.encryptionsdk.kms.KmsUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.isArnWellFormed;
+import static java.util.Collections.emptyList;
+import static java.util.Collections.unmodifiableList;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * A keyring which interacts with AWS Key Management Service (KMS) to create,
+ * encrypt, and decrypt data keys using KMS defined Customer Master Keys (CMKs).
+ */
+public class KmsKeyring implements Keyring {
+
+ private final DataKeyEncryptionDao dataKeyEncryptionDao;
+ private final List keyIds;
+ private final String generatorKeyId;
+ private final boolean isDiscovery;
+
+ KmsKeyring(DataKeyEncryptionDao dataKeyEncryptionDao, List keyIds, String generatorKeyId) {
+ requireNonNull(dataKeyEncryptionDao, "dataKeyEncryptionDao is required");
+ this.dataKeyEncryptionDao = dataKeyEncryptionDao;
+ this.keyIds = keyIds == null ? emptyList() : unmodifiableList(keyIds);
+ this.generatorKeyId = generatorKeyId;
+ this.isDiscovery = this.generatorKeyId == null && this.keyIds.isEmpty();
+
+ if (!this.keyIds.stream().allMatch(KmsUtils::isArnWellFormed)) {
+ throw new MalformedArnException("keyIds must contain only CMK aliases and well formed ARNs");
+ }
+
+ if (generatorKeyId != null) {
+ if (!isArnWellFormed(generatorKeyId)) {
+ throw new MalformedArnException("generatorKeyId must be either a CMK alias or a well formed ARN");
+ }
+ if (this.keyIds.contains(generatorKeyId)) {
+ throw new IllegalArgumentException("KeyIds should not contain the generatorKeyId");
+ }
+ }
+ }
+
+ @Override
+ public void onEncrypt(EncryptionMaterials encryptionMaterials) {
+ requireNonNull(encryptionMaterials, "encryptionMaterials are required");
+
+ // If this keyring is a discovery keyring, OnEncrypt MUST return the input encryption materials unmodified.
+ if (isDiscovery) {
+ return;
+ }
+
+ // If the input encryption materials do not contain a plaintext data key and this keyring does not
+ // have a generator defined, OnEncrypt MUST not modify the encryption materials and MUST fail.
+ if (!encryptionMaterials.hasPlaintextDataKey() && generatorKeyId == null) {
+ throw new AwsCryptoException("Encryption materials must contain either a plaintext data key or a generator");
+ }
+
+ final List keyIdsToEncrypt = new ArrayList<>(keyIds);
+
+ // If the input encryption materials do not contain a plaintext data key and a generator is defined onEncrypt
+ // MUST attempt to generate a new plaintext data key and encrypt that data key by calling KMS GenerateDataKey.
+ if (!encryptionMaterials.hasPlaintextDataKey()) {
+ generateDataKey(encryptionMaterials);
+ } else {
+ // If this keyring's generator is defined and was not used to generate a data key, OnEncrypt
+ // MUST also attempt to encrypt the plaintext data key using the CMK specified by the generator.
+ keyIdsToEncrypt.add(generatorKeyId);
+ }
+
+ // Given a plaintext data key in the encryption materials, OnEncrypt MUST attempt
+ // to encrypt the plaintext data key using each CMK specified in it's key IDs list.
+ for (String keyId : keyIdsToEncrypt) {
+ encryptDataKey(keyId, encryptionMaterials);
+ }
+ }
+
+ private void generateDataKey(final EncryptionMaterials encryptionMaterials) {
+ final GenerateDataKeyResult result = dataKeyEncryptionDao.generateDataKey(generatorKeyId,
+ encryptionMaterials.getAlgorithmSuite(), encryptionMaterials.getEncryptionContext());
+
+ encryptionMaterials.setPlaintextDataKey(result.getPlaintextDataKey(),
+ new KeyringTraceEntry(KMS_PROVIDER_ID, generatorKeyId, KeyringTraceFlag.GENERATED_DATA_KEY));
+ encryptionMaterials.addEncryptedDataKey(result.getEncryptedDataKey(),
+ new KeyringTraceEntry(KMS_PROVIDER_ID, generatorKeyId, KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT));
+ }
+
+ private void encryptDataKey(final String keyId, final EncryptionMaterials encryptionMaterials) {
+ final EncryptedDataKey encryptedDataKey = dataKeyEncryptionDao.encryptDataKey(keyId,
+ encryptionMaterials.getPlaintextDataKey(), encryptionMaterials.getEncryptionContext());
+
+ encryptionMaterials.addEncryptedDataKey(encryptedDataKey,
+ new KeyringTraceEntry(KMS_PROVIDER_ID, keyId, KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT));
+ }
+
+ @Override
+ public void onDecrypt(DecryptionMaterials decryptionMaterials, List extends EncryptedDataKey> encryptedDataKeys) {
+ requireNonNull(decryptionMaterials, "decryptionMaterials are required");
+ requireNonNull(encryptedDataKeys, "encryptedDataKeys are required");
+
+ if (decryptionMaterials.hasPlaintextDataKey() || encryptedDataKeys.isEmpty()) {
+ return;
+ }
+
+ if (!encryptedDataKeys.stream()
+ .filter(edk -> edk.getProviderId().equals(KMS_PROVIDER_ID))
+ .map(edk -> new String(edk.getProviderInformation(), PROVIDER_ENCODING))
+ .allMatch(KmsUtils::isArnWellFormed)) {
+ throw new MalformedArnException("encryptedDataKeys contains a malformed ARN");
+ }
+
+ for (EncryptedDataKey encryptedDataKey : encryptedDataKeys) {
+ if (okToDecrypt(encryptedDataKey)) {
+ try {
+ final DecryptDataKeyResult result = dataKeyEncryptionDao.decryptDataKey(encryptedDataKey,
+ decryptionMaterials.getAlgorithmSuite(), decryptionMaterials.getEncryptionContext());
+
+ decryptionMaterials.setPlaintextDataKey(result.getPlaintextDataKey(),
+ new KeyringTraceEntry(KMS_PROVIDER_ID, result.getKeyArn(),
+ KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT));
+ return;
+ } catch (CannotUnwrapDataKeyException e) {
+ continue;
+ }
+ }
+ }
+ }
+
+ private boolean okToDecrypt(EncryptedDataKey encryptedDataKey) {
+ // Only attempt to decrypt keys provided by KMS
+ if (!encryptedDataKey.getProviderId().equals(KMS_PROVIDER_ID)) {
+ return false;
+ }
+
+ // If this keyring is a discovery keyring, OnDecrypt MUST attempt to
+ // decrypt every encrypted data key in the input encrypted data key list
+ if (isDiscovery) {
+ return true;
+ }
+
+ final String providerInfo = new String(encryptedDataKey.getProviderInformation(), PROVIDER_ENCODING);
+
+ // OnDecrypt MUST attempt to decrypt each input encrypted data key in the input
+ // encrypted data key list where the key provider info has a value equal to one
+ // of the ARNs in this keyring's key IDs or the generator
+ return providerInfo.equals(generatorKeyId) || keyIds.stream().anyMatch(providerInfo::equals);
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java b/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java
new file mode 100644
index 000000000..4267ba6f8
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.encryptionsdk.CryptoAlgorithm;
+import com.amazonaws.encryptionsdk.EncryptedDataKey;
+
+import javax.crypto.SecretKey;
+import java.util.List;
+import java.util.Map;
+
+public interface DataKeyEncryptionDao {
+
+ /**
+ * Generates a unique data key, returning both the plaintext copy of the key and an encrypted copy encrypted using
+ * the customer master key specified by the given keyId.
+ *
+ * @param keyId The customer master key to encrypt the generated key with.
+ * @param algorithmSuite The algorithm suite associated with the key.
+ * @param encryptionContext The encryption context.
+ * @return GenerateDataKeyResult containing the plaintext data key and the encrypted data key.
+ */
+ GenerateDataKeyResult generateDataKey(String keyId, CryptoAlgorithm algorithmSuite, Map encryptionContext);
+
+ /**
+ * Encrypts the given plaintext data key using the customer aster key specified by the given keyId.
+ *
+ * @param keyId The customer master key to encrypt the plaintext data key with.
+ * @param plaintextDataKey The plaintext data key to encrypt.
+ * @param encryptionContext The encryption context.
+ * @return The encrypted data key.
+ */
+ EncryptedDataKey encryptDataKey(final String keyId, SecretKey plaintextDataKey, Map encryptionContext);
+
+ /**
+ * Decrypted the given encrypted data key.
+ *
+ * @param encryptedDataKey The encrypted data key to decrypt.
+ * @param algorithmSuite The algorithm suite associated with the key.
+ * @param encryptionContext The encryption context.
+ * @return DecryptDataKeyResult containing the plaintext data key and the ARN of the key that decrypted it.
+ */
+ DecryptDataKeyResult decryptDataKey(EncryptedDataKey encryptedDataKey, CryptoAlgorithm algorithmSuite, Map encryptionContext);
+
+ /**
+ * Constructs an instance of DataKeyEncryptionDao that uses AWS Key Management Service (KMS) for
+ * generation, encryption, and decryption of data keys.
+ *
+ * @param clientSupplier A supplier of AWSKMS clients
+ * @param grantTokens A list of grant tokens to supply to KMS
+ * @return The DataKeyEncryptionDao
+ */
+ static DataKeyEncryptionDao kms(KmsClientSupplier clientSupplier, List grantTokens) {
+ return new KmsDataKeyEncryptionDao(clientSupplier, grantTokens);
+ }
+
+ class GenerateDataKeyResult {
+ private final SecretKey plaintextDataKey;
+ private final EncryptedDataKey encryptedDataKey;
+
+ public GenerateDataKeyResult(SecretKey plaintextDataKey, EncryptedDataKey encryptedDataKey) {
+ this.plaintextDataKey = plaintextDataKey;
+ this.encryptedDataKey = encryptedDataKey;
+ }
+
+ public SecretKey getPlaintextDataKey() {
+ return plaintextDataKey;
+ }
+
+ public EncryptedDataKey getEncryptedDataKey() {
+ return encryptedDataKey;
+ }
+ }
+
+ class DecryptDataKeyResult {
+ private final String keyArn;
+ private final SecretKey plaintextDataKey;
+
+ public DecryptDataKeyResult(String keyArn, SecretKey plaintextDataKey) {
+ this.keyArn = keyArn;
+ this.plaintextDataKey = plaintextDataKey;
+ }
+
+ public String getKeyArn() {
+ return keyArn;
+ }
+
+ public SecretKey getPlaintextDataKey() {
+ return plaintextDataKey;
+ }
+
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java
new file mode 100644
index 000000000..55ef2678b
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.services.kms.AWSKMS;
+
+import javax.annotation.Nullable;
+
+/**
+ * Represents a function that accepts an AWS region and returns an {@code AWSKMS} client for that region. The
+ * function should be able to handle when the region is null.
+ */
+@FunctionalInterface
+public interface KmsClientSupplier {
+
+ /**
+ * Gets an {@code AWSKMS} client for the given regionId.
+ *
+ * @param regionId The AWS region (or null)
+ * @return The AWSKMS client
+ */
+ AWSKMS getClient(@Nullable String regionId);
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java
new file mode 100644
index 000000000..f67846c9f
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java
@@ -0,0 +1,171 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.AmazonServiceException;
+import com.amazonaws.AmazonWebServiceRequest;
+import com.amazonaws.encryptionsdk.CryptoAlgorithm;
+import com.amazonaws.encryptionsdk.EncryptedDataKey;
+import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
+import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException;
+import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException;
+import com.amazonaws.encryptionsdk.internal.VersionInfo;
+import com.amazonaws.encryptionsdk.model.KeyBlob;
+import com.amazonaws.services.kms.model.DecryptRequest;
+import com.amazonaws.services.kms.model.EncryptRequest;
+import com.amazonaws.services.kms.model.GenerateDataKeyRequest;
+
+import javax.crypto.SecretKey;
+import javax.crypto.spec.SecretKeySpec;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.getClientByArn;
+import static java.util.Objects.requireNonNull;
+import static org.apache.commons.lang3.Validate.isTrue;
+
+/**
+ * An implementation of DataKeyEncryptionDao that uses AWS Key Management Service (KMS) for
+ * generation, encryption, and decryption of data keys. The KmsMethods interface is implemented
+ * to allow usage in KmsMasterKey.
+ */
+class KmsDataKeyEncryptionDao implements DataKeyEncryptionDao, KmsMethods {
+
+ private final KmsClientSupplier clientSupplier;
+ private List grantTokens;
+
+ KmsDataKeyEncryptionDao(KmsClientSupplier clientSupplier, List grantTokens) {
+ requireNonNull(clientSupplier, "clientSupplier is required");
+
+ this.clientSupplier = clientSupplier;
+ this.grantTokens = grantTokens == null ? new ArrayList<>() : new ArrayList<>(grantTokens);
+ }
+
+ @Override
+ public GenerateDataKeyResult generateDataKey(String keyId, CryptoAlgorithm algorithmSuite, Map encryptionContext) {
+ requireNonNull(keyId, "keyId is required");
+ requireNonNull(algorithmSuite, "algorithmSuite is required");
+ requireNonNull(encryptionContext, "encryptionContext is required");
+
+ final com.amazonaws.services.kms.model.GenerateDataKeyResult kmsResult;
+
+ try {
+ kmsResult = getClientByArn(keyId, clientSupplier)
+ .generateDataKey(updateUserAgent(
+ new GenerateDataKeyRequest()
+ .withKeyId(keyId)
+ .withNumberOfBytes(algorithmSuite.getDataKeyLength())
+ .withEncryptionContext(encryptionContext)
+ .withGrantTokens(grantTokens)));
+ } catch (final AmazonServiceException ex) {
+ throw new AwsCryptoException(ex);
+ }
+
+ final byte[] rawKey = new byte[algorithmSuite.getDataKeyLength()];
+ kmsResult.getPlaintext().get(rawKey);
+ if (kmsResult.getPlaintext().remaining() > 0) {
+ throw new IllegalStateException("Received an unexpected number of bytes from KMS");
+ }
+ final byte[] encryptedKey = new byte[kmsResult.getCiphertextBlob().remaining()];
+ kmsResult.getCiphertextBlob().get(encryptedKey);
+
+ return new GenerateDataKeyResult(new SecretKeySpec(rawKey, algorithmSuite.getDataKeyAlgo()),
+ new KeyBlob(KMS_PROVIDER_ID, kmsResult.getKeyId().getBytes(PROVIDER_ENCODING), encryptedKey));
+ }
+
+ @Override
+ public EncryptedDataKey encryptDataKey(final String keyId, SecretKey plaintextDataKey, Map encryptionContext) {
+ requireNonNull(keyId, "keyId is required");
+ requireNonNull(plaintextDataKey, "plaintextDataKey is required");
+ requireNonNull(encryptionContext, "encryptionContext is required");
+ isTrue(plaintextDataKey.getFormat().equals("RAW"), "Only RAW encoded keys are supported");
+
+ final com.amazonaws.services.kms.model.EncryptResult kmsResult;
+
+ try {
+ kmsResult = getClientByArn(keyId, clientSupplier)
+ .encrypt(updateUserAgent(new EncryptRequest()
+ .withKeyId(keyId)
+ .withPlaintext(ByteBuffer.wrap(plaintextDataKey.getEncoded()))
+ .withEncryptionContext(encryptionContext)
+ .withGrantTokens(grantTokens)));
+ } catch (final AmazonServiceException ex) {
+ throw new AwsCryptoException(ex);
+ }
+ final byte[] encryptedDataKey = new byte[kmsResult.getCiphertextBlob().remaining()];
+ kmsResult.getCiphertextBlob().get(encryptedDataKey);
+
+ return new KeyBlob(KMS_PROVIDER_ID, kmsResult.getKeyId().getBytes(PROVIDER_ENCODING), encryptedDataKey);
+
+ }
+
+ @Override
+ public DecryptDataKeyResult decryptDataKey(EncryptedDataKey encryptedDataKey, CryptoAlgorithm algorithmSuite, Map encryptionContext) {
+ requireNonNull(encryptedDataKey, "encryptedDataKey is required");
+ requireNonNull(algorithmSuite, "algorithmSuite is required");
+ requireNonNull(encryptionContext, "encryptionContext is required");
+
+ final String providerInformation = new String(encryptedDataKey.getProviderInformation(), PROVIDER_ENCODING);
+ final com.amazonaws.services.kms.model.DecryptResult kmsResult;
+
+ try {
+ kmsResult = getClientByArn(providerInformation, clientSupplier)
+ .decrypt(updateUserAgent(new DecryptRequest()
+ .withCiphertextBlob(ByteBuffer.wrap(encryptedDataKey.getEncryptedDataKey()))
+ .withEncryptionContext(encryptionContext)
+ .withGrantTokens(grantTokens)));
+ } catch (final AmazonServiceException ex) {
+ throw new CannotUnwrapDataKeyException(ex);
+ }
+
+ if (!kmsResult.getKeyId().equals(providerInformation)) {
+ throw new MismatchedDataKeyException("Received an unexpected key Id from KMS");
+ }
+
+ final byte[] rawKey = new byte[algorithmSuite.getDataKeyLength()];
+ kmsResult.getPlaintext().get(rawKey);
+ if (kmsResult.getPlaintext().remaining() > 0) {
+ throw new IllegalStateException("Received an unexpected number of bytes from KMS");
+ }
+
+ return new DecryptDataKeyResult(kmsResult.getKeyId(), new SecretKeySpec(rawKey, algorithmSuite.getDataKeyAlgo()));
+
+ }
+
+ private T updateUserAgent(T request) {
+ request.getRequestClientOptions().appendUserAgent(VersionInfo.USER_AGENT);
+
+ return request;
+ }
+
+ @Override
+ public void setGrantTokens(List grantTokens) {
+ this.grantTokens = new ArrayList<>(grantTokens);
+ }
+
+ @Override
+ public List getGrantTokens() {
+ return Collections.unmodifiableList(grantTokens);
+ }
+
+ @Override
+ public void addGrantToken(String grantToken) {
+ grantTokens.add(grantToken);
+ }
+}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java
index b78840221..4b6db613c 100644
--- a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java
@@ -14,17 +14,12 @@
package com.amazonaws.encryptionsdk.kms;
import javax.crypto.SecretKey;
-import javax.crypto.spec.SecretKeySpec;
-import java.nio.ByteBuffer;
-import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
-import com.amazonaws.AmazonServiceException;
-import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.encryptionsdk.AwsCrypto;
@@ -34,31 +29,20 @@
import com.amazonaws.encryptionsdk.MasterKey;
import com.amazonaws.encryptionsdk.MasterKeyProvider;
import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
+import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException;
import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException;
-import com.amazonaws.encryptionsdk.internal.VersionInfo;
import com.amazonaws.services.kms.AWSKMS;
-import com.amazonaws.services.kms.model.DecryptRequest;
-import com.amazonaws.services.kms.model.DecryptResult;
-import com.amazonaws.services.kms.model.EncryptRequest;
-import com.amazonaws.services.kms.model.EncryptResult;
-import com.amazonaws.services.kms.model.GenerateDataKeyRequest;
-import com.amazonaws.services.kms.model.GenerateDataKeyResult;
+
+import static java.util.Collections.emptyList;
/**
* Represents a single Customer Master Key (CMK) and is used to encrypt/decrypt data with
* {@link AwsCrypto}.
*/
public final class KmsMasterKey extends MasterKey implements KmsMethods {
- private final Supplier kms_;
+ private final KmsDataKeyEncryptionDao dataKeyEncryptionDao;
private final MasterKeyProvider sourceProvider_;
private final String id_;
- private final List grantTokens_ = new ArrayList<>();
-
- private T updateUserAgent(T request) {
- request.getRequestClientOptions().appendUserAgent(VersionInfo.USER_AGENT);
-
- return request;
- }
/**
*
@@ -84,7 +68,7 @@ static KmsMasterKey getInstance(final Supplier kms, final String id,
}
private KmsMasterKey(final Supplier kms, final String id, final MasterKeyProvider provider) {
- kms_ = kms;
+ dataKeyEncryptionDao = new KmsDataKeyEncryptionDao(s -> kms.get(), emptyList());
id_ = id;
sourceProvider_ = provider;
}
@@ -102,39 +86,27 @@ public String getKeyId() {
@Override
public DataKey generateDataKey(final CryptoAlgorithm algorithm,
final Map encryptionContext) {
- final GenerateDataKeyResult gdkResult = kms_.get().generateDataKey(updateUserAgent(
- new GenerateDataKeyRequest()
- .withKeyId(getKeyId())
- .withNumberOfBytes(algorithm.getDataKeyLength())
- .withEncryptionContext(encryptionContext)
- .withGrantTokens(grantTokens_)
- ));
- final byte[] rawKey = new byte[algorithm.getDataKeyLength()];
- gdkResult.getPlaintext().get(rawKey);
- if (gdkResult.getPlaintext().remaining() > 0) {
- throw new IllegalStateException("Recieved an unexpected number of bytes from KMS");
- }
- final byte[] encryptedKey = new byte[gdkResult.getCiphertextBlob().remaining()];
- gdkResult.getCiphertextBlob().get(encryptedKey);
-
- final SecretKeySpec key = new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo());
- return new DataKey<>(key, encryptedKey, gdkResult.getKeyId().getBytes(StandardCharsets.UTF_8), this);
+ final DataKeyEncryptionDao.GenerateDataKeyResult gdkResult = dataKeyEncryptionDao.generateDataKey(
+ getKeyId(), algorithm, encryptionContext);
+ return new DataKey<>(gdkResult.getPlaintextDataKey(),
+ gdkResult.getEncryptedDataKey().getEncryptedDataKey(),
+ gdkResult.getEncryptedDataKey().getProviderInformation(),
+ this);
}
@Override
public void setGrantTokens(final List grantTokens) {
- grantTokens_.clear();
- grantTokens_.addAll(grantTokens);
+ dataKeyEncryptionDao.setGrantTokens(grantTokens);
}
@Override
public List getGrantTokens() {
- return grantTokens_;
+ return dataKeyEncryptionDao.getGrantTokens();
}
@Override
public void addGrantToken(final String grantToken) {
- grantTokens_.add(grantToken);
+ dataKeyEncryptionDao.addGrantToken(grantToken);
}
@Override
@@ -142,22 +114,12 @@ public DataKey encryptDataKey(final CryptoAlgorithm algorithm,
final Map encryptionContext,
final DataKey> dataKey) {
final SecretKey key = dataKey.getKey();
- if (!key.getFormat().equals("RAW")) {
- throw new IllegalArgumentException("Only RAW encoded keys are supported");
- }
- try {
- final EncryptResult encryptResult = kms_.get().encrypt(updateUserAgent(
- new EncryptRequest()
- .withKeyId(id_)
- .withPlaintext(ByteBuffer.wrap(key.getEncoded()))
- .withEncryptionContext(encryptionContext)
- .withGrantTokens(grantTokens_)));
- final byte[] edk = new byte[encryptResult.getCiphertextBlob().remaining()];
- encryptResult.getCiphertextBlob().get(edk);
- return new DataKey<>(dataKey.getKey(), edk, encryptResult.getKeyId().getBytes(StandardCharsets.UTF_8), this);
- } catch (final AmazonServiceException asex) {
- throw new AwsCryptoException(asex);
- }
+ final EncryptedDataKey encryptedDataKey = dataKeyEncryptionDao.encryptDataKey(id_, key, encryptionContext);
+
+ return new DataKey<>(dataKey.getKey(),
+ encryptedDataKey.getEncryptedDataKey(),
+ encryptedDataKey.getProviderInformation(),
+ this);
}
@Override
@@ -168,24 +130,16 @@ public DataKey decryptDataKey(final CryptoAlgorithm algorithm,
final List exceptions = new ArrayList<>();
for (final EncryptedDataKey edk : encryptedDataKeys) {
try {
- final DecryptResult decryptResult = kms_.get().decrypt(updateUserAgent(
- new DecryptRequest()
- .withCiphertextBlob(ByteBuffer.wrap(edk.getEncryptedDataKey()))
- .withEncryptionContext(encryptionContext)
- .withGrantTokens(grantTokens_)));
- if (decryptResult.getKeyId().equals(id_)) {
- final byte[] rawKey = new byte[algorithm.getDataKeyLength()];
- decryptResult.getPlaintext().get(rawKey);
- if (decryptResult.getPlaintext().remaining() > 0) {
- throw new IllegalStateException("Received an unexpected number of bytes from KMS");
- }
- return new DataKey<>(
- new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()),
- edk.getEncryptedDataKey(),
- edk.getProviderInformation(), this);
- }
- } catch (final AmazonServiceException awsex) {
- exceptions.add(awsex);
+ final DataKeyEncryptionDao.DecryptDataKeyResult result = dataKeyEncryptionDao.decryptDataKey(edk, algorithm, encryptionContext);
+ return new DataKey<>(
+ result.getPlaintextDataKey(),
+ edk.getEncryptedDataKey(),
+ edk.getProviderInformation(), this);
+ } catch (final AwsCryptoException | MismatchedDataKeyException ex) {
+ // Earlier versions of KmsMasterKey compare the returned keyId to the encryptedDataKey
+ // provider information and skip that key if it doesn't match. KmsDataKeyEncryptionDao
+ // throws MismatchedDataKeyException in this case, so this maintains the existing behavior.
+ exceptions.add(ex);
}
}
diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java
new file mode 100644
index 000000000..35ddc37e2
--- /dev/null
+++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.arn.Arn;
+import com.amazonaws.encryptionsdk.exception.MalformedArnException;
+import com.amazonaws.services.kms.AWSKMS;
+
+public class KmsUtils {
+
+ private static final String ALIAS_PREFIX = "alias/";
+ /**
+ * The provider ID used for the KmsKeyring
+ */
+ public static final String KMS_PROVIDER_ID = "aws-kms";
+
+ /**
+ * Parses region from the given arn (if possible) and passes that region to the
+ * given clientSupplier to produce an {@code AWSKMS} client.
+ *
+ * @param arn The Amazon Resource Name or Key Alias
+ * @param clientSupplier The client supplier
+ * @return AWSKMS The client
+ * @throws MalformedArnException if the arn is malformed
+ */
+ public static AWSKMS getClientByArn(String arn, KmsClientSupplier clientSupplier) throws MalformedArnException {
+ if (isKeyAlias(arn)) {
+ return clientSupplier.getClient(null);
+ }
+
+ try {
+ return clientSupplier.getClient(Arn.fromString(arn).getRegion());
+ } catch (IllegalArgumentException e) {
+ throw new MalformedArnException(e);
+ }
+ }
+
+ /**
+ * Returns true if the given arn is a well formed Amazon Resource Name or Key Alias
+ *
+ * @param arn The Amazon Resource Name or Key Alias
+ * @return True if well formed, false otherwise
+ */
+ public static boolean isArnWellFormed(String arn) {
+ if (isKeyAlias(arn)) {
+ return true;
+ }
+
+ try {
+ Arn.fromString(arn);
+ return true;
+ } catch (IllegalArgumentException e) {
+ return false;
+ }
+ }
+
+ private static boolean isKeyAlias(String arn) {
+ return arn.startsWith(ALIAS_PREFIX) && !arn.contains(":");
+ }
+}
diff --git a/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java b/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java
new file mode 100644
index 000000000..0b8034b9a
--- /dev/null
+++ b/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java
@@ -0,0 +1,298 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.keyrings;
+
+import com.amazonaws.encryptionsdk.CryptoAlgorithm;
+import com.amazonaws.encryptionsdk.EncryptedDataKey;
+import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
+import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException;
+import com.amazonaws.encryptionsdk.exception.MalformedArnException;
+import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao;
+import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.DecryptDataKeyResult;
+import com.amazonaws.encryptionsdk.model.KeyBlob;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import javax.crypto.SecretKey;
+import javax.crypto.spec.SecretKeySpec;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING;
+import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate;
+import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.ENCRYPTED_DATA_KEY;
+import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.GENERATED_DATA_KEY;
+import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+class KmsKeyringTest {
+
+ private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256;
+ private static final SecretKey PLAINTEXT_DATA_KEY = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo());
+ private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue");
+ private static final String GENERATOR_KEY_ID = "arn:aws:kms:us-east-1:999999999999:key/generator-89ab-cdef-fedc-ba9876543210";
+ private static final String KEY_ID_1 = "arn:aws:kms:us-east-1:999999999999:key/key1-23bv-sdfs-werw-234323nfdsf";
+ private static final String KEY_ID_2 = "arn:aws:kms:us-east-1:999999999999:key/key2-02ds-wvjs-aswe-a4923489273";
+ private static final EncryptedDataKey ENCRYPTED_GENERATOR_KEY = new KeyBlob(KMS_PROVIDER_ID,
+ GENERATOR_KEY_ID.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength()));
+ private static final EncryptedDataKey ENCRYPTED_KEY_1 = new KeyBlob(KMS_PROVIDER_ID,
+ KEY_ID_1.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength()));
+ private static final EncryptedDataKey ENCRYPTED_KEY_2 = new KeyBlob(KMS_PROVIDER_ID,
+ KEY_ID_2.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength()));
+ private static final KeyringTraceEntry ENCRYPTED_GENERATOR_KEY_TRACE =
+ new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT);
+ private static final KeyringTraceEntry ENCRYPTED_KEY_1_TRACE =
+ new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_1, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT);
+ private static final KeyringTraceEntry ENCRYPTED_KEY_2_TRACE =
+ new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT);
+ private static final KeyringTraceEntry GENERATED_DATA_KEY_TRACE =
+ new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, GENERATED_DATA_KEY);
+ @Mock(lenient = true) private DataKeyEncryptionDao dataKeyEncryptionDao;
+ private Keyring keyring;
+
+ @BeforeEach
+ void setup() {
+ when(dataKeyEncryptionDao.encryptDataKey(GENERATOR_KEY_ID, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_GENERATOR_KEY);
+ when(dataKeyEncryptionDao.encryptDataKey(KEY_ID_1, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_KEY_1);
+ when(dataKeyEncryptionDao.encryptDataKey(KEY_ID_2, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_KEY_2);
+
+ when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_GENERATOR_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT))
+ .thenReturn(new DecryptDataKeyResult(GENERATOR_KEY_ID, PLAINTEXT_DATA_KEY));
+ when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT))
+ .thenReturn(new DecryptDataKeyResult(KEY_ID_1, PLAINTEXT_DATA_KEY));
+ when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_2, ALGORITHM_SUITE, ENCRYPTION_CONTEXT))
+ .thenReturn(new DecryptDataKeyResult(KEY_ID_2, PLAINTEXT_DATA_KEY));
+
+ List keyIds = new ArrayList<>();
+ keyIds.add(KEY_ID_1);
+ keyIds.add(KEY_ID_2);
+ keyring = new KmsKeyring(dataKeyEncryptionDao, keyIds, GENERATOR_KEY_ID);
+ }
+
+ @Test
+ void testMalformedArns() {
+ assertThrows(MalformedArnException.class, () -> new KmsKeyring(dataKeyEncryptionDao, null, "badArn"));
+ assertThrows(MalformedArnException.class, () -> new KmsKeyring(dataKeyEncryptionDao, Collections.singletonList("badArn"), GENERATOR_KEY_ID));
+
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(new KeyBlob(KMS_PROVIDER_ID, "badArn".getBytes(PROVIDER_ENCODING), new byte[]{}));
+ assertThrows(MalformedArnException.class, () -> keyring.onDecrypt(decryptionMaterials, encryptedDataKeys));
+
+ // Malformed Arn for a non KMS provider shouldn't fail
+ encryptedDataKeys.clear();
+ encryptedDataKeys.add(new KeyBlob("OtherProviderId", "badArn".getBytes(PROVIDER_ENCODING), new byte[]{}));
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+ }
+
+ @Test
+ void testGeneratorKeyInKeyIds() {
+ assertThrows(IllegalArgumentException.class, () -> new KmsKeyring(dataKeyEncryptionDao, Collections.singletonList(GENERATOR_KEY_ID), GENERATOR_KEY_ID));
+ }
+
+ @Test
+ void testEncryptDecryptExistingDataKey() {
+ EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .plaintextDataKey(PLAINTEXT_DATA_KEY)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .build();
+
+ keyring.onEncrypt(encryptionMaterials);
+
+ assertEquals(3, encryptionMaterials.getEncryptedDataKeys().size());
+ assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_GENERATOR_KEY));
+ assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_KEY_1));
+ assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_KEY_2));
+
+ assertEquals(3, encryptionMaterials.getKeyringTrace().getEntries().size());
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_GENERATOR_KEY_TRACE));
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_1_TRACE));
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_2_TRACE));
+
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(ENCRYPTED_GENERATOR_KEY);
+ encryptedDataKeys.add(ENCRYPTED_KEY_1);
+ encryptedDataKeys.add(ENCRYPTED_KEY_2);
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+
+ KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT);
+ assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0));
+ }
+
+ @Test
+ void testEncryptNullDataKey() {
+ EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .keyringTrace(new KeyringTrace())
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .build();
+
+ when(dataKeyEncryptionDao.generateDataKey(GENERATOR_KEY_ID, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)).thenReturn(new DataKeyEncryptionDao.GenerateDataKeyResult(PLAINTEXT_DATA_KEY, ENCRYPTED_GENERATOR_KEY));
+ keyring.onEncrypt(encryptionMaterials);
+
+ assertEquals(PLAINTEXT_DATA_KEY, encryptionMaterials.getPlaintextDataKey());
+
+ assertEquals(4, encryptionMaterials.getKeyringTrace().getEntries().size());
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(GENERATED_DATA_KEY_TRACE));
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_GENERATOR_KEY_TRACE));
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_1_TRACE));
+ assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_2_TRACE));
+
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(ENCRYPTED_GENERATOR_KEY);
+ encryptedDataKeys.add(ENCRYPTED_KEY_1);
+ encryptedDataKeys.add(ENCRYPTED_KEY_2);
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+
+ KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT);
+ assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0));
+ }
+
+ @Test
+ void testDiscoveryEncrypt() {
+ keyring = new KmsKeyring(dataKeyEncryptionDao, null, null);
+
+ EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .build();
+ keyring.onEncrypt(encryptionMaterials);
+
+ assertFalse(encryptionMaterials.hasPlaintextDataKey());
+ assertEquals(0, encryptionMaterials.getKeyringTrace().getEntries().size());
+ }
+
+ @Test
+ void testEncryptNoGeneratorOrPlaintextDataKey() {
+ List keyIds = new ArrayList<>();
+ keyIds.add(KEY_ID_1);
+ keyring = new KmsKeyring(dataKeyEncryptionDao, keyIds, null);
+
+ EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE).build();
+ assertThrows(AwsCryptoException.class, () -> keyring.onEncrypt(encryptionMaterials));
+ }
+
+ @Test
+ void testDecryptFirstKeyFails() {
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)).thenThrow(new CannotUnwrapDataKeyException());
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(ENCRYPTED_KEY_1);
+ encryptedDataKeys.add(ENCRYPTED_KEY_2);
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+
+ KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT);
+ assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0));
+ }
+
+ @Test
+ void testDecryptFirstKeyWrongProvider() {
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ EncryptedDataKey wrongProviderKey = new KeyBlob("OtherProvider", KEY_ID_1.getBytes(PROVIDER_ENCODING), new byte[]{});
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(wrongProviderKey);
+ encryptedDataKeys.add(ENCRYPTED_KEY_2);
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+
+ KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT);
+ assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0));
+ }
+
+ @Test
+ void testDiscoveryDecrypt() {
+ keyring = new KmsKeyring(dataKeyEncryptionDao, null, null);
+
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ List encryptedDataKeys = new ArrayList<>();
+ encryptedDataKeys.add(ENCRYPTED_KEY_1);
+ encryptedDataKeys.add(ENCRYPTED_KEY_2);
+ keyring.onDecrypt(decryptionMaterials, encryptedDataKeys);
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+
+ KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_1, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT);
+ assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0));
+ }
+
+ @Test
+ void testDecryptAlreadyDecryptedDataKey() {
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .plaintextDataKey(PLAINTEXT_DATA_KEY)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .build();
+
+ keyring.onDecrypt(decryptionMaterials, Collections.singletonList(ENCRYPTED_GENERATOR_KEY));
+
+ assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey());
+ assertEquals(0, decryptionMaterials.getKeyringTrace().getEntries().size());
+ }
+
+ @Test
+ void testDecryptNoDataKey() {
+ DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE)
+ .encryptionContext(ENCRYPTION_CONTEXT)
+ .keyringTrace(new KeyringTrace())
+ .build();
+
+ keyring.onDecrypt(decryptionMaterials, Collections.emptyList());
+
+ assertFalse(decryptionMaterials.hasPlaintextDataKey());
+ assertEquals(0, decryptionMaterials.getKeyringTrace().getEntries().size());
+ }
+}
diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java
new file mode 100644
index 000000000..b85d6f76f
--- /dev/null
+++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java
@@ -0,0 +1,201 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.AmazonWebServiceRequest;
+import com.amazonaws.RequestClientOptions;
+import com.amazonaws.encryptionsdk.CryptoAlgorithm;
+import com.amazonaws.encryptionsdk.EncryptedDataKey;
+import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
+import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException;
+import com.amazonaws.encryptionsdk.internal.VersionInfo;
+import com.amazonaws.encryptionsdk.model.KeyBlob;
+import com.amazonaws.services.kms.AWSKMS;
+import com.amazonaws.services.kms.model.DecryptRequest;
+import com.amazonaws.services.kms.model.DecryptResult;
+import com.amazonaws.services.kms.model.EncryptRequest;
+import com.amazonaws.services.kms.model.GenerateDataKeyRequest;
+import com.amazonaws.services.kms.model.KMSInvalidStateException;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+
+import javax.crypto.SecretKey;
+import javax.crypto.spec.SecretKeySpec;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate;
+import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.isA;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+class KmsDataKeyEncryptionDaoTest {
+
+ private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256;
+ private static final SecretKey DATA_KEY = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo());
+ private static final List GRANT_TOKENS = Collections.singletonList("testGrantToken");
+ private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue");
+ private static final EncryptedDataKey ENCRYPTED_DATA_KEY = new KeyBlob(KMS_PROVIDER_ID,
+ "arn:aws:kms:us-east-1:999999999999:key/01234567-89ab-cdef-fedc-ba9876543210".getBytes(EncryptedDataKey.PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength()));
+
+ @Test
+ void testEncryptAndDecrypt() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ String keyId = client.createKey().getKeyMetadata().getArn();
+ EncryptedDataKey encryptedDataKeyResult = dao.encryptDataKey(keyId, DATA_KEY, ENCRYPTION_CONTEXT);
+
+ ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class);
+ verify(client, times(1)).encrypt(er.capture());
+
+ EncryptRequest actualRequest = er.getValue();
+
+ assertEquals(keyId, actualRequest.getKeyId());
+ assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens());
+ assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext());
+ assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.getPlaintext().array());
+ assertUserAgent(actualRequest);
+
+ assertEquals(KMS_PROVIDER_ID, encryptedDataKeyResult.getProviderId());
+ assertArrayEquals(keyId.getBytes(EncryptedDataKey.PROVIDER_ENCODING), encryptedDataKeyResult.getProviderInformation());
+ assertNotNull(encryptedDataKeyResult.getEncryptedDataKey());
+
+ DataKeyEncryptionDao.DecryptDataKeyResult decryptDataKeyResult = dao.decryptDataKey(encryptedDataKeyResult, ALGORITHM_SUITE, ENCRYPTION_CONTEXT);
+
+ ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class);
+ verify(client, times(1)).decrypt(decrypt.capture());
+
+ DecryptRequest actualDecryptRequest = decrypt.getValue();
+ assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens());
+ assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext());
+ assertArrayEquals(encryptedDataKeyResult.getEncryptedDataKey(), actualDecryptRequest.getCiphertextBlob().array());
+ assertUserAgent(actualDecryptRequest);
+
+ assertEquals(DATA_KEY, decryptDataKeyResult.getPlaintextDataKey());
+ assertEquals(keyId, decryptDataKeyResult.getKeyArn());
+ }
+
+ @Test
+ void testGenerateAndDecrypt() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ String keyId = client.createKey().getKeyMetadata().getArn();
+ DataKeyEncryptionDao.GenerateDataKeyResult generateDataKeyResult = dao.generateDataKey(keyId, ALGORITHM_SUITE, ENCRYPTION_CONTEXT);
+
+ ArgumentCaptor gr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class);
+ verify(client, times(1)).generateDataKey(gr.capture());
+
+ GenerateDataKeyRequest actualRequest = gr.getValue();
+
+ assertEquals(keyId, actualRequest.getKeyId());
+ assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens());
+ assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext());
+ assertEquals(ALGORITHM_SUITE.getDataKeyLength(), actualRequest.getNumberOfBytes());
+ assertUserAgent(actualRequest);
+
+ assertNotNull(generateDataKeyResult.getPlaintextDataKey());
+ assertEquals(ALGORITHM_SUITE.getDataKeyLength(), generateDataKeyResult.getPlaintextDataKey().getEncoded().length);
+ assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), generateDataKeyResult.getPlaintextDataKey().getAlgorithm());
+ assertNotNull(generateDataKeyResult.getEncryptedDataKey());
+
+ DataKeyEncryptionDao.DecryptDataKeyResult decryptDataKeyResult = dao.decryptDataKey(generateDataKeyResult.getEncryptedDataKey(), ALGORITHM_SUITE, ENCRYPTION_CONTEXT);
+
+ ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class);
+ verify(client, times(1)).decrypt(decrypt.capture());
+
+ DecryptRequest actualDecryptRequest = decrypt.getValue();
+ assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens());
+ assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext());
+ assertArrayEquals(generateDataKeyResult.getEncryptedDataKey().getEncryptedDataKey(), actualDecryptRequest.getCiphertextBlob().array());
+ assertUserAgent(actualDecryptRequest);
+
+ assertEquals(generateDataKeyResult.getPlaintextDataKey(), decryptDataKeyResult.getPlaintextDataKey());
+ assertEquals(keyId, decryptDataKeyResult.getKeyArn());
+ }
+
+ @Test
+ void testEncryptWrongKeyFormat() {
+ SecretKey key = mock(SecretKey.class);
+ when(key.getFormat()).thenReturn("BadFormat");
+
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ String keyId = client.createKey().getKeyMetadata().getArn();
+
+ assertThrows(IllegalArgumentException.class, () -> dao.encryptDataKey(keyId, key, ENCRYPTION_CONTEXT));
+ }
+
+ @Test
+ void testKmsFailure() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ String keyId = client.createKey().getKeyMetadata().getArn();
+ doThrow(new KMSInvalidStateException("fail")).when(client).generateDataKey(isA(GenerateDataKeyRequest.class));
+ doThrow(new KMSInvalidStateException("fail")).when(client).encrypt(isA(EncryptRequest.class));
+ doThrow(new KMSInvalidStateException("fail")).when(client).decrypt(isA(DecryptRequest.class));
+
+ assertThrows(AwsCryptoException.class, () -> dao.generateDataKey(keyId, ALGORITHM_SUITE, ENCRYPTION_CONTEXT));
+ assertThrows(AwsCryptoException.class, () -> dao.encryptDataKey(keyId, DATA_KEY, ENCRYPTION_CONTEXT));
+ assertThrows(AwsCryptoException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT));
+ }
+
+ @Test
+ void testDecryptBadKmsKeyId() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ DecryptResult badResult = new DecryptResult();
+ badResult.setKeyId("badKeyId");
+
+ doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class));
+
+ assertThrows(MismatchedDataKeyException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT));
+ }
+
+ @Test
+ void testDecryptBadKmsKeyLength() {
+ AWSKMS client = spy(new MockKMSClient());
+ DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS);
+
+ DecryptResult badResult = new DecryptResult();
+ badResult.setKeyId(new String(ENCRYPTED_DATA_KEY.getProviderInformation(), EncryptedDataKey.PROVIDER_ENCODING));
+ badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength() + 1));
+
+ doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class));
+
+ assertThrows(IllegalStateException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT));
+ }
+
+ private void assertUserAgent(AmazonWebServiceRequest request) {
+ assertTrue(request.getRequestClientOptions().getClientMarker(RequestClientOptions.Marker.USER_AGENT)
+ .contains(VersionInfo.USER_AGENT));
+ }
+}
diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java
new file mode 100644
index 000000000..4682614fb
--- /dev/null
+++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
+ * in compliance with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.amazonaws.encryptionsdk.kms;
+
+import com.amazonaws.encryptionsdk.exception.MalformedArnException;
+import com.amazonaws.services.kms.AWSKMS;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+@ExtendWith(MockitoExtension.class)
+class KmsUtilsTest {
+
+ private static final String VALID_ARN = "arn:aws:kms:us-east-1:999999999999:key/01234567-89ab-cdef-fedc-ba9876543210";
+ private static final String VALID_ALIAS_ARN = "arn:aws:kms:us-east-1:999999999999:alias/MyCryptoKey";
+ private static final String VALID_ALIAS = "alias/MyCryptoKey";
+
+ @Mock
+ private AWSKMS client;
+
+
+ @Test
+ void testGetClientByArn() {
+ assertEquals(client, KmsUtils.getClientByArn(VALID_ARN, s -> client));
+ assertEquals(client, KmsUtils.getClientByArn(VALID_ALIAS_ARN, s -> client));
+ assertEquals(client, KmsUtils.getClientByArn(VALID_ALIAS, s -> client));
+ assertThrows(MalformedArnException.class, () -> KmsUtils.getClientByArn("invalid", s -> client));
+
+ }
+
+ @Test
+ void testIsArnWellFormed() {
+ assertTrue(KmsUtils.isArnWellFormed(VALID_ARN));
+ assertTrue(KmsUtils.isArnWellFormed(VALID_ALIAS_ARN));
+ assertTrue(KmsUtils.isArnWellFormed(VALID_ALIAS));
+ assertFalse(KmsUtils.isArnWellFormed("invalid"));
+
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java b/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java
index 37fe9cbff..00ce5c074 100644
--- a/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java
+++ b/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java
@@ -29,7 +29,7 @@
import com.amazonaws.ResponseMetadata;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
-import com.amazonaws.services.kms.AWSKMSClient;
+import com.amazonaws.services.kms.AbstractAWSKMS;
import com.amazonaws.services.kms.model.CreateAliasRequest;
import com.amazonaws.services.kms.model.CreateAliasResult;
import com.amazonaws.services.kms.model.CreateGrantRequest;
@@ -85,7 +85,7 @@
import com.amazonaws.services.kms.model.UpdateKeyDescriptionRequest;
import com.amazonaws.services.kms.model.UpdateKeyDescriptionResult;
-public class MockKMSClient extends AWSKMSClient {
+public class MockKMSClient extends AbstractAWSKMS {
private static final SecureRandom rnd = new SecureRandom();
private static final String ACCOUNT_ID = "01234567890";
private final Map results_ = new HashMap<>();