diff --git a/pom.xml b/pom.xml index c426db5e6..a8774878b 100644 --- a/pom.xml +++ b/pom.xml @@ -42,7 +42,7 @@ com.amazonaws aws-java-sdk - 1.11.561 + 1.11.677 true diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java new file mode 100644 index 000000000..58f78833c --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/MalformedArnException.java @@ -0,0 +1,39 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.exception; + +/** + * This exception is thrown when an Amazon Resource Name is provided that does not + * match the CMK Alias or ARN format. + */ +public class MalformedArnException extends AwsCryptoException { + + private static final long serialVersionUID = -1L; + + public MalformedArnException() { + super(); + } + + public MalformedArnException(final String message) { + super(message); + } + + public MalformedArnException(final Throwable cause) { + super(cause); + } + + public MalformedArnException(final String message, final Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java new file mode 100644 index 000000000..aa1c87799 --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/MismatchedDataKeyException.java @@ -0,0 +1,39 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.exception; + +/** + * This exception is thrown when the key used by KMS to decrypt a data key does not + * match the provider information contained within the encrypted data key. + */ +public class MismatchedDataKeyException extends IllegalStateException { + + private static final long serialVersionUID = -1L; + + public MismatchedDataKeyException() { + super(); + } + + public MismatchedDataKeyException(final String message) { + super(message); + } + + public MismatchedDataKeyException(final Throwable cause) { + super(cause); + } + + public MismatchedDataKeyException(final String message, final Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java new file mode 100644 index 000000000..1ad8543d8 --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyring.java @@ -0,0 +1,171 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.keyrings; + +import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; +import com.amazonaws.encryptionsdk.exception.MalformedArnException; +import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao; +import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.DecryptDataKeyResult; +import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.GenerateDataKeyResult; +import com.amazonaws.encryptionsdk.kms.KmsUtils; + +import java.util.ArrayList; +import java.util.List; + +import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.isArnWellFormed; +import static java.util.Collections.emptyList; +import static java.util.Collections.unmodifiableList; +import static java.util.Objects.requireNonNull; + +/** + * A keyring which interacts with AWS Key Management Service (KMS) to create, + * encrypt, and decrypt data keys using KMS defined Customer Master Keys (CMKs). + */ +public class KmsKeyring implements Keyring { + + private final DataKeyEncryptionDao dataKeyEncryptionDao; + private final List keyIds; + private final String generatorKeyId; + private final boolean isDiscovery; + + KmsKeyring(DataKeyEncryptionDao dataKeyEncryptionDao, List keyIds, String generatorKeyId) { + requireNonNull(dataKeyEncryptionDao, "dataKeyEncryptionDao is required"); + this.dataKeyEncryptionDao = dataKeyEncryptionDao; + this.keyIds = keyIds == null ? emptyList() : unmodifiableList(keyIds); + this.generatorKeyId = generatorKeyId; + this.isDiscovery = this.generatorKeyId == null && this.keyIds.isEmpty(); + + if (!this.keyIds.stream().allMatch(KmsUtils::isArnWellFormed)) { + throw new MalformedArnException("keyIds must contain only CMK aliases and well formed ARNs"); + } + + if (generatorKeyId != null) { + if (!isArnWellFormed(generatorKeyId)) { + throw new MalformedArnException("generatorKeyId must be either a CMK alias or a well formed ARN"); + } + if (this.keyIds.contains(generatorKeyId)) { + throw new IllegalArgumentException("KeyIds should not contain the generatorKeyId"); + } + } + } + + @Override + public void onEncrypt(EncryptionMaterials encryptionMaterials) { + requireNonNull(encryptionMaterials, "encryptionMaterials are required"); + + // If this keyring is a discovery keyring, OnEncrypt MUST return the input encryption materials unmodified. + if (isDiscovery) { + return; + } + + // If the input encryption materials do not contain a plaintext data key and this keyring does not + // have a generator defined, OnEncrypt MUST not modify the encryption materials and MUST fail. + if (!encryptionMaterials.hasPlaintextDataKey() && generatorKeyId == null) { + throw new AwsCryptoException("Encryption materials must contain either a plaintext data key or a generator"); + } + + final List keyIdsToEncrypt = new ArrayList<>(keyIds); + + // If the input encryption materials do not contain a plaintext data key and a generator is defined onEncrypt + // MUST attempt to generate a new plaintext data key and encrypt that data key by calling KMS GenerateDataKey. + if (!encryptionMaterials.hasPlaintextDataKey()) { + generateDataKey(encryptionMaterials); + } else { + // If this keyring's generator is defined and was not used to generate a data key, OnEncrypt + // MUST also attempt to encrypt the plaintext data key using the CMK specified by the generator. + keyIdsToEncrypt.add(generatorKeyId); + } + + // Given a plaintext data key in the encryption materials, OnEncrypt MUST attempt + // to encrypt the plaintext data key using each CMK specified in it's key IDs list. + for (String keyId : keyIdsToEncrypt) { + encryptDataKey(keyId, encryptionMaterials); + } + } + + private void generateDataKey(final EncryptionMaterials encryptionMaterials) { + final GenerateDataKeyResult result = dataKeyEncryptionDao.generateDataKey(generatorKeyId, + encryptionMaterials.getAlgorithmSuite(), encryptionMaterials.getEncryptionContext()); + + encryptionMaterials.setPlaintextDataKey(result.getPlaintextDataKey(), + new KeyringTraceEntry(KMS_PROVIDER_ID, generatorKeyId, KeyringTraceFlag.GENERATED_DATA_KEY)); + encryptionMaterials.addEncryptedDataKey(result.getEncryptedDataKey(), + new KeyringTraceEntry(KMS_PROVIDER_ID, generatorKeyId, KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT)); + } + + private void encryptDataKey(final String keyId, final EncryptionMaterials encryptionMaterials) { + final EncryptedDataKey encryptedDataKey = dataKeyEncryptionDao.encryptDataKey(keyId, + encryptionMaterials.getPlaintextDataKey(), encryptionMaterials.getEncryptionContext()); + + encryptionMaterials.addEncryptedDataKey(encryptedDataKey, + new KeyringTraceEntry(KMS_PROVIDER_ID, keyId, KeyringTraceFlag.ENCRYPTED_DATA_KEY, KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT)); + } + + @Override + public void onDecrypt(DecryptionMaterials decryptionMaterials, List encryptedDataKeys) { + requireNonNull(decryptionMaterials, "decryptionMaterials are required"); + requireNonNull(encryptedDataKeys, "encryptedDataKeys are required"); + + if (decryptionMaterials.hasPlaintextDataKey() || encryptedDataKeys.isEmpty()) { + return; + } + + if (!encryptedDataKeys.stream() + .filter(edk -> edk.getProviderId().equals(KMS_PROVIDER_ID)) + .map(edk -> new String(edk.getProviderInformation(), PROVIDER_ENCODING)) + .allMatch(KmsUtils::isArnWellFormed)) { + throw new MalformedArnException("encryptedDataKeys contains a malformed ARN"); + } + + for (EncryptedDataKey encryptedDataKey : encryptedDataKeys) { + if (okToDecrypt(encryptedDataKey)) { + try { + final DecryptDataKeyResult result = dataKeyEncryptionDao.decryptDataKey(encryptedDataKey, + decryptionMaterials.getAlgorithmSuite(), decryptionMaterials.getEncryptionContext()); + + decryptionMaterials.setPlaintextDataKey(result.getPlaintextDataKey(), + new KeyringTraceEntry(KMS_PROVIDER_ID, result.getKeyArn(), + KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT)); + return; + } catch (CannotUnwrapDataKeyException e) { + continue; + } + } + } + } + + private boolean okToDecrypt(EncryptedDataKey encryptedDataKey) { + // Only attempt to decrypt keys provided by KMS + if (!encryptedDataKey.getProviderId().equals(KMS_PROVIDER_ID)) { + return false; + } + + // If this keyring is a discovery keyring, OnDecrypt MUST attempt to + // decrypt every encrypted data key in the input encrypted data key list + if (isDiscovery) { + return true; + } + + final String providerInfo = new String(encryptedDataKey.getProviderInformation(), PROVIDER_ENCODING); + + // OnDecrypt MUST attempt to decrypt each input encrypted data key in the input + // encrypted data key list where the key provider info has a value equal to one + // of the ARNs in this keyring's key IDs or the generator + return providerInfo.equals(generatorKeyId) || keyIds.stream().anyMatch(providerInfo::equals); + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java b/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java new file mode 100644 index 000000000..4267ba6f8 --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/DataKeyEncryptionDao.java @@ -0,0 +1,104 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.EncryptedDataKey; + +import javax.crypto.SecretKey; +import java.util.List; +import java.util.Map; + +public interface DataKeyEncryptionDao { + + /** + * Generates a unique data key, returning both the plaintext copy of the key and an encrypted copy encrypted using + * the customer master key specified by the given keyId. + * + * @param keyId The customer master key to encrypt the generated key with. + * @param algorithmSuite The algorithm suite associated with the key. + * @param encryptionContext The encryption context. + * @return GenerateDataKeyResult containing the plaintext data key and the encrypted data key. + */ + GenerateDataKeyResult generateDataKey(String keyId, CryptoAlgorithm algorithmSuite, Map encryptionContext); + + /** + * Encrypts the given plaintext data key using the customer aster key specified by the given keyId. + * + * @param keyId The customer master key to encrypt the plaintext data key with. + * @param plaintextDataKey The plaintext data key to encrypt. + * @param encryptionContext The encryption context. + * @return The encrypted data key. + */ + EncryptedDataKey encryptDataKey(final String keyId, SecretKey plaintextDataKey, Map encryptionContext); + + /** + * Decrypted the given encrypted data key. + * + * @param encryptedDataKey The encrypted data key to decrypt. + * @param algorithmSuite The algorithm suite associated with the key. + * @param encryptionContext The encryption context. + * @return DecryptDataKeyResult containing the plaintext data key and the ARN of the key that decrypted it. + */ + DecryptDataKeyResult decryptDataKey(EncryptedDataKey encryptedDataKey, CryptoAlgorithm algorithmSuite, Map encryptionContext); + + /** + * Constructs an instance of DataKeyEncryptionDao that uses AWS Key Management Service (KMS) for + * generation, encryption, and decryption of data keys. + * + * @param clientSupplier A supplier of AWSKMS clients + * @param grantTokens A list of grant tokens to supply to KMS + * @return The DataKeyEncryptionDao + */ + static DataKeyEncryptionDao kms(KmsClientSupplier clientSupplier, List grantTokens) { + return new KmsDataKeyEncryptionDao(clientSupplier, grantTokens); + } + + class GenerateDataKeyResult { + private final SecretKey plaintextDataKey; + private final EncryptedDataKey encryptedDataKey; + + public GenerateDataKeyResult(SecretKey plaintextDataKey, EncryptedDataKey encryptedDataKey) { + this.plaintextDataKey = plaintextDataKey; + this.encryptedDataKey = encryptedDataKey; + } + + public SecretKey getPlaintextDataKey() { + return plaintextDataKey; + } + + public EncryptedDataKey getEncryptedDataKey() { + return encryptedDataKey; + } + } + + class DecryptDataKeyResult { + private final String keyArn; + private final SecretKey plaintextDataKey; + + public DecryptDataKeyResult(String keyArn, SecretKey plaintextDataKey) { + this.keyArn = keyArn; + this.plaintextDataKey = plaintextDataKey; + } + + public String getKeyArn() { + return keyArn; + } + + public SecretKey getPlaintextDataKey() { + return plaintextDataKey; + } + + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java new file mode 100644 index 000000000..55ef2678b --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsClientSupplier.java @@ -0,0 +1,34 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.services.kms.AWSKMS; + +import javax.annotation.Nullable; + +/** + * Represents a function that accepts an AWS region and returns an {@code AWSKMS} client for that region. The + * function should be able to handle when the region is null. + */ +@FunctionalInterface +public interface KmsClientSupplier { + + /** + * Gets an {@code AWSKMS} client for the given regionId. + * + * @param regionId The AWS region (or null) + * @return The AWSKMS client + */ + AWSKMS getClient(@Nullable String regionId); +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java new file mode 100644 index 000000000..f67846c9f --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDao.java @@ -0,0 +1,171 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.AmazonServiceException; +import com.amazonaws.AmazonWebServiceRequest; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; +import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException; +import com.amazonaws.encryptionsdk.internal.VersionInfo; +import com.amazonaws.encryptionsdk.model.KeyBlob; +import com.amazonaws.services.kms.model.DecryptRequest; +import com.amazonaws.services.kms.model.EncryptRequest; +import com.amazonaws.services.kms.model.GenerateDataKeyRequest; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.getClientByArn; +import static java.util.Objects.requireNonNull; +import static org.apache.commons.lang3.Validate.isTrue; + +/** + * An implementation of DataKeyEncryptionDao that uses AWS Key Management Service (KMS) for + * generation, encryption, and decryption of data keys. The KmsMethods interface is implemented + * to allow usage in KmsMasterKey. + */ +class KmsDataKeyEncryptionDao implements DataKeyEncryptionDao, KmsMethods { + + private final KmsClientSupplier clientSupplier; + private List grantTokens; + + KmsDataKeyEncryptionDao(KmsClientSupplier clientSupplier, List grantTokens) { + requireNonNull(clientSupplier, "clientSupplier is required"); + + this.clientSupplier = clientSupplier; + this.grantTokens = grantTokens == null ? new ArrayList<>() : new ArrayList<>(grantTokens); + } + + @Override + public GenerateDataKeyResult generateDataKey(String keyId, CryptoAlgorithm algorithmSuite, Map encryptionContext) { + requireNonNull(keyId, "keyId is required"); + requireNonNull(algorithmSuite, "algorithmSuite is required"); + requireNonNull(encryptionContext, "encryptionContext is required"); + + final com.amazonaws.services.kms.model.GenerateDataKeyResult kmsResult; + + try { + kmsResult = getClientByArn(keyId, clientSupplier) + .generateDataKey(updateUserAgent( + new GenerateDataKeyRequest() + .withKeyId(keyId) + .withNumberOfBytes(algorithmSuite.getDataKeyLength()) + .withEncryptionContext(encryptionContext) + .withGrantTokens(grantTokens))); + } catch (final AmazonServiceException ex) { + throw new AwsCryptoException(ex); + } + + final byte[] rawKey = new byte[algorithmSuite.getDataKeyLength()]; + kmsResult.getPlaintext().get(rawKey); + if (kmsResult.getPlaintext().remaining() > 0) { + throw new IllegalStateException("Received an unexpected number of bytes from KMS"); + } + final byte[] encryptedKey = new byte[kmsResult.getCiphertextBlob().remaining()]; + kmsResult.getCiphertextBlob().get(encryptedKey); + + return new GenerateDataKeyResult(new SecretKeySpec(rawKey, algorithmSuite.getDataKeyAlgo()), + new KeyBlob(KMS_PROVIDER_ID, kmsResult.getKeyId().getBytes(PROVIDER_ENCODING), encryptedKey)); + } + + @Override + public EncryptedDataKey encryptDataKey(final String keyId, SecretKey plaintextDataKey, Map encryptionContext) { + requireNonNull(keyId, "keyId is required"); + requireNonNull(plaintextDataKey, "plaintextDataKey is required"); + requireNonNull(encryptionContext, "encryptionContext is required"); + isTrue(plaintextDataKey.getFormat().equals("RAW"), "Only RAW encoded keys are supported"); + + final com.amazonaws.services.kms.model.EncryptResult kmsResult; + + try { + kmsResult = getClientByArn(keyId, clientSupplier) + .encrypt(updateUserAgent(new EncryptRequest() + .withKeyId(keyId) + .withPlaintext(ByteBuffer.wrap(plaintextDataKey.getEncoded())) + .withEncryptionContext(encryptionContext) + .withGrantTokens(grantTokens))); + } catch (final AmazonServiceException ex) { + throw new AwsCryptoException(ex); + } + final byte[] encryptedDataKey = new byte[kmsResult.getCiphertextBlob().remaining()]; + kmsResult.getCiphertextBlob().get(encryptedDataKey); + + return new KeyBlob(KMS_PROVIDER_ID, kmsResult.getKeyId().getBytes(PROVIDER_ENCODING), encryptedDataKey); + + } + + @Override + public DecryptDataKeyResult decryptDataKey(EncryptedDataKey encryptedDataKey, CryptoAlgorithm algorithmSuite, Map encryptionContext) { + requireNonNull(encryptedDataKey, "encryptedDataKey is required"); + requireNonNull(algorithmSuite, "algorithmSuite is required"); + requireNonNull(encryptionContext, "encryptionContext is required"); + + final String providerInformation = new String(encryptedDataKey.getProviderInformation(), PROVIDER_ENCODING); + final com.amazonaws.services.kms.model.DecryptResult kmsResult; + + try { + kmsResult = getClientByArn(providerInformation, clientSupplier) + .decrypt(updateUserAgent(new DecryptRequest() + .withCiphertextBlob(ByteBuffer.wrap(encryptedDataKey.getEncryptedDataKey())) + .withEncryptionContext(encryptionContext) + .withGrantTokens(grantTokens))); + } catch (final AmazonServiceException ex) { + throw new CannotUnwrapDataKeyException(ex); + } + + if (!kmsResult.getKeyId().equals(providerInformation)) { + throw new MismatchedDataKeyException("Received an unexpected key Id from KMS"); + } + + final byte[] rawKey = new byte[algorithmSuite.getDataKeyLength()]; + kmsResult.getPlaintext().get(rawKey); + if (kmsResult.getPlaintext().remaining() > 0) { + throw new IllegalStateException("Received an unexpected number of bytes from KMS"); + } + + return new DecryptDataKeyResult(kmsResult.getKeyId(), new SecretKeySpec(rawKey, algorithmSuite.getDataKeyAlgo())); + + } + + private T updateUserAgent(T request) { + request.getRequestClientOptions().appendUserAgent(VersionInfo.USER_AGENT); + + return request; + } + + @Override + public void setGrantTokens(List grantTokens) { + this.grantTokens = new ArrayList<>(grantTokens); + } + + @Override + public List getGrantTokens() { + return Collections.unmodifiableList(grantTokens); + } + + @Override + public void addGrantToken(String grantToken) { + grantTokens.add(grantToken); + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java index b78840221..4b6db613c 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java @@ -14,17 +14,12 @@ package com.amazonaws.encryptionsdk.kms; import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.function.Supplier; -import com.amazonaws.AmazonServiceException; -import com.amazonaws.AmazonWebServiceRequest; import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.encryptionsdk.AwsCrypto; @@ -34,31 +29,20 @@ import com.amazonaws.encryptionsdk.MasterKey; import com.amazonaws.encryptionsdk.MasterKeyProvider; import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException; import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; -import com.amazonaws.encryptionsdk.internal.VersionInfo; import com.amazonaws.services.kms.AWSKMS; -import com.amazonaws.services.kms.model.DecryptRequest; -import com.amazonaws.services.kms.model.DecryptResult; -import com.amazonaws.services.kms.model.EncryptRequest; -import com.amazonaws.services.kms.model.EncryptResult; -import com.amazonaws.services.kms.model.GenerateDataKeyRequest; -import com.amazonaws.services.kms.model.GenerateDataKeyResult; + +import static java.util.Collections.emptyList; /** * Represents a single Customer Master Key (CMK) and is used to encrypt/decrypt data with * {@link AwsCrypto}. */ public final class KmsMasterKey extends MasterKey implements KmsMethods { - private final Supplier kms_; + private final KmsDataKeyEncryptionDao dataKeyEncryptionDao; private final MasterKeyProvider sourceProvider_; private final String id_; - private final List grantTokens_ = new ArrayList<>(); - - private T updateUserAgent(T request) { - request.getRequestClientOptions().appendUserAgent(VersionInfo.USER_AGENT); - - return request; - } /** * @@ -84,7 +68,7 @@ static KmsMasterKey getInstance(final Supplier kms, final String id, } private KmsMasterKey(final Supplier kms, final String id, final MasterKeyProvider provider) { - kms_ = kms; + dataKeyEncryptionDao = new KmsDataKeyEncryptionDao(s -> kms.get(), emptyList()); id_ = id; sourceProvider_ = provider; } @@ -102,39 +86,27 @@ public String getKeyId() { @Override public DataKey generateDataKey(final CryptoAlgorithm algorithm, final Map encryptionContext) { - final GenerateDataKeyResult gdkResult = kms_.get().generateDataKey(updateUserAgent( - new GenerateDataKeyRequest() - .withKeyId(getKeyId()) - .withNumberOfBytes(algorithm.getDataKeyLength()) - .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens_) - )); - final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; - gdkResult.getPlaintext().get(rawKey); - if (gdkResult.getPlaintext().remaining() > 0) { - throw new IllegalStateException("Recieved an unexpected number of bytes from KMS"); - } - final byte[] encryptedKey = new byte[gdkResult.getCiphertextBlob().remaining()]; - gdkResult.getCiphertextBlob().get(encryptedKey); - - final SecretKeySpec key = new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()); - return new DataKey<>(key, encryptedKey, gdkResult.getKeyId().getBytes(StandardCharsets.UTF_8), this); + final DataKeyEncryptionDao.GenerateDataKeyResult gdkResult = dataKeyEncryptionDao.generateDataKey( + getKeyId(), algorithm, encryptionContext); + return new DataKey<>(gdkResult.getPlaintextDataKey(), + gdkResult.getEncryptedDataKey().getEncryptedDataKey(), + gdkResult.getEncryptedDataKey().getProviderInformation(), + this); } @Override public void setGrantTokens(final List grantTokens) { - grantTokens_.clear(); - grantTokens_.addAll(grantTokens); + dataKeyEncryptionDao.setGrantTokens(grantTokens); } @Override public List getGrantTokens() { - return grantTokens_; + return dataKeyEncryptionDao.getGrantTokens(); } @Override public void addGrantToken(final String grantToken) { - grantTokens_.add(grantToken); + dataKeyEncryptionDao.addGrantToken(grantToken); } @Override @@ -142,22 +114,12 @@ public DataKey encryptDataKey(final CryptoAlgorithm algorithm, final Map encryptionContext, final DataKey dataKey) { final SecretKey key = dataKey.getKey(); - if (!key.getFormat().equals("RAW")) { - throw new IllegalArgumentException("Only RAW encoded keys are supported"); - } - try { - final EncryptResult encryptResult = kms_.get().encrypt(updateUserAgent( - new EncryptRequest() - .withKeyId(id_) - .withPlaintext(ByteBuffer.wrap(key.getEncoded())) - .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens_))); - final byte[] edk = new byte[encryptResult.getCiphertextBlob().remaining()]; - encryptResult.getCiphertextBlob().get(edk); - return new DataKey<>(dataKey.getKey(), edk, encryptResult.getKeyId().getBytes(StandardCharsets.UTF_8), this); - } catch (final AmazonServiceException asex) { - throw new AwsCryptoException(asex); - } + final EncryptedDataKey encryptedDataKey = dataKeyEncryptionDao.encryptDataKey(id_, key, encryptionContext); + + return new DataKey<>(dataKey.getKey(), + encryptedDataKey.getEncryptedDataKey(), + encryptedDataKey.getProviderInformation(), + this); } @Override @@ -168,24 +130,16 @@ public DataKey decryptDataKey(final CryptoAlgorithm algorithm, final List exceptions = new ArrayList<>(); for (final EncryptedDataKey edk : encryptedDataKeys) { try { - final DecryptResult decryptResult = kms_.get().decrypt(updateUserAgent( - new DecryptRequest() - .withCiphertextBlob(ByteBuffer.wrap(edk.getEncryptedDataKey())) - .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens_))); - if (decryptResult.getKeyId().equals(id_)) { - final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; - decryptResult.getPlaintext().get(rawKey); - if (decryptResult.getPlaintext().remaining() > 0) { - throw new IllegalStateException("Received an unexpected number of bytes from KMS"); - } - return new DataKey<>( - new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()), - edk.getEncryptedDataKey(), - edk.getProviderInformation(), this); - } - } catch (final AmazonServiceException awsex) { - exceptions.add(awsex); + final DataKeyEncryptionDao.DecryptDataKeyResult result = dataKeyEncryptionDao.decryptDataKey(edk, algorithm, encryptionContext); + return new DataKey<>( + result.getPlaintextDataKey(), + edk.getEncryptedDataKey(), + edk.getProviderInformation(), this); + } catch (final AwsCryptoException | MismatchedDataKeyException ex) { + // Earlier versions of KmsMasterKey compare the returned keyId to the encryptedDataKey + // provider information and skip that key if it doesn't match. KmsDataKeyEncryptionDao + // throws MismatchedDataKeyException in this case, so this maintains the existing behavior. + exceptions.add(ex); } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java new file mode 100644 index 000000000..35ddc37e2 --- /dev/null +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsUtils.java @@ -0,0 +1,71 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.arn.Arn; +import com.amazonaws.encryptionsdk.exception.MalformedArnException; +import com.amazonaws.services.kms.AWSKMS; + +public class KmsUtils { + + private static final String ALIAS_PREFIX = "alias/"; + /** + * The provider ID used for the KmsKeyring + */ + public static final String KMS_PROVIDER_ID = "aws-kms"; + + /** + * Parses region from the given arn (if possible) and passes that region to the + * given clientSupplier to produce an {@code AWSKMS} client. + * + * @param arn The Amazon Resource Name or Key Alias + * @param clientSupplier The client supplier + * @return AWSKMS The client + * @throws MalformedArnException if the arn is malformed + */ + public static AWSKMS getClientByArn(String arn, KmsClientSupplier clientSupplier) throws MalformedArnException { + if (isKeyAlias(arn)) { + return clientSupplier.getClient(null); + } + + try { + return clientSupplier.getClient(Arn.fromString(arn).getRegion()); + } catch (IllegalArgumentException e) { + throw new MalformedArnException(e); + } + } + + /** + * Returns true if the given arn is a well formed Amazon Resource Name or Key Alias + * + * @param arn The Amazon Resource Name or Key Alias + * @return True if well formed, false otherwise + */ + public static boolean isArnWellFormed(String arn) { + if (isKeyAlias(arn)) { + return true; + } + + try { + Arn.fromString(arn); + return true; + } catch (IllegalArgumentException e) { + return false; + } + } + + private static boolean isKeyAlias(String arn) { + return arn.startsWith(ALIAS_PREFIX) && !arn.contains(":"); + } +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java b/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java new file mode 100644 index 000000000..0b8034b9a --- /dev/null +++ b/src/test/java/com/amazonaws/encryptionsdk/keyrings/KmsKeyringTest.java @@ -0,0 +1,298 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.keyrings; + +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; +import com.amazonaws.encryptionsdk.exception.MalformedArnException; +import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao; +import com.amazonaws.encryptionsdk.kms.DataKeyEncryptionDao.DecryptDataKeyResult; +import com.amazonaws.encryptionsdk.model.KeyBlob; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static com.amazonaws.encryptionsdk.EncryptedDataKey.PROVIDER_ENCODING; +import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate; +import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.ENCRYPTED_DATA_KEY; +import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.GENERATED_DATA_KEY; +import static com.amazonaws.encryptionsdk.keyrings.KeyringTraceFlag.SIGNED_ENCRYPTION_CONTEXT; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class KmsKeyringTest { + + private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + private static final SecretKey PLAINTEXT_DATA_KEY = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo()); + private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + private static final String GENERATOR_KEY_ID = "arn:aws:kms:us-east-1:999999999999:key/generator-89ab-cdef-fedc-ba9876543210"; + private static final String KEY_ID_1 = "arn:aws:kms:us-east-1:999999999999:key/key1-23bv-sdfs-werw-234323nfdsf"; + private static final String KEY_ID_2 = "arn:aws:kms:us-east-1:999999999999:key/key2-02ds-wvjs-aswe-a4923489273"; + private static final EncryptedDataKey ENCRYPTED_GENERATOR_KEY = new KeyBlob(KMS_PROVIDER_ID, + GENERATOR_KEY_ID.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey ENCRYPTED_KEY_1 = new KeyBlob(KMS_PROVIDER_ID, + KEY_ID_1.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey ENCRYPTED_KEY_2 = new KeyBlob(KMS_PROVIDER_ID, + KEY_ID_2.getBytes(PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final KeyringTraceEntry ENCRYPTED_GENERATOR_KEY_TRACE = + new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT); + private static final KeyringTraceEntry ENCRYPTED_KEY_1_TRACE = + new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_1, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT); + private static final KeyringTraceEntry ENCRYPTED_KEY_2_TRACE = + new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, ENCRYPTED_DATA_KEY, SIGNED_ENCRYPTION_CONTEXT); + private static final KeyringTraceEntry GENERATED_DATA_KEY_TRACE = + new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, GENERATED_DATA_KEY); + @Mock(lenient = true) private DataKeyEncryptionDao dataKeyEncryptionDao; + private Keyring keyring; + + @BeforeEach + void setup() { + when(dataKeyEncryptionDao.encryptDataKey(GENERATOR_KEY_ID, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_GENERATOR_KEY); + when(dataKeyEncryptionDao.encryptDataKey(KEY_ID_1, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_KEY_1); + when(dataKeyEncryptionDao.encryptDataKey(KEY_ID_2, PLAINTEXT_DATA_KEY, ENCRYPTION_CONTEXT)).thenReturn(ENCRYPTED_KEY_2); + + when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_GENERATOR_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)) + .thenReturn(new DecryptDataKeyResult(GENERATOR_KEY_ID, PLAINTEXT_DATA_KEY)); + when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)) + .thenReturn(new DecryptDataKeyResult(KEY_ID_1, PLAINTEXT_DATA_KEY)); + when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_2, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)) + .thenReturn(new DecryptDataKeyResult(KEY_ID_2, PLAINTEXT_DATA_KEY)); + + List keyIds = new ArrayList<>(); + keyIds.add(KEY_ID_1); + keyIds.add(KEY_ID_2); + keyring = new KmsKeyring(dataKeyEncryptionDao, keyIds, GENERATOR_KEY_ID); + } + + @Test + void testMalformedArns() { + assertThrows(MalformedArnException.class, () -> new KmsKeyring(dataKeyEncryptionDao, null, "badArn")); + assertThrows(MalformedArnException.class, () -> new KmsKeyring(dataKeyEncryptionDao, Collections.singletonList("badArn"), GENERATOR_KEY_ID)); + + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(new KeyBlob(KMS_PROVIDER_ID, "badArn".getBytes(PROVIDER_ENCODING), new byte[]{})); + assertThrows(MalformedArnException.class, () -> keyring.onDecrypt(decryptionMaterials, encryptedDataKeys)); + + // Malformed Arn for a non KMS provider shouldn't fail + encryptedDataKeys.clear(); + encryptedDataKeys.add(new KeyBlob("OtherProviderId", "badArn".getBytes(PROVIDER_ENCODING), new byte[]{})); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + } + + @Test + void testGeneratorKeyInKeyIds() { + assertThrows(IllegalArgumentException.class, () -> new KmsKeyring(dataKeyEncryptionDao, Collections.singletonList(GENERATOR_KEY_ID), GENERATOR_KEY_ID)); + } + + @Test + void testEncryptDecryptExistingDataKey() { + EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE) + .plaintextDataKey(PLAINTEXT_DATA_KEY) + .encryptionContext(ENCRYPTION_CONTEXT) + .build(); + + keyring.onEncrypt(encryptionMaterials); + + assertEquals(3, encryptionMaterials.getEncryptedDataKeys().size()); + assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_GENERATOR_KEY)); + assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_KEY_1)); + assertTrue(encryptionMaterials.getEncryptedDataKeys().contains(ENCRYPTED_KEY_2)); + + assertEquals(3, encryptionMaterials.getKeyringTrace().getEntries().size()); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_GENERATOR_KEY_TRACE)); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_1_TRACE)); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_2_TRACE)); + + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(ENCRYPTED_GENERATOR_KEY); + encryptedDataKeys.add(ENCRYPTED_KEY_1); + encryptedDataKeys.add(ENCRYPTED_KEY_2); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + + KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT); + assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0)); + } + + @Test + void testEncryptNullDataKey() { + EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE) + .keyringTrace(new KeyringTrace()) + .encryptionContext(ENCRYPTION_CONTEXT) + .build(); + + when(dataKeyEncryptionDao.generateDataKey(GENERATOR_KEY_ID, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)).thenReturn(new DataKeyEncryptionDao.GenerateDataKeyResult(PLAINTEXT_DATA_KEY, ENCRYPTED_GENERATOR_KEY)); + keyring.onEncrypt(encryptionMaterials); + + assertEquals(PLAINTEXT_DATA_KEY, encryptionMaterials.getPlaintextDataKey()); + + assertEquals(4, encryptionMaterials.getKeyringTrace().getEntries().size()); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(GENERATED_DATA_KEY_TRACE)); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_GENERATOR_KEY_TRACE)); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_1_TRACE)); + assertTrue(encryptionMaterials.getKeyringTrace().getEntries().contains(ENCRYPTED_KEY_2_TRACE)); + + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(ENCRYPTED_GENERATOR_KEY); + encryptedDataKeys.add(ENCRYPTED_KEY_1); + encryptedDataKeys.add(ENCRYPTED_KEY_2); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + + KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, GENERATOR_KEY_ID, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT); + assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0)); + } + + @Test + void testDiscoveryEncrypt() { + keyring = new KmsKeyring(dataKeyEncryptionDao, null, null); + + EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .build(); + keyring.onEncrypt(encryptionMaterials); + + assertFalse(encryptionMaterials.hasPlaintextDataKey()); + assertEquals(0, encryptionMaterials.getKeyringTrace().getEntries().size()); + } + + @Test + void testEncryptNoGeneratorOrPlaintextDataKey() { + List keyIds = new ArrayList<>(); + keyIds.add(KEY_ID_1); + keyring = new KmsKeyring(dataKeyEncryptionDao, keyIds, null); + + EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder(ALGORITHM_SUITE).build(); + assertThrows(AwsCryptoException.class, () -> keyring.onEncrypt(encryptionMaterials)); + } + + @Test + void testDecryptFirstKeyFails() { + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + when(dataKeyEncryptionDao.decryptDataKey(ENCRYPTED_KEY_1, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)).thenThrow(new CannotUnwrapDataKeyException()); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(ENCRYPTED_KEY_1); + encryptedDataKeys.add(ENCRYPTED_KEY_2); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + + KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT); + assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0)); + } + + @Test + void testDecryptFirstKeyWrongProvider() { + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + EncryptedDataKey wrongProviderKey = new KeyBlob("OtherProvider", KEY_ID_1.getBytes(PROVIDER_ENCODING), new byte[]{}); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(wrongProviderKey); + encryptedDataKeys.add(ENCRYPTED_KEY_2); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + + KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_2, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT); + assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0)); + } + + @Test + void testDiscoveryDecrypt() { + keyring = new KmsKeyring(dataKeyEncryptionDao, null, null); + + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + List encryptedDataKeys = new ArrayList<>(); + encryptedDataKeys.add(ENCRYPTED_KEY_1); + encryptedDataKeys.add(ENCRYPTED_KEY_2); + keyring.onDecrypt(decryptionMaterials, encryptedDataKeys); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + + KeyringTraceEntry expectedKeyringTraceEntry = new KeyringTraceEntry(KMS_PROVIDER_ID, KEY_ID_1, KeyringTraceFlag.DECRYPTED_DATA_KEY, KeyringTraceFlag.VERIFIED_ENCRYPTION_CONTEXT); + assertEquals(expectedKeyringTraceEntry, decryptionMaterials.getKeyringTrace().getEntries().get(0)); + } + + @Test + void testDecryptAlreadyDecryptedDataKey() { + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .plaintextDataKey(PLAINTEXT_DATA_KEY) + .encryptionContext(ENCRYPTION_CONTEXT) + .build(); + + keyring.onDecrypt(decryptionMaterials, Collections.singletonList(ENCRYPTED_GENERATOR_KEY)); + + assertEquals(PLAINTEXT_DATA_KEY, decryptionMaterials.getPlaintextDataKey()); + assertEquals(0, decryptionMaterials.getKeyringTrace().getEntries().size()); + } + + @Test + void testDecryptNoDataKey() { + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder(ALGORITHM_SUITE) + .encryptionContext(ENCRYPTION_CONTEXT) + .keyringTrace(new KeyringTrace()) + .build(); + + keyring.onDecrypt(decryptionMaterials, Collections.emptyList()); + + assertFalse(decryptionMaterials.hasPlaintextDataKey()); + assertEquals(0, decryptionMaterials.getKeyringTrace().getEntries().size()); + } +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java new file mode 100644 index 000000000..b85d6f76f --- /dev/null +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsDataKeyEncryptionDaoTest.java @@ -0,0 +1,201 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.AmazonWebServiceRequest; +import com.amazonaws.RequestClientOptions; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.MismatchedDataKeyException; +import com.amazonaws.encryptionsdk.internal.VersionInfo; +import com.amazonaws.encryptionsdk.model.KeyBlob; +import com.amazonaws.services.kms.AWSKMS; +import com.amazonaws.services.kms.model.DecryptRequest; +import com.amazonaws.services.kms.model.DecryptResult; +import com.amazonaws.services.kms.model.EncryptRequest; +import com.amazonaws.services.kms.model.GenerateDataKeyRequest; +import com.amazonaws.services.kms.model.KMSInvalidStateException; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate; +import static com.amazonaws.encryptionsdk.kms.KmsUtils.KMS_PROVIDER_ID; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class KmsDataKeyEncryptionDaoTest { + + private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + private static final SecretKey DATA_KEY = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo()); + private static final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + private static final EncryptedDataKey ENCRYPTED_DATA_KEY = new KeyBlob(KMS_PROVIDER_ID, + "arn:aws:kms:us-east-1:999999999999:key/01234567-89ab-cdef-fedc-ba9876543210".getBytes(EncryptedDataKey.PROVIDER_ENCODING), generate(ALGORITHM_SUITE.getDataKeyLength())); + + @Test + void testEncryptAndDecrypt() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + String keyId = client.createKey().getKeyMetadata().getArn(); + EncryptedDataKey encryptedDataKeyResult = dao.encryptDataKey(keyId, DATA_KEY, ENCRYPTION_CONTEXT); + + ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); + verify(client, times(1)).encrypt(er.capture()); + + EncryptRequest actualRequest = er.getValue(); + + assertEquals(keyId, actualRequest.getKeyId()); + assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.getPlaintext().array()); + assertUserAgent(actualRequest); + + assertEquals(KMS_PROVIDER_ID, encryptedDataKeyResult.getProviderId()); + assertArrayEquals(keyId.getBytes(EncryptedDataKey.PROVIDER_ENCODING), encryptedDataKeyResult.getProviderInformation()); + assertNotNull(encryptedDataKeyResult.getEncryptedDataKey()); + + DataKeyEncryptionDao.DecryptDataKeyResult decryptDataKeyResult = dao.decryptDataKey(encryptedDataKeyResult, ALGORITHM_SUITE, ENCRYPTION_CONTEXT); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); + + DecryptRequest actualDecryptRequest = decrypt.getValue(); + assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext()); + assertArrayEquals(encryptedDataKeyResult.getEncryptedDataKey(), actualDecryptRequest.getCiphertextBlob().array()); + assertUserAgent(actualDecryptRequest); + + assertEquals(DATA_KEY, decryptDataKeyResult.getPlaintextDataKey()); + assertEquals(keyId, decryptDataKeyResult.getKeyArn()); + } + + @Test + void testGenerateAndDecrypt() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + String keyId = client.createKey().getKeyMetadata().getArn(); + DataKeyEncryptionDao.GenerateDataKeyResult generateDataKeyResult = dao.generateDataKey(keyId, ALGORITHM_SUITE, ENCRYPTION_CONTEXT); + + ArgumentCaptor gr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class); + verify(client, times(1)).generateDataKey(gr.capture()); + + GenerateDataKeyRequest actualRequest = gr.getValue(); + + assertEquals(keyId, actualRequest.getKeyId()); + assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertEquals(ALGORITHM_SUITE.getDataKeyLength(), actualRequest.getNumberOfBytes()); + assertUserAgent(actualRequest); + + assertNotNull(generateDataKeyResult.getPlaintextDataKey()); + assertEquals(ALGORITHM_SUITE.getDataKeyLength(), generateDataKeyResult.getPlaintextDataKey().getEncoded().length); + assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), generateDataKeyResult.getPlaintextDataKey().getAlgorithm()); + assertNotNull(generateDataKeyResult.getEncryptedDataKey()); + + DataKeyEncryptionDao.DecryptDataKeyResult decryptDataKeyResult = dao.decryptDataKey(generateDataKeyResult.getEncryptedDataKey(), ALGORITHM_SUITE, ENCRYPTION_CONTEXT); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); + + DecryptRequest actualDecryptRequest = decrypt.getValue(); + assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext()); + assertArrayEquals(generateDataKeyResult.getEncryptedDataKey().getEncryptedDataKey(), actualDecryptRequest.getCiphertextBlob().array()); + assertUserAgent(actualDecryptRequest); + + assertEquals(generateDataKeyResult.getPlaintextDataKey(), decryptDataKeyResult.getPlaintextDataKey()); + assertEquals(keyId, decryptDataKeyResult.getKeyArn()); + } + + @Test + void testEncryptWrongKeyFormat() { + SecretKey key = mock(SecretKey.class); + when(key.getFormat()).thenReturn("BadFormat"); + + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + String keyId = client.createKey().getKeyMetadata().getArn(); + + assertThrows(IllegalArgumentException.class, () -> dao.encryptDataKey(keyId, key, ENCRYPTION_CONTEXT)); + } + + @Test + void testKmsFailure() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + String keyId = client.createKey().getKeyMetadata().getArn(); + doThrow(new KMSInvalidStateException("fail")).when(client).generateDataKey(isA(GenerateDataKeyRequest.class)); + doThrow(new KMSInvalidStateException("fail")).when(client).encrypt(isA(EncryptRequest.class)); + doThrow(new KMSInvalidStateException("fail")).when(client).decrypt(isA(DecryptRequest.class)); + + assertThrows(AwsCryptoException.class, () -> dao.generateDataKey(keyId, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + assertThrows(AwsCryptoException.class, () -> dao.encryptDataKey(keyId, DATA_KEY, ENCRYPTION_CONTEXT)); + assertThrows(AwsCryptoException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + } + + @Test + void testDecryptBadKmsKeyId() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + DecryptResult badResult = new DecryptResult(); + badResult.setKeyId("badKeyId"); + + doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); + + assertThrows(MismatchedDataKeyException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + } + + @Test + void testDecryptBadKmsKeyLength() { + AWSKMS client = spy(new MockKMSClient()); + DataKeyEncryptionDao dao = new KmsDataKeyEncryptionDao(s -> client, GRANT_TOKENS); + + DecryptResult badResult = new DecryptResult(); + badResult.setKeyId(new String(ENCRYPTED_DATA_KEY.getProviderInformation(), EncryptedDataKey.PROVIDER_ENCODING)); + badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength() + 1)); + + doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); + + assertThrows(IllegalStateException.class, () -> dao.decryptDataKey(ENCRYPTED_DATA_KEY, ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + } + + private void assertUserAgent(AmazonWebServiceRequest request) { + assertTrue(request.getRequestClientOptions().getClientMarker(RequestClientOptions.Marker.USER_AGENT) + .contains(VersionInfo.USER_AGENT)); + } +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java new file mode 100644 index 000000000..4682614fb --- /dev/null +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsUtilsTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except + * in compliance with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.amazonaws.encryptionsdk.kms; + +import com.amazonaws.encryptionsdk.exception.MalformedArnException; +import com.amazonaws.services.kms.AWSKMS; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import static org.junit.jupiter.api.Assertions.*; + +@ExtendWith(MockitoExtension.class) +class KmsUtilsTest { + + private static final String VALID_ARN = "arn:aws:kms:us-east-1:999999999999:key/01234567-89ab-cdef-fedc-ba9876543210"; + private static final String VALID_ALIAS_ARN = "arn:aws:kms:us-east-1:999999999999:alias/MyCryptoKey"; + private static final String VALID_ALIAS = "alias/MyCryptoKey"; + + @Mock + private AWSKMS client; + + + @Test + void testGetClientByArn() { + assertEquals(client, KmsUtils.getClientByArn(VALID_ARN, s -> client)); + assertEquals(client, KmsUtils.getClientByArn(VALID_ALIAS_ARN, s -> client)); + assertEquals(client, KmsUtils.getClientByArn(VALID_ALIAS, s -> client)); + assertThrows(MalformedArnException.class, () -> KmsUtils.getClientByArn("invalid", s -> client)); + + } + + @Test + void testIsArnWellFormed() { + assertTrue(KmsUtils.isArnWellFormed(VALID_ARN)); + assertTrue(KmsUtils.isArnWellFormed(VALID_ALIAS_ARN)); + assertTrue(KmsUtils.isArnWellFormed(VALID_ALIAS)); + assertFalse(KmsUtils.isArnWellFormed("invalid")); + + } +} \ No newline at end of file diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java b/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java index 37fe9cbff..00ce5c074 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java @@ -29,7 +29,7 @@ import com.amazonaws.ResponseMetadata; import com.amazonaws.regions.Region; import com.amazonaws.regions.Regions; -import com.amazonaws.services.kms.AWSKMSClient; +import com.amazonaws.services.kms.AbstractAWSKMS; import com.amazonaws.services.kms.model.CreateAliasRequest; import com.amazonaws.services.kms.model.CreateAliasResult; import com.amazonaws.services.kms.model.CreateGrantRequest; @@ -85,7 +85,7 @@ import com.amazonaws.services.kms.model.UpdateKeyDescriptionRequest; import com.amazonaws.services.kms.model.UpdateKeyDescriptionResult; -public class MockKMSClient extends AWSKMSClient { +public class MockKMSClient extends AbstractAWSKMS { private static final SecureRandom rnd = new SecureRandom(); private static final String ACCOUNT_ID = "01234567890"; private final Map results_ = new HashMap<>();