diff --git a/.gitmodules b/.gitmodules index eb6b6a564..b90e3bf80 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "src/test/resources/aws-encryption-sdk-test-vectors"] path = src/test/resources/aws-encryption-sdk-test-vectors - url = https://github.com/awslabs/private-aws-encryption-sdk-test-vectors-staging.git + url = https://github.com/awslabs/aws-encryption-sdk-test-vectors.git [submodule "aws-encryption-sdk-specification"] path = aws-encryption-sdk-specification - url = https://github.com/awslabs/private-aws-encryption-sdk-specification-staging.git + url = https://github.com/awslabs/aws-encryption-sdk-specification.git diff --git a/aws-encryption-sdk-specification b/aws-encryption-sdk-specification index ef3420d0f..c35fbd91b 160000 --- a/aws-encryption-sdk-specification +++ b/aws-encryption-sdk-specification @@ -1 +1 @@ -Subproject commit ef3420d0fa8740c4a98f2e9e976d75be185473e4 +Subproject commit c35fbd91b28303d69813119088c44b5006395eb4 diff --git a/buildspec.yml b/buildspec.yml index 365eb003e..1fc7d652a 100644 --- a/buildspec.yml +++ b/buildspec.yml @@ -23,3 +23,8 @@ batch: env: env: image: aws/codebuild/amazonlinux2-x86_64-standard:3.0 + - identifier: static_analysis + buildspec: codebuild/static-analysis.yml + env: + env: + image: aws/codebuild/amazonlinux2-x86_64-standard:3.0 diff --git a/codebuild/compliance.yml b/codebuild/static-analysis.yml similarity index 65% rename from codebuild/compliance.yml rename to codebuild/static-analysis.yml index fe25a9c37..798c1af3c 100644 --- a/codebuild/compliance.yml +++ b/codebuild/static-analysis.yml @@ -4,6 +4,8 @@ phases: install: runtime-versions: nodejs: 12 + java: corretto11 build: commands: + - mvn com.coveo:fmt-maven-plugin:check - ./util/test-conditions.sh diff --git a/pom.xml b/pom.xml index e8e4147f2..5a96f2025 100644 --- a/pom.xml +++ b/pom.xml @@ -199,6 +199,16 @@ + + + com.coveo + fmt-maven-plugin + 2.10 + + + + + diff --git a/src/main/java/com/amazonaws/encryptionsdk/AwsCrypto.java b/src/main/java/com/amazonaws/encryptionsdk/AwsCrypto.java index c16b6ed53..5268935b3 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/AwsCrypto.java +++ b/src/main/java/com/amazonaws/encryptionsdk/AwsCrypto.java @@ -3,898 +3,904 @@ package com.amazonaws.encryptionsdk; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.charset.StandardCharsets; -import java.util.Collections; -import java.util.Map; - import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.internal.*; import com.amazonaws.encryptionsdk.model.CiphertextHeaders; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Map; /** * Provides the primary entry-point to the AWS Encryption SDK. All encryption and decryption - * operations should start here. Most people will want to use either - * {@link #encryptData(MasterKeyProvider, byte[], Map)} and - * {@link #decryptData(MasterKeyProvider, byte[])} to encrypt/decrypt things. - * - *

- * The core concepts (and classes) in this SDK are: + * operations should start here. Most people will want to use either {@link + * #encryptData(MasterKeyProvider, byte[], Map)} and {@link #decryptData(MasterKeyProvider, byte[])} + * to encrypt/decrypt things. + * + *

The core concepts (and classes) in this SDK are: + * *

* - *

- * {@link AwsCrypto} provides the primary way to encrypt/decrypt data. It can operate on + *

{@link AwsCrypto} provides the primary way to encrypt/decrypt data. It can operate on * byte-arrays, streams, or {@link java.lang.String Strings}. This data is encrypted using the * specifed {@link CryptoAlgorithm} and a {@link DataKey} which is unique to each encrypted message. * This {@code DataKey} is then encrypted using one (or more) {@link MasterKey MasterKeys}. The * process is reversed on decryption with the code selecting a copy of the {@code DataKey} protected * by a usable {@code MasterKey}, decrypting the {@code DataKey}, and then decrypted the message. * - *

- * The main way to get a {@code MasterKey} is through the use of a {@link MasterKeyProvider}. This - * provides a common interface for the AwsEncryptionSdk to find and retrieve {@code MasterKeys}. - * (Some {@code MasterKeys} can also be constructed directly.) + *

The main way to get a {@code MasterKey} is through the use of a {@link MasterKeyProvider}. + * This provides a common interface for the AwsEncryptionSdk to find and retrieve {@code + * MasterKeys}. (Some {@code MasterKeys} can also be constructed directly.) * - *

- * {@code AwsCrypto} uses the {@code MasterKeyProvider} to determine which {@code MasterKeys} should - * be used to encrypt the {@code DataKeys} by calling - * {@link MasterKeyProvider#getMasterKeysForEncryption(MasterKeyRequest)} . When more than one - * {@code MasterKey} is returned, the first {@code MasterKeys} is used to create the - * {@code DataKeys} by calling {@link MasterKey#generateDataKey(CryptoAlgorithm,java.util.Map)} . - * All of the other {@code MasterKeys} are then used to re-encrypt that {@code DataKey} with - * {@link MasterKey#encryptDataKey(CryptoAlgorithm,java.util.Map,DataKey)} . This list of - * {@link EncryptedDataKey EncryptedDataKeys} (the same {@code DataKey} possibly encrypted multiple - * times) is stored in the {@link com.amazonaws.encryptionsdk.model.CiphertextHeaders}. + *

{@code AwsCrypto} uses the {@code MasterKeyProvider} to determine which {@code MasterKeys} + * should be used to encrypt the {@code DataKeys} by calling {@link + * MasterKeyProvider#getMasterKeysForEncryption(MasterKeyRequest)} . When more than one {@code + * MasterKey} is returned, the first {@code MasterKeys} is used to create the {@code DataKeys} by + * calling {@link MasterKey#generateDataKey(CryptoAlgorithm,java.util.Map)} . All of the other + * {@code MasterKeys} are then used to re-encrypt that {@code DataKey} with {@link + * MasterKey#encryptDataKey(CryptoAlgorithm,java.util.Map,DataKey)} . This list of {@link + * EncryptedDataKey EncryptedDataKeys} (the same {@code DataKey} possibly encrypted multiple times) + * is stored in the {@link com.amazonaws.encryptionsdk.model.CiphertextHeaders}. * - *

- * {@code AwsCrypto} also uses the {@code MasterKeyProvider} to decrypt one of the - * {@link EncryptedDataKey EncryptedDataKeys} from the header to retrieve the actual {@code DataKey} + *

{@code AwsCrypto} also uses the {@code MasterKeyProvider} to decrypt one of the {@link + * EncryptedDataKey EncryptedDataKeys} from the header to retrieve the actual {@code DataKey} * necessary to decrypt the message. * - *

- * Any place a {@code MasterKeyProvider} is used, a {@link MasterKey} can be used instead. The + *

Any place a {@code MasterKeyProvider} is used, a {@link MasterKey} can be used instead. The * {@code MasterKey} will behave as a {@code MasterKeyProvider} which is only capable of providing * itself. This is often useful when only one {@code MasterKey} is being used. * - *

- * Note regarding the use of generics: This library makes heavy use of generics to provide type + *

Note regarding the use of generics: This library makes heavy use of generics to provide type * safety to advanced developers. The great majority of users should be able to just use the * provided type parameters or the {@code ?} wildcard. */ @SuppressWarnings("WeakerAccess") // this is a public API public class AwsCrypto { - private static final Map EMPTY_MAP = Collections.emptyMap(); - - // These are volatile because we allow unsynchronized writes via our setters, - // and without setting volatile we could see strange results. - // E.g. copying these to a local might give different values on subsequent reads from the local. - // By setting them volatile we ensure that proper memory barriers are applied - // to ensure things behave in a sensible manner. - private volatile CryptoAlgorithm encryptionAlgorithm_ = null; - private volatile int encryptionFrameSize_ = getDefaultFrameSize(); - - private static final CommitmentPolicy DEFAULT_COMMITMENT_POLICY = CommitmentPolicy.RequireEncryptRequireDecrypt; - private final CommitmentPolicy commitmentPolicy_; - - /** - * The maximum number of encrypted data keys to unwrap (resp. wrap) on decrypt (resp. encrypt), if positive. - * If zero, do not limit EDKs. - */ - private final int maxEncryptedDataKeys_; - - private AwsCrypto(Builder builder) { - commitmentPolicy_ = builder.commitmentPolicy_ == null ? DEFAULT_COMMITMENT_POLICY : builder.commitmentPolicy_; - if (builder.encryptionAlgorithm_ != null && !commitmentPolicy_.algorithmAllowedForEncrypt(builder.encryptionAlgorithm_)) { - if (commitmentPolicy_ == CommitmentPolicy.ForbidEncryptAllowDecrypt) { - throw new AwsCryptoException("Configuration conflict. Cannot encrypt due to CommitmentPolicy " + - commitmentPolicy_ + " requiring only non-committed messages. Algorithm ID was " + - builder.encryptionAlgorithm_ + - ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); - } else { - throw new AwsCryptoException("Configuration conflict. Cannot encrypt due to CommitmentPolicy " + - commitmentPolicy_ + " requiring only committed messages. Algorithm ID was " + - builder.encryptionAlgorithm_ + - ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); - } - } - encryptionAlgorithm_ = builder.encryptionAlgorithm_; - encryptionFrameSize_ = builder.encryptionFrameSize_; - maxEncryptedDataKeys_ = builder.maxEncryptedDataKeys_; - } - - public static class Builder { - private CryptoAlgorithm encryptionAlgorithm_; - private int encryptionFrameSize_ = getDefaultFrameSize(); - private CommitmentPolicy commitmentPolicy_; - private int maxEncryptedDataKeys_ = CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS; - - private Builder() {} - - private Builder(final AwsCrypto client) { - encryptionAlgorithm_ = client.encryptionAlgorithm_; - encryptionFrameSize_ = client.encryptionFrameSize_; - commitmentPolicy_ = client.commitmentPolicy_; - maxEncryptedDataKeys_ = client.maxEncryptedDataKeys_; - } - - /** - * Sets the {@link CryptoAlgorithm} to encrypt with. - * The Aws Crypto client will use the last crypto algorithm set with - * either {@link AwsCrypto.Builder#withEncryptionAlgorithm(CryptoAlgorithm)} or - * {@link #setEncryptionAlgorithm(CryptoAlgorithm)} to encrypt with. - * - * @param encryptionAlgorithm The {@link CryptoAlgorithm} - * @return The Builder, for method chaining - */ - public Builder withEncryptionAlgorithm(CryptoAlgorithm encryptionAlgorithm) { - this.encryptionAlgorithm_ = encryptionAlgorithm; - return this; - } - - /** - * Sets the frame size of the encrypted messages that the Aws Crypto client produces. - * The Aws Crypto client will use the last frame size set with - * either {@link AwsCrypto.Builder#withEncryptionFrameSize(int)} or - * {@link #setEncryptionFrameSize(int)}. - * - * @param frameSize The frame size to produce encrypted messages with. - * @return The Builder, for method chaining - */ - public Builder withEncryptionFrameSize(int frameSize) { - this.encryptionFrameSize_ = frameSize; - return this; - } - - /** - * Sets the {@link CommitmentPolicy} of this Aws Crypto client. - * - * @param commitmentPolicy The commitment policy to enforce during encryption and decryption - * @return The Builder, for method chaining - */ - public Builder withCommitmentPolicy(CommitmentPolicy commitmentPolicy) { - Utils.assertNonNull(commitmentPolicy, "commitmentPolicy"); - this.commitmentPolicy_ = commitmentPolicy; - return this; - } - - /** - * Sets the maximum number of encrypted data keys that this Aws Crypto client will wrap when - * encrypting, or unwrap when decrypting, a single message. - * - * @param maxEncryptedDataKeys The maximum number of encrypted data keys; must be positive - * @return The Builder, for method chaining - */ - public Builder withMaxEncryptedDataKeys(int maxEncryptedDataKeys) { - if (maxEncryptedDataKeys < 1) { - throw new IllegalArgumentException("maxEncryptedDataKeys must be positive"); - } - this.maxEncryptedDataKeys_ = maxEncryptedDataKeys; - return this; - } - - public AwsCrypto build() { - return new AwsCrypto(this); - } - } - - public static Builder builder() { - return new Builder(); - } - - public Builder toBuilder() { - return new Builder(this); - } - - public static AwsCrypto standard() { - return AwsCrypto.builder().build(); - } - - /** - * Returns the frame size to use for encryption when none is explicitly selected. Currently it - * is 4096. - */ - public static int getDefaultFrameSize() { - return 4096; - } - - /** - * Sets the {@link CryptoAlgorithm} to use when encrypting data. This has no impact on - * decryption. - */ - public void setEncryptionAlgorithm(final CryptoAlgorithm alg) { - if (!commitmentPolicy_.algorithmAllowedForEncrypt(alg)) { - if (commitmentPolicy_ == CommitmentPolicy.ForbidEncryptAllowDecrypt) { - throw new AwsCryptoException("Configuration conflict. Cannot encrypt due to CommitmentPolicy " + - commitmentPolicy_ + " requiring only non-committed messages. Algorithm ID was " + - alg + ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); - } else { - throw new AwsCryptoException("Configuration conflict. Cannot encrypt due to CommitmentPolicy " + - commitmentPolicy_ + " requiring only committed messages. Algorithm ID was " + - alg + ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); - } - } - encryptionAlgorithm_ = alg; - } - - public CryptoAlgorithm getEncryptionAlgorithm() { - return encryptionAlgorithm_; - } - - /** - * Sets the framing size to use when encrypting data. This has no impact on decryption. - * If {@code frameSize} is 0, then framing is disabled and the entire plaintext will be encrypted - * in a single block. - * - * Note that during encryption arrays of this size will be allocated. Using extremely large frame sizes may pose - * compatibility issues when the decryptor is running on 32-bit systems. Additionally, Java VM limits may set a - * platform-specific upper bound to frame sizes. - */ - public void setEncryptionFrameSize(final int frameSize) { - if (frameSize < 0) { - throw new IllegalArgumentException("frameSize must be non-negative"); - } - - encryptionFrameSize_ = frameSize; - } - - public int getEncryptionFrameSize() { - return encryptionFrameSize_; - } - - /** - * Returns the best estimate for the output length of encrypting a plaintext with the provided - * {@code plaintextSize} and {@code encryptionContext}. The actual ciphertext may be shorter. - * - * This method is equivalent to calling {@link #estimateCiphertextSize(CryptoMaterialsManager, int, Map)} with a - * {@link DefaultCryptoMaterialsManager} based on the given provider. - */ - public > long estimateCiphertextSize( - final MasterKeyProvider provider, - final int plaintextSize, - final Map encryptionContext - ) { - return estimateCiphertextSize(new DefaultCryptoMaterialsManager(provider), plaintextSize, encryptionContext); - } - - /** - * Returns the best estimate for the output length of encrypting a plaintext with the provided - * {@code plaintextSize} and {@code encryptionContext}. The actual ciphertext may be shorter. - */ - public long estimateCiphertextSize( - CryptoMaterialsManager materialsManager, - final int plaintextSize, - final Map encryptionContext - ) { - EncryptionMaterialsRequest request = EncryptionMaterialsRequest.newBuilder() - .setContext(encryptionContext) - .setRequestedAlgorithm(getEncryptionAlgorithm()) - // We're not actually encrypting any data, so don't consume any bytes from the cache's limits. We do need to - // pass /something/ though, or the cache will be bypassed (as it'll assume this is a streaming encrypt of - // unknown size). - .setPlaintextSize(0) - .setCommitmentPolicy(commitmentPolicy_) - .build(); - - final MessageCryptoHandler cryptoHandler = new EncryptionHandler( - getEncryptionFrameSize(), - checkAlgorithm(materialsManager.getMaterialsForEncrypt(request)), - commitmentPolicy_ - ); - - return cryptoHandler.estimateOutputSize(plaintextSize); - } - - /** - * Returns the equivalent to calling - * {@link #estimateCiphertextSize(MasterKeyProvider, int, Map)} with an empty - * {@code encryptionContext}. - */ - public > long estimateCiphertextSize( - final MasterKeyProvider provider, - final int plaintextSize - ) { - return estimateCiphertextSize(provider, plaintextSize, EMPTY_MAP); - } - - /** - * Returns the equivalent to calling - * {@link #estimateCiphertextSize(CryptoMaterialsManager, int, Map)} with an empty - * {@code encryptionContext}. - */ - public long estimateCiphertextSize( - final CryptoMaterialsManager materialsManager, - final int plaintextSize - ) { - return estimateCiphertextSize(materialsManager, plaintextSize, EMPTY_MAP); - } - - /** - * Returns an encrypted form of {@code plaintext} that has been protected with {@link DataKey - * DataKeys} that are in turn protected by {@link MasterKey MasterKeys} provided by - * {@code provider}. - * - * This method is equivalent to calling {@link #encryptData(CryptoMaterialsManager, byte[], Map)} using a - * {@link DefaultCryptoMaterialsManager} based on the given provider. - */ - public > CryptoResult encryptData( - final MasterKeyProvider provider, - final byte[] plaintext, - final Map encryptionContext - ) { - //noinspection unchecked - return (CryptoResult) - encryptData(new DefaultCryptoMaterialsManager(provider), plaintext, encryptionContext); - } - - /** - * Returns an encrypted form of {@code plaintext} that has been protected with {@link DataKey - * DataKeys} that are in turn protected by the given CryptoMaterialsProvider. - */ - public CryptoResult encryptData( - CryptoMaterialsManager materialsManager, - final byte[] plaintext, - final Map encryptionContext - ) { - EncryptionMaterialsRequest request = EncryptionMaterialsRequest.newBuilder() - .setContext(encryptionContext) - .setRequestedAlgorithm(getEncryptionAlgorithm()) - .setPlaintext(plaintext) - .setCommitmentPolicy(commitmentPolicy_) - .build(); - - EncryptionMaterials encryptionMaterials = checkMaxEncryptedDataKeys(checkAlgorithm(materialsManager.getMaterialsForEncrypt(request))); - final MessageCryptoHandler cryptoHandler = new EncryptionHandler( - getEncryptionFrameSize(), - encryptionMaterials, - commitmentPolicy_ - ); - - final int outSizeEstimate = cryptoHandler.estimateOutputSize(plaintext.length); - final byte[] out = new byte[outSizeEstimate]; - int outLen = cryptoHandler.processBytes(plaintext, 0, plaintext.length, out, 0).getBytesWritten(); - outLen += cryptoHandler.doFinal(out, outLen); - - final byte[] outBytes = Utils.truncate(out, outLen); - - //noinspection unchecked - return new CryptoResult(outBytes, cryptoHandler.getMasterKeys(), cryptoHandler.getHeaders()); - } - - /** - * Returns the equivalent to calling {@link #encryptData(MasterKeyProvider, byte[], Map)} with - * an empty {@code encryptionContext}. - */ - public > CryptoResult encryptData(final MasterKeyProvider provider, - final byte[] plaintext) { - return encryptData(provider, plaintext, EMPTY_MAP); - } - - /** - * Returns the equivalent to calling {@link #encryptData(CryptoMaterialsManager, byte[], Map)} with - * an empty {@code encryptionContext}. - */ - public CryptoResult encryptData( - final CryptoMaterialsManager materialsManager, - final byte[] plaintext - ) { - return encryptData(materialsManager, plaintext, EMPTY_MAP); - } - - /** - * Calls {@link #encryptData(MasterKeyProvider, byte[], Map)} on the UTF-8 encoded bytes of - * {@code plaintext} and base64 encodes the result. - * @deprecated Use the {@link #encryptData(MasterKeyProvider, byte[], Map)} and - * {@link #decryptData(MasterKeyProvider, byte[])} APIs instead. {@code encryptString} and {@code decryptString} - * work as expected if you use them together. However, to work with other language implementations of the AWS - * Encryption SDK, you need to base64-decode the output of {@code encryptString} and base64-encode the input to - * {@code decryptString}. These deprecated APIs will be removed in the future. - */ - @Deprecated - public > CryptoResult encryptString( - final MasterKeyProvider provider, - final String plaintext, - final Map encryptionContext - ) { - //noinspection unchecked - return (CryptoResult) - encryptString(new DefaultCryptoMaterialsManager(provider), plaintext, encryptionContext); - } - - /** - * Calls {@link #encryptData(CryptoMaterialsManager, byte[], Map)} on the UTF-8 encoded bytes of - * {@code plaintext} and base64 encodes the result. - * @deprecated Use the {@link #encryptData(CryptoMaterialsManager, byte[], Map)} and - * {@link #decryptData(CryptoMaterialsManager, byte[])} APIs instead. {@code encryptString} and {@code decryptString} - * work as expected if you use them together. However, to work with other language implementations of the AWS - * Encryption SDK, you need to base64-decode the output of {@code encryptString} and base64-encode the input to - * {@code decryptString}. These deprecated APIs will be removed in the future. - */ - @Deprecated - public CryptoResult encryptString( - CryptoMaterialsManager materialsManager, - final String plaintext, - final Map encryptionContext - ) { - final CryptoResult ctBytes = encryptData( - materialsManager, - plaintext.getBytes(StandardCharsets.UTF_8), - encryptionContext - ); - return new CryptoResult<>(Utils.encodeBase64String(ctBytes.getResult()), - ctBytes.getMasterKeys(), ctBytes.getHeaders()); - } - - /** - * Returns the equivalent to calling {@link #encryptString(MasterKeyProvider, String, Map)} with - * an empty {@code encryptionContext}. - * @deprecated Use the {@link #encryptData(MasterKeyProvider, byte[])} and - * {@link #decryptData(MasterKeyProvider, byte[])} APIs instead. {@code encryptString} and {@code decryptString} - * work as expected if you use them together. However, to work with other language implementations of the AWS - * Encryption SDK, you need to base64-decode the output of {@code encryptString} and base64-encode the input to - * {@code decryptString}. These deprecated APIs will be removed in the future. - */ - @Deprecated - public > CryptoResult encryptString(final MasterKeyProvider provider, - final String plaintext) { - return encryptString(provider, plaintext, EMPTY_MAP); - } - - /** - * Returns the equivalent to calling {@link #encryptString(CryptoMaterialsManager, String, Map)} with - * an empty {@code encryptionContext}. - * @deprecated Use the {@link #encryptData(CryptoMaterialsManager, byte[])} and - * {@link #decryptData(CryptoMaterialsManager, byte[])} APIs instead. {@code encryptString} and {@code decryptString} - * work as expected if you use them together. However, to work with other language implementations of the AWS - * Encryption SDK, you need to base64-decode the output of {@code encryptString} and base64-encode the input to - * {@code decryptString}. These deprecated APIs will be removed in the future. - */ - @Deprecated - public CryptoResult encryptString( - final CryptoMaterialsManager materialsManager, - final String plaintext - ) { - return encryptString(materialsManager, plaintext, EMPTY_MAP); - } - - /** - * Decrypts the provided {@code ciphertext} by requesting that the {@code provider} unwrap any - * usable {@link DataKey} in the ciphertext and then decrypts the ciphertext using that - * {@code DataKey}. - */ - public > CryptoResult decryptData(final MasterKeyProvider provider, - final byte[] ciphertext) { - return decryptData(Utils.assertNonNull(provider, "provider"), new - ParsedCiphertext(ciphertext, maxEncryptedDataKeys_)); - } - - /** - * Decrypts the provided ciphertext by delegating to the provided materialsManager to obtain the decrypted - * {@link DataKey}. - * - * @param materialsManager the {@link CryptoMaterialsManager} to use for decryption operations. - * @param ciphertext the ciphertext to attempt to decrypt. - * @return the {@link CryptoResult} with the decrypted data. - */ - public CryptoResult decryptData( - final CryptoMaterialsManager materialsManager, - final byte[] ciphertext - ) { - return decryptData(Utils.assertNonNull(materialsManager, "materialsManager"), - new ParsedCiphertext(ciphertext, maxEncryptedDataKeys_)); - } - - /** - * @see #decryptData(MasterKeyProvider, byte[]) - */ - @SuppressWarnings("unchecked") - public > CryptoResult decryptData( - final MasterKeyProvider provider, final ParsedCiphertext ciphertext) { - Utils.assertNonNull(provider, "provider"); - return (CryptoResult) decryptData(new DefaultCryptoMaterialsManager(provider), ciphertext); - } - - /** - * @see #decryptData(CryptoMaterialsManager, byte[]) - */ - public CryptoResult decryptData( - final CryptoMaterialsManager materialsManager, - final ParsedCiphertext ciphertext - ) { - Utils.assertNonNull(materialsManager, "materialsManager"); - - final MessageCryptoHandler cryptoHandler = - DecryptionHandler.create(materialsManager, ciphertext, commitmentPolicy_, - SignaturePolicy.AllowEncryptAllowDecrypt, maxEncryptedDataKeys_); - - final byte[] ciphertextBytes = ciphertext.getCiphertext(); - final int contentLen = ciphertextBytes.length - ciphertext.getOffset(); - final int outSizeEstimate = cryptoHandler.estimateOutputSize(contentLen); - final byte[] out = new byte[outSizeEstimate]; - final ProcessingSummary processed = cryptoHandler.processBytes(ciphertextBytes, ciphertext.getOffset(), - contentLen, out, - 0); - if (processed.getBytesProcessed() != contentLen) { - throw new BadCiphertextException("Unable to process entire ciphertext. May have trailing data."); - } - int outLen = processed.getBytesWritten(); - outLen += cryptoHandler.doFinal(out, outLen); - - final byte[] outBytes = Utils.truncate(out, outLen); - - //noinspection unchecked - return new CryptoResult(outBytes, cryptoHandler.getMasterKeys(), cryptoHandler.getHeaders()); - } - - /** - * Base64 decodes the {@code ciphertext} prior to decryption and then treats the results as a - * UTF-8 encoded string. - * - * @see #decryptData(MasterKeyProvider, byte[]) - * @deprecated Use the {@link #decryptData(MasterKeyProvider, byte[])} and - * {@link #encryptData(MasterKeyProvider, byte[], Map)} APIs instead. {@code encryptString} and {@code decryptString} - * work as expected if you use them together. However, to work with other language implementations of the AWS - * Encryption SDK, you need to base64-decode the output of {@code encryptString} and base64-encode the input to - * {@code decryptString}. These deprecated APIs will be removed in the future. - */ - @Deprecated - @SuppressWarnings("unchecked") - public > CryptoResult decryptString( - final MasterKeyProvider provider, - final String ciphertext - ) { - return (CryptoResult) decryptString(new DefaultCryptoMaterialsManager(provider), ciphertext); - } - - /** - * Base64 decodes the {@code ciphertext} prior to decryption and then treats the results as a - * UTF-8 encoded string. - * - * @see #decryptData(CryptoMaterialsManager, byte[]) - * @deprecated Use the {@link #decryptData(CryptoMaterialsManager, byte[])} and - * {@link #encryptData(CryptoMaterialsManager, byte[], Map)} APIs instead. {@code encryptString} and {@code decryptString} - * work as expected if you use them together. However, to work with other language implementations of the AWS - * Encryption SDK, you need to base64-decode the output of {@code encryptString} and base64-encode the input to - * {@code decryptString}. These deprecated APIs will be removed in the future. - */ - @Deprecated - public CryptoResult decryptString(final CryptoMaterialsManager provider, - final String ciphertext) { - Utils.assertNonNull(provider, "provider"); - final byte[] ciphertextBytes; - try { - ciphertextBytes = Utils.decodeBase64String(Utils.assertNonNull(ciphertext, "ciphertext")); - } catch (final IllegalArgumentException ex) { - throw new BadCiphertextException("Invalid base 64", ex); - } - final CryptoResult ptBytes = decryptData(provider, ciphertextBytes); - //noinspection unchecked - return new CryptoResult( - new String(ptBytes.getResult(), StandardCharsets.UTF_8), - ptBytes.getMasterKeys(), ptBytes.getHeaders()); - } - - /** - * Returns a {@link CryptoOutputStream} which encrypts the data prior to passing it onto the - * underlying {@link OutputStream}. - * - * @see #encryptData(MasterKeyProvider, byte[], Map) - * @see javax.crypto.CipherOutputStream - */ - public > CryptoOutputStream createEncryptingStream( - final MasterKeyProvider provider, - final OutputStream os, - final Map encryptionContext - ) { - //noinspection unchecked - return (CryptoOutputStream) - createEncryptingStream(new DefaultCryptoMaterialsManager(provider), os, encryptionContext); - } - - /** - * Returns a {@link CryptoOutputStream} which encrypts the data prior to passing it onto the - * underlying {@link OutputStream}. - * - * @see #encryptData(MasterKeyProvider, byte[], Map) - * @see javax.crypto.CipherOutputStream - */ - public CryptoOutputStream createEncryptingStream( - final CryptoMaterialsManager materialsManager, - final OutputStream os, - final Map encryptionContext - ) { - return new CryptoOutputStream<>(os, getEncryptingStreamHandler(materialsManager, encryptionContext)); - } - - /** - * Returns the equivalent to calling - * {@link #createEncryptingStream(MasterKeyProvider, OutputStream, Map)} with an empty - * {@code encryptionContext}. - */ - public > CryptoOutputStream createEncryptingStream( - final MasterKeyProvider provider, - final OutputStream os) { - return createEncryptingStream(provider, os, EMPTY_MAP); - } - - /** - * Returns the equivalent to calling - * {@link #createEncryptingStream(CryptoMaterialsManager, OutputStream, Map)} with an empty - * {@code encryptionContext}. - */ - public CryptoOutputStream createEncryptingStream( - final CryptoMaterialsManager materialsManager, - final OutputStream os - ) { - return createEncryptingStream(materialsManager, os, EMPTY_MAP); - } - - /** - * Returns a {@link CryptoInputStream} which encrypts the data after reading it from the - * underlying {@link InputStream}. + private static final Map EMPTY_MAP = Collections.emptyMap(); + + // These are volatile because we allow unsynchronized writes via our setters, + // and without setting volatile we could see strange results. + // E.g. copying these to a local might give different values on subsequent reads from the local. + // By setting them volatile we ensure that proper memory barriers are applied + // to ensure things behave in a sensible manner. + private volatile CryptoAlgorithm encryptionAlgorithm_ = null; + private volatile int encryptionFrameSize_ = getDefaultFrameSize(); + + private static final CommitmentPolicy DEFAULT_COMMITMENT_POLICY = + CommitmentPolicy.RequireEncryptRequireDecrypt; + private final CommitmentPolicy commitmentPolicy_; + + /** + * The maximum number of encrypted data keys to unwrap (resp. wrap) on decrypt (resp. encrypt), if + * positive. If zero, do not limit EDKs. + */ + private final int maxEncryptedDataKeys_; + + private AwsCrypto(Builder builder) { + commitmentPolicy_ = + builder.commitmentPolicy_ == null ? DEFAULT_COMMITMENT_POLICY : builder.commitmentPolicy_; + if (builder.encryptionAlgorithm_ != null + && !commitmentPolicy_.algorithmAllowedForEncrypt(builder.encryptionAlgorithm_)) { + if (commitmentPolicy_ == CommitmentPolicy.ForbidEncryptAllowDecrypt) { + throw new AwsCryptoException( + "Configuration conflict. Cannot encrypt due to CommitmentPolicy " + + commitmentPolicy_ + + " requiring only non-committed messages. Algorithm ID was " + + builder.encryptionAlgorithm_ + + ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); + } else { + throw new AwsCryptoException( + "Configuration conflict. Cannot encrypt due to CommitmentPolicy " + + commitmentPolicy_ + + " requiring only committed messages. Algorithm ID was " + + builder.encryptionAlgorithm_ + + ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); + } + } + encryptionAlgorithm_ = builder.encryptionAlgorithm_; + encryptionFrameSize_ = builder.encryptionFrameSize_; + maxEncryptedDataKeys_ = builder.maxEncryptedDataKeys_; + } + + public static class Builder { + private CryptoAlgorithm encryptionAlgorithm_; + private int encryptionFrameSize_ = getDefaultFrameSize(); + private CommitmentPolicy commitmentPolicy_; + private int maxEncryptedDataKeys_ = CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS; + + private Builder() {} + + private Builder(final AwsCrypto client) { + encryptionAlgorithm_ = client.encryptionAlgorithm_; + encryptionFrameSize_ = client.encryptionFrameSize_; + commitmentPolicy_ = client.commitmentPolicy_; + maxEncryptedDataKeys_ = client.maxEncryptedDataKeys_; + } + + /** + * Sets the {@link CryptoAlgorithm} to encrypt with. The Aws Crypto client will use the last + * crypto algorithm set with either {@link + * AwsCrypto.Builder#withEncryptionAlgorithm(CryptoAlgorithm)} or {@link + * #setEncryptionAlgorithm(CryptoAlgorithm)} to encrypt with. * - * @see #encryptData(MasterKeyProvider, byte[], Map) - * @see javax.crypto.CipherInputStream + * @param encryptionAlgorithm The {@link CryptoAlgorithm} + * @return The Builder, for method chaining */ - public > CryptoInputStream createEncryptingStream( - final MasterKeyProvider provider, - final InputStream is, - final Map encryptionContext - ) { - //noinspection unchecked - return (CryptoInputStream) - createEncryptingStream(new DefaultCryptoMaterialsManager(provider), is, encryptionContext); + public Builder withEncryptionAlgorithm(CryptoAlgorithm encryptionAlgorithm) { + this.encryptionAlgorithm_ = encryptionAlgorithm; + return this; } /** - * Returns a {@link CryptoInputStream} which encrypts the data after reading it from the - * underlying {@link InputStream}. + * Sets the frame size of the encrypted messages that the Aws Crypto client produces. The Aws + * Crypto client will use the last frame size set with either {@link + * AwsCrypto.Builder#withEncryptionFrameSize(int)} or {@link #setEncryptionFrameSize(int)}. * - * @see #encryptData(MasterKeyProvider, byte[], Map) - * @see javax.crypto.CipherInputStream + * @param frameSize The frame size to produce encrypted messages with. + * @return The Builder, for method chaining */ - public CryptoInputStream createEncryptingStream( - CryptoMaterialsManager materialsManager, - final InputStream is, - final Map encryptionContext - ) { - final MessageCryptoHandler cryptoHandler = getEncryptingStreamHandler(materialsManager, encryptionContext); - - return new CryptoInputStream<>(is, cryptoHandler); + public Builder withEncryptionFrameSize(int frameSize) { + this.encryptionFrameSize_ = frameSize; + return this; } /** - * Returns the equivalent to calling - * {@link #createEncryptingStream(MasterKeyProvider, InputStream, Map)} with an empty - * {@code encryptionContext}. - */ - public > CryptoInputStream createEncryptingStream( - final MasterKeyProvider provider, - final InputStream is - ) { - return createEncryptingStream(provider, is, EMPTY_MAP); - } - - /** - * Returns the equivalent to calling - * {@link #createEncryptingStream(CryptoMaterialsManager, InputStream, Map)} with an empty - * {@code encryptionContext}. - */ - public CryptoInputStream createEncryptingStream( - final CryptoMaterialsManager materialsManager, - final InputStream is - ) { - return createEncryptingStream(materialsManager, is, EMPTY_MAP); - } - - /** - * Returns a {@link CryptoOutputStream} which decrypts the data prior to passing it onto the - * underlying {@link OutputStream}. This version only accepts unsigned messages. - * - * @see #decryptData(MasterKeyProvider, byte[]) - * @see javax.crypto.CipherOutputStream - */ - public > CryptoOutputStream createUnsignedMessageDecryptingStream( - final MasterKeyProvider provider, final OutputStream os) { - final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(provider, - commitmentPolicy_, - SignaturePolicy.AllowEncryptForbidDecrypt, - maxEncryptedDataKeys_); - return new CryptoOutputStream(os, cryptoHandler); - } - - /** - * Returns a {@link CryptoInputStream} which decrypts the data after reading it from the - * underlying {@link InputStream}. This version only accepts unsigned messages. - * - * @see #decryptData(MasterKeyProvider, byte[]) - * @see javax.crypto.CipherInputStream - */ - public > CryptoInputStream createUnsignedMessageDecryptingStream( - final MasterKeyProvider provider, final InputStream is) { - final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(provider, - commitmentPolicy_, - SignaturePolicy.AllowEncryptForbidDecrypt, - maxEncryptedDataKeys_); - return new CryptoInputStream(is, cryptoHandler); - } - - /** - * Returns a {@link CryptoOutputStream} which decrypts the data prior to passing it onto the - * underlying {@link OutputStream}. This version only accepts unsigned messages. + * Sets the {@link CommitmentPolicy} of this Aws Crypto client. * - * @see #decryptData(CryptoMaterialsManager, byte[]) - * @see javax.crypto.CipherOutputStream + * @param commitmentPolicy The commitment policy to enforce during encryption and decryption + * @return The Builder, for method chaining */ - public CryptoOutputStream createUnsignedMessageDecryptingStream( - final CryptoMaterialsManager materialsManager, final OutputStream os - ) { - final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(materialsManager, - commitmentPolicy_, - SignaturePolicy.AllowEncryptForbidDecrypt, - maxEncryptedDataKeys_); - return new CryptoOutputStream(os, cryptoHandler); + public Builder withCommitmentPolicy(CommitmentPolicy commitmentPolicy) { + Utils.assertNonNull(commitmentPolicy, "commitmentPolicy"); + this.commitmentPolicy_ = commitmentPolicy; + return this; } /** - * Returns a {@link CryptoInputStream} which decrypts the data after reading it from the - * underlying {@link InputStream}. This version only accepts unsigned messages. + * Sets the maximum number of encrypted data keys that this Aws Crypto client will wrap when + * encrypting, or unwrap when decrypting, a single message. * - * @see #encryptData(CryptoMaterialsManager, byte[], Map) - * @see javax.crypto.CipherInputStream - */ - public CryptoInputStream createUnsignedMessageDecryptingStream( - final CryptoMaterialsManager materialsManager, final InputStream is - ) { - final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(materialsManager, - commitmentPolicy_, - SignaturePolicy.AllowEncryptForbidDecrypt, - maxEncryptedDataKeys_); - return new CryptoInputStream(is, cryptoHandler); - } - - /** - * Returns a {@link CryptoOutputStream} which decrypts the data prior to passing it onto the - * underlying {@link OutputStream}. - * - * Note that if the encrypted message includes a trailing signature, by necessity it cannot be verified until - * after the decrypted plaintext has been released to the underlying {@link OutputStream}! This behavior can - * be avoided by using the non-streaming #decryptData(MasterKeyProvider, byte[]) method instead, or - * #createUnsignedMessageDecryptingStream(MasterKeyProvider, OutputStream) if you do not need to decrypt - * signed messages. - * - * @see #decryptData(MasterKeyProvider, byte[]) - * @see #createUnsignedMessageDecryptingStream(MasterKeyProvider, OutputStream) - * @see javax.crypto.CipherOutputStream - */ - public > CryptoOutputStream createDecryptingStream( - final MasterKeyProvider provider, final OutputStream os) { - final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(provider, - commitmentPolicy_, - SignaturePolicy.AllowEncryptAllowDecrypt, - maxEncryptedDataKeys_); - return new CryptoOutputStream(os, cryptoHandler); - } - - /** - * Returns a {@link CryptoInputStream} which decrypts the data after reading it from the - * underlying {@link InputStream}. - * - * Note that if the encrypted message includes a trailing signature, by necessity it cannot be verified until - * after the decrypted plaintext has been produced from the {@link InputStream}! This behavior can - * be avoided by using the non-streaming #decryptData(MasterKeyProvider, byte[]) method instead, or - * #createUnsignedMessageDecryptingStream(MasterKeyProvider, InputStream) if you do not need to decrypt - * signed messages. - * - * @see #decryptData(MasterKeyProvider, byte[]) - * @see #createUnsignedMessageDecryptingStream(MasterKeyProvider, InputStream) - * @see javax.crypto.CipherInputStream - */ - public > CryptoInputStream createDecryptingStream( - final MasterKeyProvider provider, final InputStream is) { - final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(provider, - commitmentPolicy_, - SignaturePolicy.AllowEncryptAllowDecrypt, - maxEncryptedDataKeys_); - return new CryptoInputStream(is, cryptoHandler); - } - - /** - * Returns a {@link CryptoOutputStream} which decrypts the data prior to passing it onto the - * underlying {@link OutputStream}. - * - * Note that if the encrypted message includes a trailing signature, by necessity it cannot be verified until - * after the decrypted plaintext has been released to the underlying {@link OutputStream}! This behavior can - * be avoided by using the non-streaming #decryptData(CryptoMaterialsManager, byte[]) method instead, or - * #createUnsignedMessageDecryptingStream(CryptoMaterialsManager, OutputStream) if you do not need to decrypt - * signed messages. - * - * @see #decryptData(CryptoMaterialsManager, byte[]) - * @see #createUnsignedMessageDecryptingStream(CryptoMaterialsManager, OutputStream) - * @see javax.crypto.CipherOutputStream - */ - public CryptoOutputStream createDecryptingStream( - final CryptoMaterialsManager materialsManager, final OutputStream os - ) { - final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(materialsManager, - commitmentPolicy_, - SignaturePolicy.AllowEncryptAllowDecrypt, - maxEncryptedDataKeys_); - return new CryptoOutputStream(os, cryptoHandler); - } - - /** - * Returns a {@link CryptoInputStream} which decrypts the data after reading it from the - * underlying {@link InputStream}. - * - * Note that if the encrypted message includes a trailing signature, by necessity it cannot be verified until - * after the decrypted plaintext has been produced from the {@link InputStream}! This behavior can - * be avoided by using the non-streaming #decryptData(CryptoMaterialsManager, byte[]) method instead, or - * #createUnsignedMessageDecryptingStream(CryptoMaterialsManager, InputStream) if you do not need to decrypt - * signed messages. - * - * @see #decryptData(CryptoMaterialsManager, byte[]) - * @see #createUnsignedMessageDecryptingStream(CryptoMaterialsManager, InputStream) - * @see javax.crypto.CipherInputStream - */ - public CryptoInputStream createDecryptingStream( - final CryptoMaterialsManager materialsManager, final InputStream is - ) { - final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(materialsManager, - commitmentPolicy_, - SignaturePolicy.AllowEncryptAllowDecrypt, - maxEncryptedDataKeys_); - return new CryptoInputStream(is, cryptoHandler); - } - - private MessageCryptoHandler getEncryptingStreamHandler( - CryptoMaterialsManager materialsManager, Map encryptionContext - ) { - Utils.assertNonNull(materialsManager, "materialsManager"); - Utils.assertNonNull(encryptionContext, "encryptionContext"); - - EncryptionMaterialsRequest.Builder requestBuilder = EncryptionMaterialsRequest.newBuilder() - .setContext(encryptionContext) - .setRequestedAlgorithm(getEncryptionAlgorithm()) - .setCommitmentPolicy(commitmentPolicy_); - - return new LazyMessageCryptoHandler(info -> { - // Hopefully we know the input size now, so we can pass it along to the CMM. - if (info.getMaxInputSize() != -1) { - requestBuilder.setPlaintextSize(info.getMaxInputSize()); - } - - return new EncryptionHandler( - getEncryptionFrameSize(), - checkMaxEncryptedDataKeys(checkAlgorithm(materialsManager.getMaterialsForEncrypt(requestBuilder.build()))), - commitmentPolicy_ - ); + * @param maxEncryptedDataKeys The maximum number of encrypted data keys; must be positive + * @return The Builder, for method chaining + */ + public Builder withMaxEncryptedDataKeys(int maxEncryptedDataKeys) { + if (maxEncryptedDataKeys < 1) { + throw new IllegalArgumentException("maxEncryptedDataKeys must be positive"); + } + this.maxEncryptedDataKeys_ = maxEncryptedDataKeys; + return this; + } + + public AwsCrypto build() { + return new AwsCrypto(this); + } + } + + public static Builder builder() { + return new Builder(); + } + + public Builder toBuilder() { + return new Builder(this); + } + + public static AwsCrypto standard() { + return AwsCrypto.builder().build(); + } + + /** + * Returns the frame size to use for encryption when none is explicitly selected. Currently it is + * 4096. + */ + public static int getDefaultFrameSize() { + return 4096; + } + + /** + * Sets the {@link CryptoAlgorithm} to use when encrypting data. This has no impact on + * decryption. + */ + public void setEncryptionAlgorithm(final CryptoAlgorithm alg) { + if (!commitmentPolicy_.algorithmAllowedForEncrypt(alg)) { + if (commitmentPolicy_ == CommitmentPolicy.ForbidEncryptAllowDecrypt) { + throw new AwsCryptoException( + "Configuration conflict. Cannot encrypt due to CommitmentPolicy " + + commitmentPolicy_ + + " requiring only non-committed messages. Algorithm ID was " + + alg + + ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); + } else { + throw new AwsCryptoException( + "Configuration conflict. Cannot encrypt due to CommitmentPolicy " + + commitmentPolicy_ + + " requiring only committed messages. Algorithm ID was " + + alg + + ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); + } + } + encryptionAlgorithm_ = alg; + } + + public CryptoAlgorithm getEncryptionAlgorithm() { + return encryptionAlgorithm_; + } + + /** + * Sets the framing size to use when encrypting data. This has no impact on decryption. + * If {@code frameSize} is 0, then framing is disabled and the entire plaintext will be encrypted + * in a single block. + * + *

Note that during encryption arrays of this size will be allocated. Using extremely large + * frame sizes may pose compatibility issues when the decryptor is running on 32-bit systems. + * Additionally, Java VM limits may set a platform-specific upper bound to frame sizes. + */ + public void setEncryptionFrameSize(final int frameSize) { + if (frameSize < 0) { + throw new IllegalArgumentException("frameSize must be non-negative"); + } + + encryptionFrameSize_ = frameSize; + } + + public int getEncryptionFrameSize() { + return encryptionFrameSize_; + } + + /** + * Returns the best estimate for the output length of encrypting a plaintext with the provided + * {@code plaintextSize} and {@code encryptionContext}. The actual ciphertext may be shorter. + * + *

This method is equivalent to calling {@link #estimateCiphertextSize(CryptoMaterialsManager, + * int, Map)} with a {@link DefaultCryptoMaterialsManager} based on the given provider. + */ + public > long estimateCiphertextSize( + final MasterKeyProvider provider, + final int plaintextSize, + final Map encryptionContext) { + return estimateCiphertextSize( + new DefaultCryptoMaterialsManager(provider), plaintextSize, encryptionContext); + } + + /** + * Returns the best estimate for the output length of encrypting a plaintext with the provided + * {@code plaintextSize} and {@code encryptionContext}. The actual ciphertext may be shorter. + */ + public long estimateCiphertextSize( + CryptoMaterialsManager materialsManager, + final int plaintextSize, + final Map encryptionContext) { + EncryptionMaterialsRequest request = + EncryptionMaterialsRequest.newBuilder() + .setContext(encryptionContext) + .setRequestedAlgorithm(getEncryptionAlgorithm()) + // We're not actually encrypting any data, so don't consume any bytes from the cache's + // limits. We do need to + // pass /something/ though, or the cache will be bypassed (as it'll assume this is a + // streaming encrypt of + // unknown size). + .setPlaintextSize(0) + .setCommitmentPolicy(commitmentPolicy_) + .build(); + + final MessageCryptoHandler cryptoHandler = + new EncryptionHandler( + getEncryptionFrameSize(), + checkAlgorithm(materialsManager.getMaterialsForEncrypt(request)), + commitmentPolicy_); + + return cryptoHandler.estimateOutputSize(plaintextSize); + } + + /** + * Returns the equivalent to calling {@link #estimateCiphertextSize(MasterKeyProvider, int, Map)} + * with an empty {@code encryptionContext}. + */ + public > long estimateCiphertextSize( + final MasterKeyProvider provider, final int plaintextSize) { + return estimateCiphertextSize(provider, plaintextSize, EMPTY_MAP); + } + + /** + * Returns the equivalent to calling {@link #estimateCiphertextSize(CryptoMaterialsManager, int, + * Map)} with an empty {@code encryptionContext}. + */ + public long estimateCiphertextSize( + final CryptoMaterialsManager materialsManager, final int plaintextSize) { + return estimateCiphertextSize(materialsManager, plaintextSize, EMPTY_MAP); + } + + /** + * Returns an encrypted form of {@code plaintext} that has been protected with {@link DataKey + * DataKeys} that are in turn protected by {@link MasterKey MasterKeys} provided by {@code + * provider}. + * + *

This method is equivalent to calling {@link #encryptData(CryptoMaterialsManager, byte[], + * Map)} using a {@link DefaultCryptoMaterialsManager} based on the given provider. + */ + public > CryptoResult encryptData( + final MasterKeyProvider provider, + final byte[] plaintext, + final Map encryptionContext) { + //noinspection unchecked + return (CryptoResult) + encryptData(new DefaultCryptoMaterialsManager(provider), plaintext, encryptionContext); + } + + /** + * Returns an encrypted form of {@code plaintext} that has been protected with {@link DataKey + * DataKeys} that are in turn protected by the given CryptoMaterialsProvider. + */ + public CryptoResult encryptData( + CryptoMaterialsManager materialsManager, + final byte[] plaintext, + final Map encryptionContext) { + EncryptionMaterialsRequest request = + EncryptionMaterialsRequest.newBuilder() + .setContext(encryptionContext) + .setRequestedAlgorithm(getEncryptionAlgorithm()) + .setPlaintext(plaintext) + .setCommitmentPolicy(commitmentPolicy_) + .build(); + + EncryptionMaterials encryptionMaterials = + checkMaxEncryptedDataKeys(checkAlgorithm(materialsManager.getMaterialsForEncrypt(request))); + final MessageCryptoHandler cryptoHandler = + new EncryptionHandler(getEncryptionFrameSize(), encryptionMaterials, commitmentPolicy_); + + final int outSizeEstimate = cryptoHandler.estimateOutputSize(plaintext.length); + final byte[] out = new byte[outSizeEstimate]; + int outLen = + cryptoHandler.processBytes(plaintext, 0, plaintext.length, out, 0).getBytesWritten(); + outLen += cryptoHandler.doFinal(out, outLen); + + final byte[] outBytes = Utils.truncate(out, outLen); + + //noinspection unchecked + return new CryptoResult(outBytes, cryptoHandler.getMasterKeys(), cryptoHandler.getHeaders()); + } + + /** + * Returns the equivalent to calling {@link #encryptData(MasterKeyProvider, byte[], Map)} with an + * empty {@code encryptionContext}. + */ + public > CryptoResult encryptData( + final MasterKeyProvider provider, final byte[] plaintext) { + return encryptData(provider, plaintext, EMPTY_MAP); + } + + /** + * Returns the equivalent to calling {@link #encryptData(CryptoMaterialsManager, byte[], Map)} + * with an empty {@code encryptionContext}. + */ + public CryptoResult encryptData( + final CryptoMaterialsManager materialsManager, final byte[] plaintext) { + return encryptData(materialsManager, plaintext, EMPTY_MAP); + } + + /** + * Calls {@link #encryptData(MasterKeyProvider, byte[], Map)} on the UTF-8 encoded bytes of {@code + * plaintext} and base64 encodes the result. + * + * @deprecated Use the {@link #encryptData(MasterKeyProvider, byte[], Map)} and {@link + * #decryptData(MasterKeyProvider, byte[])} APIs instead. {@code encryptString} and {@code + * decryptString} work as expected if you use them together. However, to work with other + * language implementations of the AWS Encryption SDK, you need to base64-decode the output of + * {@code encryptString} and base64-encode the input to {@code decryptString}. These + * deprecated APIs will be removed in the future. + */ + @Deprecated + public > CryptoResult encryptString( + final MasterKeyProvider provider, + final String plaintext, + final Map encryptionContext) { + //noinspection unchecked + return (CryptoResult) + encryptString(new DefaultCryptoMaterialsManager(provider), plaintext, encryptionContext); + } + + /** + * Calls {@link #encryptData(CryptoMaterialsManager, byte[], Map)} on the UTF-8 encoded bytes of + * {@code plaintext} and base64 encodes the result. + * + * @deprecated Use the {@link #encryptData(CryptoMaterialsManager, byte[], Map)} and {@link + * #decryptData(CryptoMaterialsManager, byte[])} APIs instead. {@code encryptString} and + * {@code decryptString} work as expected if you use them together. However, to work with + * other language implementations of the AWS Encryption SDK, you need to base64-decode the + * output of {@code encryptString} and base64-encode the input to {@code decryptString}. These + * deprecated APIs will be removed in the future. + */ + @Deprecated + public CryptoResult encryptString( + CryptoMaterialsManager materialsManager, + final String plaintext, + final Map encryptionContext) { + final CryptoResult ctBytes = + encryptData( + materialsManager, plaintext.getBytes(StandardCharsets.UTF_8), encryptionContext); + return new CryptoResult<>( + Utils.encodeBase64String(ctBytes.getResult()), + ctBytes.getMasterKeys(), + ctBytes.getHeaders()); + } + + /** + * Returns the equivalent to calling {@link #encryptString(MasterKeyProvider, String, Map)} with + * an empty {@code encryptionContext}. + * + * @deprecated Use the {@link #encryptData(MasterKeyProvider, byte[])} and {@link + * #decryptData(MasterKeyProvider, byte[])} APIs instead. {@code encryptString} and {@code + * decryptString} work as expected if you use them together. However, to work with other + * language implementations of the AWS Encryption SDK, you need to base64-decode the output of + * {@code encryptString} and base64-encode the input to {@code decryptString}. These + * deprecated APIs will be removed in the future. + */ + @Deprecated + public > CryptoResult encryptString( + final MasterKeyProvider provider, final String plaintext) { + return encryptString(provider, plaintext, EMPTY_MAP); + } + + /** + * Returns the equivalent to calling {@link #encryptString(CryptoMaterialsManager, String, Map)} + * with an empty {@code encryptionContext}. + * + * @deprecated Use the {@link #encryptData(CryptoMaterialsManager, byte[])} and {@link + * #decryptData(CryptoMaterialsManager, byte[])} APIs instead. {@code encryptString} and + * {@code decryptString} work as expected if you use them together. However, to work with + * other language implementations of the AWS Encryption SDK, you need to base64-decode the + * output of {@code encryptString} and base64-encode the input to {@code decryptString}. These + * deprecated APIs will be removed in the future. + */ + @Deprecated + public CryptoResult encryptString( + final CryptoMaterialsManager materialsManager, final String plaintext) { + return encryptString(materialsManager, plaintext, EMPTY_MAP); + } + + /** + * Decrypts the provided {@code ciphertext} by requesting that the {@code provider} unwrap any + * usable {@link DataKey} in the ciphertext and then decrypts the ciphertext using that {@code + * DataKey}. + */ + public > CryptoResult decryptData( + final MasterKeyProvider provider, final byte[] ciphertext) { + return decryptData( + Utils.assertNonNull(provider, "provider"), + new ParsedCiphertext(ciphertext, maxEncryptedDataKeys_)); + } + + /** + * Decrypts the provided ciphertext by delegating to the provided materialsManager to obtain the + * decrypted {@link DataKey}. + * + * @param materialsManager the {@link CryptoMaterialsManager} to use for decryption operations. + * @param ciphertext the ciphertext to attempt to decrypt. + * @return the {@link CryptoResult} with the decrypted data. + */ + public CryptoResult decryptData( + final CryptoMaterialsManager materialsManager, final byte[] ciphertext) { + return decryptData( + Utils.assertNonNull(materialsManager, "materialsManager"), + new ParsedCiphertext(ciphertext, maxEncryptedDataKeys_)); + } + + /** @see #decryptData(MasterKeyProvider, byte[]) */ + @SuppressWarnings("unchecked") + public > CryptoResult decryptData( + final MasterKeyProvider provider, final ParsedCiphertext ciphertext) { + Utils.assertNonNull(provider, "provider"); + return (CryptoResult) + decryptData(new DefaultCryptoMaterialsManager(provider), ciphertext); + } + + /** @see #decryptData(CryptoMaterialsManager, byte[]) */ + public CryptoResult decryptData( + final CryptoMaterialsManager materialsManager, final ParsedCiphertext ciphertext) { + Utils.assertNonNull(materialsManager, "materialsManager"); + + final MessageCryptoHandler cryptoHandler = + DecryptionHandler.create( + materialsManager, + ciphertext, + commitmentPolicy_, + SignaturePolicy.AllowEncryptAllowDecrypt, + maxEncryptedDataKeys_); + + final byte[] ciphertextBytes = ciphertext.getCiphertext(); + final int contentLen = ciphertextBytes.length - ciphertext.getOffset(); + final int outSizeEstimate = cryptoHandler.estimateOutputSize(contentLen); + final byte[] out = new byte[outSizeEstimate]; + final ProcessingSummary processed = + cryptoHandler.processBytes(ciphertextBytes, ciphertext.getOffset(), contentLen, out, 0); + if (processed.getBytesProcessed() != contentLen) { + throw new BadCiphertextException( + "Unable to process entire ciphertext. May have trailing data."); + } + int outLen = processed.getBytesWritten(); + outLen += cryptoHandler.doFinal(out, outLen); + + final byte[] outBytes = Utils.truncate(out, outLen); + + //noinspection unchecked + return new CryptoResult(outBytes, cryptoHandler.getMasterKeys(), cryptoHandler.getHeaders()); + } + + /** + * Base64 decodes the {@code ciphertext} prior to decryption and then treats the results as a + * UTF-8 encoded string. + * + * @see #decryptData(MasterKeyProvider, byte[]) + * @deprecated Use the {@link #decryptData(MasterKeyProvider, byte[])} and {@link + * #encryptData(MasterKeyProvider, byte[], Map)} APIs instead. {@code encryptString} and + * {@code decryptString} work as expected if you use them together. However, to work with + * other language implementations of the AWS Encryption SDK, you need to base64-decode the + * output of {@code encryptString} and base64-encode the input to {@code decryptString}. These + * deprecated APIs will be removed in the future. + */ + @Deprecated + @SuppressWarnings("unchecked") + public > CryptoResult decryptString( + final MasterKeyProvider provider, final String ciphertext) { + return (CryptoResult) + decryptString(new DefaultCryptoMaterialsManager(provider), ciphertext); + } + + /** + * Base64 decodes the {@code ciphertext} prior to decryption and then treats the results as a + * UTF-8 encoded string. + * + * @see #decryptData(CryptoMaterialsManager, byte[]) + * @deprecated Use the {@link #decryptData(CryptoMaterialsManager, byte[])} and {@link + * #encryptData(CryptoMaterialsManager, byte[], Map)} APIs instead. {@code encryptString} and + * {@code decryptString} work as expected if you use them together. However, to work with + * other language implementations of the AWS Encryption SDK, you need to base64-decode the + * output of {@code encryptString} and base64-encode the input to {@code decryptString}. These + * deprecated APIs will be removed in the future. + */ + @Deprecated + public CryptoResult decryptString( + final CryptoMaterialsManager provider, final String ciphertext) { + Utils.assertNonNull(provider, "provider"); + final byte[] ciphertextBytes; + try { + ciphertextBytes = Utils.decodeBase64String(Utils.assertNonNull(ciphertext, "ciphertext")); + } catch (final IllegalArgumentException ex) { + throw new BadCiphertextException("Invalid base 64", ex); + } + final CryptoResult ptBytes = decryptData(provider, ciphertextBytes); + //noinspection unchecked + return new CryptoResult( + new String(ptBytes.getResult(), StandardCharsets.UTF_8), + ptBytes.getMasterKeys(), + ptBytes.getHeaders()); + } + + /** + * Returns a {@link CryptoOutputStream} which encrypts the data prior to passing it onto the + * underlying {@link OutputStream}. + * + * @see #encryptData(MasterKeyProvider, byte[], Map) + * @see javax.crypto.CipherOutputStream + */ + public > CryptoOutputStream createEncryptingStream( + final MasterKeyProvider provider, + final OutputStream os, + final Map encryptionContext) { + //noinspection unchecked + return (CryptoOutputStream) + createEncryptingStream(new DefaultCryptoMaterialsManager(provider), os, encryptionContext); + } + + /** + * Returns a {@link CryptoOutputStream} which encrypts the data prior to passing it onto the + * underlying {@link OutputStream}. + * + * @see #encryptData(MasterKeyProvider, byte[], Map) + * @see javax.crypto.CipherOutputStream + */ + public CryptoOutputStream createEncryptingStream( + final CryptoMaterialsManager materialsManager, + final OutputStream os, + final Map encryptionContext) { + return new CryptoOutputStream<>( + os, getEncryptingStreamHandler(materialsManager, encryptionContext)); + } + + /** + * Returns the equivalent to calling {@link #createEncryptingStream(MasterKeyProvider, + * OutputStream, Map)} with an empty {@code encryptionContext}. + */ + public > CryptoOutputStream createEncryptingStream( + final MasterKeyProvider provider, final OutputStream os) { + return createEncryptingStream(provider, os, EMPTY_MAP); + } + + /** + * Returns the equivalent to calling {@link #createEncryptingStream(CryptoMaterialsManager, + * OutputStream, Map)} with an empty {@code encryptionContext}. + */ + public CryptoOutputStream createEncryptingStream( + final CryptoMaterialsManager materialsManager, final OutputStream os) { + return createEncryptingStream(materialsManager, os, EMPTY_MAP); + } + + /** + * Returns a {@link CryptoInputStream} which encrypts the data after reading it from the + * underlying {@link InputStream}. + * + * @see #encryptData(MasterKeyProvider, byte[], Map) + * @see javax.crypto.CipherInputStream + */ + public > CryptoInputStream createEncryptingStream( + final MasterKeyProvider provider, + final InputStream is, + final Map encryptionContext) { + //noinspection unchecked + return (CryptoInputStream) + createEncryptingStream(new DefaultCryptoMaterialsManager(provider), is, encryptionContext); + } + + /** + * Returns a {@link CryptoInputStream} which encrypts the data after reading it from the + * underlying {@link InputStream}. + * + * @see #encryptData(MasterKeyProvider, byte[], Map) + * @see javax.crypto.CipherInputStream + */ + public CryptoInputStream createEncryptingStream( + CryptoMaterialsManager materialsManager, + final InputStream is, + final Map encryptionContext) { + final MessageCryptoHandler cryptoHandler = + getEncryptingStreamHandler(materialsManager, encryptionContext); + + return new CryptoInputStream<>(is, cryptoHandler); + } + + /** + * Returns the equivalent to calling {@link #createEncryptingStream(MasterKeyProvider, + * InputStream, Map)} with an empty {@code encryptionContext}. + */ + public > CryptoInputStream createEncryptingStream( + final MasterKeyProvider provider, final InputStream is) { + return createEncryptingStream(provider, is, EMPTY_MAP); + } + + /** + * Returns the equivalent to calling {@link #createEncryptingStream(CryptoMaterialsManager, + * InputStream, Map)} with an empty {@code encryptionContext}. + */ + public CryptoInputStream createEncryptingStream( + final CryptoMaterialsManager materialsManager, final InputStream is) { + return createEncryptingStream(materialsManager, is, EMPTY_MAP); + } + + /** + * Returns a {@link CryptoOutputStream} which decrypts the data prior to passing it onto the + * underlying {@link OutputStream}. This version only accepts unsigned messages. + * + * @see #decryptData(MasterKeyProvider, byte[]) + * @see javax.crypto.CipherOutputStream + */ + public > CryptoOutputStream createUnsignedMessageDecryptingStream( + final MasterKeyProvider provider, final OutputStream os) { + final MessageCryptoHandler cryptoHandler = + DecryptionHandler.create( + provider, + commitmentPolicy_, + SignaturePolicy.AllowEncryptForbidDecrypt, + maxEncryptedDataKeys_); + return new CryptoOutputStream(os, cryptoHandler); + } + + /** + * Returns a {@link CryptoInputStream} which decrypts the data after reading it from the + * underlying {@link InputStream}. This version only accepts unsigned messages. + * + * @see #decryptData(MasterKeyProvider, byte[]) + * @see javax.crypto.CipherInputStream + */ + public > CryptoInputStream createUnsignedMessageDecryptingStream( + final MasterKeyProvider provider, final InputStream is) { + final MessageCryptoHandler cryptoHandler = + DecryptionHandler.create( + provider, + commitmentPolicy_, + SignaturePolicy.AllowEncryptForbidDecrypt, + maxEncryptedDataKeys_); + return new CryptoInputStream(is, cryptoHandler); + } + + /** + * Returns a {@link CryptoOutputStream} which decrypts the data prior to passing it onto the + * underlying {@link OutputStream}. This version only accepts unsigned messages. + * + * @see #decryptData(CryptoMaterialsManager, byte[]) + * @see javax.crypto.CipherOutputStream + */ + public CryptoOutputStream createUnsignedMessageDecryptingStream( + final CryptoMaterialsManager materialsManager, final OutputStream os) { + final MessageCryptoHandler cryptoHandler = + DecryptionHandler.create( + materialsManager, + commitmentPolicy_, + SignaturePolicy.AllowEncryptForbidDecrypt, + maxEncryptedDataKeys_); + return new CryptoOutputStream(os, cryptoHandler); + } + + /** + * Returns a {@link CryptoInputStream} which decrypts the data after reading it from the + * underlying {@link InputStream}. This version only accepts unsigned messages. + * + * @see #encryptData(CryptoMaterialsManager, byte[], Map) + * @see javax.crypto.CipherInputStream + */ + public CryptoInputStream createUnsignedMessageDecryptingStream( + final CryptoMaterialsManager materialsManager, final InputStream is) { + final MessageCryptoHandler cryptoHandler = + DecryptionHandler.create( + materialsManager, + commitmentPolicy_, + SignaturePolicy.AllowEncryptForbidDecrypt, + maxEncryptedDataKeys_); + return new CryptoInputStream(is, cryptoHandler); + } + + /** + * Returns a {@link CryptoOutputStream} which decrypts the data prior to passing it onto the + * underlying {@link OutputStream}. + * + *

Note that if the encrypted message includes a trailing signature, by necessity it cannot be + * verified until after the decrypted plaintext has been released to the underlying {@link + * OutputStream}! This behavior can be avoided by using the non-streaming + * #decryptData(MasterKeyProvider, byte[]) method instead, or + * #createUnsignedMessageDecryptingStream(MasterKeyProvider, OutputStream) if you do not need to + * decrypt signed messages. + * + * @see #decryptData(MasterKeyProvider, byte[]) + * @see #createUnsignedMessageDecryptingStream(MasterKeyProvider, OutputStream) + * @see javax.crypto.CipherOutputStream + */ + public > CryptoOutputStream createDecryptingStream( + final MasterKeyProvider provider, final OutputStream os) { + final MessageCryptoHandler cryptoHandler = + DecryptionHandler.create( + provider, + commitmentPolicy_, + SignaturePolicy.AllowEncryptAllowDecrypt, + maxEncryptedDataKeys_); + return new CryptoOutputStream(os, cryptoHandler); + } + + /** + * Returns a {@link CryptoInputStream} which decrypts the data after reading it from the + * underlying {@link InputStream}. + * + *

Note that if the encrypted message includes a trailing signature, by necessity it cannot be + * verified until after the decrypted plaintext has been produced from the {@link InputStream}! + * This behavior can be avoided by using the non-streaming #decryptData(MasterKeyProvider, byte[]) + * method instead, or #createUnsignedMessageDecryptingStream(MasterKeyProvider, InputStream) if + * you do not need to decrypt signed messages. + * + * @see #decryptData(MasterKeyProvider, byte[]) + * @see #createUnsignedMessageDecryptingStream(MasterKeyProvider, InputStream) + * @see javax.crypto.CipherInputStream + */ + public > CryptoInputStream createDecryptingStream( + final MasterKeyProvider provider, final InputStream is) { + final MessageCryptoHandler cryptoHandler = + DecryptionHandler.create( + provider, + commitmentPolicy_, + SignaturePolicy.AllowEncryptAllowDecrypt, + maxEncryptedDataKeys_); + return new CryptoInputStream(is, cryptoHandler); + } + + /** + * Returns a {@link CryptoOutputStream} which decrypts the data prior to passing it onto the + * underlying {@link OutputStream}. + * + *

Note that if the encrypted message includes a trailing signature, by necessity it cannot be + * verified until after the decrypted plaintext has been released to the underlying {@link + * OutputStream}! This behavior can be avoided by using the non-streaming + * #decryptData(CryptoMaterialsManager, byte[]) method instead, or + * #createUnsignedMessageDecryptingStream(CryptoMaterialsManager, OutputStream) if you do not need + * to decrypt signed messages. + * + * @see #decryptData(CryptoMaterialsManager, byte[]) + * @see #createUnsignedMessageDecryptingStream(CryptoMaterialsManager, OutputStream) + * @see javax.crypto.CipherOutputStream + */ + public CryptoOutputStream createDecryptingStream( + final CryptoMaterialsManager materialsManager, final OutputStream os) { + final MessageCryptoHandler cryptoHandler = + DecryptionHandler.create( + materialsManager, + commitmentPolicy_, + SignaturePolicy.AllowEncryptAllowDecrypt, + maxEncryptedDataKeys_); + return new CryptoOutputStream(os, cryptoHandler); + } + + /** + * Returns a {@link CryptoInputStream} which decrypts the data after reading it from the + * underlying {@link InputStream}. + * + *

Note that if the encrypted message includes a trailing signature, by necessity it cannot be + * verified until after the decrypted plaintext has been produced from the {@link InputStream}! + * This behavior can be avoided by using the non-streaming #decryptData(CryptoMaterialsManager, + * byte[]) method instead, or #createUnsignedMessageDecryptingStream(CryptoMaterialsManager, + * InputStream) if you do not need to decrypt signed messages. + * + * @see #decryptData(CryptoMaterialsManager, byte[]) + * @see #createUnsignedMessageDecryptingStream(CryptoMaterialsManager, InputStream) + * @see javax.crypto.CipherInputStream + */ + public CryptoInputStream createDecryptingStream( + final CryptoMaterialsManager materialsManager, final InputStream is) { + final MessageCryptoHandler cryptoHandler = + DecryptionHandler.create( + materialsManager, + commitmentPolicy_, + SignaturePolicy.AllowEncryptAllowDecrypt, + maxEncryptedDataKeys_); + return new CryptoInputStream(is, cryptoHandler); + } + + private MessageCryptoHandler getEncryptingStreamHandler( + CryptoMaterialsManager materialsManager, Map encryptionContext) { + Utils.assertNonNull(materialsManager, "materialsManager"); + Utils.assertNonNull(encryptionContext, "encryptionContext"); + + EncryptionMaterialsRequest.Builder requestBuilder = + EncryptionMaterialsRequest.newBuilder() + .setContext(encryptionContext) + .setRequestedAlgorithm(getEncryptionAlgorithm()) + .setCommitmentPolicy(commitmentPolicy_); + + return new LazyMessageCryptoHandler( + info -> { + // Hopefully we know the input size now, so we can pass it along to the CMM. + if (info.getMaxInputSize() != -1) { + requestBuilder.setPlaintextSize(info.getMaxInputSize()); + } + + return new EncryptionHandler( + getEncryptionFrameSize(), + checkMaxEncryptedDataKeys( + checkAlgorithm(materialsManager.getMaterialsForEncrypt(requestBuilder.build()))), + commitmentPolicy_); }); - } + } - private EncryptionMaterials checkAlgorithm(EncryptionMaterials result) { - if (encryptionAlgorithm_ != null && result.getAlgorithm() != encryptionAlgorithm_) { - throw new AwsCryptoException( - String.format("Materials manager ignored requested algorithm; algorithm %s was set on AwsCrypto " + - "but %s was selected", encryptionAlgorithm_, result.getAlgorithm()) - ); - } - - return result; + private EncryptionMaterials checkAlgorithm(EncryptionMaterials result) { + if (encryptionAlgorithm_ != null && result.getAlgorithm() != encryptionAlgorithm_) { + throw new AwsCryptoException( + String.format( + "Materials manager ignored requested algorithm; algorithm %s was set on AwsCrypto " + + "but %s was selected", + encryptionAlgorithm_, result.getAlgorithm())); } - private EncryptionMaterials checkMaxEncryptedDataKeys(EncryptionMaterials materials) { - if (maxEncryptedDataKeys_ > 0 && materials.getEncryptedDataKeys().size() > maxEncryptedDataKeys_) { - throw new AwsCryptoException("Encrypted data keys exceed maxEncryptedDataKeys"); - } - return materials; + return result; + } + + private EncryptionMaterials checkMaxEncryptedDataKeys(EncryptionMaterials materials) { + if (maxEncryptedDataKeys_ > 0 + && materials.getEncryptedDataKeys().size() > maxEncryptedDataKeys_) { + throw new AwsCryptoException("Encrypted data keys exceed maxEncryptedDataKeys"); } + return materials; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/CommitmentPolicy.java b/src/main/java/com/amazonaws/encryptionsdk/CommitmentPolicy.java index 295512e2c..d17c078b4 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/CommitmentPolicy.java +++ b/src/main/java/com/amazonaws/encryptionsdk/CommitmentPolicy.java @@ -3,34 +3,34 @@ package com.amazonaws.encryptionsdk; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; - public enum CommitmentPolicy { - ForbidEncryptAllowDecrypt, - RequireEncryptAllowDecrypt, - RequireEncryptRequireDecrypt; + ForbidEncryptAllowDecrypt, + RequireEncryptAllowDecrypt, + RequireEncryptRequireDecrypt; - public boolean algorithmAllowedForEncrypt(CryptoAlgorithm algorithm) { - switch (this) { - case ForbidEncryptAllowDecrypt: - return !algorithm.isCommitting(); - case RequireEncryptAllowDecrypt: - case RequireEncryptRequireDecrypt: - return algorithm.isCommitting(); - default: - throw new UnsupportedOperationException("Support for commitment policy " + this + " not yet built."); - } + public boolean algorithmAllowedForEncrypt(CryptoAlgorithm algorithm) { + switch (this) { + case ForbidEncryptAllowDecrypt: + return !algorithm.isCommitting(); + case RequireEncryptAllowDecrypt: + case RequireEncryptRequireDecrypt: + return algorithm.isCommitting(); + default: + throw new UnsupportedOperationException( + "Support for commitment policy " + this + " not yet built."); } + } - public boolean algorithmAllowedForDecrypt(CryptoAlgorithm algorithm) { - switch (this) { - case ForbidEncryptAllowDecrypt: - case RequireEncryptAllowDecrypt: - return true; - case RequireEncryptRequireDecrypt: - return algorithm.isCommitting(); - default: - throw new UnsupportedOperationException("Support for commitment policy " + this + " not yet built."); - } + public boolean algorithmAllowedForDecrypt(CryptoAlgorithm algorithm) { + switch (this) { + case ForbidEncryptAllowDecrypt: + case RequireEncryptAllowDecrypt: + return true; + case RequireEncryptRequireDecrypt: + return algorithm.isCommitting(); + default: + throw new UnsupportedOperationException( + "Support for commitment policy " + this + " not yet built."); } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/CryptoAlgorithm.java b/src/main/java/com/amazonaws/encryptionsdk/CryptoAlgorithm.java index 2ee8bec6c..821718652 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/CryptoAlgorithm.java +++ b/src/main/java/com/amazonaws/encryptionsdk/CryptoAlgorithm.java @@ -3,6 +3,11 @@ package com.amazonaws.encryptionsdk; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; +import com.amazonaws.encryptionsdk.internal.CommittedKey; +import com.amazonaws.encryptionsdk.internal.Constants; +import com.amazonaws.encryptionsdk.internal.HmacKeyDerivationFunction; +import com.amazonaws.encryptionsdk.model.CiphertextHeaders; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.security.InvalidKeyException; @@ -11,437 +16,515 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.Map; - import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; -import com.amazonaws.encryptionsdk.internal.HmacKeyDerivationFunction; - -import com.amazonaws.encryptionsdk.internal.Constants; -import com.amazonaws.encryptionsdk.internal.CommittedKey; -import com.amazonaws.encryptionsdk.model.CiphertextHeaders; - /** * Describes the cryptographic algorithms available for use in this library. * - *

- * Format: CryptoAlgorithm(block size, nonce length, tag length, max content length, key algo, key - * length, short value representing this algorithm, trailing signature alg, trailing signature + *

Format: CryptoAlgorithm(block size, nonce length, tag length, max content length, key algo, + * key length, short value representing this algorithm, trailing signature alg, trailing signature * length) */ public enum CryptoAlgorithm { - /** - * AES-GCM 128 - */ - ALG_AES_128_GCM_IV12_TAG16_NO_KDF(1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 16, 0x0014, "AES", 16, false), - /** - * AES-GCM 192 - */ - ALG_AES_192_GCM_IV12_TAG16_NO_KDF(1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 24, 0x0046, "AES", 24, false), - /** - * AES-GCM 256 - */ - ALG_AES_256_GCM_IV12_TAG16_NO_KDF(1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 32, 0x0078, "AES", 32, false), - /** - * AES-GCM 128 with HKDF-SHA256 - */ - ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256(1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 16, 0x0114, "HkdfSHA256", - 16, true), - /** - * AES-GCM 192 - */ - ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA256(1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 24, 0x0146, "HkdfSHA256", - 24, true), - /** - * AES-GCM 256 - */ - ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256(1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 32, 0x0178, "HkdfSHA256", - 32, true), - - /** - * AES-GCM 128 with ECDSA (SHA256 with the secp256r1 curve) - */ - ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256(1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 16, 0x0214, - "HkdfSHA256", 16, - true, "SHA256withECDSA", 71), - /** - * AES-GCM 192 with ECDSA (SHA384 with the secp384r1 curve) - */ - ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384(1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 24, 0x0346, - "HkdfSHA384", 24, - true, "SHA384withECDSA", 103), - /** - * AES-GCM 256 with ECDSA (SHA384 with the secp384r1 curve) - */ - ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384(1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 32, 0x0378, - "HkdfSHA384", 32, - true, "SHA384withECDSA", 103), - /** - * AES-GCM 256 with key commitment - * Note: 1.7.0 of this library only supports decryption of using this crypto algorithm and does not support encryption with this algorithm - */ - ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY(2, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 32, 0x0478, - "HkdfSHA512", 32, true, null, 0, "HkdfSHA512", 32, 32, 32), - /** - * AES-GCM 256 with ECDSA (SHA384 with the secp384r1 curve) and key commitment - * Note: 1.7.0 of this library only supports decryption of using this crypto algorithm and does not support encryption with this algorithm - */ - ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384(2, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 32, 0x0578, - "HkdfSHA512", 32, true, "SHA384withECDSA", 103, "HkdfSHA512", 32, 32, 32); - - private final byte messageFormatVersion_; - private final int blockSizeBits_; - private final byte nonceLenBytes_; - private final int tagLenBytes_; - private final long maxContentLen_; - private final String keyAlgo_; - private final int keyLenBytes_; - private final short value_; - private final String trailingSigAlgo_; - private final short trailingSigLen_; - private final String dataKeyAlgo_; - private final int dataKeyLen_; - private final boolean safeToCache_; - private final String keyCommitmentAlgo_; - private final int commitmentLength_; - private final int commitmentNonceLength_; - private final int suiteDataLength_; - - private static final byte VERSION_1 = (byte) 1; - private static final byte VERSION_2 = (byte) 2; - private static final int VERSION_1_MESSAGE_ID_LEN = 16; - private static final int VERSION_2_MESSAGE_ID_LEN = 32; - - /* - * Create a mapping between the CiphertextType object and its byte value representation. Make - * this a static method so the map is created when the object is created. This enables fast - * lookups of the CryptoAlgorithm given its short value representation and message format version. - */ - private static final Map ID_MAPPING = new HashMap<>(); - - static { - for (final CryptoAlgorithm s : EnumSet.allOf(CryptoAlgorithm.class)) { - ID_MAPPING.put(fieldsToLookupKey(s.messageFormatVersion_, s.value_), s); - } - } - - CryptoAlgorithm( - final int messageFormatVersion, - final int blockSizeBits, final int nonceLenBytes, final int tagLenBytes, - final long maxContentLen, final String keyAlgo, final int keyLenBytes, final int value, - final String dataKeyAlgo, final int dataKeyLen, boolean safeToCache - ) { - this(messageFormatVersion, blockSizeBits, nonceLenBytes, tagLenBytes, - maxContentLen, keyAlgo, keyLenBytes, value, - dataKeyAlgo, dataKeyLen, safeToCache, null, 0); - - } - - CryptoAlgorithm( - final int messageFormatVersion, - final int blockSizeBits, final int nonceLenBytes, final int tagLenBytes, - final long maxContentLen, final String keyAlgo, final int keyLenBytes, final int value, - final String dataKeyAlgo, final int dataKeyLen, - boolean safeToCache, final String trailingSignatureAlgo, final int trailingSignatureLength - ) { - this(messageFormatVersion, blockSizeBits, nonceLenBytes, tagLenBytes, - maxContentLen, keyAlgo, keyLenBytes, value, - dataKeyAlgo, dataKeyLen, safeToCache, trailingSignatureAlgo, trailingSignatureLength, - null, 0, 0, 0); - } - - CryptoAlgorithm( - final int messageFormatVersion, - final int blockSizeBits, final int nonceLenBytes, final int tagLenBytes, - final long maxContentLen, final String keyAlgo, final int keyLenBytes, final int value, - final String dataKeyAlgo, final int dataKeyLen, - boolean safeToCache, final String trailingSignatureAlgo, final int trailingSignatureLength, - final String keyCommitmentAlgo, final int commitmentLength, final int commitmentNonceLength, - final int suiteDataLength - ) { - if ((messageFormatVersion & 0xFF) != messageFormatVersion) { - throw new IllegalArgumentException("Invalid messageFormatVersion: " + messageFormatVersion); - } - // All non-null key commitment algs must be the same as the kdf alg - if (keyCommitmentAlgo != null && !keyCommitmentAlgo.equals(dataKeyAlgo)) { - throw new IllegalArgumentException("Invalid keyCommitmentAlgo " + keyCommitmentAlgo + - ". Must be equal to dataKeyAlgo " + dataKeyAlgo + "."); - } - messageFormatVersion_ = (byte) (messageFormatVersion & 0xFF); - blockSizeBits_ = blockSizeBits; - nonceLenBytes_ = (byte) nonceLenBytes; - tagLenBytes_ = tagLenBytes; - keyAlgo_ = keyAlgo; - keyLenBytes_ = keyLenBytes; - maxContentLen_ = maxContentLen; - safeToCache_ = safeToCache; - if (value > Short.MAX_VALUE || value < Short.MIN_VALUE) { - throw new IllegalArgumentException("Invalid value " + value); - } - value_ = (short) value; - dataKeyAlgo_ = dataKeyAlgo; - dataKeyLen_ = dataKeyLen; - trailingSigAlgo_ = trailingSignatureAlgo; - if (trailingSignatureLength > Short.MAX_VALUE || trailingSignatureLength < 0) { - throw new IllegalArgumentException("Invalid value " + trailingSignatureLength); - } - trailingSigLen_ = (short) trailingSignatureLength; - keyCommitmentAlgo_ = keyCommitmentAlgo; - commitmentLength_ = commitmentLength; - commitmentNonceLength_ = commitmentNonceLength; - suiteDataLength_ = suiteDataLength; - } - - private static int fieldsToLookupKey(final byte messageFormatVersion, final short algorithmId) { - // We pack the message format version and algorithm id into a single value. - // Since the algorithm ID is a short and thus 16 bits long, we'll just - // left shift the message format version by that amount. - // The message format version is 8 bits, so this totals 24 bits and fits - // within a standard 32 bit integer. - return (messageFormatVersion << 16) | algorithmId; - } - - /** - * Returns the CryptoAlgorithm object that matches the given value - * assuming a message format version of 1. - * - * @param value - * the value of the object - * @return the CryptoAlgorithm object that matches the given value, null if no match is found. - * @deprecated See {@link #deserialize(byte, short)} - */ - public static CryptoAlgorithm deserialize(final byte messageFormatVersion, final short value) { - return ID_MAPPING.get(fieldsToLookupKey(messageFormatVersion, value)); - } - - /** - * Returns the length of the message Id in the header for this algorithm. - */ - public int getMessageIdLength() { - // For now this is a derived value rather than stored explicitly - switch (messageFormatVersion_) { - case VERSION_1: - return VERSION_1_MESSAGE_ID_LEN; - case VERSION_2: - return VERSION_2_MESSAGE_ID_LEN; - default: - throw new UnsupportedOperationException("Support for version " + messageFormatVersion_ + " not yet built."); - } - } - - /** - * Returns the header nonce to use with this algorithm. - * null indicates that the header nonce is not a parameter of the algorithm, - * and is instead stored as part of the message header. - */ - public byte[] getHeaderNonce() { - // For now this is a derived value rather than stored explicitly - switch (messageFormatVersion_) { - case VERSION_1: - return null; - case VERSION_2: - // V2 explicitly uses an IV of 0 in the header - return new byte[nonceLenBytes_]; - default: - throw new UnsupportedOperationException("Support for version " + messageFormatVersion_ + " not yet built."); - } + /** AES-GCM 128 */ + ALG_AES_128_GCM_IV12_TAG16_NO_KDF( + 1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 16, 0x0014, "AES", 16, false), + /** AES-GCM 192 */ + ALG_AES_192_GCM_IV12_TAG16_NO_KDF( + 1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 24, 0x0046, "AES", 24, false), + /** AES-GCM 256 */ + ALG_AES_256_GCM_IV12_TAG16_NO_KDF( + 1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 32, 0x0078, "AES", 32, false), + /** AES-GCM 128 with HKDF-SHA256 */ + ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256( + 1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 16, 0x0114, "HkdfSHA256", 16, true), + /** AES-GCM 192 */ + ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA256( + 1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 24, 0x0146, "HkdfSHA256", 24, true), + /** AES-GCM 256 */ + ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256( + 1, 128, 12, 16, Constants.GCM_MAX_CONTENT_LEN, "AES", 32, 0x0178, "HkdfSHA256", 32, true), + + /** AES-GCM 128 with ECDSA (SHA256 with the secp256r1 curve) */ + ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256( + 1, + 128, + 12, + 16, + Constants.GCM_MAX_CONTENT_LEN, + "AES", + 16, + 0x0214, + "HkdfSHA256", + 16, + true, + "SHA256withECDSA", + 71), + /** AES-GCM 192 with ECDSA (SHA384 with the secp384r1 curve) */ + ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384( + 1, + 128, + 12, + 16, + Constants.GCM_MAX_CONTENT_LEN, + "AES", + 24, + 0x0346, + "HkdfSHA384", + 24, + true, + "SHA384withECDSA", + 103), + /** AES-GCM 256 with ECDSA (SHA384 with the secp384r1 curve) */ + ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384( + 1, + 128, + 12, + 16, + Constants.GCM_MAX_CONTENT_LEN, + "AES", + 32, + 0x0378, + "HkdfSHA384", + 32, + true, + "SHA384withECDSA", + 103), + /** + * AES-GCM 256 with key commitment Note: 1.7.0 of this library only supports decryption of using + * this crypto algorithm and does not support encryption with this algorithm + */ + ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY( + 2, + 128, + 12, + 16, + Constants.GCM_MAX_CONTENT_LEN, + "AES", + 32, + 0x0478, + "HkdfSHA512", + 32, + true, + null, + 0, + "HkdfSHA512", + 32, + 32, + 32), + /** + * AES-GCM 256 with ECDSA (SHA384 with the secp384r1 curve) and key commitment Note: 1.7.0 of this + * library only supports decryption of using this crypto algorithm and does not support encryption + * with this algorithm + */ + ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384( + 2, + 128, + 12, + 16, + Constants.GCM_MAX_CONTENT_LEN, + "AES", + 32, + 0x0578, + "HkdfSHA512", + 32, + true, + "SHA384withECDSA", + 103, + "HkdfSHA512", + 32, + 32, + 32); + + private final byte messageFormatVersion_; + private final int blockSizeBits_; + private final byte nonceLenBytes_; + private final int tagLenBytes_; + private final long maxContentLen_; + private final String keyAlgo_; + private final int keyLenBytes_; + private final short value_; + private final String trailingSigAlgo_; + private final short trailingSigLen_; + private final String dataKeyAlgo_; + private final int dataKeyLen_; + private final boolean safeToCache_; + private final String keyCommitmentAlgo_; + private final int commitmentLength_; + private final int commitmentNonceLength_; + private final int suiteDataLength_; + + private static final byte VERSION_1 = (byte) 1; + private static final byte VERSION_2 = (byte) 2; + private static final int VERSION_1_MESSAGE_ID_LEN = 16; + private static final int VERSION_2_MESSAGE_ID_LEN = 32; + + /* + * Create a mapping between the CiphertextType object and its byte value representation. Make + * this a static method so the map is created when the object is created. This enables fast + * lookups of the CryptoAlgorithm given its short value representation and message format version. + */ + private static final Map ID_MAPPING = new HashMap<>(); + + static { + for (final CryptoAlgorithm s : EnumSet.allOf(CryptoAlgorithm.class)) { + ID_MAPPING.put(fieldsToLookupKey(s.messageFormatVersion_, s.value_), s); } - - /** - * Returns the message format version associated with this algorithm suite. - */ - public byte getMessageFormatVersion() { - return messageFormatVersion_; + } + + CryptoAlgorithm( + final int messageFormatVersion, + final int blockSizeBits, + final int nonceLenBytes, + final int tagLenBytes, + final long maxContentLen, + final String keyAlgo, + final int keyLenBytes, + final int value, + final String dataKeyAlgo, + final int dataKeyLen, + boolean safeToCache) { + this( + messageFormatVersion, + blockSizeBits, + nonceLenBytes, + tagLenBytes, + maxContentLen, + keyAlgo, + keyLenBytes, + value, + dataKeyAlgo, + dataKeyLen, + safeToCache, + null, + 0); + } + + CryptoAlgorithm( + final int messageFormatVersion, + final int blockSizeBits, + final int nonceLenBytes, + final int tagLenBytes, + final long maxContentLen, + final String keyAlgo, + final int keyLenBytes, + final int value, + final String dataKeyAlgo, + final int dataKeyLen, + boolean safeToCache, + final String trailingSignatureAlgo, + final int trailingSignatureLength) { + this( + messageFormatVersion, + blockSizeBits, + nonceLenBytes, + tagLenBytes, + maxContentLen, + keyAlgo, + keyLenBytes, + value, + dataKeyAlgo, + dataKeyLen, + safeToCache, + trailingSignatureAlgo, + trailingSignatureLength, + null, + 0, + 0, + 0); + } + + CryptoAlgorithm( + final int messageFormatVersion, + final int blockSizeBits, + final int nonceLenBytes, + final int tagLenBytes, + final long maxContentLen, + final String keyAlgo, + final int keyLenBytes, + final int value, + final String dataKeyAlgo, + final int dataKeyLen, + boolean safeToCache, + final String trailingSignatureAlgo, + final int trailingSignatureLength, + final String keyCommitmentAlgo, + final int commitmentLength, + final int commitmentNonceLength, + final int suiteDataLength) { + if ((messageFormatVersion & 0xFF) != messageFormatVersion) { + throw new IllegalArgumentException("Invalid messageFormatVersion: " + messageFormatVersion); } - - /** - * Returns the block size of this algorithm in bytes. - */ - public int getBlockSize() { - return blockSizeBits_ / 8; + // All non-null key commitment algs must be the same as the kdf alg + if (keyCommitmentAlgo != null && !keyCommitmentAlgo.equals(dataKeyAlgo)) { + throw new IllegalArgumentException( + "Invalid keyCommitmentAlgo " + + keyCommitmentAlgo + + ". Must be equal to dataKeyAlgo " + + dataKeyAlgo + + "."); } - - /** - * Returns the nonce length used in this algorithm in bytes. - */ - public byte getNonceLen() { - return nonceLenBytes_; + messageFormatVersion_ = (byte) (messageFormatVersion & 0xFF); + blockSizeBits_ = blockSizeBits; + nonceLenBytes_ = (byte) nonceLenBytes; + tagLenBytes_ = tagLenBytes; + keyAlgo_ = keyAlgo; + keyLenBytes_ = keyLenBytes; + maxContentLen_ = maxContentLen; + safeToCache_ = safeToCache; + if (value > Short.MAX_VALUE || value < Short.MIN_VALUE) { + throw new IllegalArgumentException("Invalid value " + value); } - - /** - * Returns the tag length used in this algorithm in bytes. - */ - public int getTagLen() { - return tagLenBytes_; - } - - /** - * Returns the maximum content length in bytes that can be processed under a single data key in - * this algorithm. - */ - public long getMaxContentLen() { - return maxContentLen_; + value_ = (short) value; + dataKeyAlgo_ = dataKeyAlgo; + dataKeyLen_ = dataKeyLen; + trailingSigAlgo_ = trailingSignatureAlgo; + if (trailingSignatureLength > Short.MAX_VALUE || trailingSignatureLength < 0) { + throw new IllegalArgumentException("Invalid value " + trailingSignatureLength); } - - /** - * Returns the algorithm used for encrypting the plaintext data. - */ - public String getKeyAlgo() { - return keyAlgo_; + trailingSigLen_ = (short) trailingSignatureLength; + keyCommitmentAlgo_ = keyCommitmentAlgo; + commitmentLength_ = commitmentLength; + commitmentNonceLength_ = commitmentNonceLength; + suiteDataLength_ = suiteDataLength; + } + + private static int fieldsToLookupKey(final byte messageFormatVersion, final short algorithmId) { + // We pack the message format version and algorithm id into a single value. + // Since the algorithm ID is a short and thus 16 bits long, we'll just + // left shift the message format version by that amount. + // The message format version is 8 bits, so this totals 24 bits and fits + // within a standard 32 bit integer. + return (messageFormatVersion << 16) | algorithmId; + } + + /** + * Returns the CryptoAlgorithm object that matches the given value assuming a message format + * version of 1. + * + * @param value the value of the object + * @return the CryptoAlgorithm object that matches the given value, null if no match is found. + * @deprecated See {@link #deserialize(byte, short)} + */ + public static CryptoAlgorithm deserialize(final byte messageFormatVersion, final short value) { + return ID_MAPPING.get(fieldsToLookupKey(messageFormatVersion, value)); + } + + /** Returns the length of the message Id in the header for this algorithm. */ + public int getMessageIdLength() { + // For now this is a derived value rather than stored explicitly + switch (messageFormatVersion_) { + case VERSION_1: + return VERSION_1_MESSAGE_ID_LEN; + case VERSION_2: + return VERSION_2_MESSAGE_ID_LEN; + default: + throw new UnsupportedOperationException( + "Support for version " + messageFormatVersion_ + " not yet built."); } - - /** - * Returns the length of the key used in this algorithm in bytes. - */ - public int getKeyLength() { - return keyLenBytes_; + } + + /** + * Returns the header nonce to use with this algorithm. null indicates that the header nonce is + * not a parameter of the algorithm, and is instead stored as part of the message header. + */ + public byte[] getHeaderNonce() { + // For now this is a derived value rather than stored explicitly + switch (messageFormatVersion_) { + case VERSION_1: + return null; + case VERSION_2: + // V2 explicitly uses an IV of 0 in the header + return new byte[nonceLenBytes_]; + default: + throw new UnsupportedOperationException( + "Support for version " + messageFormatVersion_ + " not yet built."); } - - /** - * Returns the value used to encode this algorithm in the ciphertext. - */ - public short getValue() { - return value_; + } + + /** Returns the message format version associated with this algorithm suite. */ + public byte getMessageFormatVersion() { + return messageFormatVersion_; + } + + /** Returns the block size of this algorithm in bytes. */ + public int getBlockSize() { + return blockSizeBits_ / 8; + } + + /** Returns the nonce length used in this algorithm in bytes. */ + public byte getNonceLen() { + return nonceLenBytes_; + } + + /** Returns the tag length used in this algorithm in bytes. */ + public int getTagLen() { + return tagLenBytes_; + } + + /** + * Returns the maximum content length in bytes that can be processed under a single data key in + * this algorithm. + */ + public long getMaxContentLen() { + return maxContentLen_; + } + + /** Returns the algorithm used for encrypting the plaintext data. */ + public String getKeyAlgo() { + return keyAlgo_; + } + + /** Returns the length of the key used in this algorithm in bytes. */ + public int getKeyLength() { + return keyLenBytes_; + } + + /** Returns the value used to encode this algorithm in the ciphertext. */ + public short getValue() { + return value_; + } + + /** Returns the algorithm associated with the data key. */ + public String getDataKeyAlgo() { + return dataKeyAlgo_; + } + + /** Returns the length of the data key in bytes. */ + public int getDataKeyLength() { + return dataKeyLen_; + } + + /** Returns the algorithm used to calculate the trailing signature */ + public String getTrailingSignatureAlgo() { + return trailingSigAlgo_; + } + + /** + * Returns whether data keys used with this crypto algorithm can safely be cached and reused for a + * different message. If this returns false, reuse of data keys is likely to result in severe + * cryptographic weaknesses, potentially even with only a single such use. + */ + public boolean isSafeToCache() { + return safeToCache_; + } + + /** + * Returns the length of the trailing signature generated by this algorithm. The actual trailing + * signature may be shorter than this. + */ + public short getTrailingSignatureLength() { + return trailingSigLen_; + } + + public String getKeyCommitmentAlgo_() { + return keyCommitmentAlgo_; + } + + /** + * Returns a derived value of whether a commitment value is generated with the key in order to + * ensure key commitment. + */ + public boolean isCommitting() { + return keyCommitmentAlgo_ != null; + } + + public int getCommitmentLength() { + return commitmentLength_; + } + + public int getCommitmentNonceLength() { + return commitmentNonceLength_; + } + + public int getSuiteDataLength() { + return suiteDataLength_; + } + + public SecretKey getEncryptionKeyFromDataKey( + final SecretKey dataKey, final CiphertextHeaders headers) throws InvalidKeyException { + if (!dataKey.getAlgorithm().equalsIgnoreCase(getDataKeyAlgo())) { + throw new InvalidKeyException( + "DataKey of incorrect algorithm. Expected " + + getDataKeyAlgo() + + " but was " + + dataKey.getAlgorithm()); } - /** - * Returns the algorithm associated with the data key. - */ - public String getDataKeyAlgo() { - return dataKeyAlgo_; + // We perform key derivation differently depending on the message format version + switch (messageFormatVersion_) { + case VERSION_1: + return getNonCommittedEncryptionKey(dataKey, headers); + case VERSION_2: + return getCommittedEncryptionKey(dataKey, headers); + default: + throw new UnsupportedOperationException( + "Support for message format version " + messageFormatVersion_ + " not yet built."); } - - /** - * Returns the length of the data key in bytes. - */ - public int getDataKeyLength() { - return dataKeyLen_; + } + + private SecretKey getCommittedEncryptionKey( + final SecretKey dataKey, final CiphertextHeaders headers) throws InvalidKeyException { + final CommittedKey committedKey = CommittedKey.generate(this, dataKey, headers.getMessageId()); + if (!MessageDigest.isEqual(committedKey.getCommitment(), headers.getSuiteData())) { + throw new BadCiphertextException( + "Key commitment validation failed. Key identity does not match the " + + "identity asserted in the message. Halting processing of this message."); } - - /** - * Returns the algorithm used to calculate the trailing signature - */ - public String getTrailingSignatureAlgo() { - return trailingSigAlgo_; + return committedKey.getKey(); + } + + private SecretKey getNonCommittedEncryptionKey( + final SecretKey dataKey, final CiphertextHeaders headers) throws InvalidKeyException { + final String macAlgorithm; + + switch (this) { + case ALG_AES_128_GCM_IV12_TAG16_NO_KDF: + case ALG_AES_192_GCM_IV12_TAG16_NO_KDF: + case ALG_AES_256_GCM_IV12_TAG16_NO_KDF: + return dataKey; + case ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256: + case ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA256: + case ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256: + case ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256: + macAlgorithm = "HmacSHA256"; + break; + case ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384: + case ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384: + macAlgorithm = "HmacSHA384"; + break; + default: + throw new UnsupportedOperationException("Support for " + this + " not yet built."); } - - /** - * Returns whether data keys used with this crypto algorithm can safely be cached and reused for a different - * message. If this returns false, reuse of data keys is likely to result in severe cryptographic weaknesses, - * potentially even with only a single such use. - */ - public boolean isSafeToCache() { - return safeToCache_; + if (!dataKey.getFormat().equalsIgnoreCase("RAW")) { + throw new InvalidKeyException( + "Currently only RAW format keys are supported for HKDF algorithms. Actual format was " + + dataKey.getFormat()); } - - /** - * Returns the length of the trailing signature generated by this algorithm. The actual trailing - * signature may be shorter than this. - */ - public short getTrailingSignatureLength() { - return trailingSigLen_; + final byte[] messageId = headers.getMessageId(); + final ByteBuffer info = ByteBuffer.allocate(messageId.length + 2); + info.order(ByteOrder.BIG_ENDIAN); + info.putShort(getValue()); + info.put(messageId); + + final byte[] rawDataKey = dataKey.getEncoded(); + if (rawDataKey.length != getDataKeyLength()) { + throw new InvalidKeyException( + "DataKey of incorrect length. Expected " + + getDataKeyLength() + + " but was " + + rawDataKey.length); } - public String getKeyCommitmentAlgo_() { - return keyCommitmentAlgo_; + final HmacKeyDerivationFunction hkdf; + try { + hkdf = HmacKeyDerivationFunction.getInstance(macAlgorithm); + } catch (NoSuchAlgorithmException e) { + throw new IllegalStateException(e); } - /** - * Returns a derived value of whether a commitment value is generated with the key in order to ensure key commitment. - */ - public boolean isCommitting() { - return keyCommitmentAlgo_ != null; - } - - public int getCommitmentLength() { - return commitmentLength_; - } - - public int getCommitmentNonceLength() { - return commitmentNonceLength_; - } - - public int getSuiteDataLength() { - return suiteDataLength_; - } - - public SecretKey getEncryptionKeyFromDataKey(final SecretKey dataKey, final CiphertextHeaders headers) - throws InvalidKeyException { - if (!dataKey.getAlgorithm().equalsIgnoreCase(getDataKeyAlgo())) { - throw new InvalidKeyException("DataKey of incorrect algorithm. Expected " + getDataKeyAlgo() + " but was " - + dataKey.getAlgorithm()); - } - - // We perform key derivation differently depending on the message format version - switch (messageFormatVersion_) { - case VERSION_1: - return getNonCommittedEncryptionKey(dataKey, headers); - case VERSION_2: - return getCommittedEncryptionKey(dataKey, headers); - default: - throw new UnsupportedOperationException("Support for message format version " + messageFormatVersion_ + - " not yet built."); - } - } - - private SecretKey getCommittedEncryptionKey(final SecretKey dataKey, final CiphertextHeaders headers) - throws InvalidKeyException { - final CommittedKey committedKey = CommittedKey.generate(this, dataKey, headers.getMessageId()); - if (!MessageDigest.isEqual(committedKey.getCommitment(), headers.getSuiteData())) { - throw new BadCiphertextException("Key commitment validation failed. Key identity does not match the " + - "identity asserted in the message. Halting processing of this message."); - } - return committedKey.getKey(); - } - - private SecretKey getNonCommittedEncryptionKey(final SecretKey dataKey, final CiphertextHeaders headers) - throws InvalidKeyException { - final String macAlgorithm; - - switch (this) { - case ALG_AES_128_GCM_IV12_TAG16_NO_KDF: - case ALG_AES_192_GCM_IV12_TAG16_NO_KDF: - case ALG_AES_256_GCM_IV12_TAG16_NO_KDF: - return dataKey; - case ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256: - case ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA256: - case ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256: - case ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256: - macAlgorithm = "HmacSHA256"; - break; - case ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384: - case ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384: - macAlgorithm = "HmacSHA384"; - break; - default: - throw new UnsupportedOperationException("Support for " + this + " not yet built."); - } - if (!dataKey.getFormat().equalsIgnoreCase("RAW")) { - throw new InvalidKeyException( - "Currently only RAW format keys are supported for HKDF algorithms. Actual format was " - + dataKey.getFormat()); - } - final byte[] messageId = headers.getMessageId(); - final ByteBuffer info = ByteBuffer.allocate(messageId.length + 2); - info.order(ByteOrder.BIG_ENDIAN); - info.putShort(getValue()); - info.put(messageId); - - final byte[] rawDataKey = dataKey.getEncoded(); - if (rawDataKey.length != getDataKeyLength()) { - throw new InvalidKeyException("DataKey of incorrect length. Expected " + getDataKeyLength() + " but was " - + rawDataKey.length); - } - - final HmacKeyDerivationFunction hkdf; - try { - hkdf = HmacKeyDerivationFunction.getInstance(macAlgorithm); - } catch (NoSuchAlgorithmException e) { - throw new IllegalStateException(e); - } - - hkdf.init(rawDataKey); - return new SecretKeySpec(hkdf.deriveKey(info.array(), getKeyLength()), getKeyAlgo()); - } + hkdf.init(rawDataKey); + return new SecretKeySpec(hkdf.deriveKey(info.array(), getKeyLength()), getKeyAlgo()); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/CryptoInputStream.java b/src/main/java/com/amazonaws/encryptionsdk/CryptoInputStream.java index c71dca1e9..33ff8bd5e 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/CryptoInputStream.java +++ b/src/main/java/com/amazonaws/encryptionsdk/CryptoInputStream.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -15,245 +15,239 @@ import static com.amazonaws.encryptionsdk.internal.Utils.assertNonNull; -import java.io.IOException; -import java.io.InputStream; -import java.util.List; - import com.amazonaws.encryptionsdk.caching.CachingCryptoMaterialsManager; import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.internal.MessageCryptoHandler; import com.amazonaws.encryptionsdk.internal.Utils; +import java.io.IOException; +import java.io.InputStream; +import java.util.List; /** * A CryptoInputStream is a subclass of java.io.InputStream. It performs cryptographic * transformation of the bytes passing through it. * - *

- * The CryptoInputStream wraps a provided InputStream object and performs cryptographic + *

The CryptoInputStream wraps a provided InputStream object and performs cryptographic * transformation of the bytes read from the wrapped InputStream. It uses the cryptography handler * provided during construction to invoke methods that perform the cryptographic transformations. * - *

- * In short, reading from the CryptoInputStream returns bytes that are the cryptographic + *

In short, reading from the CryptoInputStream returns bytes that are the cryptographic * transformations of the bytes read from the wrapped InputStream. * - *

- * For example, if the cryptography handler provides methods for decryption, the CryptoInputStream - * will read ciphertext bytes from the wrapped InputStream, decrypt, and return them as plaintext - * bytes. + *

For example, if the cryptography handler provides methods for decryption, the + * CryptoInputStream will read ciphertext bytes from the wrapped InputStream, decrypt, and return + * them as plaintext bytes. * - *

- * This class adheres strictly to the semantics, especially the failure semantics, of its ancestor - * class java.io.InputStream. This class overrides all the methods specified in its ancestor class. + *

This class adheres strictly to the semantics, especially the failure semantics, of its + * ancestor class java.io.InputStream. This class overrides all the methods specified in its + * ancestor class. * - *

- * To instantiate an instance of this class, please see {@link AwsCrypto}. + *

To instantiate an instance of this class, please see {@link AwsCrypto}. * - * @param - * The type of {@link MasterKey}s used to manipulate the data. + * @param The type of {@link MasterKey}s used to manipulate the data. */ public class CryptoInputStream> extends InputStream { - private static final int MAX_READ_LEN = 4096; - - private byte[] outBytes_ = new byte[0]; - private int outStart_; - private int outEnd_; - private final InputStream inputStream_; - private final MessageCryptoHandler cryptoHandler_; - private boolean hasFinalCalled_; - private boolean hasProcessBytesCalled_; - - /** - * Constructs a CryptoInputStream that wraps the provided InputStream object. It performs - * cryptographic transformation of the bytes read from the wrapped InputStream using the methods - * provided in the provided CryptoHandler implementation. - * - * @param inputStream - * the inputStream object to be wrapped. - * @param cryptoHandler - * the cryptoHandler implementation that provides the methods to use in performing - * cryptographic transformation of the bytes read from the inputStream. - */ - CryptoInputStream(final InputStream inputStream, final MessageCryptoHandler cryptoHandler) { - inputStream_ = Utils.assertNonNull(inputStream, "inputStream"); - cryptoHandler_ = Utils.assertNonNull(cryptoHandler, "cryptoHandler"); - } - - /** - * Fill the output bytes by reading from the wrapped InputStream and processing it through the - * crypto handler. - * - * @return the number of bytes processed and returned by the crypto handler. - */ - private int fillOutBytes() throws IOException, BadCiphertextException { - final byte[] inputStreamBytes = new byte[MAX_READ_LEN]; - - final int readLen = inputStream_.read(inputStreamBytes); - - outStart_ = 0; - - int processedLen; - if (readLen < 0) { - // Mark end of stream until doFinal returns something. - processedLen = -1; - - if (!hasFinalCalled_) { - int outOffset = 0; - int outLen = 0; - - // Handle the case where processBytes() was never called before. - // This happens with an empty file where the end of stream is - // reached on the first read attempt. In this case, - // processBytes() must be called so the header bytes are written - // during encryption. - if (!hasProcessBytesCalled_) { - outBytes_ = new byte[cryptoHandler_.estimateOutputSize(0)]; - outLen += cryptoHandler_.processBytes(inputStreamBytes, 0, 0, outBytes_, outOffset) - .getBytesWritten(); - outOffset += outLen; - } else { - outBytes_ = new byte[cryptoHandler_.estimateFinalOutputSize()]; - } - - // Get final bytes. - outLen += cryptoHandler_.doFinal(outBytes_, outOffset); - processedLen = outLen; - hasFinalCalled_ = true; - } + private static final int MAX_READ_LEN = 4096; + + private byte[] outBytes_ = new byte[0]; + private int outStart_; + private int outEnd_; + private final InputStream inputStream_; + private final MessageCryptoHandler cryptoHandler_; + private boolean hasFinalCalled_; + private boolean hasProcessBytesCalled_; + + /** + * Constructs a CryptoInputStream that wraps the provided InputStream object. It performs + * cryptographic transformation of the bytes read from the wrapped InputStream using the methods + * provided in the provided CryptoHandler implementation. + * + * @param inputStream the inputStream object to be wrapped. + * @param cryptoHandler the cryptoHandler implementation that provides the methods to use in + * performing cryptographic transformation of the bytes read from the inputStream. + */ + CryptoInputStream(final InputStream inputStream, final MessageCryptoHandler cryptoHandler) { + inputStream_ = Utils.assertNonNull(inputStream, "inputStream"); + cryptoHandler_ = Utils.assertNonNull(cryptoHandler, "cryptoHandler"); + } + + /** + * Fill the output bytes by reading from the wrapped InputStream and processing it through the + * crypto handler. + * + * @return the number of bytes processed and returned by the crypto handler. + */ + private int fillOutBytes() throws IOException, BadCiphertextException { + final byte[] inputStreamBytes = new byte[MAX_READ_LEN]; + + final int readLen = inputStream_.read(inputStreamBytes); + + outStart_ = 0; + + int processedLen; + if (readLen < 0) { + // Mark end of stream until doFinal returns something. + processedLen = -1; + + if (!hasFinalCalled_) { + int outOffset = 0; + int outLen = 0; + + // Handle the case where processBytes() was never called before. + // This happens with an empty file where the end of stream is + // reached on the first read attempt. In this case, + // processBytes() must be called so the header bytes are written + // during encryption. + if (!hasProcessBytesCalled_) { + outBytes_ = new byte[cryptoHandler_.estimateOutputSize(0)]; + outLen += + cryptoHandler_ + .processBytes(inputStreamBytes, 0, 0, outBytes_, outOffset) + .getBytesWritten(); + outOffset += outLen; } else { - // process the read bytes. - outBytes_ = new byte[cryptoHandler_.estimatePartialOutputSize(readLen)]; - processedLen = cryptoHandler_.processBytes(inputStreamBytes, 0, readLen, outBytes_, outStart_) - .getBytesWritten(); - hasProcessBytesCalled_ = true; + outBytes_ = new byte[cryptoHandler_.estimateFinalOutputSize()]; } - outEnd_ = processedLen; - return processedLen; + // Get final bytes. + outLen += cryptoHandler_.doFinal(outBytes_, outOffset); + processedLen = outLen; + hasFinalCalled_ = true; + } + } else { + // process the read bytes. + outBytes_ = new byte[cryptoHandler_.estimatePartialOutputSize(readLen)]; + processedLen = + cryptoHandler_ + .processBytes(inputStreamBytes, 0, readLen, outBytes_, outStart_) + .getBytesWritten(); + hasProcessBytesCalled_ = true; } - /** - * {@inheritDoc} - * - * @throws BadCiphertextException - * This is thrown only during decryption if b contains invalid or corrupt - * ciphertext. - */ - @Override - public int read(final byte[] b, final int off, final int len) throws IllegalArgumentException, IOException, - BadCiphertextException { - assertNonNull(b, "b"); - - if (len < 0 || off < 0) { - throw new IllegalArgumentException(String.format("Invalid values for offset: %d and length: %d", off, len)); - } - - if (b.length == 0 || len == 0) { - return 0; - } - - // fill the output bytes if there aren't any left to return. - if ((outEnd_ - outStart_) <= 0) { - int newBytesLen = 0; - - // Block until a byte is read or end of stream in the underlying - // stream is reached. - while (newBytesLen == 0) { - newBytesLen = fillOutBytes(); - } - if (newBytesLen < 0) { - return -1; - } - } - - final int copyLen = Math.min((outEnd_ - outStart_), len); - System.arraycopy(outBytes_, outStart_, b, off, copyLen); - outStart_ += copyLen; - - return copyLen; + outEnd_ = processedLen; + return processedLen; + } + + /** + * {@inheritDoc} + * + * @throws BadCiphertextException This is thrown only during decryption if b contains invalid or + * corrupt ciphertext. + */ + @Override + public int read(final byte[] b, final int off, final int len) + throws IllegalArgumentException, IOException, BadCiphertextException { + assertNonNull(b, "b"); + + if (len < 0 || off < 0) { + throw new IllegalArgumentException( + String.format("Invalid values for offset: %d and length: %d", off, len)); } - /** - * {@inheritDoc} - * - * @throws BadCiphertextException - * This is thrown only during decryption if b contains invalid or corrupt - * ciphertext. - */ - @Override - public int read(final byte[] b) throws IllegalArgumentException, IOException, BadCiphertextException { - return read(b, 0, b.length); + if (b.length == 0 || len == 0) { + return 0; } - /** - * {@inheritDoc} - * - * @throws BadCiphertextException - * if b contains invalid or corrupt ciphertext. This is thrown only during - * decryption. - */ - @Override - public int read() throws IOException, BadCiphertextException { - final byte[] bArray = new byte[1]; - int result = 0; - - while (result == 0) { - result = read(bArray, 0, 1); - } - - if (result > 0) { - return (bArray[0] & 0xFF); - } else { - return result; - } + // fill the output bytes if there aren't any left to return. + if ((outEnd_ - outStart_) <= 0) { + int newBytesLen = 0; + + // Block until a byte is read or end of stream in the underlying + // stream is reached. + while (newBytesLen == 0) { + newBytesLen = fillOutBytes(); + } + if (newBytesLen < 0) { + return -1; + } } - @Override - public void close() throws IOException { - inputStream_.close(); + final int copyLen = Math.min((outEnd_ - outStart_), len); + System.arraycopy(outBytes_, outStart_, b, off, copyLen); + outStart_ += copyLen; + + return copyLen; + } + + /** + * {@inheritDoc} + * + * @throws BadCiphertextException This is thrown only during decryption if b contains invalid or + * corrupt ciphertext. + */ + @Override + public int read(final byte[] b) + throws IllegalArgumentException, IOException, BadCiphertextException { + return read(b, 0, b.length); + } + + /** + * {@inheritDoc} + * + * @throws BadCiphertextException if b contains invalid or corrupt ciphertext. This is thrown only + * during decryption. + */ + @Override + public int read() throws IOException, BadCiphertextException { + final byte[] bArray = new byte[1]; + int result = 0; + + while (result == 0) { + result = read(bArray, 0, 1); } - /** - * Returns metadata associated with the performed cryptographic operation. - */ - @Override - public int available() throws IOException { - return (outBytes_.length + inputStream_.available()); + if (result > 0) { + return (bArray[0] & 0xFF); + } else { + return result; } - - /** - * Sets an upper bound on the size of the input data. This method should be called before reading any data from the - * stream. If this method is not called prior to reading any data, performance may be reduced (notably, it will not - * be possible to cache data keys when encrypting). - * - * Among other things, this size is used to enforce limits configured on the {@link CachingCryptoMaterialsManager}. - * - * If the input size set here is exceeded, an exception will be thrown, and the encyption or decryption will fail. - * - * If this method is called multiple times, the smallest bound will be used. - * - * @param size Maximum input size. - */ - public void setMaxInputLength(long size) { - cryptoHandler_.setMaxInputLength(size); - } - - /** - * Returns the result of the cryptographic operations including associate metadata. - * - * @throws IOException - * @throws BadCiphertextException - */ - public CryptoResult, K> getCryptoResult() throws BadCiphertextException, IOException { - while (!cryptoHandler_.getHeaders().isComplete()) { - if (fillOutBytes() == -1) { - throw new BadCiphertextException("No CiphertextHeaders found."); - } - } - //noinspection unchecked - return new CryptoResult<>( - this, - (List) cryptoHandler_.getMasterKeys(), - cryptoHandler_.getHeaders()); + } + + @Override + public void close() throws IOException { + inputStream_.close(); + } + + /** Returns metadata associated with the performed cryptographic operation. */ + @Override + public int available() throws IOException { + return (outBytes_.length + inputStream_.available()); + } + + /** + * Sets an upper bound on the size of the input data. This method should be called before reading + * any data from the stream. If this method is not called prior to reading any data, performance + * may be reduced (notably, it will not be possible to cache data keys when encrypting). + * + *

Among other things, this size is used to enforce limits configured on the {@link + * CachingCryptoMaterialsManager}. + * + *

If the input size set here is exceeded, an exception will be thrown, and the encyption or + * decryption will fail. + * + *

If this method is called multiple times, the smallest bound will be used. + * + * @param size Maximum input size. + */ + public void setMaxInputLength(long size) { + cryptoHandler_.setMaxInputLength(size); + } + + /** + * Returns the result of the cryptographic operations including associate metadata. + * + * @throws IOException + * @throws BadCiphertextException + */ + public CryptoResult, K> getCryptoResult() + throws BadCiphertextException, IOException { + while (!cryptoHandler_.getHeaders().isComplete()) { + if (fillOutBytes() == -1) { + throw new BadCiphertextException("No CiphertextHeaders found."); + } } + //noinspection unchecked + return new CryptoResult<>( + this, (List) cryptoHandler_.getMasterKeys(), cryptoHandler_.getHeaders()); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/CryptoMaterialsManager.java b/src/main/java/com/amazonaws/encryptionsdk/CryptoMaterialsManager.java index 6b29b2358..38bcb1b22 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/CryptoMaterialsManager.java +++ b/src/main/java/com/amazonaws/encryptionsdk/CryptoMaterialsManager.java @@ -1,29 +1,29 @@ package com.amazonaws.encryptionsdk; -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.DecryptionMaterials; +import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; /** - * The crypto materials manager is responsible for preparing the cryptographic materials needed to process a request - - * notably, preparing the cleartext data key and (if applicable) trailing signature keys on both encrypt and decrypt. + * The crypto materials manager is responsible for preparing the cryptographic materials needed to + * process a request - notably, preparing the cleartext data key and (if applicable) trailing + * signature keys on both encrypt and decrypt. */ public interface CryptoMaterialsManager { - /** - * Prepares materials for an encrypt request. The resulting materials result must have a cleartext data key and - * (if applicable for the crypto algorithm in use) a trailing signature key. - * - * The encryption context returned may be different from the one passed in the materials request, and will be - * serialized (in cleartext) within the encrypted message. - * - * @see EncryptionMaterials - * @see EncryptionMaterialsRequest - * - * @param request - * @return - */ - EncryptionMaterials getMaterialsForEncrypt(EncryptionMaterialsRequest request); + /** + * Prepares materials for an encrypt request. The resulting materials result must have a cleartext + * data key and (if applicable for the crypto algorithm in use) a trailing signature key. + * + *

The encryption context returned may be different from the one passed in the materials + * request, and will be serialized (in cleartext) within the encrypted message. + * + * @see EncryptionMaterials + * @see EncryptionMaterialsRequest + * @param request + * @return + */ + EncryptionMaterials getMaterialsForEncrypt(EncryptionMaterialsRequest request); - DecryptionMaterials decryptMaterials(DecryptionMaterialsRequest request); + DecryptionMaterials decryptMaterials(DecryptionMaterialsRequest request); } diff --git a/src/main/java/com/amazonaws/encryptionsdk/CryptoOutputStream.java b/src/main/java/com/amazonaws/encryptionsdk/CryptoOutputStream.java index 1dbee7c46..69526d800 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/CryptoOutputStream.java +++ b/src/main/java/com/amazonaws/encryptionsdk/CryptoOutputStream.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,174 +13,159 @@ package com.amazonaws.encryptionsdk; -import java.io.IOException; -import java.io.OutputStream; -import java.util.List; - import com.amazonaws.encryptionsdk.caching.CachingCryptoMaterialsManager; import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.internal.MessageCryptoHandler; import com.amazonaws.encryptionsdk.internal.Utils; +import java.io.IOException; +import java.io.OutputStream; +import java.util.List; /** * A CryptoOutputStream is a subclass of java.io.OutputStream. It performs cryptographic * transformation of the bytes passing through it. - * - *

- * The CryptoOutputStream wraps a provided OutputStream object and performs cryptographic + * + *

The CryptoOutputStream wraps a provided OutputStream object and performs cryptographic * transformation of the bytes written to it. The transformed bytes are then written to the wrapped * OutputStream. It uses the cryptography handler provided during construction to invoke methods * that perform the cryptographic transformations. - * - *

- * In short, writing to the CryptoOutputStream results in those bytes being cryptographically + * + *

In short, writing to the CryptoOutputStream results in those bytes being cryptographically * transformed and written to the wrapped OutputStream. - * - *

- * For example, if the crypto handler provides methods for decryption, the CryptoOutputStream will - * decrypt the provided ciphertext bytes and write the plaintext bytes to the wrapped OutputStream. - * - *

- * This class adheres strictly to the semantics, especially the failure semantics, of its ancestor - * class java.io.OutputStream. This class overrides all the methods specified in its ancestor class. - * - *

- * To instantiate an instance of this class, please see {@link AwsCrypto}. - * - * @param - * The type of {@link MasterKey}s used to manipulate the data. + * + *

For example, if the crypto handler provides methods for decryption, the CryptoOutputStream + * will decrypt the provided ciphertext bytes and write the plaintext bytes to the wrapped + * OutputStream. + * + *

This class adheres strictly to the semantics, especially the failure semantics, of its + * ancestor class java.io.OutputStream. This class overrides all the methods specified in its + * ancestor class. + * + *

To instantiate an instance of this class, please see {@link AwsCrypto}. + * + * @param The type of {@link MasterKey}s used to manipulate the data. */ public class CryptoOutputStream> extends OutputStream { - private final OutputStream outputStream_; - - private final MessageCryptoHandler cryptoHandler_; - - /** - * Constructs a CryptoOutputStream that wraps the provided OutputStream object. It performs - * cryptographic transformation of the bytes written to it using the methods provided in the - * provided CryptoHandler implementation. The transformed bytes are then written to the wrapped - * OutputStream. - * - * @param outputStream - * the outputStream object to be wrapped. - * @param cryptoHandler - * the cryptoHandler implementation that provides the methods to use in performing - * cryptographic transformation of the bytes written to this stream. - */ - CryptoOutputStream(final OutputStream outputStream, final MessageCryptoHandler cryptoHandler) { - outputStream_ = Utils.assertNonNull(outputStream, "outputStream"); - cryptoHandler_ = Utils.assertNonNull(cryptoHandler, "cryptoHandler"); + private final OutputStream outputStream_; + + private final MessageCryptoHandler cryptoHandler_; + + /** + * Constructs a CryptoOutputStream that wraps the provided OutputStream object. It performs + * cryptographic transformation of the bytes written to it using the methods provided in the + * provided CryptoHandler implementation. The transformed bytes are then written to the wrapped + * OutputStream. + * + * @param outputStream the outputStream object to be wrapped. + * @param cryptoHandler the cryptoHandler implementation that provides the methods to use in + * performing cryptographic transformation of the bytes written to this stream. + */ + CryptoOutputStream(final OutputStream outputStream, final MessageCryptoHandler cryptoHandler) { + outputStream_ = Utils.assertNonNull(outputStream, "outputStream"); + cryptoHandler_ = Utils.assertNonNull(cryptoHandler, "cryptoHandler"); + } + + /** + * {@inheritDoc} + * + * @throws BadCiphertextException This is thrown only during decryption if b contains invalid or + * corrupt ciphertext. + */ + @Override + public void write(final byte[] b) + throws IllegalArgumentException, IOException, BadCiphertextException { + if (b == null) { + throw new IllegalArgumentException("b cannot be null"); } - - /** - * {@inheritDoc} - * - * @throws BadCiphertextException - * This is thrown only during decryption if b contains invalid or corrupt - * ciphertext. - */ - @Override - public void write(final byte[] b) throws IllegalArgumentException, IOException, BadCiphertextException { - if (b == null) { - throw new IllegalArgumentException("b cannot be null"); - } - write(b, 0, b.length); + write(b, 0, b.length); + } + + /** + * {@inheritDoc} + * + * @throws BadCiphertextException This is thrown only during decryption if b contains invalid or + * corrupt ciphertext. + */ + @Override + public void write(final byte[] b, final int off, final int len) + throws IllegalArgumentException, IOException, BadCiphertextException { + if (b == null) { + throw new IllegalArgumentException("b cannot be null"); } - /** - * {@inheritDoc} - * - * @throws BadCiphertextException - * This is thrown only during decryption if b contains invalid or corrupt - * ciphertext. - */ - @Override - public void write(final byte[] b, final int off, final int len) throws IllegalArgumentException, IOException, - BadCiphertextException { - if (b == null) { - throw new IllegalArgumentException("b cannot be null"); - } - - if (len < 0 || off < 0) { - throw new IllegalArgumentException(String.format("Invalid values for offset: %d and length: %d", off, len)); - } - - final int outLen = cryptoHandler_.estimatePartialOutputSize(len); - final byte[] outBytes = new byte[outLen]; - - int bytesWritten = cryptoHandler_.processBytes(b, off, len, outBytes, 0).getBytesWritten(); - if (bytesWritten > 0) { - outputStream_.write(outBytes, 0, bytesWritten); - } + if (len < 0 || off < 0) { + throw new IllegalArgumentException( + String.format("Invalid values for offset: %d and length: %d", off, len)); } - /** - * {@inheritDoc} - * - * @throws BadCiphertextException - * This is thrown only during decryption if b contains invalid or corrupt - * ciphertext. - */ - @Override - public void write(int b) throws IOException, BadCiphertextException { - byte[] bArray = new byte[1]; - bArray[0] = (byte) b; - write(bArray, 0, 1); - } + final int outLen = cryptoHandler_.estimatePartialOutputSize(len); + final byte[] outBytes = new byte[outLen]; - /** - * Closes this output stream and releases any system resources associated - * with this stream. - * - *

- * This method writes any final bytes to the underlying stream that complete - * the cyptographic transformation of the written bytes. It also calls close - * on the wrapped OutputStream. - * - * @throws IOException - * if an I/O error occurs. - * @throws BadCiphertextException - * This is thrown only during decryption if b contains invalid - * or corrupt ciphertext. - */ - @Override - public void close() throws IOException, BadCiphertextException { - final byte[] outBytes = new byte[cryptoHandler_.estimateFinalOutputSize()]; - int finalLen = cryptoHandler_.doFinal(outBytes, 0); - - outputStream_.write(outBytes, 0, finalLen); - outputStream_.close(); + int bytesWritten = cryptoHandler_.processBytes(b, off, len, outBytes, 0).getBytesWritten(); + if (bytesWritten > 0) { + outputStream_.write(outBytes, 0, bytesWritten); } - - /** - * Sets an upper bound on the size of the input data. This method should be called before writing any data to the - * stream. If this method is not called prior to writing data, performance may be reduced (notably, it will not - * be possible to cache data keys when encrypting). - * - * Among other things, this size is used to enforce limits configured on the {@link CachingCryptoMaterialsManager}. - * - * If the size set here is exceeded, an exception will be thrown, and the encyption or decryption will fail. - * - * If this method is called multiple times, the smallest bound will be used. - * - * @param size Maximum input size. - */ - public void setMaxInputLength(long size) { - cryptoHandler_.setMaxInputLength(size); + } + + /** + * {@inheritDoc} + * + * @throws BadCiphertextException This is thrown only during decryption if b contains invalid or + * corrupt ciphertext. + */ + @Override + public void write(int b) throws IOException, BadCiphertextException { + byte[] bArray = new byte[1]; + bArray[0] = (byte) b; + write(bArray, 0, 1); + } + + /** + * Closes this output stream and releases any system resources associated with this stream. + * + *

This method writes any final bytes to the underlying stream that complete the cyptographic + * transformation of the written bytes. It also calls close on the wrapped OutputStream. + * + * @throws IOException if an I/O error occurs. + * @throws BadCiphertextException This is thrown only during decryption if b contains invalid or + * corrupt ciphertext. + */ + @Override + public void close() throws IOException, BadCiphertextException { + final byte[] outBytes = new byte[cryptoHandler_.estimateFinalOutputSize()]; + int finalLen = cryptoHandler_.doFinal(outBytes, 0); + + outputStream_.write(outBytes, 0, finalLen); + outputStream_.close(); + } + + /** + * Sets an upper bound on the size of the input data. This method should be called before writing + * any data to the stream. If this method is not called prior to writing data, performance may be + * reduced (notably, it will not be possible to cache data keys when encrypting). + * + *

Among other things, this size is used to enforce limits configured on the {@link + * CachingCryptoMaterialsManager}. + * + *

If the size set here is exceeded, an exception will be thrown, and the encyption or + * decryption will fail. + * + *

If this method is called multiple times, the smallest bound will be used. + * + * @param size Maximum input size. + */ + public void setMaxInputLength(long size) { + cryptoHandler_.setMaxInputLength(size); + } + + /** Returns the result of the cryptographic operations including associate metadata. */ + public CryptoResult, K> getCryptoResult() { + if (!cryptoHandler_.getHeaders().isComplete()) { + throw new IllegalStateException("Ciphertext headers not yet written to stream"); } - /** - * Returns the result of the cryptographic operations including associate metadata. - */ - public CryptoResult, K> getCryptoResult() { - if (!cryptoHandler_.getHeaders().isComplete()) { - throw new IllegalStateException("Ciphertext headers not yet written to stream"); - } - - //noinspection unchecked - return new CryptoResult<>( - this, - (List) cryptoHandler_.getMasterKeys(), - cryptoHandler_.getHeaders()); - } + //noinspection unchecked + return new CryptoResult<>( + this, (List) cryptoHandler_.getMasterKeys(), cryptoHandler_.getHeaders()); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/CryptoResult.java b/src/main/java/com/amazonaws/encryptionsdk/CryptoResult.java index 616ed3756..167044132 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/CryptoResult.java +++ b/src/main/java/com/amazonaws/encryptionsdk/CryptoResult.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,84 +13,75 @@ package com.amazonaws.encryptionsdk; +import com.amazonaws.encryptionsdk.model.CiphertextHeaders; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; -import com.amazonaws.encryptionsdk.model.CiphertextHeaders; - /** - * Represents the result of an operation by {@link AwsCrypto}. It not only captures the - * {@code result} of the operation but also additional metadata such as the - * {@code encryptionContext}, {@code algorithm}, {@link MasterKey}(s), and any other information - * captured in the {@link CiphertextHeaders}. + * Represents the result of an operation by {@link AwsCrypto}. It not only captures the {@code + * result} of the operation but also additional metadata such as the {@code encryptionContext}, + * {@code algorithm}, {@link MasterKey}(s), and any other information captured in the {@link + * CiphertextHeaders}. * - * @param - * the type of the underlying {@code result} - * @param - * the type of the {@link MasterKey}s used in production of this result + * @param the type of the underlying {@code result} + * @param the type of the {@link MasterKey}s used in production of this result */ public class CryptoResult> { - private final T result_; - private final List masterKeys_; - private final Map encryptionContext_; - private final CiphertextHeaders headers_; + private final T result_; + private final List masterKeys_; + private final Map encryptionContext_; + private final CiphertextHeaders headers_; - /** - * Note, does not make a defensive copy of any of the data. - */ - CryptoResult(final T result, final List masterKeys, final CiphertextHeaders headers) { - result_ = result; - masterKeys_ = Collections.unmodifiableList(masterKeys); - headers_ = headers; - encryptionContext_ = headers_.getEncryptionContextMap(); - } + /** Note, does not make a defensive copy of any of the data. */ + CryptoResult(final T result, final List masterKeys, final CiphertextHeaders headers) { + result_ = result; + masterKeys_ = Collections.unmodifiableList(masterKeys); + headers_ = headers; + encryptionContext_ = headers_.getEncryptionContextMap(); + } - /** - * The actual result of the cryptographic operation. This is not a defensive copy and callers - * should not modify it. - * - * @return - */ - public T getResult() { - return result_; - } + /** + * The actual result of the cryptographic operation. This is not a defensive copy and callers + * should not modify it. + * + * @return + */ + public T getResult() { + return result_; + } - /** - * Returns all relevant {@link MasterKey}s. In the case of encryption, returns all - * {@code MasterKey}s used to protect the ciphertext. In the case of decryption, returns just - * the {@code MasterKey} used to decrypt the ciphertext. - * - * @return - */ - public List getMasterKeys() { - return masterKeys_; - } + /** + * Returns all relevant {@link MasterKey}s. In the case of encryption, returns all {@code + * MasterKey}s used to protect the ciphertext. In the case of decryption, returns just the {@code + * MasterKey} used to decrypt the ciphertext. + * + * @return + */ + public List getMasterKeys() { + return masterKeys_; + } - /** - * Convenience method for retrieving the keyIds in the results from {@link #getMasterKeys()}. - */ - public List getMasterKeyIds() { - final List result = new ArrayList<>(masterKeys_.size()); - for (final MasterKey mk : masterKeys_) { - result.add(mk.getKeyId()); - } - return result; + /** Convenience method for retrieving the keyIds in the results from {@link #getMasterKeys()}. */ + public List getMasterKeyIds() { + final List result = new ArrayList<>(masterKeys_.size()); + for (final MasterKey mk : masterKeys_) { + result.add(mk.getKeyId()); } + return result; + } - public Map getEncryptionContext() { - return encryptionContext_; - } + public Map getEncryptionContext() { + return encryptionContext_; + } - /** - * Convenience method equivalent to {@link #getHeaders()}.{@code getCryptoAlgoId()}. - */ - public CryptoAlgorithm getCryptoAlgorithm() { - return headers_.getCryptoAlgoId(); - } + /** Convenience method equivalent to {@link #getHeaders()}.{@code getCryptoAlgoId()}. */ + public CryptoAlgorithm getCryptoAlgorithm() { + return headers_.getCryptoAlgoId(); + } - public CiphertextHeaders getHeaders() { - return headers_; - } + public CiphertextHeaders getHeaders() { + return headers_; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/DataKey.java b/src/main/java/com/amazonaws/encryptionsdk/DataKey.java index 21d174f00..33e9492fd 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/DataKey.java +++ b/src/main/java/com/amazonaws/encryptionsdk/DataKey.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -18,50 +18,48 @@ /** * Represents both the cleartext and encrypted bytes of a data key. * - * @param - * the type of {@link MasterKey} used to protect this {@code DataKey}. + * @param the type of {@link MasterKey} used to protect this {@code DataKey}. */ public class DataKey> implements EncryptedDataKey { - private final byte[] providerInformation_; - private final byte[] encryptedDataKey_; - private final SecretKey key_; - private final M masterKey_; + private final byte[] providerInformation_; + private final byte[] encryptedDataKey_; + private final SecretKey key_; + private final M masterKey_; - public DataKey(final SecretKey key, final byte[] encryptedDataKey, final byte[] providerInformation, - final M masterKey) { - super(); - key_ = key; - encryptedDataKey_ = encryptedDataKey.clone(); - providerInformation_ = providerInformation.clone(); - masterKey_ = masterKey; - } + public DataKey( + final SecretKey key, + final byte[] encryptedDataKey, + final byte[] providerInformation, + final M masterKey) { + super(); + key_ = key; + encryptedDataKey_ = encryptedDataKey.clone(); + providerInformation_ = providerInformation.clone(); + masterKey_ = masterKey; + } - /** - * Returns the cleartext bytes of the data key. - */ - public SecretKey getKey() { - return key_; - } + /** Returns the cleartext bytes of the data key. */ + public SecretKey getKey() { + return key_; + } - @Override - public String getProviderId() { - return masterKey_.getProviderId(); - } + @Override + public String getProviderId() { + return masterKey_.getProviderId(); + } - @Override - public byte[] getProviderInformation() { - return providerInformation_.clone(); - } + @Override + public byte[] getProviderInformation() { + return providerInformation_.clone(); + } - @Override - public byte[] getEncryptedDataKey() { - return encryptedDataKey_.clone(); - } + @Override + public byte[] getEncryptedDataKey() { + return encryptedDataKey_.clone(); + } - /** - * Returns the {@link MasterKey} used to encrypt this {@link DataKey}. - */ - public M getMasterKey() { - return masterKey_; - } + /** Returns the {@link MasterKey} used to encrypt this {@link DataKey}. */ + public M getMasterKey() { + return masterKey_; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/DefaultCryptoMaterialsManager.java b/src/main/java/com/amazonaws/encryptionsdk/DefaultCryptoMaterialsManager.java index 43cc78c6f..22980ad2f 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/DefaultCryptoMaterialsManager.java +++ b/src/main/java/com/amazonaws/encryptionsdk/DefaultCryptoMaterialsManager.java @@ -5,157 +5,157 @@ import static com.amazonaws.encryptionsdk.internal.Utils.assertNonNull; -import java.security.GeneralSecurityException; -import java.security.KeyPair; -import java.security.PublicKey; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; import com.amazonaws.encryptionsdk.internal.Constants; -import com.amazonaws.encryptionsdk.internal.Utils; import com.amazonaws.encryptionsdk.internal.TrailingSignatureAlgorithm; -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; +import com.amazonaws.encryptionsdk.internal.Utils; import com.amazonaws.encryptionsdk.model.DecryptionMaterials; +import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; +import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.KeyBlob; -import com.amazonaws.encryptionsdk.model.EncryptionMaterials; +import java.security.GeneralSecurityException; +import java.security.KeyPair; +import java.security.PublicKey; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; /** * The default implementation of {@link CryptoMaterialsManager}, used implicitly when passing a * {@link MasterKeyProvider} to methods in {@link AwsCrypto}. * - * This default implementation delegates to a specific {@link MasterKeyProvider} specified at construction time. It also - * handles generating trailing signature keys when needed, placing them in the encryption context (and extracting them - * at decrypt time). + *

This default implementation delegates to a specific {@link MasterKeyProvider} specified at + * construction time. It also handles generating trailing signature keys when needed, placing them + * in the encryption context (and extracting them at decrypt time). */ public class DefaultCryptoMaterialsManager implements CryptoMaterialsManager { - private final MasterKeyProvider mkp; - - private final CryptoAlgorithm DEFAULT_CRYPTO_ALGORITHM = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; - - /** - * @param mkp The master key provider to delegate to - */ - public DefaultCryptoMaterialsManager(MasterKeyProvider mkp) { - Utils.assertNonNull(mkp, "mkp"); - this.mkp = mkp; + private final MasterKeyProvider mkp; + + private final CryptoAlgorithm DEFAULT_CRYPTO_ALGORITHM = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; + + /** @param mkp The master key provider to delegate to */ + public DefaultCryptoMaterialsManager(MasterKeyProvider mkp) { + Utils.assertNonNull(mkp, "mkp"); + this.mkp = mkp; + } + + @Override + public EncryptionMaterials getMaterialsForEncrypt(EncryptionMaterialsRequest request) { + Map context = request.getContext(); + + CryptoAlgorithm algo = request.getRequestedAlgorithm(); + CommitmentPolicy commitmentPolicy = request.getCommitmentPolicy(); + // Set default according to commitment policy + if (algo == null && commitmentPolicy == CommitmentPolicy.ForbidEncryptAllowDecrypt) { + algo = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; + } else if (algo == null) { + algo = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384; } - @Override public EncryptionMaterials getMaterialsForEncrypt(EncryptionMaterialsRequest request) { - Map context = request.getContext(); - - CryptoAlgorithm algo = request.getRequestedAlgorithm(); - CommitmentPolicy commitmentPolicy = request.getCommitmentPolicy(); - // Set default according to commitment policy - if (algo == null && commitmentPolicy == CommitmentPolicy.ForbidEncryptAllowDecrypt) { - algo = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; - } else if (algo == null) { - algo = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384; - } - - KeyPair trailingKeys = null; - if (algo.getTrailingSignatureLength() > 0) { - try { - trailingKeys = generateTrailingSigKeyPair(algo); - if (context.containsKey(Constants.EC_PUBLIC_KEY_FIELD)) { - throw new IllegalArgumentException("EncryptionContext contains reserved field " - + Constants.EC_PUBLIC_KEY_FIELD); - } - // make mutable - context = new HashMap<>(context); - context.put(Constants.EC_PUBLIC_KEY_FIELD, serializeTrailingKeyForEc(algo, trailingKeys)); - } catch (final GeneralSecurityException ex) { - throw new AwsCryptoException(ex); - } + KeyPair trailingKeys = null; + if (algo.getTrailingSignatureLength() > 0) { + try { + trailingKeys = generateTrailingSigKeyPair(algo); + if (context.containsKey(Constants.EC_PUBLIC_KEY_FIELD)) { + throw new IllegalArgumentException( + "EncryptionContext contains reserved field " + Constants.EC_PUBLIC_KEY_FIELD); } + // make mutable + context = new HashMap<>(context); + context.put(Constants.EC_PUBLIC_KEY_FIELD, serializeTrailingKeyForEc(algo, trailingKeys)); + } catch (final GeneralSecurityException ex) { + throw new AwsCryptoException(ex); + } + } - final MasterKeyRequest.Builder mkRequestBuilder = MasterKeyRequest.newBuilder(); - mkRequestBuilder.setEncryptionContext(context); + final MasterKeyRequest.Builder mkRequestBuilder = MasterKeyRequest.newBuilder(); + mkRequestBuilder.setEncryptionContext(context); - mkRequestBuilder.setStreaming(request.getPlaintextSize() == -1); - if (request.getPlaintext() != null) { - mkRequestBuilder.setPlaintext(request.getPlaintext()); - } else { - mkRequestBuilder.setSize(request.getPlaintextSize()); - } - - @SuppressWarnings("unchecked") - final List mks - = (List)assertNonNull(mkp, "provider") - .getMasterKeysForEncryption(mkRequestBuilder.build()); + mkRequestBuilder.setStreaming(request.getPlaintextSize() == -1); + if (request.getPlaintext() != null) { + mkRequestBuilder.setPlaintext(request.getPlaintext()); + } else { + mkRequestBuilder.setSize(request.getPlaintextSize()); + } - if (mks.isEmpty()) { - throw new IllegalArgumentException("No master keys provided"); - } + @SuppressWarnings("unchecked") + final List mks = + (List) + assertNonNull(mkp, "provider").getMasterKeysForEncryption(mkRequestBuilder.build()); - DataKey dataKey = mks.get(0).generateDataKey(algo, context); + if (mks.isEmpty()) { + throw new IllegalArgumentException("No master keys provided"); + } - List keyBlobs = new ArrayList<>(mks.size()); - keyBlobs.add(new KeyBlob(dataKey)); + DataKey dataKey = mks.get(0).generateDataKey(algo, context); - for (int i = 1; i < mks.size(); i++) { - //noinspection unchecked - keyBlobs.add(new KeyBlob(mks.get(i).encryptDataKey(algo, context, dataKey))); - } + List keyBlobs = new ArrayList<>(mks.size()); + keyBlobs.add(new KeyBlob(dataKey)); - //noinspection unchecked - return EncryptionMaterials.newBuilder() - .setAlgorithm(algo) - .setCleartextDataKey(dataKey.getKey()) - .setEncryptedDataKeys(keyBlobs) - .setEncryptionContext(context) - .setTrailingSignatureKey(trailingKeys == null ? null : trailingKeys.getPrivate()) - .setMasterKeys(mks) - .build(); + for (int i = 1; i < mks.size(); i++) { + //noinspection unchecked + keyBlobs.add(new KeyBlob(mks.get(i).encryptDataKey(algo, context, dataKey))); } - @Override public DecryptionMaterials decryptMaterials(DecryptionMaterialsRequest request) { - DataKey dataKey = mkp.decryptDataKey( - request.getAlgorithm(), - request.getEncryptedDataKeys(), - request.getEncryptionContext() - ); + //noinspection unchecked + return EncryptionMaterials.newBuilder() + .setAlgorithm(algo) + .setCleartextDataKey(dataKey.getKey()) + .setEncryptedDataKeys(keyBlobs) + .setEncryptionContext(context) + .setTrailingSignatureKey(trailingKeys == null ? null : trailingKeys.getPrivate()) + .setMasterKeys(mks) + .build(); + } + + @Override + public DecryptionMaterials decryptMaterials(DecryptionMaterialsRequest request) { + DataKey dataKey = + mkp.decryptDataKey( + request.getAlgorithm(), request.getEncryptedDataKeys(), request.getEncryptionContext()); + + if (dataKey == null) { + throw new CannotUnwrapDataKeyException("Could not decrypt any data keys"); + } - if (dataKey == null) { - throw new CannotUnwrapDataKeyException("Could not decrypt any data keys"); - } + PublicKey pubKey = null; + if (request.getAlgorithm().getTrailingSignatureLength() > 0) { + try { + String serializedPubKey = request.getEncryptionContext().get(Constants.EC_PUBLIC_KEY_FIELD); - PublicKey pubKey = null; - if (request.getAlgorithm().getTrailingSignatureLength() > 0) { - try { - String serializedPubKey = request.getEncryptionContext().get(Constants.EC_PUBLIC_KEY_FIELD); - - if (serializedPubKey == null) { - throw new AwsCryptoException("Missing trailing signature public key"); - } - - pubKey = deserializeTrailingKeyFromEc(request.getAlgorithm(), serializedPubKey); - } catch (final IllegalStateException ex) { - throw new AwsCryptoException(ex); - } - } else if (request.getEncryptionContext().containsKey(Constants.EC_PUBLIC_KEY_FIELD)) { - throw new AwsCryptoException("Trailing signature public key found for non-signed algorithm"); + if (serializedPubKey == null) { + throw new AwsCryptoException("Missing trailing signature public key"); } - return DecryptionMaterials.newBuilder() - .setDataKey(dataKey) - .setTrailingSignatureKey(pubKey) - .build(); + pubKey = deserializeTrailingKeyFromEc(request.getAlgorithm(), serializedPubKey); + } catch (final IllegalStateException ex) { + throw new AwsCryptoException(ex); + } + } else if (request.getEncryptionContext().containsKey(Constants.EC_PUBLIC_KEY_FIELD)) { + throw new AwsCryptoException("Trailing signature public key found for non-signed algorithm"); } - private PublicKey deserializeTrailingKeyFromEc(CryptoAlgorithm algo, String pubKey) { - return TrailingSignatureAlgorithm.forCryptoAlgorithm(algo).deserializePublicKey(pubKey); - } - - private static String serializeTrailingKeyForEc(CryptoAlgorithm algo, KeyPair trailingKeys) { - return TrailingSignatureAlgorithm.forCryptoAlgorithm(algo).serializePublicKey(trailingKeys.getPublic()); - } - - private static KeyPair generateTrailingSigKeyPair(CryptoAlgorithm algo) throws GeneralSecurityException { - return TrailingSignatureAlgorithm.forCryptoAlgorithm(algo).generateKey(); - } + return DecryptionMaterials.newBuilder() + .setDataKey(dataKey) + .setTrailingSignatureKey(pubKey) + .build(); + } + + private PublicKey deserializeTrailingKeyFromEc(CryptoAlgorithm algo, String pubKey) { + return TrailingSignatureAlgorithm.forCryptoAlgorithm(algo).deserializePublicKey(pubKey); + } + + private static String serializeTrailingKeyForEc(CryptoAlgorithm algo, KeyPair trailingKeys) { + return TrailingSignatureAlgorithm.forCryptoAlgorithm(algo) + .serializePublicKey(trailingKeys.getPublic()); + } + + private static KeyPair generateTrailingSigKeyPair(CryptoAlgorithm algo) + throws GeneralSecurityException { + return TrailingSignatureAlgorithm.forCryptoAlgorithm(algo).generateKey(); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/EncryptedDataKey.java b/src/main/java/com/amazonaws/encryptionsdk/EncryptedDataKey.java index 4629a9e07..781d8e93c 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/EncryptedDataKey.java +++ b/src/main/java/com/amazonaws/encryptionsdk/EncryptedDataKey.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,112 +13,111 @@ package com.amazonaws.encryptionsdk; -//@ model import java.util.Arrays; -//@ model import java.nio.charset.StandardCharsets; +// @ model import java.util.Arrays; +// @ model import java.nio.charset.StandardCharsets; - -//@ nullable_by_default +// @ nullable_by_default public interface EncryptedDataKey { - - //@// An EncryptedDataKey object abstractly contains 3 pieces of data. - //@// These are represented by 3 byte arrays: - - //@ model public instance byte[] providerId; - //@ model public instance byte[] providerInformation; - //@ model public instance byte[] encryptedDataKey; - - //@// The fields of an EncryptedDataKey may be populated via deserialization. The - //@// Encryption SDK design allows the deserialization routine to be called repeatedly, - //@// each call trying to fill in information that for some reason was not possible - //@// with the previous call. In some such "intermediate" states, the deserialization - //@// is incomplete in a way that other methods don't expect. Therefore, those methods - //@// should not be called in these incomplete intermediate states. The model field - //@// isDeserializing is true in those incomplete intermediate states, and it is used - //@// in method specifications. - //@ public model instance boolean isDeserializing; - //@// There are some complications surrounding the representations of strings versus - //@// byte arrays. The serialized form in message headers is always a sequence of - //@// bytes, but the EncryptedDataKey interface (and some other methods) - //@// expose the provider ID as if it were a string. Conversions (using UTF-8) - //@// between byte arrays and strings (which in Java use UTF-16) are not bijections. - //@// For example, both "\u003f".getBytes() and "\ud800".getBytes() yield a 1-byte - //@// array containing [0x3f], and calling `new String(..., StandardCharsets.UTF_8)` - //@// with either the 1-byte array [0x80] or the 3-byte array [0xef,0xbf,0xbd] yields - //@// the string "\ufffd". Therefore, all we can say about these conversions - //@// is that a given byte[]-String pair satisfies a conversion relation. - //@// - //@// The model functions "ba2s" and "s2ba" are used to specify the conversions - //@// between byte arrays and strings: - /*@ public normal_behavior - @ requires s != null; - @ ensures \result != null; - @ function - @ public model static byte[] s2ba(String s) { - @ return s.getBytes(StandardCharsets.UTF_8); - @ } - @*/ - /*@ public normal_behavior - @ requires ba != null; - @ ensures \result != null; - @ function - @ public model static String ba2s(byte[] ba) { - @ return new String(ba, StandardCharsets.UTF_8); - @ } - @*/ - //@// The "ba2s" and "s2ba" are given function bodies above, but the verification - //@// does not rely on these function bodies directly. Instead, the code (in KeyBlob) - //@// uses "assume" statements when it necessary to connect these functions with - //@// copies of their bodies that appear in the code. This is a limitation of JML. - //@// - //@// One of the properties that holds of "s2ba(s)" is that its result depends not - //@// on the particular String reference "s" being passed in, but only the contents - //@// of the string referenced by "s". This property is captured in the following - //@// lemma: - /*@ public normal_behavior - @ requires s != null && t != null && String.equals(s, t); - @ ensures Arrays.equalArrays(s2ba(s), s2ba(t)); - @ pure - @ public model static void lemma_s2ba_depends_only_string_contents_only(String s, String t); - @*/ - //@// - //@// As a specification convenience, the model function "ba2s2ba" uses the two - //@// model functions above to convert from a byte array to a String and then back - //@// to a byte array. As mentioned above, this does not always result in a byte - //@// array with the original contents. The "assume" statements about the conversion - //@// functions need to be careful not to assume too much. - /*@ public normal_behavior - @ requires ba != null; - @ ensures \result == s2ba(ba2s(ba)); - @ function - @ public model static byte[] ba2s2ba(byte[] ba) { - @ return s2ba(ba2s(ba)); - @ } - @*/ - - //@// Here follows 3 methods that access the abstract values of interface properties. - //@// Something to note about these methods is that each one requires the property - //@// requested to be known to be non-null. For example, "getProviderId" is only allowed - //@// to be called when "providerId" is known to be non-null. - - //@ public normal_behavior - //@ requires providerId != null; - //@ ensures \result != null; - //@ ensures String.equals(\result, ba2s(providerId)); - //@ pure - public String getProviderId(); + // @// An EncryptedDataKey object abstractly contains 3 pieces of data. + // @// These are represented by 3 byte arrays: + + // @ model public instance byte[] providerId; + // @ model public instance byte[] providerInformation; + // @ model public instance byte[] encryptedDataKey; + + // @// The fields of an EncryptedDataKey may be populated via deserialization. The + // @// Encryption SDK design allows the deserialization routine to be called repeatedly, + // @// each call trying to fill in information that for some reason was not possible + // @// with the previous call. In some such "intermediate" states, the deserialization + // @// is incomplete in a way that other methods don't expect. Therefore, those methods + // @// should not be called in these incomplete intermediate states. The model field + // @// isDeserializing is true in those incomplete intermediate states, and it is used + // @// in method specifications. + // @ public model instance boolean isDeserializing; + + // @// There are some complications surrounding the representations of strings versus + // @// byte arrays. The serialized form in message headers is always a sequence of + // @// bytes, but the EncryptedDataKey interface (and some other methods) + // @// expose the provider ID as if it were a string. Conversions (using UTF-8) + // @// between byte arrays and strings (which in Java use UTF-16) are not bijections. + // @// For example, both "\u003f".getBytes() and "\ud800".getBytes() yield a 1-byte + // @// array containing [0x3f], and calling `new String(..., StandardCharsets.UTF_8)` + // @// with either the 1-byte array [0x80] or the 3-byte array [0xef,0xbf,0xbd] yields + // @// the string "\ufffd". Therefore, all we can say about these conversions + // @// is that a given byte[]-String pair satisfies a conversion relation. + // @// + // @// The model functions "ba2s" and "s2ba" are used to specify the conversions + // @// between byte arrays and strings: + /*@ public normal_behavior + @ requires s != null; + @ ensures \result != null; + @ function + @ public model static byte[] s2ba(String s) { + @ return s.getBytes(StandardCharsets.UTF_8); + @ } + @*/ + /*@ public normal_behavior + @ requires ba != null; + @ ensures \result != null; + @ function + @ public model static String ba2s(byte[] ba) { + @ return new String(ba, StandardCharsets.UTF_8); + @ } + @*/ + // @// The "ba2s" and "s2ba" are given function bodies above, but the verification + // @// does not rely on these function bodies directly. Instead, the code (in KeyBlob) + // @// uses "assume" statements when it necessary to connect these functions with + // @// copies of their bodies that appear in the code. This is a limitation of JML. + // @// + // @// One of the properties that holds of "s2ba(s)" is that its result depends not + // @// on the particular String reference "s" being passed in, but only the contents + // @// of the string referenced by "s". This property is captured in the following + // @// lemma: + /*@ public normal_behavior + @ requires s != null && t != null && String.equals(s, t); + @ ensures Arrays.equalArrays(s2ba(s), s2ba(t)); + @ pure + @ public model static void lemma_s2ba_depends_only_string_contents_only(String s, String t); + @*/ + // @// + // @// As a specification convenience, the model function "ba2s2ba" uses the two + // @// model functions above to convert from a byte array to a String and then back + // @// to a byte array. As mentioned above, this does not always result in a byte + // @// array with the original contents. The "assume" statements about the conversion + // @// functions need to be careful not to assume too much. + /*@ public normal_behavior + @ requires ba != null; + @ ensures \result == s2ba(ba2s(ba)); + @ function + @ public model static byte[] ba2s2ba(byte[] ba) { + @ return s2ba(ba2s(ba)); + @ } + @*/ + + // @// Here follows 3 methods that access the abstract values of interface properties. + // @// Something to note about these methods is that each one requires the property + // @// requested to be known to be non-null. For example, "getProviderId" is only allowed + // @// to be called when "providerId" is known to be non-null. + + // @ public normal_behavior + // @ requires providerId != null; + // @ ensures \result != null; + // @ ensures String.equals(\result, ba2s(providerId)); + // @ pure + public String getProviderId(); - //@ public normal_behavior - //@ requires providerInformation != null; - //@ ensures \fresh(\result); - //@ ensures Arrays.equalArrays(providerInformation, \result); - //@ pure - public byte[] getProviderInformation(); + // @ public normal_behavior + // @ requires providerInformation != null; + // @ ensures \fresh(\result); + // @ ensures Arrays.equalArrays(providerInformation, \result); + // @ pure + public byte[] getProviderInformation(); - //@ public normal_behavior - //@ requires encryptedDataKey != null; - //@ ensures \fresh(\result); - //@ ensures Arrays.equalArrays(encryptedDataKey, \result); - //@ pure - public byte[] getEncryptedDataKey(); + // @ public normal_behavior + // @ requires encryptedDataKey != null; + // @ ensures \fresh(\result); + // @ ensures Arrays.equalArrays(encryptedDataKey, \result); + // @ pure + public byte[] getEncryptedDataKey(); } diff --git a/src/main/java/com/amazonaws/encryptionsdk/MasterKey.java b/src/main/java/com/amazonaws/encryptionsdk/MasterKey.java index ae64e752e..b61e5640b 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/MasterKey.java +++ b/src/main/java/com/amazonaws/encryptionsdk/MasterKey.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,120 +13,120 @@ package com.amazonaws.encryptionsdk; +import com.amazonaws.encryptionsdk.exception.NoSuchMasterKeyException; +import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; -import com.amazonaws.encryptionsdk.exception.NoSuchMasterKeyException; -import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; - /** * Represents the cryptographic key used to protect the {@link DataKey} (which, in turn, protects * the data). * - * All MasterKeys extend {@link MasterKeyProvider} because they are all capable of providing exactly - * themselves. This simplifies implementation when only a single {@link MasterKey} is used and/or - * expected. + *

All MasterKeys extend {@link MasterKeyProvider} because they are all capable of providing + * exactly themselves. This simplifies implementation when only a single {@link MasterKey} is used + * and/or expected. * - * @param - * the concrete type of the {@link MasterKey} + * @param the concrete type of the {@link MasterKey} */ public abstract class MasterKey> extends MasterKeyProvider { - public abstract String getProviderId(); + public abstract String getProviderId(); - /** - * Equivalent to calling {@link #getProviderId()}. - */ - @Override - public String getDefaultProviderId() { - return getProviderId(); - } + /** Equivalent to calling {@link #getProviderId()}. */ + @Override + public String getDefaultProviderId() { + return getProviderId(); + } - public abstract String getKeyId(); + public abstract String getKeyId(); - /** - * Generates a new {@link DataKey} which is protected by this {@link MasterKey} for use with - * {@code algorithm} and associated with the provided {@code encryptionContext}. - */ - public abstract DataKey generateDataKey(CryptoAlgorithm algorithm, Map encryptionContext); + /** + * Generates a new {@link DataKey} which is protected by this {@link MasterKey} for use with + * {@code algorithm} and associated with the provided {@code encryptionContext}. + */ + public abstract DataKey generateDataKey( + CryptoAlgorithm algorithm, Map encryptionContext); - /** - * Returns a new copy of the provided {@code dataKey} which is protected by this - * {@link MasterKey} for use with {@code algorithm} and associated with the provided - * {@code encryptionContext}. - */ - public abstract DataKey encryptDataKey(CryptoAlgorithm algorithm, Map encryptionContext, - DataKey dataKey); + /** + * Returns a new copy of the provided {@code dataKey} which is protected by this {@link MasterKey} + * for use with {@code algorithm} and associated with the provided {@code encryptionContext}. + */ + public abstract DataKey encryptDataKey( + CryptoAlgorithm algorithm, Map encryptionContext, DataKey dataKey); - /** - * Returns {@code true} if and only if {@code provider} equals {@link #getProviderId()}. - */ - @Override - public boolean canProvide(final String provider) { - return getProviderId().equals(provider); - } + /** Returns {@code true} if and only if {@code provider} equals {@link #getProviderId()}. */ + @Override + public boolean canProvide(final String provider) { + return getProviderId().equals(provider); + } - /** - * Returns {@code this} if {@code provider} and {@code keyId} match {@code this}. Otherwise, - * throws an appropriate exception. - */ - @SuppressWarnings("unchecked") - @Override - public K getMasterKey(final String provider, final String keyId) throws UnsupportedProviderException, - NoSuchMasterKeyException { - if (!canProvide(provider)) { - throw new UnsupportedProviderException("MasterKeys can only provide themselves. Requested " - + buildName(provider, keyId) + " but only " + toString() + " is available"); - } - if (!getKeyId().equals(keyId)) { - throw new NoSuchMasterKeyException("MasterKeys can only provide themselves. Requested " - + buildName(provider, keyId) + " but only " + toString() + " is available"); - } - return (K) this; + /** + * Returns {@code this} if {@code provider} and {@code keyId} match {@code this}. Otherwise, + * throws an appropriate exception. + */ + @SuppressWarnings("unchecked") + @Override + public K getMasterKey(final String provider, final String keyId) + throws UnsupportedProviderException, NoSuchMasterKeyException { + if (!canProvide(provider)) { + throw new UnsupportedProviderException( + "MasterKeys can only provide themselves. Requested " + + buildName(provider, keyId) + + " but only " + + toString() + + " is available"); } - - @Override - public String toString() { - return buildName(getProviderId(), getKeyId()); + if (!getKeyId().equals(keyId)) { + throw new NoSuchMasterKeyException( + "MasterKeys can only provide themselves. Requested " + + buildName(provider, keyId) + + " but only " + + toString() + + " is available"); } + return (K) this; + } - /** - * Returns a list of length {@code 1} containing {@code this}. - */ - @SuppressWarnings("unchecked") - @Override - public List getMasterKeysForEncryption(final MasterKeyRequest request) { - return (List) Collections.singletonList(this); - } + @Override + public String toString() { + return buildName(getProviderId(), getKeyId()); + } - private static String buildName(final String provider, final String keyId) { - return String.format("%s://%s", provider, keyId); - } + /** Returns a list of length {@code 1} containing {@code this}. */ + @SuppressWarnings("unchecked") + @Override + public List getMasterKeysForEncryption(final MasterKeyRequest request) { + return (List) Collections.singletonList(this); + } - /** - * Two {@link MasterKey}s are equal if they are instances of the exact same class and - * their values for {@code keyId}, {@code providerId}, and {@code defaultProviderId} are equal. - */ - @Override - public boolean equals(final Object obj) { - if (obj == null) { - return false; - } - if (this == obj) { - return true; - } - if (!obj.getClass().equals(getClass())) { - return false; - } - final MasterKey mk = (MasterKey) obj; - return Objects.equals(getKeyId(), mk.getKeyId()) && - Objects.equals(getProviderId(), mk.getProviderId()) && - Objects.equals(getDefaultProviderId(), mk.getDefaultProviderId()); - } + private static String buildName(final String provider, final String keyId) { + return String.format("%s://%s", provider, keyId); + } - @Override - public int hashCode() { - return Objects.hash(getKeyId(), getProviderId(), getDefaultProviderId()); + /** + * Two {@link MasterKey}s are equal if they are instances of the exact same class and + * their values for {@code keyId}, {@code providerId}, and {@code defaultProviderId} are equal. + */ + @Override + public boolean equals(final Object obj) { + if (obj == null) { + return false; + } + if (this == obj) { + return true; } + if (!obj.getClass().equals(getClass())) { + return false; + } + final MasterKey mk = (MasterKey) obj; + return Objects.equals(getKeyId(), mk.getKeyId()) + && Objects.equals(getProviderId(), mk.getProviderId()) + && Objects.equals(getDefaultProviderId(), mk.getDefaultProviderId()); + } + + @Override + public int hashCode() { + return Objects.hash(getKeyId(), getProviderId(), getDefaultProviderId()); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/MasterKeyProvider.java b/src/main/java/com/amazonaws/encryptionsdk/MasterKeyProvider.java index 978586f2e..a99e8fb67 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/MasterKeyProvider.java +++ b/src/main/java/com/amazonaws/encryptionsdk/MasterKeyProvider.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,101 +13,97 @@ package com.amazonaws.encryptionsdk; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; - import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; import com.amazonaws.encryptionsdk.exception.NoSuchMasterKeyException; import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; /** * Represents the logic necessary to select and construct {@link MasterKey}s for encrypting and * decrypting messages. This is an abstract class. * - * @param - * the type of {@link MasterKey} returned by this provider + * @param the type of {@link MasterKey} returned by this provider */ public abstract class MasterKeyProvider> { - /** - * ProviderId used by this instance when no other is specified. - */ - public abstract String getDefaultProviderId(); + /** ProviderId used by this instance when no other is specified. */ + public abstract String getDefaultProviderId(); - /** - * Returns true if this MasterKeyProvider can provide keys from the specified @{code provider}. - * - * @param provider - * @return - */ - public boolean canProvide(final String provider) { - return getDefaultProviderId().equals(provider); - } + /** + * Returns true if this MasterKeyProvider can provide keys from the specified @{code provider}. + * + * @param provider + * @return + */ + public boolean canProvide(final String provider) { + return getDefaultProviderId().equals(provider); + } - /** - * Equivalent to calling {@link #getMasterKey(String, String)} using - * {@link #getDefaultProviderId()} as the provider. - */ - public K getMasterKey(final String keyId) throws UnsupportedProviderException, NoSuchMasterKeyException { - return getMasterKey(getDefaultProviderId(), keyId); - } + /** + * Equivalent to calling {@link #getMasterKey(String, String)} using {@link + * #getDefaultProviderId()} as the provider. + */ + public K getMasterKey(final String keyId) + throws UnsupportedProviderException, NoSuchMasterKeyException { + return getMasterKey(getDefaultProviderId(), keyId); + } - /** - * Returns the specified {@link MasterKey} if possible. - * - * @param provider - * @param keyId - * @return - * @throws UnsupportedProviderException - * if this object cannot return {@link MasterKey}s associated with the given - * provider - * @throws NoSuchMasterKeyException - * if this object cannot find (and thus construct) the {@link MasterKey} associated - * with {@code keyId} - */ - public abstract K getMasterKey(String provider, String keyId) throws UnsupportedProviderException, - NoSuchMasterKeyException; + /** + * Returns the specified {@link MasterKey} if possible. + * + * @param provider + * @param keyId + * @return + * @throws UnsupportedProviderException if this object cannot return {@link MasterKey}s associated + * with the given provider + * @throws NoSuchMasterKeyException if this object cannot find (and thus construct) the {@link + * MasterKey} associated with {@code keyId} + */ + public abstract K getMasterKey(String provider, String keyId) + throws UnsupportedProviderException, NoSuchMasterKeyException; - /** - * Returns all {@link MasterKey}s which should be used to protect the plaintext described by - * {@code request}. - */ - public abstract List getMasterKeysForEncryption(MasterKeyRequest request); + /** + * Returns all {@link MasterKey}s which should be used to protect the plaintext described by + * {@code request}. + */ + public abstract List getMasterKeysForEncryption(MasterKeyRequest request); - /** - * Iterates through {@code encryptedDataKeys} and returns the first one which can be - * successfully decrypted. - * - * @return a DataKey if one can be decrypted, otherwise returns {@code null} - * @throws UnsupportedProviderException - * if the {@code encryptedDataKey} is associated with an unsupported provider - * @throws CannotUnwrapDataKeyException - * if the {@code encryptedDataKey} cannot be decrypted - */ - public abstract DataKey decryptDataKey(CryptoAlgorithm algorithm, - Collection encryptedDataKeys, Map encryptionContext) - throws UnsupportedProviderException, AwsCryptoException; + /** + * Iterates through {@code encryptedDataKeys} and returns the first one which can be successfully + * decrypted. + * + * @return a DataKey if one can be decrypted, otherwise returns {@code null} + * @throws UnsupportedProviderException if the {@code encryptedDataKey} is associated with an + * unsupported provider + * @throws CannotUnwrapDataKeyException if the {@code encryptedDataKey} cannot be decrypted + */ + public abstract DataKey decryptDataKey( + CryptoAlgorithm algorithm, + Collection encryptedDataKeys, + Map encryptionContext) + throws UnsupportedProviderException, AwsCryptoException; - protected AwsCryptoException buildCannotDecryptDksException() { - return buildCannotDecryptDksException(Collections. emptyList()); - } + protected AwsCryptoException buildCannotDecryptDksException() { + return buildCannotDecryptDksException(Collections.emptyList()); + } - protected AwsCryptoException buildCannotDecryptDksException(Throwable t) { - return buildCannotDecryptDksException(Collections.singletonList(t)); - } + protected AwsCryptoException buildCannotDecryptDksException(Throwable t) { + return buildCannotDecryptDksException(Collections.singletonList(t)); + } - protected AwsCryptoException buildCannotDecryptDksException(List t) { - if (t == null || t.isEmpty()) { - return new CannotUnwrapDataKeyException("Unable to decrypt any data keys"); - } else { - final CannotUnwrapDataKeyException ex = new CannotUnwrapDataKeyException("Unable to decrypt any data keys", - t.get(0)); - for (final Throwable e : t) { - ex.addSuppressed(e); - } - return ex; - } + protected AwsCryptoException buildCannotDecryptDksException(List t) { + if (t == null || t.isEmpty()) { + return new CannotUnwrapDataKeyException("Unable to decrypt any data keys"); + } else { + final CannotUnwrapDataKeyException ex = + new CannotUnwrapDataKeyException("Unable to decrypt any data keys", t.get(0)); + for (final Throwable e : t) { + ex.addSuppressed(e); + } + return ex; } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/MasterKeyRequest.java b/src/main/java/com/amazonaws/encryptionsdk/MasterKeyRequest.java index 650fd2ee4..1b6350c3c 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/MasterKeyRequest.java +++ b/src/main/java/com/amazonaws/encryptionsdk/MasterKeyRequest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,120 +13,120 @@ package com.amazonaws.encryptionsdk; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.util.Collections; import java.util.HashMap; import java.util.Map; -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; - /** * Contains information which {@link MasterKeyProvider}s can use to select which {@link MasterKey}s * to use to protect a given plaintext. This class is immutable. */ public final class MasterKeyRequest { - private final Map encryptionContext_; - private final boolean isStreaming_; - private final byte[] plaintext_; - private final long size_; - - private MasterKeyRequest(final Map encryptionContext, final boolean isStreaming, - final byte[] plaintext, final long size) { - encryptionContext_ = encryptionContext; - isStreaming_ = isStreaming; - plaintext_ = plaintext; - size_ = size; - } + private final Map encryptionContext_; + private final boolean isStreaming_; + private final byte[] plaintext_; + private final long size_; + + private MasterKeyRequest( + final Map encryptionContext, + final boolean isStreaming, + final byte[] plaintext, + final long size) { + encryptionContext_ = encryptionContext; + isStreaming_ = isStreaming; + plaintext_ = plaintext; + size_ = size; + } + + public Map getEncryptionContext() { + return encryptionContext_; + } + + public boolean isStreaming() { + return isStreaming_; + } + + /** The plaintext, if available, to be protected by this request. Otherwise, {@code null}. */ + public byte[] getPlaintext() { + return plaintext_ != null ? plaintext_.clone() : null; + } + + /** The size of the plaintext, if available. Otherwise {@code -1}. */ + public long getSize() { + return size_; + } + + public static Builder newBuilder() { + return new Builder(); + } + + public static final class Builder { + private Map encryptionContext_ = new HashMap<>(); + private boolean isStreaming_ = false; + private byte[] plaintext_ = null; + private long size_ = -1; public Map getEncryptionContext() { - return encryptionContext_; + return encryptionContext_; + } + + public Builder setEncryptionContext(final Map encryptionContext) { + encryptionContext_ = encryptionContext; + return this; } public boolean isStreaming() { - return isStreaming_; + return isStreaming_; + } + + public Builder setStreaming(final boolean isStreaming) { + isStreaming_ = isStreaming; + return this; } /** - * The plaintext, if available, to be protected by this request. Otherwise, {@code null}. + * Please note that this does not make a defensive copy of the plaintext and so any + * modifications made to the backing array will be reflected in this Builder. */ + @SuppressFBWarnings("EI_EXPOSE_REP") public byte[] getPlaintext() { - return plaintext_ != null ? plaintext_.clone() : null; + return plaintext_; } /** - * The size of the plaintext, if available. Otherwise {@code -1}. + * Please note that this does not make a defensive copy of the plaintext and so any + * modifications made to the backing array will be reflected in this Builder. */ - public long getSize() { - return size_; + @SuppressFBWarnings("EI_EXPOSE_REP") + public Builder setPlaintext(final byte[] plaintext) { + if (size_ != -1) { + throw new IllegalStateException( + "The plaintext may only be set if the size has not been explicitly set"); + } + plaintext_ = plaintext; + return this; } - public static Builder newBuilder() { - return new Builder(); + public Builder setSize(final long size) { + if (plaintext_ != null) { + throw new IllegalStateException( + "Size may only explicitly set when the plaintext is not set"); + } + size_ = size; + return this; + } + + public long getSize() { + return size_; } - public final static class Builder { - private Map encryptionContext_ = new HashMap<>(); - private boolean isStreaming_ = false; - private byte[] plaintext_ = null; - private long size_ = -1; - - public Map getEncryptionContext() { - return encryptionContext_; - } - - public Builder setEncryptionContext(final Map encryptionContext) { - encryptionContext_ = encryptionContext; - return this; - } - - public boolean isStreaming() { - return isStreaming_; - } - - public Builder setStreaming(final boolean isStreaming) { - isStreaming_ = isStreaming; - return this; - } - - /** - * Please note that this does not make a defensive copy of the plaintext and so any - * modifications made to the backing array will be reflected in this Builder. - */ - @SuppressFBWarnings("EI_EXPOSE_REP") - public byte[] getPlaintext() { - return plaintext_; - } - - /** - * Please note that this does not make a defensive copy of the plaintext and so any - * modifications made to the backing array will be reflected in this Builder. - */ - @SuppressFBWarnings("EI_EXPOSE_REP") - public Builder setPlaintext(final byte[] plaintext) { - if (size_ != -1) { - throw new IllegalStateException("The plaintext may only be set if the size has not been explicitly set"); - } - plaintext_ = plaintext; - return this; - } - - public Builder setSize(final long size) { - if (plaintext_ != null) { - throw new IllegalStateException("Size may only explicitly set when the plaintext is not set"); - } - size_ = size; - return this; - } - - public long getSize() { - return size_; - } - - public MasterKeyRequest build() { - return new MasterKeyRequest( - Collections.unmodifiableMap(new HashMap<>(encryptionContext_)), - isStreaming_, - plaintext_, - plaintext_ != null ? plaintext_.length : size_); - } + public MasterKeyRequest build() { + return new MasterKeyRequest( + Collections.unmodifiableMap(new HashMap<>(encryptionContext_)), + isStreaming_, + plaintext_, + plaintext_ != null ? plaintext_.length : size_); } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/ParsedCiphertext.java b/src/main/java/com/amazonaws/encryptionsdk/ParsedCiphertext.java index 13f0926b1..998a9beb6 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/ParsedCiphertext.java +++ b/src/main/java/com/amazonaws/encryptionsdk/ParsedCiphertext.java @@ -13,61 +13,58 @@ package com.amazonaws.encryptionsdk; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.internal.Utils; import com.amazonaws.encryptionsdk.model.CiphertextHeaders; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; - import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; /** * Exposes header information of ciphertexts to make it easier to inspect the algorithm, keys, and * encryption context prior to decryption. * - * Please note that the class does not make defensive copies. + *

Please note that the class does not make defensive copies. */ public class ParsedCiphertext extends CiphertextHeaders { - private final byte[] ciphertext_; - private final int offset_; + private final byte[] ciphertext_; + private final int offset_; - /** - * Parses {@code ciphertext}. Please note that this does not make a defensive copy of - * {@code ciphertext} and that any changes made to the backing array will be reflected here as - * well. - * - * @param ciphertext The ciphertext to parse - * @param maxEncryptedDataKeys The maximum number of encrypted data keys to parse. - * Zero indicates no maximum. - */ - public ParsedCiphertext(final byte[] ciphertext, final int maxEncryptedDataKeys) { - ciphertext_ = Utils.assertNonNull(ciphertext, "ciphertext"); - offset_ = deserialize(ciphertext_, 0, maxEncryptedDataKeys); - if (!this.isComplete()) { - throw new BadCiphertextException("Incomplete ciphertext."); - } + /** + * Parses {@code ciphertext}. Please note that this does not make a defensive copy of + * {@code ciphertext} and that any changes made to the backing array will be reflected here as + * well. + * + * @param ciphertext The ciphertext to parse + * @param maxEncryptedDataKeys The maximum number of encrypted data keys to parse. Zero indicates + * no maximum. + */ + public ParsedCiphertext(final byte[] ciphertext, final int maxEncryptedDataKeys) { + ciphertext_ = Utils.assertNonNull(ciphertext, "ciphertext"); + offset_ = deserialize(ciphertext_, 0, maxEncryptedDataKeys); + if (!this.isComplete()) { + throw new BadCiphertextException("Incomplete ciphertext."); } + } - /** - * Parses {@code ciphertext} without enforcing a max EDK count. Please note that this does - * not make a defensive copy of {@code ciphertext} and that any changes made to the - * backing array will be reflected here as well. - */ - public ParsedCiphertext(final byte[] ciphertext) { - this(ciphertext, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - } + /** + * Parses {@code ciphertext} without enforcing a max EDK count. Please note that this does + * not make a defensive copy of {@code ciphertext} and that any changes made to the + * backing array will be reflected here as well. + */ + public ParsedCiphertext(final byte[] ciphertext) { + this(ciphertext, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + } - /** - * Returns the raw ciphertext backing this object. This is not a defensive copy and so - * must not be modified by callers. - */ - @SuppressFBWarnings("EI_EXPOSE_REP") - public byte[] getCiphertext() { - return ciphertext_; - } + /** + * Returns the raw ciphertext backing this object. This is not a defensive copy and so + * must not be modified by callers. + */ + @SuppressFBWarnings("EI_EXPOSE_REP") + public byte[] getCiphertext() { + return ciphertext_; + } - /** - * The offset at which the first non-header byte in {@code ciphertext} is located. - */ - public int getOffset() { - return offset_; - } + /** The offset at which the first non-header byte in {@code ciphertext} is located. */ + public int getOffset() { + return offset_; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/caching/CachingCryptoMaterialsManager.java b/src/main/java/com/amazonaws/encryptionsdk/caching/CachingCryptoMaterialsManager.java index 64a574282..2511939e6 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/caching/CachingCryptoMaterialsManager.java +++ b/src/main/java/com/amazonaws/encryptionsdk/caching/CachingCryptoMaterialsManager.java @@ -1,12 +1,5 @@ package com.amazonaws.encryptionsdk.caching; -import java.nio.charset.StandardCharsets; -import java.security.GeneralSecurityException; -import java.security.MessageDigest; -import java.util.ArrayList; -import java.util.UUID; -import java.util.concurrent.TimeUnit; - import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.CryptoMaterialsManager; import com.amazonaws.encryptionsdk.DefaultCryptoMaterialsManager; @@ -14,378 +7,408 @@ import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.internal.EncryptionContextSerializer; import com.amazonaws.encryptionsdk.internal.Utils; -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.DecryptionMaterials; +import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.KeyBlob; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import java.security.MessageDigest; +import java.util.ArrayList; +import java.util.UUID; +import java.util.concurrent.TimeUnit; /** - * The CachingCryptoMaterialsManager wraps another {@link CryptoMaterialsManager}, and caches its results. This helps reduce - * the number of calls made to the underlying {@link CryptoMaterialsManager} and/or {@link MasterKeyProvider}, which may - * help reduce cost and/or improve performance. + * The CachingCryptoMaterialsManager wraps another {@link CryptoMaterialsManager}, and caches its + * results. This helps reduce the number of calls made to the underlying {@link + * CryptoMaterialsManager} and/or {@link MasterKeyProvider}, which may help reduce cost and/or + * improve performance. * - * The CachingCryptoMaterialsManager helps enforce a number of usage limits on encrypt. Specifically, it limits the number of - * individual messages encrypted with a particular data key, and the number of plaintext bytes encrypted with the same - * data key. It also allows you to configure a maximum time-to-live for cache entries. + *

The CachingCryptoMaterialsManager helps enforce a number of usage limits on encrypt. + * Specifically, it limits the number of individual messages encrypted with a particular data key, + * and the number of plaintext bytes encrypted with the same data key. It also allows you to + * configure a maximum time-to-live for cache entries. * - * Note that when performing streaming encryption operations, unless you set the stream size before writing any data - * using {@link com.amazonaws.encryptionsdk.CryptoOutputStream#setMaxInputLength(long)} or - * {@link com.amazonaws.encryptionsdk.CryptoInputStream#setMaxInputLength(long)}, the size of the message will not be - * known, and to avoid exceeding byte use limits, caching will not be performed. + *

Note that when performing streaming encryption operations, unless you set the stream size + * before writing any data using {@link + * com.amazonaws.encryptionsdk.CryptoOutputStream#setMaxInputLength(long)} or {@link + * com.amazonaws.encryptionsdk.CryptoInputStream#setMaxInputLength(long)}, the size of the message + * will not be known, and to avoid exceeding byte use limits, caching will not be performed. * - * By default, two different {@link CachingCryptoMaterialsManager}s will not share cached entries, even when using the same - * {@link CryptoMaterialsCache}. However, it's possible to make different {@link CachingCryptoMaterialsManager}s share the same - * cached entries by assigning a partition ID to them; all {@link CachingCryptoMaterialsManager}s with the same partition ID - * will share the same cached entries. + *

By default, two different {@link CachingCryptoMaterialsManager}s will not share cached + * entries, even when using the same {@link CryptoMaterialsCache}. However, it's possible to make + * different {@link CachingCryptoMaterialsManager}s share the same cached entries by assigning a + * partition ID to them; all {@link CachingCryptoMaterialsManager}s with the same partition ID will + * share the same cached entries. * - * Assigning partition IDs manually requires great care; if the backing {@link CryptoMaterialsManager}s are not - * equivalent, having entries cross over between them can result in problems such as encrypting messages to the wrong - * key, or accidentally bypassing access controls. For this reason we recommend not supplying a partition ID unless - * required for your use case. + *

Assigning partition IDs manually requires great care; if the backing {@link + * CryptoMaterialsManager}s are not equivalent, having entries cross over between them can result in + * problems such as encrypting messages to the wrong key, or accidentally bypassing access controls. + * For this reason we recommend not supplying a partition ID unless required for your use case. */ public class CachingCryptoMaterialsManager implements CryptoMaterialsManager { - private static final String CACHE_ID_HASH_ALGORITHM = "SHA-512"; - private static final long MAX_MESSAGE_USE_LIMIT = 1L << 32; - private static final long MAX_BYTE_USE_LIMIT = Long.MAX_VALUE; // 2^63 - 1 + private static final String CACHE_ID_HASH_ALGORITHM = "SHA-512"; + private static final long MAX_MESSAGE_USE_LIMIT = 1L << 32; + private static final long MAX_BYTE_USE_LIMIT = Long.MAX_VALUE; // 2^63 - 1 - private final CryptoMaterialsManager backingCMM; - private final CryptoMaterialsCache cache; + private final CryptoMaterialsManager backingCMM; + private final CryptoMaterialsCache cache; - private final byte[] partitionIdHash; - private final String partitionId; + private final byte[] partitionIdHash; + private final String partitionId; - private final long maxAgeMs; - private final long messageUseLimit; - private final long byteUseLimit; - - private final CryptoMaterialsCache.CacheHint hint = new CryptoMaterialsCache.CacheHint() { - @Override public long getMaxAgeMillis() { - return maxAgeMs; - } - }; - - public static class Builder { - private CryptoMaterialsManager backingCMM; - private CryptoMaterialsCache cache; - private String partitionId = null; - private long maxAge = 0; - private long messageUseLimit = MAX_MESSAGE_USE_LIMIT; - private long byteUseLimit = Long.MAX_VALUE; - - private Builder() {} - - /** - * Sets the {@link CryptoMaterialsManager} that should be queried when the {@link CachingCryptoMaterialsManager} - * (CCMM) incurs a cache miss. - * - * You can set either a MasterKeyProvider or a CryptoMaterialsManager to back the CCMM - the last value set will - * be used. - * - * @param backingCMM The CryptoMaterialsManager to invoke on cache misses - * @return this builder - */ - public Builder withBackingMaterialsManager(CryptoMaterialsManager backingCMM) { - this.backingCMM = backingCMM; - return this; - } - - /** - * Sets the {@link MasterKeyProvider} that should be queried when the {@link CachingCryptoMaterialsManager} - * (CCMM) incurs a cache miss. - * - * You can set either a MasterKeyProvider or a CryptoMaterialsManager to back the CCMM - the last value set will - * be used. - * - * This method is equivalent to calling {@link #withBackingMaterialsManager(CryptoMaterialsManager)} passing a - * {@link DefaultCryptoMaterialsManager} constructed using your {@link MasterKeyProvider}. - * - * @param mkp The MasterKeyProvider to invoke on cache misses - * @return this builder - */ - public Builder withMasterKeyProvider(MasterKeyProvider mkp) { - return withBackingMaterialsManager(new DefaultCryptoMaterialsManager(mkp)); - } - - /** - * Sets the cache to which this {@link CryptoMaterialsManager} will be bound. - * @param cache The cache to associate with the CMM - * @return this builder - */ - public Builder withCache(CryptoMaterialsCache cache) { - this.cache = cache; - return this; - } - - /** - * Sets the partition ID for this CMM. This is an optional operation. - * - * By default, two CMMs will never use each other's cache entries. This helps ensure that CMMs with different - * delegates won't incorrectly use each other's encrypt and decrypt results. However, in certain special - * circumstances it can be useful to share entries between different CMMs - for example, if the backing CMM is - * constructed based on some parameters that depend on the operation, you may wish for delegates constructed - * with the same parameters to share the same partition. - * - * To accomplish this, set the same partition ID and backing cache on both CMMs; entries cached from one of - * these CMMs can then be used by the other. This should only be done with careful consideration and - * verification that the CMM delegates are equivalent for your application's purposes. - * - * By default, the partition ID is set to a random UUID to avoid any collisions. - * - * @param partitionId The partition ID - * @return this builder - */ - public Builder withPartitionId(String partitionId) { - this.partitionId = partitionId; - return this; - } - - /** - * Sets the maximum lifetime for entries in the cache, for both encrypt and decrypt operations. When the - * specified amount of time passes after initial creation of the entry, the entry will be considered unusable, - * and the next operation will incur a cache miss. - * - * @param maxAge The amount of time entries are allowed to live. Must be positive. - * @param units The units maxAge is expressed in - * @return this builder - */ - public Builder withMaxAge(long maxAge, TimeUnit units) { - if (maxAge <= 0) { - throw new IllegalArgumentException("Max age must be positive"); - } - - this.maxAge = units.toMillis(maxAge); - return this; - } + private final long maxAgeMs; + private final long messageUseLimit; + private final long byteUseLimit; - /** - * Sets the maximum number of individual messages that can be encrypted under the same a cached data key. This - * does not affect decrypt operations. - * - * Specifying this limit is optional; by default, the limit is set to 2^32. This is also the maximum accepted - * value; if you specify a higher limit, an {@link IllegalArgumentException} will be thrown. - * - * @param messageUseLimit The maximum number of messages that can be encrypted by the same data key. Must be - * positive. - * @return this builder - */ - public Builder withMessageUseLimit(long messageUseLimit) { - if (messageUseLimit <= 0) { - throw new IllegalArgumentException("Message use limit must be positive"); - } - - if (messageUseLimit > MAX_MESSAGE_USE_LIMIT) { - throw new IllegalArgumentException("Message use limit exceeds limit of " + MAX_MESSAGE_USE_LIMIT); - } - - // We limit the number of messages encrypted under the same data key primarily to stay far away from any - // chance of message ID collisions (and therefore collisions of the key+IV used for the actual message - // encryption). - this.messageUseLimit = messageUseLimit; - return this; - } - - /** - * Sets the maximum number of plaintext bytes that can be encrypted under the same a cached data key. This does - * not affect decrypt operations. - * - * Specifying this limit is optional; by default, the limit is set to 2^63 - 1. - * - * While this limit can be set to zero, in this case keys can only be cached if they are used for zero-length - * messages. - * - * @param byteUseLimit The maximum number of bytes that can be encrypted by the same data key. Must be - * non-negative. - * - * @return this builder - */ - public Builder withByteUseLimit(long byteUseLimit) { - if (byteUseLimit < 0) { - throw new IllegalArgumentException("Byte use limit must be non-negative"); - } - - // Currently, since the byte use limit is Long.MAX_VALUE, this can't be reached, but is included for - // consistency. - - //noinspection ConstantConditions - if (byteUseLimit > MAX_BYTE_USE_LIMIT) { - throw new IllegalArgumentException("Byte use limit exceeds maximum of " + MAX_BYTE_USE_LIMIT); - } - - this.byteUseLimit = byteUseLimit; - return this; + private final CryptoMaterialsCache.CacheHint hint = + new CryptoMaterialsCache.CacheHint() { + @Override + public long getMaxAgeMillis() { + return maxAgeMs; } + }; + + public static class Builder { + private CryptoMaterialsManager backingCMM; + private CryptoMaterialsCache cache; + private String partitionId = null; + private long maxAge = 0; + private long messageUseLimit = MAX_MESSAGE_USE_LIMIT; + private long byteUseLimit = Long.MAX_VALUE; + + private Builder() {} + + /** + * Sets the {@link CryptoMaterialsManager} that should be queried when the {@link + * CachingCryptoMaterialsManager} (CCMM) incurs a cache miss. + * + *

You can set either a MasterKeyProvider or a CryptoMaterialsManager to back the CCMM - the + * last value set will be used. + * + * @param backingCMM The CryptoMaterialsManager to invoke on cache misses + * @return this builder + */ + public Builder withBackingMaterialsManager(CryptoMaterialsManager backingCMM) { + this.backingCMM = backingCMM; + return this; + } - public CachingCryptoMaterialsManager build() { - if (backingCMM == null) { - throw new IllegalArgumentException("Backing CMM must be set"); - } + /** + * Sets the {@link MasterKeyProvider} that should be queried when the {@link + * CachingCryptoMaterialsManager} (CCMM) incurs a cache miss. + * + *

You can set either a MasterKeyProvider or a CryptoMaterialsManager to back the CCMM - the + * last value set will be used. + * + *

This method is equivalent to calling {@link + * #withBackingMaterialsManager(CryptoMaterialsManager)} passing a {@link + * DefaultCryptoMaterialsManager} constructed using your {@link MasterKeyProvider}. + * + * @param mkp The MasterKeyProvider to invoke on cache misses + * @return this builder + */ + public Builder withMasterKeyProvider(MasterKeyProvider mkp) { + return withBackingMaterialsManager(new DefaultCryptoMaterialsManager(mkp)); + } - if (cache == null) { - throw new IllegalArgumentException("Cache must be set"); - } + /** + * Sets the cache to which this {@link CryptoMaterialsManager} will be bound. + * + * @param cache The cache to associate with the CMM + * @return this builder + */ + public Builder withCache(CryptoMaterialsCache cache) { + this.cache = cache; + return this; + } - if (maxAge <= 0) { - throw new IllegalArgumentException("Max age must be set"); - } + /** + * Sets the partition ID for this CMM. This is an optional operation. + * + *

By default, two CMMs will never use each other's cache entries. This helps ensure that + * CMMs with different delegates won't incorrectly use each other's encrypt and decrypt results. + * However, in certain special circumstances it can be useful to share entries between different + * CMMs - for example, if the backing CMM is constructed based on some parameters that depend on + * the operation, you may wish for delegates constructed with the same parameters to share the + * same partition. + * + *

To accomplish this, set the same partition ID and backing cache on both CMMs; entries + * cached from one of these CMMs can then be used by the other. This should only be done with + * careful consideration and verification that the CMM delegates are equivalent for your + * application's purposes. + * + *

By default, the partition ID is set to a random UUID to avoid any collisions. + * + * @param partitionId The partition ID + * @return this builder + */ + public Builder withPartitionId(String partitionId) { + this.partitionId = partitionId; + return this; + } - return new CachingCryptoMaterialsManager(this); - } + /** + * Sets the maximum lifetime for entries in the cache, for both encrypt and decrypt operations. + * When the specified amount of time passes after initial creation of the entry, the entry will + * be considered unusable, and the next operation will incur a cache miss. + * + * @param maxAge The amount of time entries are allowed to live. Must be positive. + * @param units The units maxAge is expressed in + * @return this builder + */ + public Builder withMaxAge(long maxAge, TimeUnit units) { + if (maxAge <= 0) { + throw new IllegalArgumentException("Max age must be positive"); + } + + this.maxAge = units.toMillis(maxAge); + return this; } - public static Builder newBuilder() { - return new Builder(); + /** + * Sets the maximum number of individual messages that can be encrypted under the same a cached + * data key. This does not affect decrypt operations. + * + *

Specifying this limit is optional; by default, the limit is set to 2^32. This is also the + * maximum accepted value; if you specify a higher limit, an {@link IllegalArgumentException} + * will be thrown. + * + * @param messageUseLimit The maximum number of messages that can be encrypted by the same data + * key. Must be positive. + * @return this builder + */ + public Builder withMessageUseLimit(long messageUseLimit) { + if (messageUseLimit <= 0) { + throw new IllegalArgumentException("Message use limit must be positive"); + } + + if (messageUseLimit > MAX_MESSAGE_USE_LIMIT) { + throw new IllegalArgumentException( + "Message use limit exceeds limit of " + MAX_MESSAGE_USE_LIMIT); + } + + // We limit the number of messages encrypted under the same data key primarily to stay far + // away from any + // chance of message ID collisions (and therefore collisions of the key+IV used for the actual + // message + // encryption). + this.messageUseLimit = messageUseLimit; + return this; } - private CachingCryptoMaterialsManager(Builder builder) { - this.backingCMM = builder.backingCMM; - this.cache = builder.cache; - this.partitionId = builder.partitionId != null ? builder.partitionId : UUID.randomUUID().toString(); - this.maxAgeMs = builder.maxAge; - this.messageUseLimit = builder.messageUseLimit; - this.byteUseLimit = builder.byteUseLimit; - - try { - this.partitionIdHash = MessageDigest.getInstance(CACHE_ID_HASH_ALGORITHM).digest( - partitionId.getBytes(StandardCharsets.UTF_8) - ); - } catch (GeneralSecurityException e) { - throw new AwsCryptoException(e); - } + /** + * Sets the maximum number of plaintext bytes that can be encrypted under the same a cached data + * key. This does not affect decrypt operations. + * + *

Specifying this limit is optional; by default, the limit is set to 2^63 - 1. + * + *

While this limit can be set to zero, in this case keys can only be cached if they are used + * for zero-length messages. + * + * @param byteUseLimit The maximum number of bytes that can be encrypted by the same data key. + * Must be non-negative. + * @return this builder + */ + public Builder withByteUseLimit(long byteUseLimit) { + if (byteUseLimit < 0) { + throw new IllegalArgumentException("Byte use limit must be non-negative"); + } + + // Currently, since the byte use limit is Long.MAX_VALUE, this can't be reached, but is + // included for + // consistency. + + //noinspection ConstantConditions + if (byteUseLimit > MAX_BYTE_USE_LIMIT) { + throw new IllegalArgumentException( + "Byte use limit exceeds maximum of " + MAX_BYTE_USE_LIMIT); + } + + this.byteUseLimit = byteUseLimit; + return this; } - @Override public EncryptionMaterials getMaterialsForEncrypt( - EncryptionMaterialsRequest request - ) { - // We cannot correctly enforce size limits if the request has no known plaintext size, so bypass the cache in - // this case. - if (request.getPlaintextSize() == -1) { - return backingCMM.getMaterialsForEncrypt(request); - } + public CachingCryptoMaterialsManager build() { + if (backingCMM == null) { + throw new IllegalArgumentException("Backing CMM must be set"); + } - // Strip off information on the plaintext length & contents - we do this because we will be (potentially) - // reusing the result from the backing CMM across multiple requests, and as such it would be misleading to pass on - // the first such request's information to the backing CMM. + if (cache == null) { + throw new IllegalArgumentException("Cache must be set"); + } - EncryptionMaterialsRequest upstreamRequest = request.toBuilder() - .setPlaintext(null) - .setPlaintextSize(-1) - .build(); + if (maxAge <= 0) { + throw new IllegalArgumentException("Max age must be set"); + } - byte[] cacheId = getCacheIdentifier(upstreamRequest); + return new CachingCryptoMaterialsManager(this); + } + } + + public static Builder newBuilder() { + return new Builder(); + } + + private CachingCryptoMaterialsManager(Builder builder) { + this.backingCMM = builder.backingCMM; + this.cache = builder.cache; + this.partitionId = + builder.partitionId != null ? builder.partitionId : UUID.randomUUID().toString(); + this.maxAgeMs = builder.maxAge; + this.messageUseLimit = builder.messageUseLimit; + this.byteUseLimit = builder.byteUseLimit; + + try { + this.partitionIdHash = + MessageDigest.getInstance(CACHE_ID_HASH_ALGORITHM) + .digest(partitionId.getBytes(StandardCharsets.UTF_8)); + } catch (GeneralSecurityException e) { + throw new AwsCryptoException(e); + } + } + + @Override + public EncryptionMaterials getMaterialsForEncrypt(EncryptionMaterialsRequest request) { + // We cannot correctly enforce size limits if the request has no known plaintext size, so bypass + // the cache in + // this case. + if (request.getPlaintextSize() == -1) { + return backingCMM.getMaterialsForEncrypt(request); + } - CryptoMaterialsCache.UsageStats increment = initialIncrementForRequest(request); + // Strip off information on the plaintext length & contents - we do this because we will be + // (potentially) + // reusing the result from the backing CMM across multiple requests, and as such it would be + // misleading to pass on + // the first such request's information to the backing CMM. - // If our plaintext size is such that even a brand new entry would reach or exceed cache limits, there's no - // point in accessing the cache - in fact, doing so would poison a cache entry that could potentially be still - // used for a smaller request. So we'll bypass the cache and just call the backing CMM directly in this case. - if (increment.getBytesEncrypted() >= byteUseLimit) { - return backingCMM.getMaterialsForEncrypt(request); - } + EncryptionMaterialsRequest upstreamRequest = + request.toBuilder().setPlaintext(null).setPlaintextSize(-1).build(); - CryptoMaterialsCache.EncryptCacheEntry entry = cache.getEntryForEncrypt(cacheId, increment); - if (entry != null - && !isEntryExpired(entry.getEntryCreationTime()) - && !hasExceededLimits(entry.getUsageStats())) { - return entry.getResult(); - } else if (entry != null) { - // entry has potentially expired, so hint to the cache that it should be removed, in case the cache stores - // multiple entries or something - entry.invalidate(); - } - - // Cache miss. - EncryptionMaterials result = backingCMM.getMaterialsForEncrypt(request); + byte[] cacheId = getCacheIdentifier(upstreamRequest); - if (result.getAlgorithm().isSafeToCache()) { - cache.putEntryForEncrypt(cacheId, result, hint, initialIncrementForRequest(request)); - } + CryptoMaterialsCache.UsageStats increment = initialIncrementForRequest(request); - return result; + // If our plaintext size is such that even a brand new entry would reach or exceed cache limits, + // there's no + // point in accessing the cache - in fact, doing so would poison a cache entry that could + // potentially be still + // used for a smaller request. So we'll bypass the cache and just call the backing CMM directly + // in this case. + if (increment.getBytesEncrypted() >= byteUseLimit) { + return backingCMM.getMaterialsForEncrypt(request); } - private boolean hasExceededLimits(final CryptoMaterialsCache.UsageStats stats) { - return stats.getBytesEncrypted() > byteUseLimit - || stats.getMessagesEncrypted() > messageUseLimit; + CryptoMaterialsCache.EncryptCacheEntry entry = cache.getEntryForEncrypt(cacheId, increment); + if (entry != null + && !isEntryExpired(entry.getEntryCreationTime()) + && !hasExceededLimits(entry.getUsageStats())) { + return entry.getResult(); + } else if (entry != null) { + // entry has potentially expired, so hint to the cache that it should be removed, in case the + // cache stores + // multiple entries or something + entry.invalidate(); } - private boolean isEntryExpired(final long entryCreationTime) { - return System.currentTimeMillis() - entryCreationTime > maxAgeMs; - } + // Cache miss. + EncryptionMaterials result = backingCMM.getMaterialsForEncrypt(request); - private CryptoMaterialsCache.UsageStats initialIncrementForRequest(EncryptionMaterialsRequest request) { - return new CryptoMaterialsCache.UsageStats(request.getPlaintextSize(), 1); + if (result.getAlgorithm().isSafeToCache()) { + cache.putEntryForEncrypt(cacheId, result, hint, initialIncrementForRequest(request)); } - @Override public DecryptionMaterials decryptMaterials(DecryptionMaterialsRequest request) { - byte[] cacheId = getCacheIdentifier(request); + return result; + } - CryptoMaterialsCache.DecryptCacheEntry entry = cache.getEntryForDecrypt(cacheId); + private boolean hasExceededLimits(final CryptoMaterialsCache.UsageStats stats) { + return stats.getBytesEncrypted() > byteUseLimit + || stats.getMessagesEncrypted() > messageUseLimit; + } - if (entry != null && !isEntryExpired(entry.getEntryCreationTime())) { - return entry.getResult(); - } + private boolean isEntryExpired(final long entryCreationTime) { + return System.currentTimeMillis() - entryCreationTime > maxAgeMs; + } + + private CryptoMaterialsCache.UsageStats initialIncrementForRequest( + EncryptionMaterialsRequest request) { + return new CryptoMaterialsCache.UsageStats(request.getPlaintextSize(), 1); + } - DecryptionMaterials result = backingCMM.decryptMaterials(request); - cache.putEntryForDecrypt(cacheId, result, hint); + @Override + public DecryptionMaterials decryptMaterials(DecryptionMaterialsRequest request) { + byte[] cacheId = getCacheIdentifier(request); - return result; + CryptoMaterialsCache.DecryptCacheEntry entry = cache.getEntryForDecrypt(cacheId); + + if (entry != null && !isEntryExpired(entry.getEntryCreationTime())) { + return entry.getResult(); } - private byte[] getCacheIdentifier(EncryptionMaterialsRequest req) { - try { - MessageDigest digest = MessageDigest.getInstance(CACHE_ID_HASH_ALGORITHM); + DecryptionMaterials result = backingCMM.decryptMaterials(request); + cache.putEntryForDecrypt(cacheId, result, hint); - digest.update(partitionIdHash); + return result; + } - CryptoAlgorithm algorithm = req.getRequestedAlgorithm(); - digest.update((byte) (algorithm != null ? 1 : 0)); - if (algorithm != null) { - updateDigestWithAlgorithm(digest, algorithm); - } + private byte[] getCacheIdentifier(EncryptionMaterialsRequest req) { + try { + MessageDigest digest = MessageDigest.getInstance(CACHE_ID_HASH_ALGORITHM); - digest.update(MessageDigest.getInstance(CACHE_ID_HASH_ALGORITHM).digest( - EncryptionContextSerializer.serialize(req.getContext()) - )); + digest.update(partitionIdHash); - return digest.digest(); - } catch (GeneralSecurityException e) { - throw new AwsCryptoException(e); - } + CryptoAlgorithm algorithm = req.getRequestedAlgorithm(); + digest.update((byte) (algorithm != null ? 1 : 0)); + if (algorithm != null) { + updateDigestWithAlgorithm(digest, algorithm); + } + + digest.update( + MessageDigest.getInstance(CACHE_ID_HASH_ALGORITHM) + .digest(EncryptionContextSerializer.serialize(req.getContext()))); + + return digest.digest(); + } catch (GeneralSecurityException e) { + throw new AwsCryptoException(e); } + } - private byte[] getCacheIdentifier(DecryptionMaterialsRequest req) { - try { - MessageDigest digest = MessageDigest.getInstance(CACHE_ID_HASH_ALGORITHM); - byte[] hashOfContext = digest.digest(EncryptionContextSerializer.serialize(req.getEncryptionContext())); + private byte[] getCacheIdentifier(DecryptionMaterialsRequest req) { + try { + MessageDigest digest = MessageDigest.getInstance(CACHE_ID_HASH_ALGORITHM); + byte[] hashOfContext = + digest.digest(EncryptionContextSerializer.serialize(req.getEncryptionContext())); - ArrayList keyBlobHashes = new ArrayList<>(req.getEncryptedDataKeys().size()); + ArrayList keyBlobHashes = new ArrayList<>(req.getEncryptedDataKeys().size()); - for (KeyBlob blob : req.getEncryptedDataKeys()) { - keyBlobHashes.add(digest.digest(blob.toByteArray())); - } - keyBlobHashes.sort(new Utils.ComparingByteArrays()); + for (KeyBlob blob : req.getEncryptedDataKeys()) { + keyBlobHashes.add(digest.digest(blob.toByteArray())); + } + keyBlobHashes.sort(new Utils.ComparingByteArrays()); - // Now starting the digest of the actual cache identifier - digest.update(partitionIdHash); - updateDigestWithAlgorithm(digest, req.getAlgorithm()); + // Now starting the digest of the actual cache identifier + digest.update(partitionIdHash); + updateDigestWithAlgorithm(digest, req.getAlgorithm()); - keyBlobHashes.forEach(digest::update); + keyBlobHashes.forEach(digest::update); - // This all-zero sentinel field indicates the end of the key blob hashes. - digest.update(new byte[digest.getDigestLength()]); - digest.update(hashOfContext); + // This all-zero sentinel field indicates the end of the key blob hashes. + digest.update(new byte[digest.getDigestLength()]); + digest.update(hashOfContext); - return digest.digest(); - } catch (GeneralSecurityException e) { - throw new AwsCryptoException(e); - } + return digest.digest(); + } catch (GeneralSecurityException e) { + throw new AwsCryptoException(e); } + } - // Common helper to add the algorithm identifier (in proper big endian order) for both encrypt and decrypt paths. - private void updateDigestWithAlgorithm(MessageDigest digest, CryptoAlgorithm algorithm) { - short algId = algorithm.getValue(); + // Common helper to add the algorithm identifier (in proper big endian order) for both encrypt and + // decrypt paths. + private void updateDigestWithAlgorithm(MessageDigest digest, CryptoAlgorithm algorithm) { + short algId = algorithm.getValue(); - digest.update(new byte[] { (byte)(algId >> 8), (byte)(algId) }); - } + digest.update(new byte[] {(byte) (algId >> 8), (byte) (algId)}); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/caching/CryptoMaterialsCache.java b/src/main/java/com/amazonaws/encryptionsdk/caching/CryptoMaterialsCache.java index d28c521bb..578711ae2 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/caching/CryptoMaterialsCache.java +++ b/src/main/java/com/amazonaws/encryptionsdk/caching/CryptoMaterialsCache.java @@ -1,223 +1,220 @@ package com.amazonaws.encryptionsdk.caching; -import java.util.Objects; - import com.amazonaws.encryptionsdk.internal.Utils; import com.amazonaws.encryptionsdk.model.DecryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; +import java.util.Objects; /** - * Represents a generic cache for cryptographic materials. MaterialsCaches store mappings from abstract bytestring - * identifiers to MaterialsResults and DecryptResults. + * Represents a generic cache for cryptographic materials. MaterialsCaches store mappings from + * abstract bytestring identifiers to MaterialsResults and DecryptResults. * - * In general, the materials cache is concerned about the proper storage of these materials, and managing size limits - * on the cache. While it stores statistics about cache usage limits, the enforcement of these limits is left to the - * caller (typically, a {@link CachingCryptoMaterialsManager}). + *

In general, the materials cache is concerned about the proper storage of these materials, and + * managing size limits on the cache. While it stores statistics about cache usage limits, the + * enforcement of these limits is left to the caller (typically, a {@link + * CachingCryptoMaterialsManager}). * - * For encrypt, a cache implementation may store multiple cache entries for the same identifier. This allows for usage - * limits to be enforced even when doing multiple streaming requests in parallel. However, the cache is permitted to - * set a limit on the number of such entries (even as low as only allowing one entry per identifier), and if it does so - * should evict the excess entries. + *

For encrypt, a cache implementation may store multiple cache entries for the same identifier. + * This allows for usage limits to be enforced even when doing multiple streaming requests in + * parallel. However, the cache is permitted to set a limit on the number of such entries (even as + * low as only allowing one entry per identifier), and if it does so should evict the excess + * entries. * - * Being a cache, a CryptoMaterialsCache is permitted to evict entries at any time. However, a caller can explicitly hint - * the cache to invalidate an entry in the encrypt-side cache. This is generally done when usage limits are exceeded. - * The cache is not required to respect this invalidation hint. + *

Being a cache, a CryptoMaterialsCache is permitted to evict entries at any time. However, a + * caller can explicitly hint the cache to invalidate an entry in the encrypt-side cache. This is + * generally done when usage limits are exceeded. The cache is not required to respect this + * invalidation hint. * - * Likewise, the CacheHint passed to the put calls on caches will indicate the maximum lifetime of entries; the cache - * is allowed - but not required - to evict entries automatically upon expiration of this lifetime. + *

Likewise, the CacheHint passed to the put calls on caches will indicate the maximum lifetime + * of entries; the cache is allowed - but not required - to evict entries automatically upon + * expiration of this lifetime. */ public interface CryptoMaterialsCache { + /** + * Searches for an entry in the encrypt cache matching a particular cache identifier, and returns + * one if found. + * + * @param cacheId The identifier for the item in the cache + * @param usageIncrement The amount of usage to atomically add to the returned entry. This usage + * increment must be reflected in the getUsageStats() method on the returned cache entry. + * @return The entry, or null if not found or an error occurred + */ + EncryptCacheEntry getEntryForEncrypt(byte[] cacheId, final UsageStats usageIncrement); + + /** + * Attempts to add a new entry to the encrypt cache to be returned on subsequent {@link + * #getEntryForEncrypt(byte[], UsageStats)} calls. + * + *

In the event that an error occurs when adding the entry to the cache, this function shall + * still return a EncryptCacheEntry instance, which shall act as if the cache entry was + * immediately evicted and/or invalidated. + * + * @param cacheId The identifier for the item in the cache + * @param encryptionMaterials The {@link EncryptionMaterials} to associate with this new entry + * @param initialUsage The initial usage stats for the cache entry + * @return A new, locked EncryptCacheEntry instance containing the given encryptionMaterials + */ + EncryptCacheEntry putEntryForEncrypt( + byte[] cacheId, + EncryptionMaterials encryptionMaterials, + CacheHint hint, + UsageStats initialUsage); + + /** + * Searches for an entry in the encrypt cache matching a particular cache identifier, and returns + * one if found. + * + *

In the event of an error accessing the cache, this function will return null. + * + * @param cacheId The identifier for the item in the cache + * @return The cached decryption result, or null if none was found or an error occurred. + */ + DecryptCacheEntry getEntryForDecrypt(byte[] cacheId); + + /** + * Adds a new entry to the decrypt cache. In the event of an error, this function will return + * silently without propagating the exception. + * + *

If an entry already exists for this cache ID, the cache may either overwrite the entry in + * its entirety, or update the creation timestamp for the existing entry, at its option. + * + * @param cacheId The identifier for the item in the cache + * @param decryptionMaterials The {@link DecryptionMaterials} to associate with the new entry. + */ + void putEntryForDecrypt(byte[] cacheId, DecryptionMaterials decryptionMaterials, CacheHint hint); + + /** + * Contains some additional information associated with a cache entry. The cache receiving this + * hint may take some actions based on the hint, or it may ignore it entirely. + */ + interface CacheHint { /** - * Searches for an entry in the encrypt cache matching a particular cache identifier, and returns one if found. - * - * @param cacheId The identifier for the item in the cache - * @param usageIncrement The amount of usage to atomically add to the returned entry. This usage increment must be - * reflected in the getUsageStats() method on the returned cache entry. - * @return The entry, or null if not found or an error occurred - */ - EncryptCacheEntry getEntryForEncrypt( - byte[] cacheId, final UsageStats usageIncrement - ); - - /** - * Attempts to add a new entry to the encrypt cache to be returned on subsequent - * {@link #getEntryForEncrypt(byte[], UsageStats)} calls. + * Returns the lifetime of the cache entry. This hint suggests to the cache that the entry will + * not be useful after the provided number of milliseconds passes, and suggests that the cache + * should discard the entry when this interval elapses even if it is not explicitly invalidated. * - * In the event that an error occurs when adding the entry to the cache, this function shall still return a - * EncryptCacheEntry instance, which shall act as if the cache entry was immediately evicted and/or invalidated. + *

Note that this time counts from the creation of the entry, not from last use. * - * @param cacheId The identifier for the item in the cache - * @param encryptionMaterials The {@link EncryptionMaterials} to associate with this new entry - * @param initialUsage The initial usage stats for the cache entry - * @return A new, locked EncryptCacheEntry instance containing the given encryptionMaterials + * @return The lifetime of this entry, in milliseconds. If the lifetime is unknown or + * irrelevant, this will return {@link Long#MAX_VALUE}. */ - EncryptCacheEntry putEntryForEncrypt( - byte[] cacheId, - EncryptionMaterials encryptionMaterials, - CacheHint hint, - UsageStats initialUsage - ); - + long getMaxAgeMillis(); + } + + /** + * Represents an entry in the encrypt cache, and provides methods for manipulating the entry. + * + *

Note that the EncryptCacheEntry Java object remains valid even after the cache entry is + * invalidated or evicted; getResult will still return a valid result, for example. + */ + interface EncryptCacheEntry { /** - * Searches for an entry in the encrypt cache matching a particular cache identifier, and returns one if found. - * - * In the event of an error accessing the cache, this function will return null. - * - * @param cacheId The identifier for the item in the cache - * @return The cached decryption result, or null if none was found or an error occurred. + * @return The current usage statistics for this entry. The caller should be aware that these + * statistics may be stale by the time they are returned. */ - DecryptCacheEntry getEntryForDecrypt(byte[] cacheId); + UsageStats getUsageStats(); - /** - * Adds a new entry to the decrypt cache. In the event of an error, this function will return silently without - * propagating the exception. - * - * If an entry already exists for this cache ID, the cache may either overwrite the entry in its entirety, or update - * the creation timestamp for the existing entry, at its option. - * - * @param cacheId The identifier for the item in the cache - * @param decryptionMaterials The {@link DecryptionMaterials} to associate with the new entry. - */ - void putEntryForDecrypt(byte[] cacheId, DecryptionMaterials decryptionMaterials, CacheHint hint); + /** @return The unix timestamp at which this entry was added to the cache, in milliseconds */ + long getEntryCreationTime(); /** - * Contains some additional information associated with a cache entry. The cache receiving this hint may take some - * actions based on the hint, or it may ignore it entirely. + * @return The EncryptionMaterials associated with this cache entry. The encrypt completion + * callback should be a no-op. */ - interface CacheHint { - /** - * Returns the lifetime of the cache entry. This hint suggests to the cache that the entry will not be useful - * after the provided number of milliseconds passes, and suggests that the cache should discard the entry when - * this interval elapses even if it is not explicitly invalidated. - * - * Note that this time counts from the creation of the entry, not from last use. - * - * @return The lifetime of this entry, in milliseconds. If the lifetime is unknown or irrelevant, this will - * return {@link Long#MAX_VALUE}. - */ - long getMaxAgeMillis(); + EncryptionMaterials getResult(); + + /** Suggests to the cache that this entry should be removed from the cache. */ + default void invalidate() {} + } + + final class UsageStats { + public static final UsageStats ZERO = new UsageStats(0, 0); + + private final long bytesEncrypted; + private final long messagesEncrypted; + + public UsageStats(long bytesEncrypted, long messagesEncrypted) { + if (bytesEncrypted < 0) { + throw new IllegalArgumentException("Negative bytes encrypted is not permitted"); + } + + if (messagesEncrypted < 0) { + throw new IllegalArgumentException("Negative messages encrypted is not permitted"); + } + + this.bytesEncrypted = bytesEncrypted; + this.messagesEncrypted = messagesEncrypted; } - /** - * Represents an entry in the encrypt cache, and provides methods for manipulating the entry. - * - * Note that the EncryptCacheEntry Java object remains valid even after the cache entry is invalidated or evicted; - * getResult will still return a valid result, for example. - */ - interface EncryptCacheEntry { - /** - * @return The current usage statistics for this entry. The caller should be aware that these statistics may be - * stale by the time they are returned. - */ - UsageStats getUsageStats(); - - /** - * @return The unix timestamp at which this entry was added to the cache, in milliseconds - */ - long getEntryCreationTime(); - - /** - * @return The EncryptionMaterials associated with this cache entry. The encrypt completion callback should be a - * no-op. - */ - EncryptionMaterials getResult(); - - /** - * Suggests to the cache that this entry should be removed from the cache. - */ - default void invalidate() {} + public long getBytesEncrypted() { + return bytesEncrypted; } - final class UsageStats { - public static final UsageStats ZERO = new UsageStats(0, 0); - - private final long bytesEncrypted; - private final long messagesEncrypted; - - public UsageStats(long bytesEncrypted, long messagesEncrypted) { - if (bytesEncrypted < 0) { - throw new IllegalArgumentException("Negative bytes encrypted is not permitted"); - } - - if (messagesEncrypted < 0) { - throw new IllegalArgumentException("Negative messages encrypted is not permitted"); - } - - this.bytesEncrypted = bytesEncrypted; - this.messagesEncrypted = messagesEncrypted; - } - - public long getBytesEncrypted() { - return bytesEncrypted; - } - - public long getMessagesEncrypted() { - return messagesEncrypted; - } - - /** - * Performs a pairwise add of two UsageStats objects. In the event of overflow, counters saturate at - * {@link Long#MAX_VALUE} - * - * @param other - * @return - */ - public UsageStats add(UsageStats other) { - return new UsageStats( - saturatingAdd(getBytesEncrypted(), other.getBytesEncrypted()), - saturatingAdd(getMessagesEncrypted(), other.getMessagesEncrypted()) - ); - } - - static long saturatingAdd(long a, long b) { - if (a < 0 || b < 0) { - throw new IllegalArgumentException("Negative usage values are not permitted"); - } - - return Utils.saturatingAdd(a, b); - } - - @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - UsageStats that = (UsageStats) o; - return getBytesEncrypted() == that.getBytesEncrypted() && - getMessagesEncrypted() == that.getMessagesEncrypted(); - } - - @Override public int hashCode() { - return Objects.hash(getBytesEncrypted(), getMessagesEncrypted()); - } - - @Override public String toString() { - return "UsageStats{" + - "bytesEncrypted=" + bytesEncrypted + - ", messagesEncrypted=" + messagesEncrypted + - '}'; - } + public long getMessagesEncrypted() { + return messagesEncrypted; } /** - * Represents an entry in the decrypt cache, and provides methods for manipulating the entry. + * Performs a pairwise add of two UsageStats objects. In the event of overflow, counters + * saturate at {@link Long#MAX_VALUE} * - * Note that the DecryptCacheEntry JAva object remains valid even after the cache entry is invalidated or evicted; - * getResult will still return a valid result, for example. + * @param other + * @return */ - interface DecryptCacheEntry { - /** - * Returns the DecryptionMaterials associated with this entry. - */ - DecryptionMaterials getResult(); - - /** - * Advises the cache that this entry should be removed from the cache. - */ - void invalidate(); - - /** - * @return The unix timestamp at which this entry was added to the cache, in milliseconds - */ - long getEntryCreationTime(); + public UsageStats add(UsageStats other) { + return new UsageStats( + saturatingAdd(getBytesEncrypted(), other.getBytesEncrypted()), + saturatingAdd(getMessagesEncrypted(), other.getMessagesEncrypted())); + } + + static long saturatingAdd(long a, long b) { + if (a < 0 || b < 0) { + throw new IllegalArgumentException("Negative usage values are not permitted"); + } + + return Utils.saturatingAdd(a, b); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + UsageStats that = (UsageStats) o; + return getBytesEncrypted() == that.getBytesEncrypted() + && getMessagesEncrypted() == that.getMessagesEncrypted(); + } + + @Override + public int hashCode() { + return Objects.hash(getBytesEncrypted(), getMessagesEncrypted()); + } + + @Override + public String toString() { + return "UsageStats{" + + "bytesEncrypted=" + + bytesEncrypted + + ", messagesEncrypted=" + + messagesEncrypted + + '}'; } + } + + /** + * Represents an entry in the decrypt cache, and provides methods for manipulating the entry. + * + *

Note that the DecryptCacheEntry JAva object remains valid even after the cache entry is + * invalidated or evicted; getResult will still return a valid result, for example. + */ + interface DecryptCacheEntry { + /** Returns the DecryptionMaterials associated with this entry. */ + DecryptionMaterials getResult(); + + /** Advises the cache that this entry should be removed from the cache. */ + void invalidate(); + + /** @return The unix timestamp at which this entry was added to the cache, in milliseconds */ + long getEntryCreationTime(); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCache.java b/src/main/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCache.java index 8f7085b23..6cb95d16a 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCache.java +++ b/src/main/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCache.java @@ -1,22 +1,19 @@ package com.amazonaws.encryptionsdk.caching; -import static java.util.Collections.max; - -import javax.annotation.concurrent.GuardedBy; -import java.util.Arrays; -import java.util.LinkedHashMap; -import java.util.TreeSet; - import com.amazonaws.encryptionsdk.internal.Utils; import com.amazonaws.encryptionsdk.model.DecryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.TreeSet; +import javax.annotation.concurrent.GuardedBy; /** * A simple implementation of the {@link CryptoMaterialsCache} using a basic LRU cache. * - * Example usage: - * {@code + *

Example usage: * + *

{@code
  * LocalCryptoMaterialsCache cache = new LocalCryptoMaterialsCache(500);
  *
  * CachingCryptoMaterialsManager materialsManager = CachingCryptoMaterialsManager.builder()
@@ -26,284 +23,284 @@
  *      .build();
  *
  * byte[] data = new AwsCrypto().encryptData(materialsManager, plaintext).getResult();
- * }
+ * }
*/ public class LocalCryptoMaterialsCache implements CryptoMaterialsCache { - // The maximum number of entries to implicitly prune per access due to TTL expiration. We limit this to avoid - // latency spikes when a large number of entries have expired since the last cache usage. - private static final int MAX_TTL_PRUNE = 10; + // The maximum number of entries to implicitly prune per access due to TTL expiration. We limit + // this to avoid + // latency spikes when a large number of entries have expired since the last cache usage. + private static final int MAX_TTL_PRUNE = 10; - // Mockable time source, to allow us to test TTL pruning. - // package access for tests - // note: we're not using the java 8 time APIs in order to improve android compatibility - MsClock clock = MsClock.WALLCLOCK; + // Mockable time source, to allow us to test TTL pruning. + // package access for tests + // note: we're not using the java 8 time APIs in order to improve android compatibility + MsClock clock = MsClock.WALLCLOCK; - // The magic numbers here are the normal defaults for LinkedHashMap; we have to specify them explicitly if we are to - // specify accessOrder=true, which enables LRU behavior - private final LinkedHashMap cacheMap = new LinkedHashMap<>( - /* capacity */ 16, /* loadFactor */ 0.75f, /* accessOrder */ true - ); + // The magic numbers here are the normal defaults for LinkedHashMap; we have to specify them + // explicitly if we are to + // specify accessOrder=true, which enables LRU behavior + private final LinkedHashMap cacheMap = + new LinkedHashMap<>(/* capacity */ 16, /* loadFactor */ 0.75f, /* accessOrder */ true); - // This is a treeset sorted by TTL to allow us to quickly find expired entries - private final TreeSet expirationQueue = new TreeSet<>(LocalCryptoMaterialsCache::compareEntries); + // This is a treeset sorted by TTL to allow us to quickly find expired entries + private final TreeSet expirationQueue = + new TreeSet<>(LocalCryptoMaterialsCache::compareEntries); - private final int capacity; + private final int capacity; - public LocalCryptoMaterialsCache(int capacity) { - this.capacity = capacity; - } + public LocalCryptoMaterialsCache(int capacity) { + this.capacity = capacity; + } - private static int compareEntries(BaseEntry a, BaseEntry b) { - int result; + private static int compareEntries(BaseEntry a, BaseEntry b) { + int result; - if (a == b) { - return 0; - } - - result = Long.compare(a.expirationTimestamp_, b.expirationTimestamp_); - if (result != 0) { - return result; - } - - return Utils.compareObjectIdentity(a, b); + if (a == b) { + return 0; } - /** - * A common base for both encrypt and decrypt entries - */ - private class BaseEntry { - final CacheIdentifier identifier_; - final long expirationTimestamp_; - final long creationTime = clock.timestamp(); - - private BaseEntry(CacheIdentifier identifier, long expiration) { - this.identifier_ = identifier; - this.expirationTimestamp_ = expiration; - } + result = Long.compare(a.expirationTimestamp_, b.expirationTimestamp_); + if (result != 0) { + return result; } - /** - * This wrapper just gives us a usable hashcode over our cache identifiers. - */ - private static final class CacheIdentifier { - private final byte[] identifier; - private final int hashCode; - - private CacheIdentifier(byte[] passed_id) { - this.identifier = passed_id.clone(); - this.hashCode = Arrays.hashCode(passed_id); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - return Arrays.equals(identifier, ((CacheIdentifier)o).identifier); - } - - @Override - public int hashCode() { - return hashCode; - } - } + return Utils.compareObjectIdentity(a, b); + } - // Note: We take locks on both cache entries as well as the overall cache. - // The lock order is overall cache -> cache entry; this means that the entry cannot call back into the parent cache - // while holding its own lock. - private final class EncryptCacheEntryInternal extends BaseEntry { - private final EncryptionMaterials result; + /** A common base for both encrypt and decrypt entries */ + private class BaseEntry { + final CacheIdentifier identifier_; + final long expirationTimestamp_; + final long creationTime = clock.timestamp(); - @GuardedBy("this") - private UsageStats usageStats = UsageStats.ZERO; + private BaseEntry(CacheIdentifier identifier, long expiration) { + this.identifier_ = identifier; + this.expirationTimestamp_ = expiration; + } + } - private EncryptCacheEntryInternal( - CacheIdentifier identifier, - long expiration, - EncryptionMaterials result - ) { - super(identifier, expiration); + /** This wrapper just gives us a usable hashcode over our cache identifiers. */ + private static final class CacheIdentifier { + private final byte[] identifier; + private final int hashCode; - this.result = result; - } + private CacheIdentifier(byte[] passed_id) { + this.identifier = passed_id.clone(); + this.hashCode = Arrays.hashCode(passed_id); + } - synchronized UsageStats addAndGetUsageStats(UsageStats delta) { - this.usageStats = this.usageStats.add(delta); + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; - return this.usageStats; - } + return Arrays.equals(identifier, ((CacheIdentifier) o).identifier); } - // When returning cache entries, we create a new object to represent the snapshot of usage stats at time of get. - // This helps avoid races where two gets together push an entry over usage limits, and then both miss when they - // see the entry over the limit. - // - // Not static as invalidate calls back into the cache. - private final class EncryptCacheEntryExposed implements EncryptCacheEntry { - private final UsageStats usageStats_; - private final EncryptCacheEntryInternal internal_; - - private EncryptCacheEntryExposed( - final UsageStats usageStats, - final EncryptCacheEntryInternal internal - ) { - usageStats_ = usageStats; - internal_ = internal; - } - - @Override public UsageStats getUsageStats() { - return usageStats_; - } - - @Override public long getEntryCreationTime() { - return internal_.creationTime; - } - - @Override public EncryptionMaterials getResult() { - return internal_.result; - } - - @Override public void invalidate() { - removeEntry(internal_); - } + @Override + public int hashCode() { + return hashCode; } + } - private final class DecryptCacheEntryInternal extends BaseEntry implements DecryptCacheEntry { - final DecryptionMaterials result; + // Note: We take locks on both cache entries as well as the overall cache. + // The lock order is overall cache -> cache entry; this means that the entry cannot call back into + // the parent cache + // while holding its own lock. + private final class EncryptCacheEntryInternal extends BaseEntry { + private final EncryptionMaterials result; - private DecryptCacheEntryInternal( - CacheIdentifier identifier, - long expiration, - DecryptionMaterials result - ) { - super(identifier, expiration); + @GuardedBy("this") + private UsageStats usageStats = UsageStats.ZERO; - this.result = result; - } + private EncryptCacheEntryInternal( + CacheIdentifier identifier, long expiration, EncryptionMaterials result) { + super(identifier, expiration); - @Override public DecryptionMaterials getResult() { - return result; - } + this.result = result; + } - @Override public void invalidate() { - removeEntry(this); - } + synchronized UsageStats addAndGetUsageStats(UsageStats delta) { + this.usageStats = this.usageStats.add(delta); - @Override public long getEntryCreationTime() { - return creationTime; - } + return this.usageStats; + } + } + + // When returning cache entries, we create a new object to represent the snapshot of usage stats + // at time of get. + // This helps avoid races where two gets together push an entry over usage limits, and then both + // miss when they + // see the entry over the limit. + // + // Not static as invalidate calls back into the cache. + private final class EncryptCacheEntryExposed implements EncryptCacheEntry { + private final UsageStats usageStats_; + private final EncryptCacheEntryInternal internal_; + + private EncryptCacheEntryExposed( + final UsageStats usageStats, final EncryptCacheEntryInternal internal) { + usageStats_ = usageStats; + internal_ = internal; } - /** - * Removes an entry from the cache. - * @param e the entry to remove - */ - private synchronized void removeEntry(BaseEntry e) { - expirationQueue.remove(e); - // This does not update the LRU if the value does not match - cacheMap.remove(e.identifier_, e); + @Override + public UsageStats getUsageStats() { + return usageStats_; } - /** - * Prunes all TTL-expired entries, plus LRU entries until we are under capacity limits. - */ - private synchronized void prune() { - // Purge maxage-expired entries first, to avoid pruning entries by LRU unnecessarily when we're about to free - // up space anyway. - ttlPrune(); - - while (cacheMap.size() > capacity) { - removeEntry(cacheMap.values().iterator().next()); - } + @Override + public long getEntryCreationTime() { + return internal_.creationTime; } - /** - * Prunes all TTL-expired entries. Does not check capacity. - */ - private void ttlPrune() { - int pruneCount = 0; - long now = clock.timestamp(); - - while (!expirationQueue.isEmpty() && expirationQueue.first().expirationTimestamp_ < now && pruneCount < MAX_TTL_PRUNE) { - removeEntry(expirationQueue.first()); - pruneCount++; - } + @Override + public EncryptionMaterials getResult() { + return internal_.result; } - private synchronized T getEntry(Class klass, byte[] identifier) { - // Perform cache maintenance first - ttlPrune(); + @Override + public void invalidate() { + removeEntry(internal_); + } + } - BaseEntry e = cacheMap.get(new CacheIdentifier(identifier)); + private final class DecryptCacheEntryInternal extends BaseEntry implements DecryptCacheEntry { + final DecryptionMaterials result; - if (e == null) { - return null; - } else { - if (e.expirationTimestamp_ < clock.timestamp()) { - removeEntry(e); - return null; - } + private DecryptCacheEntryInternal( + CacheIdentifier identifier, long expiration, DecryptionMaterials result) { + super(identifier, expiration); - return klass.cast(e); - } + this.result = result; } - private synchronized void putEntry(final BaseEntry entry) { - BaseEntry oldEntry = cacheMap.put(entry.identifier_, entry); + @Override + public DecryptionMaterials getResult() { + return result; + } - if (oldEntry != null) { - expirationQueue.remove(oldEntry); - } - expirationQueue.add(entry); + @Override + public void invalidate() { + removeEntry(this); + } - prune(); + @Override + public long getEntryCreationTime() { + return creationTime; } + } + + /** + * Removes an entry from the cache. + * + * @param e the entry to remove + */ + private synchronized void removeEntry(BaseEntry e) { + expirationQueue.remove(e); + // This does not update the LRU if the value does not match + cacheMap.remove(e.identifier_, e); + } + + /** Prunes all TTL-expired entries, plus LRU entries until we are under capacity limits. */ + private synchronized void prune() { + // Purge maxage-expired entries first, to avoid pruning entries by LRU unnecessarily when we're + // about to free + // up space anyway. + ttlPrune(); + + while (cacheMap.size() > capacity) { + removeEntry(cacheMap.values().iterator().next()); + } + } + + /** Prunes all TTL-expired entries. Does not check capacity. */ + private void ttlPrune() { + int pruneCount = 0; + long now = clock.timestamp(); + + while (!expirationQueue.isEmpty() + && expirationQueue.first().expirationTimestamp_ < now + && pruneCount < MAX_TTL_PRUNE) { + removeEntry(expirationQueue.first()); + pruneCount++; + } + } - @Override public EncryptCacheEntry getEntryForEncrypt( - byte[] cacheId, final UsageStats usageIncrement - ) { - EncryptCacheEntryInternal entry = getEntry(EncryptCacheEntryInternal.class, cacheId); + private synchronized T getEntry(Class klass, byte[] identifier) { + // Perform cache maintenance first + ttlPrune(); - if (entry != null) { - UsageStats stats = entry.addAndGetUsageStats(usageIncrement); - return new EncryptCacheEntryExposed(stats, entry); - } + BaseEntry e = cacheMap.get(new CacheIdentifier(identifier)); + if (e == null) { + return null; + } else { + if (e.expirationTimestamp_ < clock.timestamp()) { + removeEntry(e); return null; - } - - @Override public EncryptCacheEntry putEntryForEncrypt( - byte[] cacheId, - EncryptionMaterials encryptionMaterials, - CacheHint hint, - UsageStats initialUsage - ) { - EncryptCacheEntryInternal entry = new EncryptCacheEntryInternal( - new CacheIdentifier(cacheId), - Utils.saturatingAdd(clock.timestamp(), hint.getMaxAgeMillis()), - encryptionMaterials - ); + } - entry.addAndGetUsageStats(initialUsage); + return klass.cast(e); + } + } - putEntry(entry); + private synchronized void putEntry(final BaseEntry entry) { + BaseEntry oldEntry = cacheMap.put(entry.identifier_, entry); - return new EncryptCacheEntryExposed(initialUsage, entry); + if (oldEntry != null) { + expirationQueue.remove(oldEntry); } + expirationQueue.add(entry); - @Override public DecryptCacheEntry getEntryForDecrypt(byte[] cacheId) { - return getEntry(DecryptCacheEntryInternal.class, cacheId); - } + prune(); + } - @Override public void putEntryForDecrypt( - byte[] cacheId, DecryptionMaterials decryptionMaterials, CacheHint hint - ) { - DecryptCacheEntryInternal entry = new DecryptCacheEntryInternal( - new CacheIdentifier(cacheId), - Utils.saturatingAdd(clock.timestamp(), hint.getMaxAgeMillis()), - decryptionMaterials - ); + @Override + public EncryptCacheEntry getEntryForEncrypt(byte[] cacheId, final UsageStats usageIncrement) { + EncryptCacheEntryInternal entry = getEntry(EncryptCacheEntryInternal.class, cacheId); - putEntry(entry); + if (entry != null) { + UsageStats stats = entry.addAndGetUsageStats(usageIncrement); + return new EncryptCacheEntryExposed(stats, entry); } + + return null; + } + + @Override + public EncryptCacheEntry putEntryForEncrypt( + byte[] cacheId, + EncryptionMaterials encryptionMaterials, + CacheHint hint, + UsageStats initialUsage) { + EncryptCacheEntryInternal entry = + new EncryptCacheEntryInternal( + new CacheIdentifier(cacheId), + Utils.saturatingAdd(clock.timestamp(), hint.getMaxAgeMillis()), + encryptionMaterials); + + entry.addAndGetUsageStats(initialUsage); + + putEntry(entry); + + return new EncryptCacheEntryExposed(initialUsage, entry); + } + + @Override + public DecryptCacheEntry getEntryForDecrypt(byte[] cacheId) { + return getEntry(DecryptCacheEntryInternal.class, cacheId); + } + + @Override + public void putEntryForDecrypt( + byte[] cacheId, DecryptionMaterials decryptionMaterials, CacheHint hint) { + DecryptCacheEntryInternal entry = + new DecryptCacheEntryInternal( + new CacheIdentifier(cacheId), + Utils.saturatingAdd(clock.timestamp(), hint.getMaxAgeMillis()), + decryptionMaterials); + + putEntry(entry); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/caching/MsClock.java b/src/main/java/com/amazonaws/encryptionsdk/caching/MsClock.java index 883dbc89c..61e3faca3 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/caching/MsClock.java +++ b/src/main/java/com/amazonaws/encryptionsdk/caching/MsClock.java @@ -1,7 +1,7 @@ package com.amazonaws.encryptionsdk.caching; interface MsClock { - MsClock WALLCLOCK = System::currentTimeMillis; + MsClock WALLCLOCK = System::currentTimeMillis; - public long timestamp(); + public long timestamp(); } diff --git a/src/main/java/com/amazonaws/encryptionsdk/caching/NullCryptoMaterialsCache.java b/src/main/java/com/amazonaws/encryptionsdk/caching/NullCryptoMaterialsCache.java index b8f1653f9..3f4806699 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/caching/NullCryptoMaterialsCache.java +++ b/src/main/java/com/amazonaws/encryptionsdk/caching/NullCryptoMaterialsCache.java @@ -3,49 +3,47 @@ import com.amazonaws.encryptionsdk.model.DecryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; -/** - * A {@link CryptoMaterialsCache} that doesn't actually cache anything. - */ +/** A {@link CryptoMaterialsCache} that doesn't actually cache anything. */ public class NullCryptoMaterialsCache implements CryptoMaterialsCache { - @Override public EncryptCacheEntry getEntryForEncrypt( - byte[] cacheId, - final UsageStats usageIncrement - ) { - return null; - } - - @Override public EncryptCacheEntry putEntryForEncrypt( - byte[] cacheId, - EncryptionMaterials encryptionMaterials, - CacheHint hint, - UsageStats initialUsage - ) { - return new EncryptCacheEntry() { - private final long creationTime = System.currentTimeMillis(); - - @Override public synchronized UsageStats getUsageStats() { - return initialUsage; - } - - @Override public long getEntryCreationTime() { - return creationTime; - } - - @Override public EncryptionMaterials getResult() { - return encryptionMaterials; - } - }; - } - - @Override public DecryptCacheEntry getEntryForDecrypt( - byte[] cacheId - ) { - return null; - } - - @Override public void putEntryForDecrypt( - byte[] cacheId, DecryptionMaterials decryptionMaterials, CacheHint hint - ) { - // no-op - } + @Override + public EncryptCacheEntry getEntryForEncrypt(byte[] cacheId, final UsageStats usageIncrement) { + return null; + } + + @Override + public EncryptCacheEntry putEntryForEncrypt( + byte[] cacheId, + EncryptionMaterials encryptionMaterials, + CacheHint hint, + UsageStats initialUsage) { + return new EncryptCacheEntry() { + private final long creationTime = System.currentTimeMillis(); + + @Override + public synchronized UsageStats getUsageStats() { + return initialUsage; + } + + @Override + public long getEntryCreationTime() { + return creationTime; + } + + @Override + public EncryptionMaterials getResult() { + return encryptionMaterials; + } + }; + } + + @Override + public DecryptCacheEntry getEntryForDecrypt(byte[] cacheId) { + return null; + } + + @Override + public void putEntryForDecrypt( + byte[] cacheId, DecryptionMaterials decryptionMaterials, CacheHint hint) { + // no-op + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/AwsCryptoException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/AwsCryptoException.java index 9abe55cb6..d1f3a418f 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/exception/AwsCryptoException.java +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/AwsCryptoException.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,46 +13,47 @@ package com.amazonaws.encryptionsdk.exception; -/** - * This is the parent class of the runtime exceptions thrown by the AWS Encryption SDK. - */ -//@ non_null_by_default +/** This is the parent class of the runtime exceptions thrown by the AWS Encryption SDK. */ +// @ non_null_by_default public class AwsCryptoException extends RuntimeException { - private static final long serialVersionUID = -1L; + private static final long serialVersionUID = -1L; - //@ public normal_behavior - //@ ensures standardThrowable(); - //@ pure - public AwsCryptoException() { - super(); - } + // @ public normal_behavior + // @ ensures standardThrowable(); + // @ pure + public AwsCryptoException() { + super(); + } - //@ public normal_behavior - //@ ensures standardThrowable(message); - //@ pure - public AwsCryptoException(final String message) { - super(message); - } + // @ public normal_behavior + // @ ensures standardThrowable(message); + // @ pure + public AwsCryptoException(final String message) { + super(message); + } - //@ public normal_behavior - //@ ensures standardThrowable(cause); - //@ pure - public AwsCryptoException(final Throwable cause) { - super(cause); - } + // @ public normal_behavior + // @ ensures standardThrowable(cause); + // @ pure + public AwsCryptoException(final Throwable cause) { + super(cause); + } - //@ public normal_behavior - //@ ensures standardThrowable(message,cause); - //@ pure - public AwsCryptoException(final String message, final Throwable cause) { - super(message, cause); - } + // @ public normal_behavior + // @ ensures standardThrowable(message,cause); + // @ pure + public AwsCryptoException(final String message, final Throwable cause) { + super(message, cause); + } - //@ public normal_behavior - //@ ensures standardThrowable(message,cause); - //@ pure // TODO - public AwsCryptoException(final String message, final Throwable cause, final boolean enableSuppression, - final boolean writableStackTrace) { - super(message, cause, enableSuppression, writableStackTrace); - } + // @ public normal_behavior + // @ ensures standardThrowable(message,cause); + // @ pure // TODO + public AwsCryptoException( + final String message, + final Throwable cause, + final boolean enableSuppression, + final boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/BadCiphertextException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/BadCiphertextException.java index e7c1be95f..76efc6157 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/exception/BadCiphertextException.java +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/BadCiphertextException.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -14,42 +14,41 @@ package com.amazonaws.encryptionsdk.exception; /** - * This exception is thrown when the values found in a ciphertext message are - * invalid or corrupt. + * This exception is thrown when the values found in a ciphertext message are invalid or corrupt. */ -//@ non_null_by_default +// @ non_null_by_default public class BadCiphertextException extends AwsCryptoException { - private static final long serialVersionUID = -1L; + private static final long serialVersionUID = -1L; - /*@ public normal_behavior - @ ensures standardThrowable(); - @*/ - //@ pure - public BadCiphertextException() { - super(); - } + /*@ public normal_behavior + @ ensures standardThrowable(); + @*/ + // @ pure + public BadCiphertextException() { + super(); + } - /*@ public normal_behavior - @ ensures standardThrowable(message); - @*/ - //@ pure - public BadCiphertextException(final String message) { - super(message); - } + /*@ public normal_behavior + @ ensures standardThrowable(message); + @*/ + // @ pure + public BadCiphertextException(final String message) { + super(message); + } - /*@ public normal_behavior - @ ensures standardThrowable(cause); - @*/ - //@ pure - public BadCiphertextException(final Throwable cause) { - super(cause); - } + /*@ public normal_behavior + @ ensures standardThrowable(cause); + @*/ + // @ pure + public BadCiphertextException(final Throwable cause) { + super(cause); + } - /*@ public normal_behavior - @ ensures standardThrowable(message,cause); - @*/ - //@ pure - public BadCiphertextException(final String message, final Throwable cause) { - super(message, cause); - } + /*@ public normal_behavior + @ ensures standardThrowable(message,cause); + @*/ + // @ pure + public BadCiphertextException(final String message, final Throwable cause) { + super(message, cause); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/CannotUnwrapDataKeyException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/CannotUnwrapDataKeyException.java index c40f1b08e..30f6ddd15 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/exception/CannotUnwrapDataKeyException.java +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/CannotUnwrapDataKeyException.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -15,42 +15,40 @@ import com.amazonaws.encryptionsdk.DataKey; -/** - * This exception is thrown when there are no {@link DataKey}s which can be decrypted. - */ -//@ non_null_by_default +/** This exception is thrown when there are no {@link DataKey}s which can be decrypted. */ +// @ non_null_by_default public class CannotUnwrapDataKeyException extends AwsCryptoException { - private static final long serialVersionUID = -1L; + private static final long serialVersionUID = -1L; - /*@ public normal_behavior - @ ensures standardThrowable(); - @*/ - //@ pure - public CannotUnwrapDataKeyException() { - super(); - } + /*@ public normal_behavior + @ ensures standardThrowable(); + @*/ + // @ pure + public CannotUnwrapDataKeyException() { + super(); + } - /*@ public normal_behavior - @ ensures standardThrowable(message); - @*/ - //@ pure - public CannotUnwrapDataKeyException(final String message) { - super(message); - } + /*@ public normal_behavior + @ ensures standardThrowable(message); + @*/ + // @ pure + public CannotUnwrapDataKeyException(final String message) { + super(message); + } - /*@ public normal_behavior - @ ensures standardThrowable(cause); - @*/ - //@ pure - public CannotUnwrapDataKeyException(final Throwable cause) { - super(cause); - } + /*@ public normal_behavior + @ ensures standardThrowable(cause); + @*/ + // @ pure + public CannotUnwrapDataKeyException(final Throwable cause) { + super(cause); + } - /*@ public normal_behavior - @ ensures standardThrowable(message,cause); - @*/ - //@ pure - public CannotUnwrapDataKeyException(final String message, final Throwable cause) { - super(message, cause); - } + /*@ public normal_behavior + @ ensures standardThrowable(message,cause); + @*/ + // @ pure + public CannotUnwrapDataKeyException(final String message, final Throwable cause) { + super(message, cause); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/NoSuchMasterKeyException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/NoSuchMasterKeyException.java index a9f68fb85..5ed9883a4 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/exception/NoSuchMasterKeyException.java +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/NoSuchMasterKeyException.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -19,48 +19,49 @@ * This exception is thrown when the SDK attempts to use a {@link MasterKey} which either doesn't * exist or to which it doesn't have access. */ -//@ non_null_by_default +// @ non_null_by_default public class NoSuchMasterKeyException extends AwsCryptoException { - private static final long serialVersionUID = -1L; + private static final long serialVersionUID = -1L; - /*@ public normal_behavior - @ ensures standardThrowable(); - @*/ - //@ pure - public NoSuchMasterKeyException() { - } + /*@ public normal_behavior + @ ensures standardThrowable(); + @*/ + // @ pure + public NoSuchMasterKeyException() {} - /*@ public normal_behavior - @ ensures standardThrowable(message); - @*/ - //@ pure - public NoSuchMasterKeyException(final String message) { - super(message); - } + /*@ public normal_behavior + @ ensures standardThrowable(message); + @*/ + // @ pure + public NoSuchMasterKeyException(final String message) { + super(message); + } - /*@ public normal_behavior - @ ensures standardThrowable(cause); - @*/ - //@ pure - public NoSuchMasterKeyException(final Throwable cause) { - super(cause); - } + /*@ public normal_behavior + @ ensures standardThrowable(cause); + @*/ + // @ pure + public NoSuchMasterKeyException(final Throwable cause) { + super(cause); + } - /*@ public normal_behavior - @ ensures standardThrowable(message,cause); - @*/ - //@ pure - public NoSuchMasterKeyException(final String message, final Throwable cause) { - super(message, cause); - } - - /*@ public normal_behavior - @ ensures standardThrowable(message,cause); - @*/ - //@ pure // TODO - public NoSuchMasterKeyException(final String message, final Throwable cause, - final boolean enableSuppression, final boolean writableStackTrace) { - super(message, cause, enableSuppression, writableStackTrace); - } + /*@ public normal_behavior + @ ensures standardThrowable(message,cause); + @*/ + // @ pure + public NoSuchMasterKeyException(final String message, final Throwable cause) { + super(message, cause); + } + /*@ public normal_behavior + @ ensures standardThrowable(message,cause); + @*/ + // @ pure // TODO + public NoSuchMasterKeyException( + final String message, + final Throwable cause, + final boolean enableSuppression, + final boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/ParseException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/ParseException.java index 6d163c838..e03721aa3 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/exception/ParseException.java +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/ParseException.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -17,74 +17,66 @@ * This exception is thrown when there are not enough bytes to parse a primitive, a specified number * of bytes, or the bytes does not properly represent the desired object. */ -//@ non_null_by_default +// @ non_null_by_default public class ParseException extends AwsCryptoException { - private static final long serialVersionUID = -1L; + private static final long serialVersionUID = -1L; - /** - * Constructs a new exception with no detail message. - */ - /*@ public normal_behavior - @ ensures standardThrowable(); - @*/ - //@ pure - public ParseException() { - super(); - } + /** Constructs a new exception with no detail message. */ + /*@ public normal_behavior + @ ensures standardThrowable(); + @*/ + // @ pure + public ParseException() { + super(); + } - /** - * Constructs a new exception with the specified detail message. - * - * @param message - * the detail message. - */ - /*@ public normal_behavior - @ ensures standardThrowable(message); - @*/ - //@ pure - public ParseException(final String message) { - super(message); - } + /** + * Constructs a new exception with the specified detail message. + * + * @param message the detail message. + */ + /*@ public normal_behavior + @ ensures standardThrowable(message); + @*/ + // @ pure + public ParseException(final String message) { + super(message); + } - /** - * Constructs a new exception with the specified cause and a detail message of - * (cause==null ? null : cause.toString()) (which typically contains the class and - * detail message of cause). - * - * @param cause - * the cause (which is saved for later retrieval by the {@link Throwable#getCause()} - * method). (A null value is permitted, and indicates that the cause is - * nonexistent or unknown.) - */ - /*@ public normal_behavior - @ ensures standardThrowable(cause); - @*/ - //@ pure - public ParseException(final Throwable cause) { - super(cause); - } + /** + * Constructs a new exception with the specified cause and a detail message of (cause==null ? + * null : cause.toString()) (which typically contains the class and detail message of + * cause). + * + * @param cause the cause (which is saved for later retrieval by the {@link Throwable#getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent + * or unknown.) + */ + /*@ public normal_behavior + @ ensures standardThrowable(cause); + @*/ + // @ pure + public ParseException(final Throwable cause) { + super(cause); + } - /** - * Constructs a new exception with the specified detail message and cause. - * - *

- * Note that the detail message associated with cause is not automatically incorporated in this - * exception's detail message. - * - * @param message - * the detail message (which is saved for later retrieval by the - * {@link Throwable#getMessage()} method). - * - * @param cause - * the cause (which is saved for later retrieval by the {@link Throwable#getCause()} - * method). (A null value is permitted, and indicates that the cause is - * nonexistent or unknown.) - */ - /*@ public normal_behavior - @ ensures standardThrowable(message,cause); - @*/ - //@ pure - public ParseException(final String message, final Throwable cause) { - super(message, cause); - } + /** + * Constructs a new exception with the specified detail message and cause. + * + *

Note that the detail message associated with cause is not automatically incorporated in this + * exception's detail message. + * + * @param message the detail message (which is saved for later retrieval by the {@link + * Throwable#getMessage()} method). + * @param cause the cause (which is saved for later retrieval by the {@link Throwable#getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent + * or unknown.) + */ + /*@ public normal_behavior + @ ensures standardThrowable(message,cause); + @*/ + // @ pure + public ParseException(final String message, final Throwable cause) { + super(message, cause); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/UnsupportedProviderException.java b/src/main/java/com/amazonaws/encryptionsdk/exception/UnsupportedProviderException.java index d6fe91621..8e81170b0 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/exception/UnsupportedProviderException.java +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/UnsupportedProviderException.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -19,48 +19,49 @@ * This exception is thrown when there are no {@link MasterKeyProvider}s which which support the * requested {@code provider} value. */ -//@ non_null_by_default +// @ non_null_by_default public class UnsupportedProviderException extends AwsCryptoException { - private static final long serialVersionUID = -1L; + private static final long serialVersionUID = -1L; - /*@ public normal_behavior - @ ensures standardThrowable(); - @*/ - //@ pure - public UnsupportedProviderException() { - } + /*@ public normal_behavior + @ ensures standardThrowable(); + @*/ + // @ pure + public UnsupportedProviderException() {} - /*@ public normal_behavior - @ ensures standardThrowable(message); - @*/ - //@ pure - public UnsupportedProviderException(final String message) { - super(message); - } + /*@ public normal_behavior + @ ensures standardThrowable(message); + @*/ + // @ pure + public UnsupportedProviderException(final String message) { + super(message); + } - /*@ public normal_behavior - @ ensures standardThrowable(cause); - @*/ - //@ pure - public UnsupportedProviderException(final Throwable cause) { - super(cause); - } + /*@ public normal_behavior + @ ensures standardThrowable(cause); + @*/ + // @ pure + public UnsupportedProviderException(final Throwable cause) { + super(cause); + } - /*@ public normal_behavior - @ ensures standardThrowable(message,cause); - @*/ - //@ pure - public UnsupportedProviderException(final String message, final Throwable cause) { - super(message, cause); - } - - /*@ public normal_behavior - @ ensures standardThrowable(message,cause); - @*/ - //@ pure // TODO - public UnsupportedProviderException(final String message, final Throwable cause, - final boolean enableSuppression, final boolean writableStackTrace) { - super(message, cause, enableSuppression, writableStackTrace); - } + /*@ public normal_behavior + @ ensures standardThrowable(message,cause); + @*/ + // @ pure + public UnsupportedProviderException(final String message, final Throwable cause) { + super(message, cause); + } + /*@ public normal_behavior + @ ensures standardThrowable(message,cause); + @*/ + // @ pure // TODO + public UnsupportedProviderException( + final String message, + final Throwable cause, + final boolean enableSuppression, + final boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/exception/package-info.java b/src/main/java/com/amazonaws/encryptionsdk/exception/package-info.java index 613e2205e..935a00835 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/exception/package-info.java +++ b/src/main/java/com/amazonaws/encryptionsdk/exception/package-info.java @@ -1,17 +1,15 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. */ -/** - * Contains the various exceptions which may be thrown by the AWS Encryption SDK. - */ +/** Contains the various exceptions which may be thrown by the AWS Encryption SDK. */ package com.amazonaws.encryptionsdk.exception; diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/AesGcmJceKeyCipher.java b/src/main/java/com/amazonaws/encryptionsdk/internal/AesGcmJceKeyCipher.java index 7a4511f01..bf7843f35 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/AesGcmJceKeyCipher.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/AesGcmJceKeyCipher.java @@ -13,82 +13,87 @@ package com.amazonaws.encryptionsdk.internal; -import javax.crypto.Cipher; -import javax.crypto.SecretKey; -import javax.crypto.spec.GCMParameterSpec; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; import java.security.InvalidKeyException; import java.security.Key; import java.util.Map; +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.spec.GCMParameterSpec; -/** - * A JceKeyCipher based on the Advanced Encryption Standard in Galois/Counter Mode. - */ +/** A JceKeyCipher based on the Advanced Encryption Standard in Galois/Counter Mode. */ class AesGcmJceKeyCipher extends JceKeyCipher { - private static final int NONCE_LENGTH = 12; - private static final int TAG_LENGTH = 128; - private static final String TRANSFORMATION = "AES/GCM/NoPadding"; - private static final int SPEC_LENGTH = Integer.BYTES + Integer.BYTES + NONCE_LENGTH; + private static final int NONCE_LENGTH = 12; + private static final int TAG_LENGTH = 128; + private static final String TRANSFORMATION = "AES/GCM/NoPadding"; + private static final int SPEC_LENGTH = Integer.BYTES + Integer.BYTES + NONCE_LENGTH; - AesGcmJceKeyCipher(SecretKey key) { - super(key, key); - } + AesGcmJceKeyCipher(SecretKey key) { + super(key, key); + } - private static byte[] specToBytes(final GCMParameterSpec spec) { - final byte[] nonce = spec.getIV(); - final byte[] result = new byte[SPEC_LENGTH]; - final ByteBuffer buffer = ByteBuffer.wrap(result); - buffer.putInt(spec.getTLen()); - buffer.putInt(nonce.length); - buffer.put(nonce); - return result; + private static byte[] specToBytes(final GCMParameterSpec spec) { + final byte[] nonce = spec.getIV(); + final byte[] result = new byte[SPEC_LENGTH]; + final ByteBuffer buffer = ByteBuffer.wrap(result); + buffer.putInt(spec.getTLen()); + buffer.putInt(nonce.length); + buffer.put(nonce); + return result; + } + + private static GCMParameterSpec bytesToSpec(final byte[] data, final int offset) + throws InvalidKeyException { + if (data.length - offset != SPEC_LENGTH) { + throw new InvalidKeyException("Algorithm specification was an invalid data size"); } - private static GCMParameterSpec bytesToSpec(final byte[] data, final int offset) throws InvalidKeyException { - if (data.length - offset != SPEC_LENGTH) { - throw new InvalidKeyException("Algorithm specification was an invalid data size"); - } + final ByteBuffer buffer = ByteBuffer.wrap(data, offset, SPEC_LENGTH); + final int tagLen = buffer.getInt(); + final int nonceLen = buffer.getInt(); - final ByteBuffer buffer = ByteBuffer.wrap(data, offset, SPEC_LENGTH); - final int tagLen = buffer.getInt(); - final int nonceLen = buffer.getInt(); + if (tagLen != TAG_LENGTH) { + throw new InvalidKeyException( + String.format("Authentication tag length must be %s", TAG_LENGTH)); + } - if (tagLen != TAG_LENGTH) { - throw new InvalidKeyException(String.format("Authentication tag length must be %s", TAG_LENGTH)); - } + if (nonceLen != NONCE_LENGTH) { + throw new InvalidKeyException( + String.format("Initialization vector (IV) length must be %s", NONCE_LENGTH)); + } - if (nonceLen != NONCE_LENGTH) { - throw new InvalidKeyException(String.format("Initialization vector (IV) length must be %s", NONCE_LENGTH)); - } + final byte[] nonce = new byte[nonceLen]; + buffer.get(nonce); - final byte[] nonce = new byte[nonceLen]; - buffer.get(nonce); + return new GCMParameterSpec(tagLen, nonce); + } - return new GCMParameterSpec(tagLen, nonce); - } + @Override + WrappingData buildWrappingCipher(final Key key, final Map encryptionContext) + throws GeneralSecurityException { + final byte[] nonce = new byte[NONCE_LENGTH]; + Utils.getSecureRandom().nextBytes(nonce); + final GCMParameterSpec spec = new GCMParameterSpec(TAG_LENGTH, nonce); + final Cipher cipher = Cipher.getInstance(TRANSFORMATION); + cipher.init(Cipher.ENCRYPT_MODE, key, spec); + final byte[] aad = EncryptionContextSerializer.serialize(encryptionContext); + cipher.updateAAD(aad); + return new WrappingData(cipher, specToBytes(spec)); + } - @Override - WrappingData buildWrappingCipher(final Key key, final Map encryptionContext) - throws GeneralSecurityException { - final byte[] nonce = new byte[NONCE_LENGTH]; - Utils.getSecureRandom().nextBytes(nonce); - final GCMParameterSpec spec = new GCMParameterSpec(TAG_LENGTH, nonce); - final Cipher cipher = Cipher.getInstance(TRANSFORMATION); - cipher.init(Cipher.ENCRYPT_MODE, key, spec); - final byte[] aad = EncryptionContextSerializer.serialize(encryptionContext); - cipher.updateAAD(aad); - return new WrappingData(cipher, specToBytes(spec)); - } - - @Override - Cipher buildUnwrappingCipher(final Key key, final byte[] extraInfo, final int offset, - final Map encryptionContext) throws GeneralSecurityException { - final GCMParameterSpec spec = bytesToSpec(extraInfo, offset); - final Cipher cipher = Cipher.getInstance(TRANSFORMATION); - cipher.init(Cipher.DECRYPT_MODE, key, spec); - final byte[] aad = EncryptionContextSerializer.serialize(encryptionContext); - cipher.updateAAD(aad); - return cipher; - } + @Override + Cipher buildUnwrappingCipher( + final Key key, + final byte[] extraInfo, + final int offset, + final Map encryptionContext) + throws GeneralSecurityException { + final GCMParameterSpec spec = bytesToSpec(extraInfo, offset); + final Cipher cipher = Cipher.getInstance(TRANSFORMATION); + cipher.init(Cipher.DECRYPT_MODE, key, spec); + final byte[] aad = EncryptionContextSerializer.serialize(encryptionContext); + cipher.updateAAD(aad); + return cipher; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/AwsKmsCmkArnInfo.java b/src/main/java/com/amazonaws/encryptionsdk/internal/AwsKmsCmkArnInfo.java index c1dddde1c..5c306622d 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/AwsKmsCmkArnInfo.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/AwsKmsCmkArnInfo.java @@ -2,345 +2,314 @@ import java.util.Arrays; - /** - * A class to parse and handle AWS KMS identifiers. - * Mostly AWS KMS ARNs but raw resources - * are also used in the AWS Encryption SDK. + * A class to parse and handle AWS KMS identifiers. Mostly AWS KMS ARNs but raw resources are also + * used in the AWS Encryption SDK. */ public final class AwsKmsCmkArnInfo { - final private static String arnLiteral = "arn"; - final private static String kmsServiceName = "kms"; - - /** - * Takes an AWS KMS identifier that may or may not be an ARN - * and attempts to parse the identifier as an ARN. - * If the identifier is not an ARN, it returns - * null. This is an expected condition, not an error. - * - * @param keyArn The string to parse - */ - public static AwsKmsCmkArnInfo parseInfoFromKeyArn(final String keyArn) { - /* Precondition: keyArn must be a string. */ - if (keyArn == null || keyArn.isEmpty()) return null; - - final String[] parts = AwsKmsArnParts.splitArn(keyArn); - - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //# MUST start with string "arn" - if (!arnLiteral.equals(parts[AwsKmsArnParts.ArnLiteral.index()])) { - return null; - } - - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //# The service MUST be the string "kms" - if (!kmsServiceName.equals(parts[AwsKmsArnParts.Service.index()])) { - return null; - } - - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //# The partition MUST be a non-empty - // - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //# The region MUST be a non-empty string - // - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //# The account MUST be a non-empty string - // - final boolean emptyParts = Arrays.stream(parts).anyMatch(String::isEmpty); - if (emptyParts || AwsKmsArnParts.values().length != parts.length) return null; - - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //# The resource section MUST be non-empty and MUST be split by a - //# single "/" any additional "/" are included in the resource id - String[] resourceParts = AwsKmsArnParts - .Resource - .splitResourceParts(parts[AwsKmsArnParts.ResourceParts.index()]); - - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //# The resource id MUST be a non-empty string - if (Arrays.stream(resourceParts).anyMatch(String::isEmpty) - || AwsKmsArnParts.Resource.values().length > resourceParts.length - ) { - return null; - } - - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //# The resource type MUST be either "alias" or "key" - if (!("key".equals(resourceParts[AwsKmsArnParts.Resource.ResourceType.index()]) - || "alias".equals(resourceParts[AwsKmsArnParts.Resource.ResourceType.index()]))) { - return null; - } - - return new AwsKmsCmkArnInfo( - parts[AwsKmsArnParts.Partition.index()], - parts[AwsKmsArnParts.Region.index()], - parts[AwsKmsArnParts.Account.index()], - resourceParts[AwsKmsArnParts.Resource.ResourceType.index()], - resourceParts[AwsKmsArnParts.Resource.Resource.index()] - ); + private static final String arnLiteral = "arn"; + private static final String kmsServiceName = "kms"; + + /** + * Takes an AWS KMS identifier that may or may not be an ARN and attempts to parse the identifier + * as an ARN. If the identifier is not an ARN, it returns null. This is an expected condition, not + * an error. + * + * @param keyArn The string to parse + */ + public static AwsKmsCmkArnInfo parseInfoFromKeyArn(final String keyArn) { + /* Precondition: keyArn must be a string. */ + if (keyArn == null || keyArn.isEmpty()) return null; + + final String[] parts = AwsKmsArnParts.splitArn(keyArn); + + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // # MUST start with string "arn" + if (!arnLiteral.equals(parts[AwsKmsArnParts.ArnLiteral.index()])) { + return null; } - /** Takes a string an will throw if this identifier is invalid - * Raw resources like a key ID or alias - * `mrk-edb7fe6942894d32ac46dbb1c922d574`, `alias/my-alias` - * or ARNs like - * arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574 - * arn:aws:kms:us-west-2:111122223333:alias/my-alias - * - * @param identifier an identifier that is an ARN or raw resource - */ - public static void validAwsKmsIdentifier(final String identifier) { - /* Exceptional Postcondition: Null or empty string is not a valid identifier. */ - if (identifier == null || identifier.isEmpty()) { - throw new IllegalArgumentException("Null or empty string is not a valid Aws KMS identifier."); - } - - /* Exceptional Postcondition: Things that start with `arn:` MUST be ARNs. */ - if (identifier.startsWith("arn:") && parseInfoFromKeyArn(identifier) == null) { - throw new IllegalArgumentException("Invalid ARN used as an identifier."); - }; - /* Postcondition: Raw alias starts with `alias/`. */ - if (identifier.startsWith("alias/")) return; - - /* Postcondition: There are no requirements on key ids. - * Even thought they look like UUID, this is not required. - * Take multi region keys: mrk-edb7fe6942894d32ac46dbb1c922d574 - */ - return; + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // # The service MUST be the string "kms" + if (!kmsServiceName.equals(parts[AwsKmsArnParts.Service.index()])) { + return null; } - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 - //# This function MUST take a single AWS KMS identifier - /** - * Identifies Multi Region AWS KMS keys. - * This can misidentify an alias that starts with "mrk-". - * - */ - public static boolean isMRK(final String resource) { - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 - //# If the input starts with "arn:", this MUST return the output of - //# identifying an an AWS KMS multi-Region ARN (aws-kms-key- - //# arn.md#identifying-an-an-aws-kms-multi-region-arn) called with this - //# input. - if (resource.startsWith("arn:")) return isMRK(parseInfoFromKeyArn(resource)); - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 - //# If the input starts with "alias/", this an AWS KMS alias and - //# not a multi-Region key id and MUST return false. - if (resource.startsWith("alias/")) return false; - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 - //# If the input starts - //# with "mrk-", this is a multi-Region key id and MUST return true. - // - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 - //# If - //# the input does not start with any of the above, this is not a multi- - //# Region key id and MUST return false. - return resource.startsWith("mrk-"); + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // # The partition MUST be a non-empty + // + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // # The region MUST be a non-empty string + // + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // # The account MUST be a non-empty string + // + final boolean emptyParts = Arrays.stream(parts).anyMatch(String::isEmpty); + if (emptyParts || AwsKmsArnParts.values().length != parts.length) return null; + + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // # The resource section MUST be non-empty and MUST be split by a + // # single "/" any additional "/" are included in the resource id + String[] resourceParts = + AwsKmsArnParts.Resource.splitResourceParts(parts[AwsKmsArnParts.ResourceParts.index()]); + + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // # The resource id MUST be a non-empty string + if (Arrays.stream(resourceParts).anyMatch(String::isEmpty) + || AwsKmsArnParts.Resource.values().length > resourceParts.length) { + return null; } - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 - //# This function MUST take a single AWS KMS ARN - /** - * Identifies Multi Region AWS KMS keys. - * The resource type check is to protect against the edge case where an alias starts with - * `mrk-` * e.g. arn:aws:kms:us-west-2:111122223333:alias/mrk-someOtherName - * - */ - public static boolean isMRK(final AwsKmsCmkArnInfo arn) { - - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 - //# If the input is an invalid AWS KMS ARN this function MUST error. - if (arn == null) throw new Error("Invalid Arn"); - - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 - //# If resource type is "alias", this is an AWS KMS alias ARN and MUST - //# return false. - // - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 - //# If resource type is "key" and resource ID starts with - //# "mrk-", this is a AWS KMS multi-Region key ARN and MUST return true. - // - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 - //# If resource type is "key" and resource ID does not start with "mrk-", - //# this is a (single-region) AWS KMS key ARN and MUST return false. - return isMRK(arn.getResource()) && arn.getResourceType().equals("key"); + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // # The resource type MUST be either "alias" or "key" + if (!("key".equals(resourceParts[AwsKmsArnParts.Resource.ResourceType.index()]) + || "alias".equals(resourceParts[AwsKmsArnParts.Resource.ResourceType.index()]))) { + return null; } - //= compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 - //# The caller MUST provide: - /** - * Tell if two different AWS KMS ARNs match. - * For identical keys this is trivial, - * but multi-Region keys can match across regions. - * - */ - public static boolean awsKmsArnMatchForDecrypt( - final String configuredKeyIdentifier, - final String providerInfoKeyIdentifier - ) { - //= compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 - //# If both identifiers are identical, this function MUST return "true". - if (configuredKeyIdentifier.equals(providerInfoKeyIdentifier)) return true; - - final AwsKmsCmkArnInfo configuredArnInfo = parseInfoFromKeyArn(configuredKeyIdentifier); - final AwsKmsCmkArnInfo providerInfoKeyArnInfo = parseInfoFromKeyArn(providerInfoKeyIdentifier); - - /* Check for early return (Postcondition): Both identifiers are not ARNs and not equal, therefore they can not match. */ - if (providerInfoKeyArnInfo == null || configuredArnInfo == null) return false; - - //= compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 - //# Otherwise if either input is not identified as a multi-Region key - //# (aws-kms-key-arn.md#identifying-an-aws-kms-multi-region-key), then - //# this function MUST return "false". - if (!isMRK(configuredArnInfo) || !isMRK(providerInfoKeyArnInfo)) return false; - - //= compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 - //# Otherwise if both inputs are - //# identified as a multi-Region keys (aws-kms-key-arn.md#identifying-an- - //# aws-kms-multi-region-key), this function MUST return the result of - //# comparing the "partition", "service", "accountId", "resourceType", - //# and "resource" parts of both ARN inputs. - //Service is not matched because AwsKmsCmkArnInfo only allows a service of `kms`. - return configuredArnInfo.getPartition().equals(providerInfoKeyArnInfo.getPartition()) && - configuredArnInfo.getAccountId().equals(providerInfoKeyArnInfo.getAccountId()) && - configuredArnInfo.getResourceType().equals(providerInfoKeyArnInfo.getResourceType()) && - configuredArnInfo.getResource().equals(providerInfoKeyArnInfo.getResource()); + return new AwsKmsCmkArnInfo( + parts[AwsKmsArnParts.Partition.index()], + parts[AwsKmsArnParts.Region.index()], + parts[AwsKmsArnParts.Account.index()], + resourceParts[AwsKmsArnParts.Resource.ResourceType.index()], + resourceParts[AwsKmsArnParts.Resource.Resource.index()]); + } + + /** + * Takes a string an will throw if this identifier is invalid Raw resources like a key ID or alias + * `mrk-edb7fe6942894d32ac46dbb1c922d574`, `alias/my-alias` or ARNs like + * arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574 + * arn:aws:kms:us-west-2:111122223333:alias/my-alias + * + * @param identifier an identifier that is an ARN or raw resource + */ + public static void validAwsKmsIdentifier(final String identifier) { + /* Exceptional Postcondition: Null or empty string is not a valid identifier. */ + if (identifier == null || identifier.isEmpty()) { + throw new IllegalArgumentException("Null or empty string is not a valid Aws KMS identifier."); } - private final String partition_; - private final String accountId_; - private final String region_; - private final String resource_; - private final String resourceType_; - - /** - * Data structure to hold the parts of an AWS KMS ARN - * - */ - AwsKmsCmkArnInfo( - String partition, - String region, - String accountId, - String resourceType, - String resource - ) { - partition_ = partition; - region_ = region; - accountId_ = accountId; - resourceType_ = resourceType; - resource_ = resource; + /* Exceptional Postcondition: Things that start with `arn:` MUST be ARNs. */ + if (identifier.startsWith("arn:") && parseInfoFromKeyArn(identifier) == null) { + throw new IllegalArgumentException("Invalid ARN used as an identifier."); } + ; + /* Postcondition: Raw alias starts with `alias/`. */ + if (identifier.startsWith("alias/")) return; - public String getPartition() { - return partition_; + /* Postcondition: There are no requirements on key ids. + * Even thought they look like UUID, this is not required. + * Take multi region keys: mrk-edb7fe6942894d32ac46dbb1c922d574 + */ + return; + } + + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 + // # This function MUST take a single AWS KMS identifier + /** + * Identifies Multi Region AWS KMS keys. This can misidentify an alias that starts with "mrk-". + */ + public static boolean isMRK(final String resource) { + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 + // # If the input starts with "arn:", this MUST return the output of + // # identifying an an AWS KMS multi-Region ARN (aws-kms-key- + // # arn.md#identifying-an-an-aws-kms-multi-region-arn) called with this + // # input. + if (resource.startsWith("arn:")) return isMRK(parseInfoFromKeyArn(resource)); + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 + // # If the input starts with "alias/", this an AWS KMS alias and + // # not a multi-Region key id and MUST return false. + if (resource.startsWith("alias/")) return false; + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 + // # If the input starts + // # with "mrk-", this is a multi-Region key id and MUST return true. + // + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 + // # If + // # the input does not start with any of the above, this is not a multi- + // # Region key id and MUST return false. + return resource.startsWith("mrk-"); + } + + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 + // # This function MUST take a single AWS KMS ARN + /** + * Identifies Multi Region AWS KMS keys. The resource type check is to protect against the edge + * case where an alias starts with `mrk-` * e.g. + * arn:aws:kms:us-west-2:111122223333:alias/mrk-someOtherName + */ + public static boolean isMRK(final AwsKmsCmkArnInfo arn) { + + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 + // # If the input is an invalid AWS KMS ARN this function MUST error. + if (arn == null) throw new Error("Invalid Arn"); + + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 + // # If resource type is "alias", this is an AWS KMS alias ARN and MUST + // # return false. + // + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 + // # If resource type is "key" and resource ID starts with + // # "mrk-", this is a AWS KMS multi-Region key ARN and MUST return true. + // + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 + // # If resource type is "key" and resource ID does not start with "mrk-", + // # this is a (single-region) AWS KMS key ARN and MUST return false. + return isMRK(arn.getResource()) && arn.getResourceType().equals("key"); + } + + // = compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 + // # The caller MUST provide: + /** + * Tell if two different AWS KMS ARNs match. For identical keys this is trivial, but multi-Region + * keys can match across regions. + */ + public static boolean awsKmsArnMatchForDecrypt( + final String configuredKeyIdentifier, final String providerInfoKeyIdentifier) { + // = compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 + // # If both identifiers are identical, this function MUST return "true". + if (configuredKeyIdentifier.equals(providerInfoKeyIdentifier)) return true; + + final AwsKmsCmkArnInfo configuredArnInfo = parseInfoFromKeyArn(configuredKeyIdentifier); + final AwsKmsCmkArnInfo providerInfoKeyArnInfo = parseInfoFromKeyArn(providerInfoKeyIdentifier); + + /* Check for early return (Postcondition): Both identifiers are not ARNs and not equal, therefore they can not match. */ + if (providerInfoKeyArnInfo == null || configuredArnInfo == null) return false; + + // = compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 + // # Otherwise if either input is not identified as a multi-Region key + // # (aws-kms-key-arn.md#identifying-an-aws-kms-multi-region-key), then + // # this function MUST return "false". + if (!isMRK(configuredArnInfo) || !isMRK(providerInfoKeyArnInfo)) return false; + + // = compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 + // # Otherwise if both inputs are + // # identified as a multi-Region keys (aws-kms-key-arn.md#identifying-an- + // # aws-kms-multi-region-key), this function MUST return the result of + // # comparing the "partition", "service", "accountId", "resourceType", + // # and "resource" parts of both ARN inputs. + // Service is not matched because AwsKmsCmkArnInfo only allows a service of `kms`. + return configuredArnInfo.getPartition().equals(providerInfoKeyArnInfo.getPartition()) + && configuredArnInfo.getAccountId().equals(providerInfoKeyArnInfo.getAccountId()) + && configuredArnInfo.getResourceType().equals(providerInfoKeyArnInfo.getResourceType()) + && configuredArnInfo.getResource().equals(providerInfoKeyArnInfo.getResource()); + } + + private final String partition_; + private final String accountId_; + private final String region_; + private final String resource_; + private final String resourceType_; + + /** Data structure to hold the parts of an AWS KMS ARN */ + AwsKmsCmkArnInfo( + String partition, String region, String accountId, String resourceType, String resource) { + partition_ = partition; + region_ = region; + accountId_ = accountId; + resourceType_ = resourceType; + resource_ = resource; + } + + public String getPartition() { + return partition_; + } + + public String getAccountId() { + return accountId_; + } + + public String getRegion() { + return region_; + } + + public String getResourceType() { + return resourceType_; + } + + public String getResource() { + return resource_; + } + + /** Returns the well-formed ARN this object describes. */ + @Override + public String toString() { + return toString(region_); + } + + /** + * AWS KMS multi-Region keys can have replicas in other region. A compatible ARN in a different + * Region may be required. + * + * @param mrkRegion The region to use instead of the region in the ARN + */ + public String toString(String mrkRegion) { + return String.join( + AwsKmsArnParts.Delimiter, + arnLiteral, + partition_, + kmsServiceName, + mrkRegion, + accountId_, + String.join(AwsKmsArnParts.Resource.ResourceDelimiter, resourceType_, resource_)); + } + + /** + * Structure information about an ARN. This structure is only expecting to process AWS KMS ARNs + * see https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more + * details. + */ + enum AwsKmsArnParts { + ArnLiteral(0), + Partition(1), + Service(2), + Region(3), + Account(4), + ResourceParts(5); + + int index_; + + AwsKmsArnParts(int i) { + index_ = i; } - public String getAccountId() { - return accountId_; + int index() { + return index_; } - public String getRegion() { - return region_; + public static String[] splitArn(String arn) { + return arn.split(AwsKmsArnParts.Delimiter, AwsKmsArnParts.values().length); } - public String getResourceType() { return resourceType_; } - - public String getResource() { return resource_; } - + static String Delimiter = ":"; /** - * Returns the well-formed ARN this object describes. + * Structure information about the resource part of an ARN This structure is only expecting to + * process AWS KMS ARNs see + * https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more details. * + *

Of note, is that the ARN specification lets the `/` also be a `:` however AWS KMS does not + * support this. AWS KMS _only_ uses `/` to delimit the resource type and resource. */ - @Override - public String toString() { - return toString(region_); - } + enum Resource { + ResourceType(0), + Resource(1); - /** - * AWS KMS multi-Region keys can have replicas in other region. - * A compatible ARN in a different Region may be required. - * - * @param mrkRegion The region to use instead of the region in the ARN - */ - public String toString(String mrkRegion) { - return String.join( - AwsKmsArnParts.Delimiter, - arnLiteral, - partition_, - kmsServiceName, - mrkRegion, - accountId_, - String.join( - AwsKmsArnParts.Resource.ResourceDelimiter, - resourceType_, - resource_)); - } + static String ResourceDelimiter = "/"; - /** - * Structure information about an ARN. - * This structure is only expecting - * to process AWS KMS ARNs - * see https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html - * for more details. - * - */ - enum AwsKmsArnParts { - ArnLiteral(0), - Partition(1), - Service(2), - Region(3), - Account(4), - ResourceParts(5); - - int index_; - AwsKmsArnParts(int i) { - index_ = i; - } - int index() { - return index_; - } - - public static String[] splitArn(String arn) { - return arn.split( - AwsKmsArnParts.Delimiter, - AwsKmsArnParts.values().length); - } - - static String Delimiter = ":"; - - /** - * Structure information about the resource part of an ARN - * This structure is only expecting - * to process AWS KMS ARNs - * see https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html - * for more details. - * - * Of note, is that the ARN specification lets the `/` also be a `:` - * however AWS KMS does not support this. - * AWS KMS _only_ uses `/` to delimit the resource type and resource. - * - */ - enum Resource { - ResourceType(0), - Resource(1); - - static String ResourceDelimiter = "/"; - - int index_; - Resource(int i) { - index_ = i; - } - int index() { - return index_; - } - - public static String[] splitResourceParts(String resource) { - return resource.split( - ResourceDelimiter, - 2); - } - } + int index_; + + Resource(int i) { + index_ = i; + } + + int index() { + return index_; + } + + public static String[] splitResourceParts(String resource) { + return resource.split(ResourceDelimiter, 2); + } } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/BlockDecryptionHandler.java b/src/main/java/com/amazonaws/encryptionsdk/internal/BlockDecryptionHandler.java index bfb42333d..e5b5957c4 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/BlockDecryptionHandler.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/BlockDecryptionHandler.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -17,233 +17,206 @@ import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.model.CipherBlockHeaders; - +import java.util.Arrays; import javax.crypto.Cipher; import javax.crypto.SecretKey; -import java.util.Arrays; /** - * The block decryption handler is an implementation of CryptoHandler that - * provides methods to decrypt content encrypted and stored in a single block. - * - *

- * In this SDK, this class decrypts content that is encrypted by - * {@link BlockEncryptionHandler}. + * The block decryption handler is an implementation of CryptoHandler that provides methods to + * decrypt content encrypted and stored in a single block. + * + *

In this SDK, this class decrypts content that is encrypted by {@link BlockEncryptionHandler}. */ class BlockDecryptionHandler implements CryptoHandler { - private final SecretKey decryptionKey_; - private final short nonceLen_; - private final CryptoAlgorithm cryptoAlgo_; - private final byte[] messageId_; - private final CipherBlockHeaders blockHeaders_; - - private final byte[] bytesToDecrypt_ = new byte[0]; - private byte[] unparsedBytes_ = new byte[0]; - private boolean complete_ = false; - - /** - * Construct a decryption handler for decrypting bytes stored in a single - * block. - * - * @param decryptionKey - * the key to use for decrypting the ciphertext - * @param nonceLen - * the length to use when parsing the nonce in the block headers. - * @param cryptoAlgo - * the crypto algorithm to use for decrypting the ciphertext - * @param messageId - * the byte array containing the message identifier that is used - * in binding the encrypted content to the headers in the - * ciphertext. - */ - public BlockDecryptionHandler(final SecretKey decryptionKey, final short nonceLen, - final CryptoAlgorithm cryptoAlgo, final byte[] messageId) { - decryptionKey_ = decryptionKey; - nonceLen_ = nonceLen; - cryptoAlgo_ = cryptoAlgo; - messageId_ = messageId; - blockHeaders_ = new CipherBlockHeaders(); + private final SecretKey decryptionKey_; + private final short nonceLen_; + private final CryptoAlgorithm cryptoAlgo_; + private final byte[] messageId_; + private final CipherBlockHeaders blockHeaders_; + + private final byte[] bytesToDecrypt_ = new byte[0]; + private byte[] unparsedBytes_ = new byte[0]; + private boolean complete_ = false; + + /** + * Construct a decryption handler for decrypting bytes stored in a single block. + * + * @param decryptionKey the key to use for decrypting the ciphertext + * @param nonceLen the length to use when parsing the nonce in the block headers. + * @param cryptoAlgo the crypto algorithm to use for decrypting the ciphertext + * @param messageId the byte array containing the message identifier that is used in binding the + * encrypted content to the headers in the ciphertext. + */ + public BlockDecryptionHandler( + final SecretKey decryptionKey, + final short nonceLen, + final CryptoAlgorithm cryptoAlgo, + final byte[] messageId) { + decryptionKey_ = decryptionKey; + nonceLen_ = nonceLen; + cryptoAlgo_ = cryptoAlgo; + messageId_ = messageId; + blockHeaders_ = new CipherBlockHeaders(); + } + + /** + * Decrypt the ciphertext bytes provided in {@code in} containing the encrypted bytes of the + * plaintext stored in a single block. The decrypted bytes are copied into {@code out} starting at + * {@code outOff}. + * + *

This method performs two operations: parses the headers of the single block structure in the + * ciphertext and processes the encrypted content following the headers and decrypts it. + * + * @param in the input byte array. + * @param off the offset into the in array where the data to be decrypted starts. + * @param len the number of bytes to be decrypted. + * @param out the output buffer the decrypted plaintext bytes go into. + * @param outOff the offset into the output byte array the decrypted data starts at. + * @return the number of bytes written to out. + * @throws AwsCryptoException if the content type found in the headers is not of single-block + * type. + */ + @Override + public synchronized ProcessingSummary processBytes( + final byte[] in, final int off, final int len, final byte[] out, final int outOff) + throws AwsCryptoException { + + if (complete_) { + throw new AwsCryptoException("Ciphertext has already been processed."); } - /** - * Decrypt the ciphertext bytes provided in {@code in} containing the - * encrypted bytes of the plaintext stored in a single block. The decrypted - * bytes are copied into {@code out} starting at {@code outOff}. - * - * This method performs two operations: parses the headers of the single - * block structure in the ciphertext and processes the encrypted content - * following the headers and decrypts it. - * - * @param in - * the input byte array. - * @param off - * the offset into the in array where the data to be decrypted - * starts. - * @param len - * the number of bytes to be decrypted. - * @param out - * the output buffer the decrypted plaintext bytes go into. - * @param outOff - * the offset into the output byte array the decrypted data - * starts at. - * @return - * the number of bytes written to out. - * @throws AwsCryptoException - * if the content type found in the headers is not of - * single-block type. - */ - @Override - synchronized public ProcessingSummary processBytes(final byte[] in, final int off, final int len, - final byte[] out, - final int outOff) throws AwsCryptoException { - - if (complete_) { - throw new AwsCryptoException("Ciphertext has already been processed."); + final byte[] bytesToParse = new byte[unparsedBytes_.length + len]; + // If there were previously unparsed bytes, add them as the first + // set of bytes to be parsed in this call. + System.arraycopy(unparsedBytes_, 0, bytesToParse, 0, unparsedBytes_.length); + System.arraycopy(in, off, bytesToParse, unparsedBytes_.length, len); + + long parsedBytes = 0; + + // Parse available bytes. Stop parsing when there aren't enough + // bytes to complete parsing of the : + // - the blockcipher headers + // - encrypted content + while (!complete_ && parsedBytes < bytesToParse.length) { + blockHeaders_.setNonceLength(nonceLen_); + + parsedBytes += blockHeaders_.deserialize(bytesToParse, (int) parsedBytes); + if (parsedBytes > Integer.MAX_VALUE) { + throw new AwsCryptoException( + "Integer overflow of the total bytes to parse and decrypt occured."); + } + + // if we have all header fields, process the encrypted content. + if (blockHeaders_.isComplete() == true) { + if (blockHeaders_.getContentLength() > Integer.MAX_VALUE) { + throw new AwsCryptoException("Content length exceeds the maximum allowed value."); } + int protectedContentLen = (int) blockHeaders_.getContentLength(); - final byte[] bytesToParse = new byte[unparsedBytes_.length + len]; - // If there were previously unparsed bytes, add them as the first - // set of bytes to be parsed in this call. - System.arraycopy(unparsedBytes_, 0, bytesToParse, 0, unparsedBytes_.length); - System.arraycopy(in, off, bytesToParse, unparsedBytes_.length, len); - - long parsedBytes = 0; - - // Parse available bytes. Stop parsing when there aren't enough - // bytes to complete parsing of the : - // - the blockcipher headers - // - encrypted content - while (!complete_ && parsedBytes < bytesToParse.length) { - blockHeaders_.setNonceLength(nonceLen_); - - parsedBytes += blockHeaders_.deserialize(bytesToParse, (int) parsedBytes); - if (parsedBytes > Integer.MAX_VALUE) { - throw new AwsCryptoException( - "Integer overflow of the total bytes to parse and decrypt occured."); - } - - // if we have all header fields, process the encrypted content. - if (blockHeaders_.isComplete() == true) { - if (blockHeaders_.getContentLength() > Integer.MAX_VALUE) { - throw new AwsCryptoException("Content length exceeds the maximum allowed value."); - } - int protectedContentLen = (int) blockHeaders_.getContentLength(); - - // include the tag which is added by the underlying cipher. - protectedContentLen += cryptoAlgo_.getTagLen(); - - if ((bytesToParse.length - parsedBytes) < protectedContentLen) { - // if we don't have all of the encrypted bytes, break - // until they become available. - break; - } - byte[] plaintext = decryptContent(bytesToParse, (int) parsedBytes, protectedContentLen); - System.arraycopy(plaintext, 0, out, outOff, plaintext.length); - - complete_ = true; - return new ProcessingSummary(plaintext.length, (int) (parsedBytes + protectedContentLen) - - unparsedBytes_.length); - } else { - // if there aren't enough bytes to parse the block headers, - // we can't continue parsing. - break; - } - } - - // buffer remaining bytes for parsing in the next round. - unparsedBytes_ = Arrays.copyOfRange(bytesToParse, (int) parsedBytes, bytesToParse.length); - - return new ProcessingSummary(0, len); - } + // include the tag which is added by the underlying cipher. + protectedContentLen += cryptoAlgo_.getTagLen(); - /** - * Finish processing of the bytes by decrypting the ciphertext. - * - * @param out - * space for any resulting output data. - * @param outOff - * offset into {@code out} to start copying the data at. - * @return - * number of bytes written into {@code out}. - * @throws BadCiphertextException - * if the bytes do not decrypt correctly. - */ - @Override - synchronized public int doFinal(final byte[] out, final int outOff) throws BadCiphertextException { - if (!complete_) { - throw new BadCiphertextException("Unable to process entire ciphertext."); + if ((bytesToParse.length - parsedBytes) < protectedContentLen) { + // if we don't have all of the encrypted bytes, break + // until they become available. + break; } - return 0; + byte[] plaintext = decryptContent(bytesToParse, (int) parsedBytes, protectedContentLen); + System.arraycopy(plaintext, 0, out, outOff, plaintext.length); + + complete_ = true; + return new ProcessingSummary( + plaintext.length, (int) (parsedBytes + protectedContentLen) - unparsedBytes_.length); + } else { + // if there aren't enough bytes to parse the block headers, + // we can't continue parsing. + break; + } } - /** - * Return the size of the output buffer required for a processBytes plus a - * doFinal with an input of inLen bytes. - * - * @param inLen - * the length of the input. - * @return - * the space required to accommodate a call to processBytes and - * doFinal with len bytes of input. - */ - @Override - synchronized public int estimateOutputSize(final int inLen) { - // include any buffered bytes - int outSize = bytesToDecrypt_.length + unparsedBytes_.length; - - if (inLen > 0) { - outSize += inLen; - } - - return outSize; + // buffer remaining bytes for parsing in the next round. + unparsedBytes_ = Arrays.copyOfRange(bytesToParse, (int) parsedBytes, bytesToParse.length); + + return new ProcessingSummary(0, len); + } + + /** + * Finish processing of the bytes by decrypting the ciphertext. + * + * @param out space for any resulting output data. + * @param outOff offset into {@code out} to start copying the data at. + * @return number of bytes written into {@code out}. + * @throws BadCiphertextException if the bytes do not decrypt correctly. + */ + @Override + public synchronized int doFinal(final byte[] out, final int outOff) + throws BadCiphertextException { + if (!complete_) { + throw new BadCiphertextException("Unable to process entire ciphertext."); } - - @Override - public int estimatePartialOutputSize(int inLen) { - return estimateOutputSize(inLen); + return 0; + } + + /** + * Return the size of the output buffer required for a processBytes plus a doFinal with an input + * of inLen bytes. + * + * @param inLen the length of the input. + * @return the space required to accommodate a call to processBytes and doFinal with len bytes of + * input. + */ + @Override + public synchronized int estimateOutputSize(final int inLen) { + // include any buffered bytes + int outSize = bytesToDecrypt_.length + unparsedBytes_.length; + + if (inLen > 0) { + outSize += inLen; } - @Override - public int estimateFinalOutputSize() { - return estimateOutputSize(0); + return outSize; + } + + @Override + public int estimatePartialOutputSize(int inLen) { + return estimateOutputSize(inLen); + } + + @Override + public int estimateFinalOutputSize() { + return estimateOutputSize(0); + } + + /** + * Returns the plaintext bytes of the encrypted content. + * + * @param input the input bytes containing the content + * @param off the offset into the input array where the data to be decrypted starts. + * @param len the number of bytes to be decrypted. + * @return the plaintext bytes of the encrypted content. + * @throws BadCiphertextException if the MAC tag verification fails or an invalid header value is + * found. + */ + private byte[] decryptContent(final byte[] input, final int off, final int len) + throws BadCiphertextException { + if (blockHeaders_.isComplete() == false) { + return new byte[0]; } - /** - * Returns the plaintext bytes of the encrypted content. - * - * @param input - * the input bytes containing the content - * @param off - * the offset into the input array where the data to be decrypted - * starts. - * @param len - * the number of bytes to be decrypted. - * @return - * the plaintext bytes of the encrypted content. - * @throws BadCiphertextException - * if the MAC tag verification fails or an invalid header value - * is found. - */ - private byte[] decryptContent(final byte[] input, final int off, final int len) throws BadCiphertextException { - if (blockHeaders_.isComplete() == false) { - return new byte[0]; - } + final byte[] nonce = blockHeaders_.getNonce(); + final int seqNum = 1; // always 1 for single block case. - final byte[] nonce = blockHeaders_.getNonce(); - final int seqNum = 1; // always 1 for single block case. + final byte[] contentAad = + Utils.generateContentAad( + messageId_, Constants.SINGLE_BLOCK_STRING_ID, seqNum, blockHeaders_.getContentLength()); - final byte[] contentAad = Utils.generateContentAad( - messageId_, - Constants.SINGLE_BLOCK_STRING_ID, - seqNum, - blockHeaders_.getContentLength()); + final CipherHandler cipherHandler = + new CipherHandler(decryptionKey_, Cipher.DECRYPT_MODE, cryptoAlgo_); + return cipherHandler.cipherData(nonce, contentAad, input, off, len); + } - final CipherHandler cipherHandler = new CipherHandler(decryptionKey_, Cipher.DECRYPT_MODE, cryptoAlgo_); - return cipherHandler.cipherData(nonce, contentAad, input, off, len); - } - - @Override - public boolean isComplete() { - return complete_; - } + @Override + public boolean isComplete() { + return complete_; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/BlockEncryptionHandler.java b/src/main/java/com/amazonaws/encryptionsdk/internal/BlockEncryptionHandler.java index d50b983ca..8ba6d36bf 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/BlockEncryptionHandler.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/BlockEncryptionHandler.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,202 +13,181 @@ package com.amazonaws.encryptionsdk.internal; -import java.io.ByteArrayOutputStream; -import javax.crypto.Cipher; -import javax.crypto.SecretKey; - import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.model.CipherBlockHeaders; +import java.io.ByteArrayOutputStream; +import javax.crypto.Cipher; +import javax.crypto.SecretKey; /** - * The block encryption handler is an implementation of {@link MessageCryptoHandler} - * that provides methods to encrypt content and store it in a single - * block. - * - *

- * In this SDK, content encrypted by this class is decrypted by the - * {@link BlockDecryptionHandler}. + * The block encryption handler is an implementation of {@link MessageCryptoHandler} that provides + * methods to encrypt content and store it in a single block. + * + *

In this SDK, content encrypted by this class is decrypted by the {@link + * BlockDecryptionHandler}. */ class BlockEncryptionHandler implements CryptoHandler { - private final SecretKey encryptionKey_; - private final CryptoAlgorithm cryptoAlgo_; - private final int nonceLen_; - private final byte[] messageId_; - private final int tagLenBytes_; - - private final ByteArrayOutputStream bytesToEncryptStream_ = new ByteArrayOutputStream(1024); - - private boolean complete_ = false; - - /** - * Construct an encryption handler for encrypting bytes and storing them in - * a single block. - * - * @param encryptionKey - * the key to use for encrypting the plaintext - * @param nonceLen - * the length of the nonce to use when encrypting the plaintext - * @param cryptoAlgo - * the crypto algorithm to use for encrypting the plaintext - * @param messageId - * the byte array containing the message identifier that is used - * in binding the encrypted content to the headers in the - * ciphertext. - */ - public BlockEncryptionHandler(final SecretKey encryptionKey, final int nonceLen, final CryptoAlgorithm cryptoAlgo, - final byte[] messageId) { - encryptionKey_ = encryptionKey; - cryptoAlgo_ = cryptoAlgo; - nonceLen_ = nonceLen; - messageId_ = messageId.clone(); - tagLenBytes_ = cryptoAlgo_.getTagLen(); + private final SecretKey encryptionKey_; + private final CryptoAlgorithm cryptoAlgo_; + private final int nonceLen_; + private final byte[] messageId_; + private final int tagLenBytes_; + + private final ByteArrayOutputStream bytesToEncryptStream_ = new ByteArrayOutputStream(1024); + + private boolean complete_ = false; + + /** + * Construct an encryption handler for encrypting bytes and storing them in a single block. + * + * @param encryptionKey the key to use for encrypting the plaintext + * @param nonceLen the length of the nonce to use when encrypting the plaintext + * @param cryptoAlgo the crypto algorithm to use for encrypting the plaintext + * @param messageId the byte array containing the message identifier that is used in binding the + * encrypted content to the headers in the ciphertext. + */ + public BlockEncryptionHandler( + final SecretKey encryptionKey, + final int nonceLen, + final CryptoAlgorithm cryptoAlgo, + final byte[] messageId) { + encryptionKey_ = encryptionKey; + cryptoAlgo_ = cryptoAlgo; + nonceLen_ = nonceLen; + messageId_ = messageId.clone(); + tagLenBytes_ = cryptoAlgo_.getTagLen(); + } + + /** + * Encrypt the block of bytes provide in {@code in} and copy the resulting ciphertext bytes into + * {@code out}. + * + * @param in the input byte array containing plaintext bytes. + * @param off the offset into {@code in} where the data to be encrypted starts. + * @param len the number of bytes to be encrypted. + * @param out the output buffer the encrypted bytes are copied into. + * @param outOff the offset into the output byte array the encrypted data starts at. + * @return the number of bytes written to {@code out} and the number of bytes processed + */ + @Override + public ProcessingSummary processBytes( + final byte[] in, final int off, final int len, final byte[] out, final int outOff) { + bytesToEncryptStream_.write(in, off, len); + return new ProcessingSummary(0, len); + } + + /** + * Finish encryption of the plaintext bytes. + * + * @param out space for any resulting output data. + * @param outOff offset into {@code out} to start copying the data at. + * @return number of bytes written into {@code out}. + * @throws BadCiphertextException thrown by the underlying cipher handler. + */ + @Override + public int doFinal(final byte[] out, final int outOff) throws BadCiphertextException { + complete_ = true; + return writeEncryptedBlock( + bytesToEncryptStream_.toByteArray(), 0, bytesToEncryptStream_.size(), out, outOff); + } + + /** + * Return the size of the output buffer required for a processBytes plus a doFinal with an input + * size of {@code inLen} bytes. + * + * @param inLen the length of the input. + * @return the space required to accommodate a call to processBytes and doFinal with {@code inLen} + * bytes of input. + */ + @Override + public int estimateOutputSize(final int inLen) { + int outSize = 0; + + outSize += nonceLen_ + tagLenBytes_; + // include long for storing size of content + outSize += Long.SIZE / Byte.SIZE; + + // include any buffered bytes + outSize += bytesToEncryptStream_.size(); + + if (inLen > 0) { + outSize += inLen; } - /** - * Encrypt the block of bytes provide in {@code in} and copy the resulting ciphertext bytes into - * {@code out}. - * - * @param in - * the input byte array containing plaintext bytes. - * @param off - * the offset into {@code in} where the data to be encrypted starts. - * @param len - * the number of bytes to be encrypted. - * @param out - * the output buffer the encrypted bytes are copied into. - * @param outOff - * the offset into the output byte array the encrypted data starts at. - * @return the number of bytes written to {@code out} and the number of bytes processed - */ - @Override - public ProcessingSummary processBytes(final byte[] in, final int off, final int len, final byte[] out, - final int outOff) { - bytesToEncryptStream_.write(in, off, len); - return new ProcessingSummary(0, len); + return outSize; + } + + @Override + public int estimatePartialOutputSize(int inLen) { + return 0; + } + + @Override + public int estimateFinalOutputSize() { + return estimateOutputSize(0); + } + + /** + * This method encrypts the provided bytes, creates the headers for the block, and assembles the + * block containing the headers and the encrypted bytes. + * + * @param in the input byte array. + * @param inOff the offset into {@code in} array where the data to be encrypted starts. + * @param inLen the number of bytes to be encrypted. + * @param out the output buffer the encrypted bytes is copied into. + * @param outOff the offset into the output byte array the encrypted data starts at. + * @return the number of bytes written to {@code out}. + * @throws BadCiphertextException thrown by the underlying cipher handler. + */ + private int writeEncryptedBlock( + final byte[] input, final int off, final int len, final byte[] out, final int outOff) + throws BadCiphertextException { + if (out.length == 0) { + return 0; } - /** - * Finish encryption of the plaintext bytes. - * - * @param out - * space for any resulting output data. - * @param outOff - * offset into {@code out} to start copying the data at. - * @return - * number of bytes written into {@code out}. - * @throws BadCiphertextException - * thrown by the underlying cipher handler. - */ - @Override - public int doFinal(final byte[] out, final int outOff) throws BadCiphertextException { - complete_ = true; - return writeEncryptedBlock(bytesToEncryptStream_.toByteArray(), 0, bytesToEncryptStream_.size(), out, outOff); - } + int outLen = 0; + final int seqNum = 1; // always 1 for single block case - /** - * Return the size of the output buffer required for a processBytes plus a - * doFinal with an input size of {@code inLen} bytes. - * - * @param inLen - * the length of the input. - * @return - * the space required to accommodate a call to processBytes and - * doFinal with {@code inLen} bytes of input. - */ - @Override - public int estimateOutputSize(final int inLen) { - int outSize = 0; - - outSize += nonceLen_ + tagLenBytes_; - // include long for storing size of content - outSize += Long.SIZE / Byte.SIZE; - - // include any buffered bytes - outSize += bytesToEncryptStream_.size(); - - if (inLen > 0) { - outSize += inLen; - } - - return outSize; - } + final byte[] contentAad = + Utils.generateContentAad(messageId_, Constants.SINGLE_BLOCK_STRING_ID, seqNum, len); - @Override - public int estimatePartialOutputSize(int inLen) { - return 0; - } + final byte[] nonce = getNonce(); - @Override - public int estimateFinalOutputSize() { - return estimateOutputSize(0); - } + final byte[] encryptedBytes = + new CipherHandler(encryptionKey_, Cipher.ENCRYPT_MODE, cryptoAlgo_) + .cipherData(nonce, contentAad, input, off, len); - /** - * This method encrypts the provided bytes, creates the headers for the - * block, and assembles the block containing the headers and the encrypted - * bytes. - * - * @param in - * the input byte array. - * @param inOff - * the offset into {@code in} array where the data to be - * encrypted starts. - * @param inLen - * the number of bytes to be encrypted. - * @param out - * the output buffer the encrypted bytes is copied into. - * @param outOff - * the offset into the output byte array the encrypted data - * starts at. - * @return - * the number of bytes written to {@code out}. - * @throws BadCiphertextException - * thrown by the underlying cipher handler. - */ - private int writeEncryptedBlock(final byte[] input, final int off, final int len, final byte[] out, final int outOff) - throws BadCiphertextException { - if (out.length == 0) { - return 0; - } - - int outLen = 0; - final int seqNum = 1; // always 1 for single block case - - final byte[] contentAad = Utils - .generateContentAad(messageId_, Constants.SINGLE_BLOCK_STRING_ID, seqNum, len); - - final byte[] nonce = getNonce(); - - final byte[] encryptedBytes = new CipherHandler(encryptionKey_, Cipher.ENCRYPT_MODE, cryptoAlgo_) - .cipherData(nonce, contentAad, input, off, len); - - // create the cipherblock headers now for the encrypted data - final int encryptedContentLen = encryptedBytes.length - tagLenBytes_; - final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce, encryptedContentLen); - final byte[] cipherBlockHeaderBytes = cipherBlockHeaders.toByteArray(); - - // assemble the headers and the encrypted bytes into a single block - System.arraycopy(cipherBlockHeaderBytes, 0, out, outOff + outLen, cipherBlockHeaderBytes.length); - outLen += cipherBlockHeaderBytes.length; - System.arraycopy(encryptedBytes, 0, out, outOff + outLen, encryptedBytes.length); - outLen += encryptedBytes.length; - - return outLen; - } + // create the cipherblock headers now for the encrypted data + final int encryptedContentLen = encryptedBytes.length - tagLenBytes_; + final CipherBlockHeaders cipherBlockHeaders = + new CipherBlockHeaders(nonce, encryptedContentLen); + final byte[] cipherBlockHeaderBytes = cipherBlockHeaders.toByteArray(); - private byte[] getNonce() { - final byte[] nonce = new byte[nonceLen_]; + // assemble the headers and the encrypted bytes into a single block + System.arraycopy( + cipherBlockHeaderBytes, 0, out, outOff + outLen, cipherBlockHeaderBytes.length); + outLen += cipherBlockHeaderBytes.length; + System.arraycopy(encryptedBytes, 0, out, outOff + outLen, encryptedBytes.length); + outLen += encryptedBytes.length; - // The IV for the non-framed encryption case is generated as if we were encrypting a message with a single - // frame. - nonce[nonce.length - 1] = 1; + return outLen; + } - return nonce; - } + private byte[] getNonce() { + final byte[] nonce = new byte[nonceLen_]; - @Override - public boolean isComplete() { - return complete_; - } + // The IV for the non-framed encryption case is generated as if we were encrypting a message + // with a single + // frame. + nonce[nonce.length - 1] = 1; + + return nonce; + } + + @Override + public boolean isComplete() { + return complete_; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/CipherHandler.java b/src/main/java/com/amazonaws/encryptionsdk/internal/CipherHandler.java index 900805bf9..220e9ba2e 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/CipherHandler.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/CipherHandler.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,102 +13,92 @@ package com.amazonaws.encryptionsdk.internal; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import java.security.GeneralSecurityException; import java.security.spec.AlgorithmParameterSpec; - import javax.annotation.concurrent.NotThreadSafe; import javax.crypto.Cipher; import javax.crypto.SecretKey; import javax.crypto.spec.GCMParameterSpec; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; - /** * This class provides a cryptographic cipher handler powered by an underlying block cipher. The * block cipher performs authenticated encryption of the provided bytes using Additional * Authenticated Data (AAD). * - *

- * This class implements a method called cipherData() that encrypts or decrypts a byte array by + *

This class implements a method called cipherData() that encrypts or decrypts a byte array by * calling methods on the underlying block cipher. */ @NotThreadSafe class CipherHandler { - private final int cipherMode_; - private final SecretKey key_; - private final CryptoAlgorithm cryptoAlgorithm_; - private final Cipher cipher_; + private final int cipherMode_; + private final SecretKey key_; + private final CryptoAlgorithm cryptoAlgorithm_; + private final Cipher cipher_; - /** - * Process data through the cipher. - * - *

- * This method calls the update and doFinal methods on the underlying - * cipher to complete processing of the data. - * - * @param nonce - * the nonce to be used by the underlying cipher - * @param contentAad - * the optional additional authentication data to be used by the underlying cipher - * @param content - * the content to be processed by the underlying cipher - * @param off - * the offset into content array to be processed - * @param len - * the number of bytes to process - * @return the bytes processed by the underlying cipher - * @throws AwsCryptoException if cipher initialization fails - * @throws BadCiphertextException - * if processing the data through the cipher fails - */ - public byte[] cipherData(byte[] nonce, byte[] contentAad, final byte[] content, final int off, final int len) { - if (nonce.length != cryptoAlgorithm_.getNonceLen()) { - throw new IllegalArgumentException("Invalid nonce length: " + nonce.length); - } - final AlgorithmParameterSpec spec = new GCMParameterSpec(cryptoAlgorithm_.getTagLen() * 8, nonce, 0, nonce.length); - - try { - cipher_.init(cipherMode_, key_, spec); - if (contentAad != null) { - cipher_.updateAAD(contentAad); - } - } catch (final GeneralSecurityException gsx) { - throw new AwsCryptoException(gsx); - } - try { - return cipher_.doFinal(content, off, len); - } catch (final GeneralSecurityException gsx) { - throw new BadCiphertextException(gsx); - } + /** + * Process data through the cipher. + * + *

This method calls the update and doFinal methods on the underlying + * cipher to complete processing of the data. + * + * @param nonce the nonce to be used by the underlying cipher + * @param contentAad the optional additional authentication data to be used by the underlying + * cipher + * @param content the content to be processed by the underlying cipher + * @param off the offset into content array to be processed + * @param len the number of bytes to process + * @return the bytes processed by the underlying cipher + * @throws AwsCryptoException if cipher initialization fails + * @throws BadCiphertextException if processing the data through the cipher fails + */ + public byte[] cipherData( + byte[] nonce, byte[] contentAad, final byte[] content, final int off, final int len) { + if (nonce.length != cryptoAlgorithm_.getNonceLen()) { + throw new IllegalArgumentException("Invalid nonce length: " + nonce.length); } + final AlgorithmParameterSpec spec = + new GCMParameterSpec(cryptoAlgorithm_.getTagLen() * 8, nonce, 0, nonce.length); - /** - * Create a cipher handler for processing bytes using an underlying block cipher. - * - * @param key - * the key to use in encrypting or decrypting bytes - * @param cipherMode - * the mode for processing the bytes as defined in - * {@link Cipher#init(int, java.security.Key)} - * @param cryptoAlgorithm - * the cryptography algorithm to be used by the underlying block cipher. - * @throws GeneralSecurityException - */ - CipherHandler(final SecretKey key, final int cipherMode, final CryptoAlgorithm cryptoAlgorithm) { - this.cipherMode_ = cipherMode; - this.key_ = key; - this.cryptoAlgorithm_ = cryptoAlgorithm; - this.cipher_ = buildCipherObject(cryptoAlgorithm); + try { + cipher_.init(cipherMode_, key_, spec); + if (contentAad != null) { + cipher_.updateAAD(contentAad); + } + } catch (final GeneralSecurityException gsx) { + throw new AwsCryptoException(gsx); + } + try { + return cipher_.doFinal(content, off, len); + } catch (final GeneralSecurityException gsx) { + throw new BadCiphertextException(gsx); } + } + + /** + * Create a cipher handler for processing bytes using an underlying block cipher. + * + * @param key the key to use in encrypting or decrypting bytes + * @param cipherMode the mode for processing the bytes as defined in {@link Cipher#init(int, + * java.security.Key)} + * @param cryptoAlgorithm the cryptography algorithm to be used by the underlying block cipher. + * @throws GeneralSecurityException + */ + CipherHandler(final SecretKey key, final int cipherMode, final CryptoAlgorithm cryptoAlgorithm) { + this.cipherMode_ = cipherMode; + this.key_ = key; + this.cryptoAlgorithm_ = cryptoAlgorithm; + this.cipher_ = buildCipherObject(cryptoAlgorithm); + } - private static Cipher buildCipherObject(final CryptoAlgorithm alg) { - try { - // Right now, just GCM is supported - return Cipher.getInstance("AES/GCM/NoPadding"); - } catch (final GeneralSecurityException ex) { - throw new IllegalStateException("Java does not support the requested algorithm", ex); - } + private static Cipher buildCipherObject(final CryptoAlgorithm alg) { + try { + // Right now, just GCM is supported + return Cipher.getInstance("AES/GCM/NoPadding"); + } catch (final GeneralSecurityException ex) { + throw new IllegalStateException("Java does not support the requested algorithm", ex); } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/CommittedKey.java b/src/main/java/com/amazonaws/encryptionsdk/internal/CommittedKey.java index e9e232354..fa3cd84d9 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/CommittedKey.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/CommittedKey.java @@ -4,104 +4,110 @@ package com.amazonaws.encryptionsdk.internal; import com.amazonaws.encryptionsdk.CryptoAlgorithm; - +import java.nio.charset.StandardCharsets; +import java.security.NoSuchAlgorithmException; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; import javax.management.openmbean.InvalidKeyException; -import java.nio.charset.StandardCharsets; -import java.security.NoSuchAlgorithmException; public final class CommittedKey { - private final SecretKey key_; - private final byte[] commitment_; - - CommittedKey(SecretKey key, byte[] commitment) { - key_ = key; - commitment_ = commitment; + private final SecretKey key_; + private final byte[] commitment_; + + CommittedKey(SecretKey key, byte[] commitment) { + key_ = key; + commitment_ = commitment; + } + + public SecretKey getKey() { + return key_; + } + + public byte[] getCommitment() { + return commitment_.clone(); + } + + /** + * The template for creating the label/info for deriving the encryption key from the data key. + * + *

Note that this value must be cloned and modified prior to use. Cloned to prevent + * modification of the template and threading issues. Modified to insert the algorithm id into the + * first two bytes. + */ + private static byte[] DERIVE_KEY_LABEL_TEMPLATE = "__DERIVEKEY".getBytes(StandardCharsets.UTF_8); + + /** + * Full label/info for deriving the key commitment value from the data key. + * + *

Unlike {@link #DERIVE_KEY_LABEL_TEMPLATE} this value does not need to be cloned or modified + * prior to use. + */ + private static byte[] COMMITKEY_LABEL = "COMMITKEY".getBytes(StandardCharsets.UTF_8); + + private static final String RAW_DATA_FORMAT = "RAW"; + private static final String HKDF_SHA_512 = "HkdfSHA512"; + private static final String HMAC_SHA_512 = "HmacSHA512"; + + /** Generates an encryption key along with associated commitment value. */ + public static CommittedKey generate(CryptoAlgorithm alg, SecretKey dataKey, byte[] nonce) + throws InvalidKeyException { + if (!alg.isCommitting()) { + throw new IllegalArgumentException("Algorithm does not support key commitment."); } - - public SecretKey getKey() { - return key_; + if (nonce.length != alg.getCommitmentNonceLength()) { + throw new IllegalArgumentException("Invalid nonce size"); + } + if (dataKey.getFormat() == null || !dataKey.getFormat().equalsIgnoreCase(RAW_DATA_FORMAT)) { + throw new IllegalArgumentException( + "Currently only RAW format keys are supported for HKDF algorithms. Actual format was " + + dataKey.getFormat()); + } + if (dataKey.getAlgorithm() == null + || !dataKey.getAlgorithm().equalsIgnoreCase(alg.getDataKeyAlgo())) { + throw new IllegalArgumentException( + "DataKey of incorrect algorithm. Expected " + + alg.getDataKeyAlgo() + + " but was " + + dataKey.getAlgorithm()); + } + final byte[] rawDataKey = dataKey.getEncoded(); + if (rawDataKey.length != alg.getDataKeyLength()) { + throw new IllegalArgumentException( + "DataKey of incorrect length. Expected " + + alg.getDataKeyLength() + + " but was " + + rawDataKey.length); } - public byte[] getCommitment() { - return commitment_.clone(); + final String macAlgorithm; + switch (alg.getKeyCommitmentAlgo_()) { + case HKDF_SHA_512: + macAlgorithm = HMAC_SHA_512; + break; + default: + throw new UnsupportedOperationException( + "Support for commitment with " + alg.getKeyCommitmentAlgo_() + " not yet built."); } - /** - * The template for creating the label/info for deriving the encryption key from the data key. - * - * Note that this value must be cloned and modified prior to use. - * Cloned to prevent modification of the template and threading issues. - * Modified to insert the algorithm id into the first two bytes. - */ - private static byte[] DERIVE_KEY_LABEL_TEMPLATE = "__DERIVEKEY".getBytes(StandardCharsets.UTF_8); - - /** - * Full label/info for deriving the key commitment value from the data key. - * - * Unlike {@link #DERIVE_KEY_LABEL_TEMPLATE} this value does not need to be cloned or modified - * prior to use. - */ - private static byte[] COMMITKEY_LABEL = "COMMITKEY".getBytes(StandardCharsets.UTF_8); - - private static final String RAW_DATA_FORMAT = "RAW"; - private static final String HKDF_SHA_512 = "HkdfSHA512"; - private static final String HMAC_SHA_512 = "HmacSHA512"; - - /** - * Generates an encryption key along with associated commitment value. - */ - public static CommittedKey generate(CryptoAlgorithm alg, SecretKey dataKey, byte[] nonce) - throws InvalidKeyException { - if (!alg.isCommitting()) { - throw new IllegalArgumentException("Algorithm does not support key commitment."); - } - if (nonce.length != alg.getCommitmentNonceLength()) { - throw new IllegalArgumentException("Invalid nonce size"); - } - if (dataKey.getFormat() == null || !dataKey.getFormat().equalsIgnoreCase(RAW_DATA_FORMAT)) { - throw new IllegalArgumentException( - "Currently only RAW format keys are supported for HKDF algorithms. Actual format was " - + dataKey.getFormat()); - } - if (dataKey.getAlgorithm() == null || !dataKey.getAlgorithm().equalsIgnoreCase(alg.getDataKeyAlgo())) { - throw new IllegalArgumentException("DataKey of incorrect algorithm. Expected " + alg.getDataKeyAlgo() + " but was " - + dataKey.getAlgorithm()); - } - final byte[] rawDataKey = dataKey.getEncoded(); - if (rawDataKey.length != alg.getDataKeyLength()) { - throw new IllegalArgumentException("DataKey of incorrect length. Expected " + alg.getDataKeyLength() + " but was " - + rawDataKey.length); - } - - final String macAlgorithm; - switch (alg.getKeyCommitmentAlgo_()) { - case HKDF_SHA_512: - macAlgorithm = HMAC_SHA_512; - break; - default: - throw new UnsupportedOperationException("Support for commitment with " + alg.getKeyCommitmentAlgo_() + " not yet built."); - } - - HmacKeyDerivationFunction kdf = null; - try { - kdf = HmacKeyDerivationFunction.getInstance(macAlgorithm); - } catch (NoSuchAlgorithmException e) { - throw new IllegalStateException(e); - } - kdf.init(rawDataKey, nonce); - - final byte[] commitment = kdf.deriveKey(COMMITKEY_LABEL, alg.getCommitmentLength()); - - // Clone to prevent modification of the master copy - final byte[] deriveKeyLabel = DERIVE_KEY_LABEL_TEMPLATE.clone(); - final short algId = alg.getValue(); - deriveKeyLabel[0] = (byte) ((algId >> 8) & 0xFF); - deriveKeyLabel[1] = (byte) (algId & 0xFF); - SecretKey ek = new SecretKeySpec(kdf.deriveKey(deriveKeyLabel, alg.getKeyLength()), alg.getKeyAlgo()); - - return new CommittedKey(ek, commitment); + HmacKeyDerivationFunction kdf = null; + try { + kdf = HmacKeyDerivationFunction.getInstance(macAlgorithm); + } catch (NoSuchAlgorithmException e) { + throw new IllegalStateException(e); } + kdf.init(rawDataKey, nonce); + + final byte[] commitment = kdf.deriveKey(COMMITKEY_LABEL, alg.getCommitmentLength()); + + // Clone to prevent modification of the master copy + final byte[] deriveKeyLabel = DERIVE_KEY_LABEL_TEMPLATE.clone(); + final short algId = alg.getValue(); + deriveKeyLabel[0] = (byte) ((algId >> 8) & 0xFF); + deriveKeyLabel[1] = (byte) (algId & 0xFF); + SecretKey ek = + new SecretKeySpec(kdf.deriveKey(deriveKeyLabel, alg.getKeyLength()), alg.getKeyAlgo()); + + return new CommittedKey(ek, commitment); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/Constants.java b/src/main/java/com/amazonaws/encryptionsdk/internal/Constants.java index 3fbf23810..08c4fedd0 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/Constants.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/Constants.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -16,54 +16,48 @@ import com.amazonaws.encryptionsdk.CryptoAlgorithm; public final class Constants { - /** - * Default length of the message identifier used to uniquely identify every - * ciphertext created by this library. - * @deprecated This value may change based on {@link CryptoAlgorithm#getMessageIdLength()} - */ - @Deprecated - public static final int MESSAGE_ID_LEN = 16; - - private Constants() { - // Prevent instantiation - } - - /** - * Marker for identifying the final frame. - */ - public static final int ENDFRAME_SEQUENCE_NUMBER = ~0; // is 0xFFFFFFFF - - /** - * The identifier for non-final frames in the framing content type. This value is used as part - * of the additional authenticated data (AAD) when encryption of content in a frame. - */ - public static final String FRAME_STRING_ID = "AWSKMSEncryptionClient Frame"; - - /** - * The identifier for the final frame in the framing content type. This value is used as part of - * the additional authenticated data (AAD) when encryption of content in a frame. - */ - public static final String FINAL_FRAME_STRING_ID = "AWSKMSEncryptionClient Final Frame"; - - /** - * The identifier for the single block content type. This value is used as part of the - * additional authenticated data (AAD) when encryption of content in a single block. - */ - public static final String SINGLE_BLOCK_STRING_ID = "AWSKMSEncryptionClient Single Block"; - - /** - * Maximum length of the content that can be encrypted in GCM mode. - */ - public static final long GCM_MAX_CONTENT_LEN = (1L << 36) - 32; - - public static final int MAX_NONCE_LENGTH = (1 << 8) - 1; - - /** - * Maximum value of an unsigned short. - */ - public static final int UNSIGNED_SHORT_MAX_VAL = (1 << 16) - 1; - - public static final long MAX_FRAME_NUMBER = (1L << 32) - 1; - - public static final String EC_PUBLIC_KEY_FIELD = "aws-crypto-public-key"; + /** + * Default length of the message identifier used to uniquely identify every ciphertext created by + * this library. + * + * @deprecated This value may change based on {@link CryptoAlgorithm#getMessageIdLength()} + */ + @Deprecated public static final int MESSAGE_ID_LEN = 16; + + private Constants() { + // Prevent instantiation + } + + /** Marker for identifying the final frame. */ + public static final int ENDFRAME_SEQUENCE_NUMBER = ~0; // is 0xFFFFFFFF + + /** + * The identifier for non-final frames in the framing content type. This value is used as part of + * the additional authenticated data (AAD) when encryption of content in a frame. + */ + public static final String FRAME_STRING_ID = "AWSKMSEncryptionClient Frame"; + + /** + * The identifier for the final frame in the framing content type. This value is used as part of + * the additional authenticated data (AAD) when encryption of content in a frame. + */ + public static final String FINAL_FRAME_STRING_ID = "AWSKMSEncryptionClient Final Frame"; + + /** + * The identifier for the single block content type. This value is used as part of the additional + * authenticated data (AAD) when encryption of content in a single block. + */ + public static final String SINGLE_BLOCK_STRING_ID = "AWSKMSEncryptionClient Single Block"; + + /** Maximum length of the content that can be encrypted in GCM mode. */ + public static final long GCM_MAX_CONTENT_LEN = (1L << 36) - 32; + + public static final int MAX_NONCE_LENGTH = (1 << 8) - 1; + + /** Maximum value of an unsigned short. */ + public static final int UNSIGNED_SHORT_MAX_VAL = (1 << 16) - 1; + + public static final long MAX_FRAME_NUMBER = (1L << 32) - 1; + + public static final String EC_PUBLIC_KEY_FIELD = "aws-crypto-public-key"; } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/CryptoHandler.java b/src/main/java/com/amazonaws/encryptionsdk/internal/CryptoHandler.java index ddfb25583..f0dffee5c 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/CryptoHandler.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/CryptoHandler.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -16,91 +16,77 @@ /** * This interface defines the contract for the implementation of encryption and decryption handlers * in this library. - * - *

- * The implementations of this interface provided in this package currently process bytes in a + * + *

The implementations of this interface provided in this package currently process bytes in a * single block mode (where all input data is processed in entirety, or in a framing mode (where * data is processed incrementally in chunks). */ public interface CryptoHandler { - /** - * Process a block of bytes from {@code in} putting the result into {@code out}. - * - * @param in - * the input byte array. - * @param inOff - * the offset into the {@code in} array where the data to be processed starts. - * @param inLen - * the number of bytes to be processed. - * @param out - * the output buffer the processed bytes go into. - * @param outOff - * the offset into the output byte array the processed data starts at. - * @return the number of bytes written to {@code out} and the number of bytes parsed. - */ - ProcessingSummary processBytes(final byte[] in, final int inOff, final int inLen, byte[] out, final int outOff); + /** + * Process a block of bytes from {@code in} putting the result into {@code out}. + * + * @param in the input byte array. + * @param inOff the offset into the {@code in} array where the data to be processed starts. + * @param inLen the number of bytes to be processed. + * @param out the output buffer the processed bytes go into. + * @param outOff the offset into the output byte array the processed data starts at. + * @return the number of bytes written to {@code out} and the number of bytes parsed. + */ + ProcessingSummary processBytes( + final byte[] in, final int inOff, final int inLen, byte[] out, final int outOff); - /** - * Finish processing of the bytes. - * - * @param out - * the output buffer for copying any remaining output data. - * @param outOff - * offset into {@code out} to start copying the output data. - * @return number of bytes written into {@code out}. - */ - int doFinal(final byte[] out, final int outOff); + /** + * Finish processing of the bytes. + * + * @param out the output buffer for copying any remaining output data. + * @param outOff offset into {@code out} to start copying the output data. + * @return number of bytes written into {@code out}. + */ + int doFinal(final byte[] out, final int outOff); - /** - * Return the size of the output buffer required for a - * {@link #processBytes(byte[], int, int, byte[], int)} plus a {@link #doFinal(byte[], int)} - * call with an input of {@code inLen} bytes. - * - *

- * Note this method is allowed to return an estimation of the output size that is greater - * than the actual size of the output. Returning an estimate that is lesser than the actual size - * of the output will result in underflow exceptions. - * - * @param inLen - * the length of the input. - * @return the space required to accommodate a call to processBytes and - * {@link #doFinal(byte[], int)} with an input of size {@code inLen} bytes. - */ - int estimateOutputSize(final int inLen); + /** + * Return the size of the output buffer required for a {@link #processBytes(byte[], int, int, + * byte[], int)} plus a {@link #doFinal(byte[], int)} call with an input of {@code inLen} bytes. + * + *

Note this method is allowed to return an estimation of the output size that is + * greater than the actual size of the output. Returning an estimate that is lesser than + * the actual size of the output will result in underflow exceptions. + * + * @param inLen the length of the input. + * @return the space required to accommodate a call to processBytes and {@link #doFinal(byte[], + * int)} with an input of size {@code inLen} bytes. + */ + int estimateOutputSize(final int inLen); - /** - * Return the size of the output buffer required for a call to - * {@link #processBytes(byte[], int, int, byte[], int)}. - * - *

- * Note this method is allowed to return an estimation of the output size that is greater - * than the actual size of the output. Returning an estimate that is lesser than the actual size - * of the output will result in underflow exceptions. - * - * @param inLen - * the length of the input. - * @return the space required to accommodate a call to - * {@link #processBytes(byte[], int, int, byte[], int)} with an input of size - * {@code inLen} bytes. - */ - int estimatePartialOutputSize(final int inLen); + /** + * Return the size of the output buffer required for a call to {@link #processBytes(byte[], int, + * int, byte[], int)}. + * + *

Note this method is allowed to return an estimation of the output size that is + * greater than the actual size of the output. Returning an estimate that is lesser than + * the actual size of the output will result in underflow exceptions. + * + * @param inLen the length of the input. + * @return the space required to accommodate a call to {@link #processBytes(byte[], int, int, + * byte[], int)} with an input of size {@code inLen} bytes. + */ + int estimatePartialOutputSize(final int inLen); - /** - * Return the size of the output buffer required for a call to {@link #doFinal(byte[], int)}. - * - *

- * Note this method is allowed to return an estimation of the output size that is greater - * than the actual size of the output. Returning an estimate that is lesser than the actual size - * of the output will result in underflow exceptions. - * - * @return the space required to accomodate a call to {@link #doFinal(byte[], int)} - */ - int estimateFinalOutputSize(); + /** + * Return the size of the output buffer required for a call to {@link #doFinal(byte[], int)}. + * + *

Note this method is allowed to return an estimation of the output size that is + * greater than the actual size of the output. Returning an estimate that is lesser than + * the actual size of the output will result in underflow exceptions. + * + * @return the space required to accomodate a call to {@link #doFinal(byte[], int)} + */ + int estimateFinalOutputSize(); - /** - * For decrypt and parsing flows returns {@code true} when this has handled as many bytes as it - * can. This usually means that it has reached the end of an object, file, or other delimited - * stream. - */ - boolean isComplete(); + /** + * For decrypt and parsing flows returns {@code true} when this has handled as many bytes as it + * can. This usually means that it has reached the end of an object, file, or other delimited + * stream. + */ + boolean isComplete(); } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/DecryptionHandler.java b/src/main/java/com/amazonaws/encryptionsdk/internal/DecryptionHandler.java index 8a27329e8..05a31be98 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/DecryptionHandler.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/DecryptionHandler.java @@ -3,6 +3,15 @@ package com.amazonaws.encryptionsdk.internal; +import com.amazonaws.encryptionsdk.*; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; +import com.amazonaws.encryptionsdk.model.CiphertextFooters; +import com.amazonaws.encryptionsdk.model.CiphertextHeaders; +import com.amazonaws.encryptionsdk.model.CiphertextType; +import com.amazonaws.encryptionsdk.model.ContentType; +import com.amazonaws.encryptionsdk.model.DecryptionMaterials; +import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import java.security.GeneralSecurityException; import java.security.InvalidKeyException; import java.security.PublicKey; @@ -12,697 +21,692 @@ import java.util.Collections; import java.util.List; import java.util.Map; - import javax.crypto.Cipher; import javax.crypto.SecretKey; -import com.amazonaws.encryptionsdk.*; -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; -import com.amazonaws.encryptionsdk.model.CiphertextFooters; -import com.amazonaws.encryptionsdk.model.CiphertextHeaders; -import com.amazonaws.encryptionsdk.model.CiphertextType; -import com.amazonaws.encryptionsdk.model.ContentType; -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; -import com.amazonaws.encryptionsdk.model.DecryptionMaterials; - /** - * This class implements the CryptoHandler interface by providing methods for - * the decryption of ciphertext produced by the methods in - * {@link EncryptionHandler}. - * - *

- * This class reads and parses the values in the ciphertext headers and - * delegates the decryption of the ciphertext to the - * {@link BlockDecryptionHandler} or {@link FrameDecryptionHandler} based on the - * content type parsed in the ciphertext headers. + * This class implements the CryptoHandler interface by providing methods for the decryption of + * ciphertext produced by the methods in {@link EncryptionHandler}. + * + *

This class reads and parses the values in the ciphertext headers and delegates the decryption + * of the ciphertext to the {@link BlockDecryptionHandler} or {@link FrameDecryptionHandler} based + * on the content type parsed in the ciphertext headers. */ public class DecryptionHandler> implements MessageCryptoHandler { - private final CryptoMaterialsManager materialsManager_; - private final CommitmentPolicy commitmentPolicy_; - /** - * The maximum number of encrypted data keys to parse, if positive. - * If zero, do not limit EDKs. - */ - private final int maxEncryptedDataKeys_; - private final SignaturePolicy signaturePolicy_; - - private final CiphertextHeaders ciphertextHeaders_; - private final CiphertextFooters ciphertextFooters_; - private boolean ciphertextHeadersParsed_; - - private CryptoHandler contentCryptoHandler_; - - private DataKey dataKey_; - private SecretKey decryptionKey_; - private CryptoAlgorithm cryptoAlgo_; - private Signature trailingSig_; - - private Map encryptionContext_ = null; - - private byte[] unparsedBytes_ = new byte[0]; - private boolean complete_ = false; - - private long ciphertextSizeBound_ = -1; - private long ciphertextBytesSupplied_ = 0; - - // These ctors are private to ensure type safety - we must ensure construction using a CMM results in a - // DecryptionHandler, not a DecryptionHandler, since the CryptoMaterialsManager is not itself - // genericized. - private DecryptionHandler(final CryptoMaterialsManager materialsManager, final CommitmentPolicy commitmentPolicy, - final SignaturePolicy signaturePolicy, final int maxEncryptedDataKeys) { - Utils.assertNonNull(materialsManager, "materialsManager"); - Utils.assertNonNull(commitmentPolicy, "commitmentPolicy"); - Utils.assertNonNull(signaturePolicy, "signaturePolicy"); - - this.materialsManager_ = materialsManager; - this.commitmentPolicy_ = commitmentPolicy; - this.maxEncryptedDataKeys_ = maxEncryptedDataKeys; - this.signaturePolicy_ = signaturePolicy; - ciphertextHeaders_ = new CiphertextHeaders(); - ciphertextFooters_ = new CiphertextFooters(); + private final CryptoMaterialsManager materialsManager_; + private final CommitmentPolicy commitmentPolicy_; + /** + * The maximum number of encrypted data keys to parse, if positive. If zero, do not limit EDKs. + */ + private final int maxEncryptedDataKeys_; + + private final SignaturePolicy signaturePolicy_; + + private final CiphertextHeaders ciphertextHeaders_; + private final CiphertextFooters ciphertextFooters_; + private boolean ciphertextHeadersParsed_; + + private CryptoHandler contentCryptoHandler_; + + private DataKey dataKey_; + private SecretKey decryptionKey_; + private CryptoAlgorithm cryptoAlgo_; + private Signature trailingSig_; + + private Map encryptionContext_ = null; + + private byte[] unparsedBytes_ = new byte[0]; + private boolean complete_ = false; + + private long ciphertextSizeBound_ = -1; + private long ciphertextBytesSupplied_ = 0; + + // These ctors are private to ensure type safety - we must ensure construction using a CMM results + // in a + // DecryptionHandler, not a DecryptionHandler, since the + // CryptoMaterialsManager is not itself + // genericized. + private DecryptionHandler( + final CryptoMaterialsManager materialsManager, + final CommitmentPolicy commitmentPolicy, + final SignaturePolicy signaturePolicy, + final int maxEncryptedDataKeys) { + Utils.assertNonNull(materialsManager, "materialsManager"); + Utils.assertNonNull(commitmentPolicy, "commitmentPolicy"); + Utils.assertNonNull(signaturePolicy, "signaturePolicy"); + + this.materialsManager_ = materialsManager; + this.commitmentPolicy_ = commitmentPolicy; + this.maxEncryptedDataKeys_ = maxEncryptedDataKeys; + this.signaturePolicy_ = signaturePolicy; + ciphertextHeaders_ = new CiphertextHeaders(); + ciphertextFooters_ = new CiphertextFooters(); + } + + private DecryptionHandler( + final CryptoMaterialsManager materialsManager, + final CiphertextHeaders headers, + final CommitmentPolicy commitmentPolicy, + final SignaturePolicy signaturePolicy, + final int maxEncryptedDataKeys) + throws AwsCryptoException { + Utils.assertNonNull(materialsManager, "materialsManager"); + Utils.assertNonNull(commitmentPolicy, "commitmentPolicy"); + Utils.assertNonNull(signaturePolicy, "signaturePolicy"); + + materialsManager_ = materialsManager; + ciphertextHeaders_ = headers; + commitmentPolicy_ = commitmentPolicy; + signaturePolicy_ = signaturePolicy; + maxEncryptedDataKeys_ = maxEncryptedDataKeys; + ciphertextFooters_ = new CiphertextFooters(); + if (headers instanceof ParsedCiphertext) { + ciphertextBytesSupplied_ = ((ParsedCiphertext) headers).getOffset(); + } else { + // This is a little more expensive, hence the public create(...) methods + // that take a CiphertextHeaders instead of a ParsedCiphertext are + // deprecated. + ciphertextBytesSupplied_ = headers.toByteArray().length; } - - private DecryptionHandler(final CryptoMaterialsManager materialsManager, final CiphertextHeaders headers, - final CommitmentPolicy commitmentPolicy, final SignaturePolicy signaturePolicy, - final int maxEncryptedDataKeys) - throws AwsCryptoException - { - Utils.assertNonNull(materialsManager, "materialsManager"); - Utils.assertNonNull(commitmentPolicy, "commitmentPolicy"); - Utils.assertNonNull(signaturePolicy, "signaturePolicy"); - - materialsManager_ = materialsManager; - ciphertextHeaders_ = headers; - commitmentPolicy_ = commitmentPolicy; - signaturePolicy_ = signaturePolicy; - maxEncryptedDataKeys_ = maxEncryptedDataKeys; - ciphertextFooters_ = new CiphertextFooters(); - if (headers instanceof ParsedCiphertext) { - ciphertextBytesSupplied_ = ((ParsedCiphertext)headers).getOffset(); - } else { - // This is a little more expensive, hence the public create(...) methods - // that take a CiphertextHeaders instead of a ParsedCiphertext are - // deprecated. - ciphertextBytesSupplied_ = headers.toByteArray().length; - } - readHeaderFields(headers); - updateTrailingSignature(headers); + readHeaderFields(headers); + updateTrailingSignature(headers); + } + + /** + * Create a decryption handler using the provided master key. + * + *

Note the methods in the provided master key are used in decrypting the encrypted data key + * parsed from the ciphertext headers. + * + * @param customerMasterKeyProvider the master key provider to use in picking a master key from + * the key blobs encoded in the provided ciphertext. + * @param commitmentPolicy The commitment policy to enforce during decryption + * @param signaturePolicy The signature policy to enforce during decryption + * @param maxEncryptedDataKeys The maximum number of encrypted data keys to unwrap during + * decryption; zero indicates no maximum + * @throws AwsCryptoException if the master key is null. + */ + @SuppressWarnings("unchecked") + public static > DecryptionHandler create( + final MasterKeyProvider customerMasterKeyProvider, + final CommitmentPolicy commitmentPolicy, + final SignaturePolicy signaturePolicy, + final int maxEncryptedDataKeys) + throws AwsCryptoException { + Utils.assertNonNull(customerMasterKeyProvider, "customerMasterKeyProvider"); + + return (DecryptionHandler) + create( + new DefaultCryptoMaterialsManager(customerMasterKeyProvider), + commitmentPolicy, + signaturePolicy, + maxEncryptedDataKeys); + } + + /** + * Create a decryption handler using the provided master key and already parsed {@code headers}. + * + *

Note the methods in the provided master key are used in decrypting the encrypted data key + * parsed from the ciphertext headers. + * + * @param customerMasterKeyProvider the master key provider to use in picking a master key from + * the key blobs encoded in the provided ciphertext. + * @param headers already parsed headers which will not be passed into {@link + * #processBytes(byte[], int, int, byte[], int)} + * @param commitmentPolicy The commitment policy to enforce during decryption + * @param signaturePolicy The signature policy to enforce during decryption + * @param maxEncryptedDataKeys The maximum number of encrypted data keys to unwrap during + * decryption; zero indicates no maximum + * @throws AwsCryptoException if the master key is null. + * @deprecated This version may have to recalculate the number of bytes already parsed, which adds + * a performance penalty. Use {@link #create(CryptoMaterialsManager, ParsedCiphertext, + * CommitmentPolicy, SignaturePolicy, int)} instead, which makes the parsed byte count + * directly available instead. + */ + @SuppressWarnings("unchecked") + @Deprecated + public static > DecryptionHandler create( + final MasterKeyProvider customerMasterKeyProvider, + final CiphertextHeaders headers, + final CommitmentPolicy commitmentPolicy, + final SignaturePolicy signaturePolicy, + final int maxEncryptedDataKeys) + throws AwsCryptoException { + Utils.assertNonNull(customerMasterKeyProvider, "customerMasterKeyProvider"); + + return (DecryptionHandler) + create( + new DefaultCryptoMaterialsManager(customerMasterKeyProvider), + headers, + commitmentPolicy, + signaturePolicy, + maxEncryptedDataKeys); + } + + /** + * Create a decryption handler using the provided master key and already parsed {@code headers}. + * + *

Note the methods in the provided master key are used in decrypting the encrypted data key + * parsed from the ciphertext headers. + * + * @param customerMasterKeyProvider the master key provider to use in picking a master key from + * the key blobs encoded in the provided ciphertext. + * @param headers already parsed headers which will not be passed into {@link + * #processBytes(byte[], int, int, byte[], int)} + * @param commitmentPolicy The commitment policy to enforce during decryption + * @param signaturePolicy The signature policy to enforce during decryption + * @param maxEncryptedDataKeys The maximum number of encrypted data keys to unwrap during + * decryption; zero indicates no maximum + * @throws AwsCryptoException if the master key is null. + */ + @SuppressWarnings("unchecked") + public static > DecryptionHandler create( + final MasterKeyProvider customerMasterKeyProvider, + final ParsedCiphertext headers, + final CommitmentPolicy commitmentPolicy, + final SignaturePolicy signaturePolicy, + final int maxEncryptedDataKeys) + throws AwsCryptoException { + Utils.assertNonNull(customerMasterKeyProvider, "customerMasterKeyProvider"); + + return (DecryptionHandler) + create( + new DefaultCryptoMaterialsManager(customerMasterKeyProvider), + headers, + commitmentPolicy, + signaturePolicy, + maxEncryptedDataKeys); + } + + /** + * Create a decryption handler using the provided materials manager. + * + *

Note the methods in the provided materials manager are used in decrypting the encrypted data + * key parsed from the ciphertext headers. + * + * @param materialsManager the materials manager to use in decrypting the data key from the key + * blobs encoded in the provided ciphertext. + * @param commitmentPolicy The commitment policy to enforce during decryption + * @param signaturePolicy The signature policy to enforce during decryption + * @param maxEncryptedDataKeys The maximum number of encrypted data keys to unwrap during + * decryption; zero indicates no maximum + * @throws AwsCryptoException if the master key is null. + */ + public static DecryptionHandler create( + final CryptoMaterialsManager materialsManager, + final CommitmentPolicy commitmentPolicy, + final SignaturePolicy signaturePolicy, + final int maxEncryptedDataKeys) + throws AwsCryptoException { + return new DecryptionHandler( + materialsManager, commitmentPolicy, signaturePolicy, maxEncryptedDataKeys); + } + + /** + * Create a decryption handler using the provided materials manager and already parsed {@code + * headers}. + * + *

Note the methods in the provided materials manager are used in decrypting the encrypted data + * key parsed from the ciphertext headers. + * + * @param materialsManager the materials manager to use in decrypting the data key from the key + * blobs encoded in the provided ciphertext. + * @param headers already parsed headers which will not be passed into {@link + * #processBytes(byte[], int, int, byte[], int)} + * @param commitmentPolicy The commitment policy to enforce during decryption + * @param signaturePolicy The signature policy to enforce during decryption + * @param maxEncryptedDataKeys The maximum number of encrypted data keys to unwrap during + * decryption; zero indicates no maximum + * @throws AwsCryptoException if the master key is null. + * @deprecated This version may have to recalculate the number of bytes already parsed, which adds + * a performance penalty. Use {@link #create(CryptoMaterialsManager, ParsedCiphertext, + * CommitmentPolicy, SignaturePolicy, int)} instead, which makes the parsed byte count + * directly available instead. + */ + @Deprecated + public static DecryptionHandler create( + final CryptoMaterialsManager materialsManager, + final CiphertextHeaders headers, + final CommitmentPolicy commitmentPolicy, + final SignaturePolicy signaturePolicy, + final int maxEncryptedDataKeys) + throws AwsCryptoException { + return new DecryptionHandler( + materialsManager, headers, commitmentPolicy, signaturePolicy, maxEncryptedDataKeys); + } + + /** + * Create a decryption handler using the provided materials manager and already parsed {@code + * headers}. + * + *

Note the methods in the provided materials manager are used in decrypting the encrypted data + * key parsed from the ciphertext headers. + * + * @param materialsManager the materials manager to use in decrypting the data key from the key + * blobs encoded in the provided ciphertext. + * @param headers already parsed headers which will not be passed into {@link + * #processBytes(byte[], int, int, byte[], int)} + * @param commitmentPolicy The commitment policy to enforce during decryption + * @param signaturePolicy The signature policy to enforce during decryption + * @param maxEncryptedDataKeys The maximum number of encrypted data keys to unwrap during + * decryption; zero indicates no maximum + * @throws AwsCryptoException if the master key is null. + */ + public static DecryptionHandler create( + final CryptoMaterialsManager materialsManager, + final ParsedCiphertext headers, + final CommitmentPolicy commitmentPolicy, + final SignaturePolicy signaturePolicy, + final int maxEncryptedDataKeys) + throws AwsCryptoException { + return new DecryptionHandler( + materialsManager, headers, commitmentPolicy, signaturePolicy, maxEncryptedDataKeys); + } + + /** + * Decrypt the ciphertext bytes provided in {@code in} and copy the plaintext bytes to {@code + * out}. + * + *

This method consumes and parses the ciphertext headers. The decryption of the actual content + * is delegated to {@link BlockDecryptionHandler} or {@link FrameDecryptionHandler} based on the + * content type parsed in the ciphertext header. + * + * @param in the input byte array. + * @param off the offset into the in array where the data to be decrypted starts. + * @param len the number of bytes to be decrypted. + * @param out the output buffer the decrypted plaintext bytes go into. + * @param outOff the offset into the output byte array the decrypted data starts at. + * @return the number of bytes written to {@code out} and processed. + * @throws BadCiphertextException if the ciphertext header contains invalid entries or if the + * header integrity check fails. + * @throws AwsCryptoException if any of the offset or length arguments are negative or if the + * total bytes to decrypt exceeds the maximum allowed value. + */ + @Override + public ProcessingSummary processBytes( + final byte[] in, final int off, final int len, final byte[] out, final int outOff) + throws BadCiphertextException, AwsCryptoException { + + // We should arguably check if we are already complete_ here as other handlers + // like FrameDecryptionHandler and BlockDecryptionHandler do. + // However, adding that now could potentially break customers who have extra trailing + // bytes in their decryption streams. + // The handlers are also inconsistent in general with this check. Even those that + // do raise an exception here if already complete will not complain if + // a single call to processBytes() completes the message and provides extra trailing bytes: + // in that case they will just indicate that they didn't process the extra bytes instead. + + if (len < 0 || off < 0) { + throw new AwsCryptoException( + String.format("Invalid values for input offset: %d and length: %d", off, len)); } - /** - * Create a decryption handler using the provided master key. - * - *

- * Note the methods in the provided master key are used in decrypting the - * encrypted data key parsed from the ciphertext headers. - * - * @param customerMasterKeyProvider - * the master key provider to use in picking a master key from - * the key blobs encoded in the provided ciphertext. - * @param commitmentPolicy The commitment policy to enforce during decryption - * @param signaturePolicy The signature policy to enforce during decryption - * @param maxEncryptedDataKeys - * The maximum number of encrypted data keys to unwrap during decryption; zero indicates no maximum - * @throws AwsCryptoException - * if the master key is null. - */ - @SuppressWarnings("unchecked") - public static > DecryptionHandler create( - final MasterKeyProvider customerMasterKeyProvider, - final CommitmentPolicy commitmentPolicy, - final SignaturePolicy signaturePolicy, - final int maxEncryptedDataKeys - ) throws AwsCryptoException { - Utils.assertNonNull(customerMasterKeyProvider, "customerMasterKeyProvider"); - - return (DecryptionHandler)create(new DefaultCryptoMaterialsManager(customerMasterKeyProvider), - commitmentPolicy, signaturePolicy, maxEncryptedDataKeys); + if (in.length == 0 || len == 0) { + return ProcessingSummary.ZERO; } - /** - * Create a decryption handler using the provided master key and already parsed {@code headers}. - * - *

- * Note the methods in the provided master key are used in decrypting the encrypted data key - * parsed from the ciphertext headers. - * - * @param customerMasterKeyProvider - * the master key provider to use in picking a master key from the key blobs encoded - * in the provided ciphertext. - * @param headers - * already parsed headers which will not be passed into - * {@link #processBytes(byte[], int, int, byte[], int)} - * @param commitmentPolicy The commitment policy to enforce during decryption - * @param signaturePolicy The signature policy to enforce during decryption - * @param maxEncryptedDataKeys - * The maximum number of encrypted data keys to unwrap during decryption; zero indicates no maximum - * @throws AwsCryptoException - * if the master key is null. - * @deprecated This version may have to recalculate the number of bytes already parsed, which adds - * a performance penalty. Use {@link #create(CryptoMaterialsManager, ParsedCiphertext, - * CommitmentPolicy, SignaturePolicy, int)} instead, which makes the parsed byte count directly - * available instead. - */ - @SuppressWarnings("unchecked") - @Deprecated - public static > DecryptionHandler create( - final MasterKeyProvider customerMasterKeyProvider, final CiphertextHeaders headers, - final CommitmentPolicy commitmentPolicy, - final SignaturePolicy signaturePolicy, - final int maxEncryptedDataKeys - ) throws AwsCryptoException { - Utils.assertNonNull(customerMasterKeyProvider, "customerMasterKeyProvider"); - - return (DecryptionHandler) create(new DefaultCryptoMaterialsManager(customerMasterKeyProvider), headers, - commitmentPolicy, signaturePolicy, maxEncryptedDataKeys); + final long totalBytesToParse = unparsedBytes_.length + (long) len; + // check for integer overflow + if (totalBytesToParse > Integer.MAX_VALUE) { + throw new AwsCryptoException( + "Size of the total bytes to parse and decrypt exceeded allowed maximum:" + + Integer.MAX_VALUE); } - /** - * Create a decryption handler using the provided master key and already parsed {@code headers}. - * - *

- * Note the methods in the provided master key are used in decrypting the encrypted data key - * parsed from the ciphertext headers. - * - * @param customerMasterKeyProvider - * the master key provider to use in picking a master key from the key blobs encoded - * in the provided ciphertext. - * @param headers - * already parsed headers which will not be passed into - * {@link #processBytes(byte[], int, int, byte[], int)} - * @param commitmentPolicy The commitment policy to enforce during decryption - * @param signaturePolicy The signature policy to enforce during decryption - * @param maxEncryptedDataKeys - * The maximum number of encrypted data keys to unwrap during decryption; zero indicates no maximum - * @throws AwsCryptoException - * if the master key is null. - */ - @SuppressWarnings("unchecked") - public static > DecryptionHandler create( - final MasterKeyProvider customerMasterKeyProvider, final ParsedCiphertext headers, - final CommitmentPolicy commitmentPolicy, - final SignaturePolicy signaturePolicy, - final int maxEncryptedDataKeys - ) throws AwsCryptoException { - Utils.assertNonNull(customerMasterKeyProvider, "customerMasterKeyProvider"); - - return (DecryptionHandler) create(new DefaultCryptoMaterialsManager(customerMasterKeyProvider), headers, - commitmentPolicy, signaturePolicy, maxEncryptedDataKeys); + checkSizeBound(len); + ciphertextBytesSupplied_ += len; + + final byte[] bytesToParse = new byte[(int) totalBytesToParse]; + final int leftoverBytes = unparsedBytes_.length; + // If there were previously unparsed bytes, add them as the first + // set of bytes to be parsed in this call. + System.arraycopy(unparsedBytes_, 0, bytesToParse, 0, unparsedBytes_.length); + System.arraycopy(in, off, bytesToParse, unparsedBytes_.length, len); + + int totalParsedBytes = 0; + if (!ciphertextHeadersParsed_) { + totalParsedBytes += ciphertextHeaders_.deserialize(bytesToParse, 0, maxEncryptedDataKeys_); + // When ciphertext headers are complete, we have the data + // key and cipher mode to initialize the underlying cipher + if (ciphertextHeaders_.isComplete() == true) { + readHeaderFields(ciphertextHeaders_); + updateTrailingSignature(ciphertextHeaders_); + // reset unparsed bytes as parsing of ciphertext headers is + // complete. + unparsedBytes_ = new byte[0]; + } else { + // If there aren't enough bytes to parse ciphertext + // headers, we don't have anymore bytes to continue parsing. + // But first copy the leftover bytes to unparsed bytes. + unparsedBytes_ = Arrays.copyOfRange(bytesToParse, totalParsedBytes, bytesToParse.length); + return new ProcessingSummary(0, len); + } } - /** - * Create a decryption handler using the provided materials manager. - * - *

- * Note the methods in the provided materials manager are used in decrypting the encrypted data key - * parsed from the ciphertext headers. - * - * @param materialsManager - * the materials manager to use in decrypting the data key from the key blobs encoded - * in the provided ciphertext. - * @param commitmentPolicy The commitment policy to enforce during decryption - * @param signaturePolicy The signature policy to enforce during decryption - * @param maxEncryptedDataKeys - * The maximum number of encrypted data keys to unwrap during decryption; zero indicates no maximum - * @throws AwsCryptoException - * if the master key is null. - */ - public static DecryptionHandler create( - final CryptoMaterialsManager materialsManager, - final CommitmentPolicy commitmentPolicy, - final SignaturePolicy signaturePolicy, - final int maxEncryptedDataKeys - ) throws AwsCryptoException { - return new DecryptionHandler(materialsManager, commitmentPolicy, signaturePolicy, maxEncryptedDataKeys); + int actualOutLen = 0; + if (!contentCryptoHandler_.isComplete()) { + // if there are bytes to parse further, pass it off to underlying + // content cryptohandler. + if ((bytesToParse.length - totalParsedBytes) > 0) { + final ProcessingSummary contentResult = + contentCryptoHandler_.processBytes( + bytesToParse, + totalParsedBytes, + bytesToParse.length - totalParsedBytes, + out, + outOff); + updateTrailingSignature(bytesToParse, totalParsedBytes, contentResult.getBytesProcessed()); + actualOutLen = contentResult.getBytesWritten(); + totalParsedBytes += contentResult.getBytesProcessed(); + } + if (contentCryptoHandler_.isComplete()) { + actualOutLen += contentCryptoHandler_.doFinal(out, outOff + actualOutLen); + } } - /** - * Create a decryption handler using the provided materials manager and already parsed {@code headers}. - * - *

- * Note the methods in the provided materials manager are used in decrypting the encrypted data key - * parsed from the ciphertext headers. - * - * @param materialsManager - * the materials manager to use in decrypting the data key from the key blobs encoded - * in the provided ciphertext. - * @param headers - * already parsed headers which will not be passed into - * {@link #processBytes(byte[], int, int, byte[], int)} - * @param commitmentPolicy The commitment policy to enforce during decryption - * @param signaturePolicy The signature policy to enforce during decryption - * @param maxEncryptedDataKeys - * The maximum number of encrypted data keys to unwrap during decryption; zero indicates no maximum - * @throws AwsCryptoException - * if the master key is null. - * @deprecated This version may have to recalculate the number of bytes already parsed, which adds - * a performance penalty. Use {@link #create(CryptoMaterialsManager, ParsedCiphertext, - * CommitmentPolicy, SignaturePolicy, int)} instead, which makes the parsed byte count directly - * available instead. - */ - @Deprecated - public static DecryptionHandler create( - final CryptoMaterialsManager materialsManager, final CiphertextHeaders headers, - final CommitmentPolicy commitmentPolicy, - final SignaturePolicy signaturePolicy, - final int maxEncryptedDataKeys - ) throws AwsCryptoException { - return new DecryptionHandler(materialsManager, headers, commitmentPolicy, signaturePolicy, maxEncryptedDataKeys); - } - - /** - * Create a decryption handler using the provided materials manager and already parsed {@code headers}. - * - *

- * Note the methods in the provided materials manager are used in decrypting the encrypted data key - * parsed from the ciphertext headers. - * - * @param materialsManager - * the materials manager to use in decrypting the data key from the key blobs encoded - * in the provided ciphertext. - * @param headers - * already parsed headers which will not be passed into - * {@link #processBytes(byte[], int, int, byte[], int)} - * @param commitmentPolicy The commitment policy to enforce during decryption - * @param signaturePolicy The signature policy to enforce during decryption - * @param maxEncryptedDataKeys - * The maximum number of encrypted data keys to unwrap during decryption; zero indicates no maximum - * @throws AwsCryptoException - * if the master key is null. - */ - public static DecryptionHandler create( - final CryptoMaterialsManager materialsManager, final ParsedCiphertext headers, - final CommitmentPolicy commitmentPolicy, - final SignaturePolicy signaturePolicy, - final int maxEncryptedDataKeys - ) throws AwsCryptoException { - return new DecryptionHandler(materialsManager, headers, commitmentPolicy, signaturePolicy, maxEncryptedDataKeys); - } - - /** - * Decrypt the ciphertext bytes provided in {@code in} and copy the plaintext bytes to - * {@code out}. - * - *

- * This method consumes and parses the ciphertext headers. The decryption of the actual content - * is delegated to {@link BlockDecryptionHandler} or {@link FrameDecryptionHandler} based on the - * content type parsed in the ciphertext header. - * - * @param in - * the input byte array. - * @param off - * the offset into the in array where the data to be decrypted starts. - * @param len - * the number of bytes to be decrypted. - * @param out - * the output buffer the decrypted plaintext bytes go into. - * @param outOff - * the offset into the output byte array the decrypted data starts at. - * @return the number of bytes written to {@code out} and processed. - * - * @throws BadCiphertextException - * if the ciphertext header contains invalid entries or if the header integrity - * check fails. - * @throws AwsCryptoException - * if any of the offset or length arguments are negative or if the total bytes to - * decrypt exceeds the maximum allowed value. - */ - @Override - public ProcessingSummary processBytes(final byte[] in, final int off, final int len, final byte[] out, - final int outOff) - throws BadCiphertextException, AwsCryptoException { - - // We should arguably check if we are already complete_ here as other handlers - // like FrameDecryptionHandler and BlockDecryptionHandler do. - // However, adding that now could potentially break customers who have extra trailing - // bytes in their decryption streams. - // The handlers are also inconsistent in general with this check. Even those that - // do raise an exception here if already complete will not complain if - // a single call to processBytes() completes the message and provides extra trailing bytes: - // in that case they will just indicate that they didn't process the extra bytes instead. - - if (len < 0 || off < 0) { - throw new AwsCryptoException(String.format( - "Invalid values for input offset: %d and length: %d", off, len)); - } - - if (in.length == 0 || len == 0) { - return ProcessingSummary.ZERO; - } - - final long totalBytesToParse = unparsedBytes_.length + (long) len; - // check for integer overflow - if (totalBytesToParse > Integer.MAX_VALUE) { - throw new AwsCryptoException( - "Size of the total bytes to parse and decrypt exceeded allowed maximum:" + Integer.MAX_VALUE); - } - - checkSizeBound(len); - ciphertextBytesSupplied_ += len; - - final byte[] bytesToParse = new byte[(int) totalBytesToParse]; - final int leftoverBytes = unparsedBytes_.length; - // If there were previously unparsed bytes, add them as the first - // set of bytes to be parsed in this call. - System.arraycopy(unparsedBytes_, 0, bytesToParse, 0, unparsedBytes_.length); - System.arraycopy(in, off, bytesToParse, unparsedBytes_.length, len); - - int totalParsedBytes = 0; - if (!ciphertextHeadersParsed_) { - totalParsedBytes += ciphertextHeaders_.deserialize(bytesToParse, 0, maxEncryptedDataKeys_); - // When ciphertext headers are complete, we have the data - // key and cipher mode to initialize the underlying cipher - if (ciphertextHeaders_.isComplete() == true) { - readHeaderFields(ciphertextHeaders_); - updateTrailingSignature(ciphertextHeaders_); - // reset unparsed bytes as parsing of ciphertext headers is - // complete. - unparsedBytes_ = new byte[0]; - } else { - // If there aren't enough bytes to parse ciphertext - // headers, we don't have anymore bytes to continue parsing. - // But first copy the leftover bytes to unparsed bytes. - unparsedBytes_ = Arrays.copyOfRange(bytesToParse, totalParsedBytes, bytesToParse.length); - return new ProcessingSummary(0, len); - } - } - - int actualOutLen = 0; - if (!contentCryptoHandler_.isComplete()) { - // if there are bytes to parse further, pass it off to underlying - // content cryptohandler. - if ((bytesToParse.length - totalParsedBytes) > 0) { - final ProcessingSummary contentResult = contentCryptoHandler_.processBytes(bytesToParse, - totalParsedBytes, bytesToParse.length - totalParsedBytes, - out, outOff); - updateTrailingSignature(bytesToParse, totalParsedBytes, contentResult.getBytesProcessed()); - actualOutLen = contentResult.getBytesWritten(); - totalParsedBytes += contentResult.getBytesProcessed(); - } - if (contentCryptoHandler_.isComplete()) { - actualOutLen += contentCryptoHandler_.doFinal(out, outOff + actualOutLen); - } - } - - if (contentCryptoHandler_.isComplete() ) { - // If the crypto algorithm contains trailing signature, we will need to verify - // the footer of the message. - if (cryptoAlgo_.getTrailingSignatureLength() > 0) { - totalParsedBytes += ciphertextFooters_.deserialize(bytesToParse, totalParsedBytes); - if (ciphertextFooters_.isComplete()) { - // reset unparsed bytes as parsing of the ciphertext footer is - // complete. - // This isn't strictly necessary since processing any further data - // should be an error. - unparsedBytes_ = new byte[0]; - - try { - if (!trailingSig_.verify(ciphertextFooters_.getMAuth())) { - throw new BadCiphertextException("Bad trailing signature"); - } - } catch (final SignatureException ex) { - throw new BadCiphertextException("Bad trailing signature", ex); - } - complete_ = true; - } else { - // If there aren't enough bytes to parse the ciphertext - // footer, we don't have any more bytes to continue parsing. - // But first copy the leftover bytes to unparsed bytes. - unparsedBytes_ = Arrays.copyOfRange(bytesToParse, totalParsedBytes, bytesToParse.length); - return new ProcessingSummary(actualOutLen, len); - } - } else { - complete_ = true; + if (contentCryptoHandler_.isComplete()) { + // If the crypto algorithm contains trailing signature, we will need to verify + // the footer of the message. + if (cryptoAlgo_.getTrailingSignatureLength() > 0) { + totalParsedBytes += ciphertextFooters_.deserialize(bytesToParse, totalParsedBytes); + if (ciphertextFooters_.isComplete()) { + // reset unparsed bytes as parsing of the ciphertext footer is + // complete. + // This isn't strictly necessary since processing any further data + // should be an error. + unparsedBytes_ = new byte[0]; + + try { + if (!trailingSig_.verify(ciphertextFooters_.getMAuth())) { + throw new BadCiphertextException("Bad trailing signature"); } + } catch (final SignatureException ex) { + throw new BadCiphertextException("Bad trailing signature", ex); + } + complete_ = true; + } else { + // If there aren't enough bytes to parse the ciphertext + // footer, we don't have any more bytes to continue parsing. + // But first copy the leftover bytes to unparsed bytes. + unparsedBytes_ = Arrays.copyOfRange(bytesToParse, totalParsedBytes, bytesToParse.length); + return new ProcessingSummary(actualOutLen, len); } - return new ProcessingSummary(actualOutLen, totalParsedBytes - leftoverBytes); + } else { + complete_ = true; + } + } + return new ProcessingSummary(actualOutLen, totalParsedBytes - leftoverBytes); + } + + /** + * Finish processing of the bytes. + * + * @param out space for any resulting output data. + * @param outOff offset into {@code out} to start copying the data at. + * @return number of bytes written into {@code out}. + * @throws BadCiphertextException if the bytes do not decrypt correctly. + */ + @Override + public int doFinal(final byte[] out, final int outOff) throws BadCiphertextException { + // This is an unfortunate special case we have to support for backwards-compatibility. + if (ciphertextBytesSupplied_ == 0) { + return 0; } - /** - * Finish processing of the bytes. - * - * @param out - * space for any resulting output data. - * @param outOff - * offset into {@code out} to start copying the data at. - * @return - * number of bytes written into {@code out}. - * @throws BadCiphertextException - * if the bytes do not decrypt correctly. - */ - @Override - public int doFinal(final byte[] out, final int outOff) throws BadCiphertextException { - // This is an unfortunate special case we have to support for backwards-compatibility. - if (ciphertextBytesSupplied_ == 0) { - return 0; - } + // check if cryptohandler for content has been created. There are cases + // when it might not have been created such as when doFinal() is called + // before the ciphertext headers are fully received and parsed. + if (contentCryptoHandler_ == null) { + throw new BadCiphertextException("Unable to process entire ciphertext."); + } else { + int result = contentCryptoHandler_.doFinal(out, outOff); - // check if cryptohandler for content has been created. There are cases - // when it might not have been created such as when doFinal() is called - // before the ciphertext headers are fully received and parsed. - if (contentCryptoHandler_ == null) { - throw new BadCiphertextException("Unable to process entire ciphertext."); - } else { - int result = contentCryptoHandler_.doFinal(out, outOff); + if (!complete_) { + throw new BadCiphertextException("Unable to process entire ciphertext."); + } - if (!complete_) { - throw new BadCiphertextException("Unable to process entire ciphertext."); - } - - return result; - } + return result; } - - /** - * Return the size of the output buffer required for a - * processBytes plus a doFinal with an input of - * inLen bytes. - * - * @param inLen - * the length of the input. - * @return - * the space required to accommodate a call to processBytes and - * doFinal with input of size {@code inLen} bytes. - */ - @Override - public int estimateOutputSize(final int inLen) { - if (contentCryptoHandler_ != null) { - return contentCryptoHandler_.estimateOutputSize(inLen); - } else { - return (inLen > 0) ? inLen : 0; - } + } + + /** + * Return the size of the output buffer required for a processBytes plus a + * doFinal with an input of inLen bytes. + * + * @param inLen the length of the input. + * @return the space required to accommodate a call to processBytes and doFinal with input of size + * {@code inLen} bytes. + */ + @Override + public int estimateOutputSize(final int inLen) { + if (contentCryptoHandler_ != null) { + return contentCryptoHandler_.estimateOutputSize(inLen); + } else { + return (inLen > 0) ? inLen : 0; } - - @Override - public int estimatePartialOutputSize(int inLen) { - if (contentCryptoHandler_ != null) { - return contentCryptoHandler_.estimatePartialOutputSize(inLen); - } else { - return (inLen > 0) ? inLen : 0; - } + } + + @Override + public int estimatePartialOutputSize(int inLen) { + if (contentCryptoHandler_ != null) { + return contentCryptoHandler_.estimatePartialOutputSize(inLen); + } else { + return (inLen > 0) ? inLen : 0; } - - @Override - public int estimateFinalOutputSize() { - if (contentCryptoHandler_ != null) { - return contentCryptoHandler_.estimateFinalOutputSize(); - } else { - return 0; - } + } + + @Override + public int estimateFinalOutputSize() { + if (contentCryptoHandler_ != null) { + return contentCryptoHandler_.estimateFinalOutputSize(); + } else { + return 0; } - - /** - * Return the encryption context. This value is parsed from the ciphertext. - * - * @return - * the key-value map containing the encryption client. - */ - @Override - public Map getEncryptionContext() { - return encryptionContext_; + } + + /** + * Return the encryption context. This value is parsed from the ciphertext. + * + * @return the key-value map containing the encryption client. + */ + @Override + public Map getEncryptionContext() { + return encryptionContext_; + } + + private void checkSizeBound(long additionalBytes) { + if (ciphertextSizeBound_ != -1 + && ciphertextBytesSupplied_ + additionalBytes > ciphertextSizeBound_) { + throw new IllegalStateException("Ciphertext size exceeds size bound"); } + } - private void checkSizeBound(long additionalBytes) { - if (ciphertextSizeBound_ != -1 && ciphertextBytesSupplied_ + additionalBytes > ciphertextSizeBound_) { - throw new IllegalStateException("Ciphertext size exceeds size bound"); - } + @Override + public void setMaxInputLength(long size) { + if (size < 0) { + throw Utils.cannotBeNegative("Max input length"); } - @Override - public void setMaxInputLength(long size) { - if (size < 0) { - throw Utils.cannotBeNegative("Max input length"); - } - - if (ciphertextSizeBound_ == -1 || ciphertextSizeBound_ > size) { - ciphertextSizeBound_ = size; - } - - // check that we haven't already exceeded the limit - checkSizeBound(0); + if (ciphertextSizeBound_ == -1 || ciphertextSizeBound_ > size) { + ciphertextSizeBound_ = size; } - long getMaxInputLength() { - return ciphertextSizeBound_; + // check that we haven't already exceeded the limit + checkSizeBound(0); + } + + long getMaxInputLength() { + return ciphertextSizeBound_; + } + + /** + * Check integrity of the header bytes by processing the parsed MAC tag in the headers through the + * cipher. + * + * @param ciphertextHeaders the ciphertext headers object whose integrity needs to be checked. + * @return true if the integrity of the header is intact; false otherwise. + */ + private void verifyHeaderIntegrity(final CiphertextHeaders ciphertextHeaders) + throws BadCiphertextException { + final CipherHandler cipherHandler = + new CipherHandler(decryptionKey_, Cipher.DECRYPT_MODE, cryptoAlgo_); + + try { + final byte[] headerTag = ciphertextHeaders.getHeaderTag(); + cipherHandler.cipherData( + ciphertextHeaders.getHeaderNonce(), + ciphertextHeaders.serializeAuthenticatedFields(), + headerTag, + 0, + headerTag.length); + } catch (BadCiphertextException e) { + throw new BadCiphertextException("Header integrity check failed.", e); } - - /** - * Check integrity of the header bytes by processing the parsed MAC tag in - * the headers through the cipher. - * - * @param ciphertextHeaders - * the ciphertext headers object whose integrity needs to be - * checked. - * @return - * true if the integrity of the header is intact; false otherwise. - */ - private void verifyHeaderIntegrity(final CiphertextHeaders ciphertextHeaders) throws BadCiphertextException { - final CipherHandler cipherHandler = new CipherHandler(decryptionKey_, Cipher.DECRYPT_MODE, cryptoAlgo_); - - try { - final byte[] headerTag = ciphertextHeaders.getHeaderTag(); - cipherHandler.cipherData(ciphertextHeaders.getHeaderNonce(), - ciphertextHeaders.serializeAuthenticatedFields(), - headerTag, 0, headerTag.length); - } catch (BadCiphertextException e) { - throw new BadCiphertextException("Header integrity check failed.", e); - } + } + + /** + * Read the fields in the ciphertext headers to populate the corresponding instance variables used + * during decryption. + * + * @param ciphertextHeaders the ciphertext headers object to read. + */ + @SuppressWarnings("unchecked") + private void readHeaderFields(final CiphertextHeaders ciphertextHeaders) { + cryptoAlgo_ = ciphertextHeaders.getCryptoAlgoId(); + + final CiphertextType ciphertextType = ciphertextHeaders.getType(); + if (ciphertextType != CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA) { + throw new BadCiphertextException("Invalid type in ciphertext."); } - /** - * Read the fields in the ciphertext headers to populate the corresponding - * instance variables used during decryption. - * - * @param ciphertextHeaders - * the ciphertext headers object to read. - */ - @SuppressWarnings("unchecked") - private void readHeaderFields(final CiphertextHeaders ciphertextHeaders) { - cryptoAlgo_ = ciphertextHeaders.getCryptoAlgoId(); - - final CiphertextType ciphertextType = ciphertextHeaders.getType(); - if (ciphertextType != CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA) { - throw new BadCiphertextException("Invalid type in ciphertext."); - } - - final byte[] messageId = ciphertextHeaders.getMessageId(); - - if (!commitmentPolicy_.algorithmAllowedForDecrypt(cryptoAlgo_)) { - throw new AwsCryptoException("Configuration conflict. " + - "Cannot decrypt message with ID " + messageId + " due to CommitmentPolicy " + - commitmentPolicy_ + " requiring only committed messages. Algorithm ID was " + - cryptoAlgo_ + ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); - } - - if (maxEncryptedDataKeys_ > 0 && ciphertextHeaders_.getEncryptedKeyBlobCount() > maxEncryptedDataKeys_) { - throw new AwsCryptoException("Ciphertext encrypted data keys exceed maxEncryptedDataKeys"); - } - - if (!signaturePolicy_.algorithmAllowedForDecrypt(cryptoAlgo_)) { - throw new AwsCryptoException("Configuration conflict. " + - "Cannot decrypt message with ID " + messageId + " because AwsCrypto.createUnsignedMessageDecryptingStream() " + - " accepts only unsigned messages. Algorithm ID was " + - cryptoAlgo_ + "."); - } - - encryptionContext_ = ciphertextHeaders.getEncryptionContextMap(); + final byte[] messageId = ciphertextHeaders.getMessageId(); + + if (!commitmentPolicy_.algorithmAllowedForDecrypt(cryptoAlgo_)) { + throw new AwsCryptoException( + "Configuration conflict. " + + "Cannot decrypt message with ID " + + messageId + + " due to CommitmentPolicy " + + commitmentPolicy_ + + " requiring only committed messages. Algorithm ID was " + + cryptoAlgo_ + + ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); + } - DecryptionMaterialsRequest request = DecryptionMaterialsRequest.newBuilder() - .setAlgorithm(cryptoAlgo_) - .setEncryptionContext(encryptionContext_) - .setEncryptedDataKeys(ciphertextHeaders.getEncryptedKeyBlobs()) - .build(); + if (maxEncryptedDataKeys_ > 0 + && ciphertextHeaders_.getEncryptedKeyBlobCount() > maxEncryptedDataKeys_) { + throw new AwsCryptoException("Ciphertext encrypted data keys exceed maxEncryptedDataKeys"); + } - DecryptionMaterials result = materialsManager_.decryptMaterials(request); + if (!signaturePolicy_.algorithmAllowedForDecrypt(cryptoAlgo_)) { + throw new AwsCryptoException( + "Configuration conflict. " + + "Cannot decrypt message with ID " + + messageId + + " because AwsCrypto.createUnsignedMessageDecryptingStream() " + + " accepts only unsigned messages. Algorithm ID was " + + cryptoAlgo_ + + "."); + } - //noinspection unchecked - dataKey_ = (DataKey)result.getDataKey(); - PublicKey trailingPublicKey = result.getTrailingSignatureKey(); + encryptionContext_ = ciphertextHeaders.getEncryptionContextMap(); - try { - decryptionKey_ = cryptoAlgo_.getEncryptionKeyFromDataKey(dataKey_.getKey(), ciphertextHeaders); - } catch (final InvalidKeyException ex) { - throw new AwsCryptoException(ex); - } + DecryptionMaterialsRequest request = + DecryptionMaterialsRequest.newBuilder() + .setAlgorithm(cryptoAlgo_) + .setEncryptionContext(encryptionContext_) + .setEncryptedDataKeys(ciphertextHeaders.getEncryptedKeyBlobs()) + .build(); - if (cryptoAlgo_.getTrailingSignatureLength() > 0) { - Utils.assertNonNull(trailingPublicKey, "trailing public key"); + DecryptionMaterials result = materialsManager_.decryptMaterials(request); - TrailingSignatureAlgorithm trailingSignatureAlgorithm - = TrailingSignatureAlgorithm.forCryptoAlgorithm(cryptoAlgo_); + //noinspection unchecked + dataKey_ = (DataKey) result.getDataKey(); + PublicKey trailingPublicKey = result.getTrailingSignatureKey(); - try { - trailingSig_ = Signature.getInstance( - trailingSignatureAlgorithm.getHashAndSignAlgorithm() - ); + try { + decryptionKey_ = + cryptoAlgo_.getEncryptionKeyFromDataKey(dataKey_.getKey(), ciphertextHeaders); + } catch (final InvalidKeyException ex) { + throw new AwsCryptoException(ex); + } - trailingSig_.initVerify(trailingPublicKey); - } catch (GeneralSecurityException e) { - throw new AwsCryptoException(e); - } - } else { - if (trailingPublicKey != null) { - throw new AwsCryptoException("Unexpected trailing signature key in context"); - } + if (cryptoAlgo_.getTrailingSignatureLength() > 0) { + Utils.assertNonNull(trailingPublicKey, "trailing public key"); - trailingSig_ = null; - } + TrailingSignatureAlgorithm trailingSignatureAlgorithm = + TrailingSignatureAlgorithm.forCryptoAlgorithm(cryptoAlgo_); - final ContentType contentType = ciphertextHeaders.getContentType(); - - final short nonceLen = ciphertextHeaders.getNonceLength(); - final int frameLen = ciphertextHeaders.getFrameLength(); - - verifyHeaderIntegrity(ciphertextHeaders); - - switch (contentType) { - case FRAME: - contentCryptoHandler_ = new FrameDecryptionHandler(decryptionKey_, (byte) nonceLen, cryptoAlgo_, - messageId, frameLen); - break; - case SINGLEBLOCK: - contentCryptoHandler_ = new BlockDecryptionHandler(decryptionKey_, (byte) nonceLen, cryptoAlgo_, - messageId); - break; - default: - // should never get here because an invalid content type is - // detected when parsing. - break; - } + try { + trailingSig_ = Signature.getInstance(trailingSignatureAlgorithm.getHashAndSignAlgorithm()); - ciphertextHeadersParsed_ = true; - } + trailingSig_.initVerify(trailingPublicKey); + } catch (GeneralSecurityException e) { + throw new AwsCryptoException(e); + } + } else { + if (trailingPublicKey != null) { + throw new AwsCryptoException("Unexpected trailing signature key in context"); + } - private void updateTrailingSignature(final CiphertextHeaders headers) { - if (trailingSig_ != null) { - final byte[] reserializedHeaders = headers.toByteArray(); - updateTrailingSignature(reserializedHeaders, 0, reserializedHeaders.length); - } + trailingSig_ = null; } - private void updateTrailingSignature(byte[] input, int offset, int len) { - if (trailingSig_ != null) { - try { - trailingSig_.update(input, offset, len); - } catch (final SignatureException ex) { - throw new AwsCryptoException(ex); - } - } + final ContentType contentType = ciphertextHeaders.getContentType(); + + final short nonceLen = ciphertextHeaders.getNonceLength(); + final int frameLen = ciphertextHeaders.getFrameLength(); + + verifyHeaderIntegrity(ciphertextHeaders); + + switch (contentType) { + case FRAME: + contentCryptoHandler_ = + new FrameDecryptionHandler( + decryptionKey_, (byte) nonceLen, cryptoAlgo_, messageId, frameLen); + break; + case SINGLEBLOCK: + contentCryptoHandler_ = + new BlockDecryptionHandler(decryptionKey_, (byte) nonceLen, cryptoAlgo_, messageId); + break; + default: + // should never get here because an invalid content type is + // detected when parsing. + break; } - @Override - public CiphertextHeaders getHeaders() { - return ciphertextHeaders_; - } + ciphertextHeadersParsed_ = true; + } - @Override - public List getMasterKeys() { - return Collections.singletonList(dataKey_.getMasterKey()); + private void updateTrailingSignature(final CiphertextHeaders headers) { + if (trailingSig_ != null) { + final byte[] reserializedHeaders = headers.toByteArray(); + updateTrailingSignature(reserializedHeaders, 0, reserializedHeaders.length); } - - @Override - public boolean isComplete() { - return complete_; + } + + private void updateTrailingSignature(byte[] input, int offset, int len) { + if (trailingSig_ != null) { + try { + trailingSig_.update(input, offset, len); + } catch (final SignatureException ex) { + throw new AwsCryptoException(ex); + } } + } + + @Override + public CiphertextHeaders getHeaders() { + return ciphertextHeaders_; + } + + @Override + public List getMasterKeys() { + return Collections.singletonList(dataKey_.getMasterKey()); + } + + @Override + public boolean isComplete() { + return complete_; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/EncryptionContextSerializer.java b/src/main/java/com/amazonaws/encryptionsdk/internal/EncryptionContextSerializer.java index d78864b6f..c6ec54d38 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/EncryptionContextSerializer.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/EncryptionContextSerializer.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,6 +13,7 @@ package com.amazonaws.encryptionsdk.internal; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import java.nio.BufferOverflowException; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; @@ -30,177 +31,173 @@ import java.util.SortedMap; import java.util.TreeMap; -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; - /** - * This class provides methods that serialize and deserialize the encryption - * context provided as a map containing key-value pairs comprised of strings. + * This class provides methods that serialize and deserialize the encryption context provided as a + * map containing key-value pairs comprised of strings. */ public class EncryptionContextSerializer { - private EncryptionContextSerializer() { - // Prevent instantiation + private EncryptionContextSerializer() { + // Prevent instantiation + } + + /** + * Serialize the encryption context provided as a map containing key-value pairs comprised of + * strings into a byte array. + * + * @param encryptionContext the map containing the encryption context to serialize. + * @return serialized bytes of the encryption context. + */ + public static byte[] serialize(Map encryptionContext) { + if (encryptionContext == null) return null; + + if (encryptionContext.size() == 0) { + return new byte[0]; } - /** - * Serialize the encryption context provided as a map containing key-value - * pairs comprised of strings into a byte array. - * - * @param encryptionContext - * the map containing the encryption context to serialize. - * @return - * serialized bytes of the encryption context. - */ - public static byte[] serialize(Map encryptionContext) { - if (encryptionContext == null) - return null; - - if (encryptionContext.size() == 0) { - return new byte[0]; + // Make sure we don't accidentally overwrite anything. + encryptionContext = Collections.unmodifiableMap(encryptionContext); + + if (encryptionContext.size() > Short.MAX_VALUE) { + throw new AwsCryptoException( + "The number of entries in encryption context exceeds the allowed maximum " + + Short.MAX_VALUE); + } + + final ByteBuffer result = ByteBuffer.allocate(Short.MAX_VALUE); + result.order(ByteOrder.BIG_ENDIAN); + // write the number of key-value entries first + result.putShort((short) encryptionContext.size()); + + try { + final CharsetEncoder encoder = StandardCharsets.UTF_8.newEncoder(); + + // ensure all failures in encoder are reported. + encoder.onMalformedInput(CodingErrorAction.REPORT); + encoder.onUnmappableCharacter(CodingErrorAction.REPORT); + + final SortedMap binaryEntries = + new TreeMap<>(new Utils.ComparingByteBuffers()); + for (Entry mapEntry : encryptionContext.entrySet()) { + if (mapEntry.getKey() == null || mapEntry.getValue() == null) { + throw new AwsCryptoException( + "All keys and values in encryption context must be non-null."); } - // Make sure we don't accidentally overwrite anything. - encryptionContext = Collections.unmodifiableMap(encryptionContext); + if (mapEntry.getKey().isEmpty() || mapEntry.getValue().isEmpty()) { + throw new AwsCryptoException( + "All keys and values in encryption context must be non-empty."); + } + + final ByteBuffer keyBytes = encoder.encode(CharBuffer.wrap(mapEntry.getKey())); + final ByteBuffer valueBytes = encoder.encode(CharBuffer.wrap(mapEntry.getValue())); - if (encryptionContext.size() > Short.MAX_VALUE) { - throw new AwsCryptoException( - "The number of entries in encryption context exceeds the allowed maximum " + Short.MAX_VALUE); + // check for duplicate entries. + if (binaryEntries.put(keyBytes, valueBytes) != null) { + throw new AwsCryptoException("Encryption context contains duplicate entries."); } - final ByteBuffer result = ByteBuffer.allocate(Short.MAX_VALUE); - result.order(ByteOrder.BIG_ENDIAN); - // write the number of key-value entries first - result.putShort((short) encryptionContext.size()); - - try { - final CharsetEncoder encoder = StandardCharsets.UTF_8.newEncoder(); - - // ensure all failures in encoder are reported. - encoder.onMalformedInput(CodingErrorAction.REPORT); - encoder.onUnmappableCharacter(CodingErrorAction.REPORT); - - final SortedMap binaryEntries = new TreeMap<>(new Utils.ComparingByteBuffers()); - for (Entry mapEntry : encryptionContext.entrySet()) { - if (mapEntry.getKey() == null || mapEntry.getValue() == null) { - throw new AwsCryptoException( - "All keys and values in encryption context must be non-null."); - } - - if (mapEntry.getKey().isEmpty() || mapEntry.getValue().isEmpty()) { - throw new AwsCryptoException( - "All keys and values in encryption context must be non-empty."); - } - - final ByteBuffer keyBytes = encoder.encode(CharBuffer.wrap(mapEntry.getKey())); - final ByteBuffer valueBytes = encoder.encode(CharBuffer.wrap(mapEntry.getValue())); - - // check for duplicate entries. - if (binaryEntries.put(keyBytes, valueBytes) != null) { - throw new AwsCryptoException("Encryption context contains duplicate entries."); - } - - if (keyBytes.limit() > Short.MAX_VALUE || valueBytes.limit() > Short.MAX_VALUE) { - throw new AwsCryptoException( - "All keys and values in encryption context must be shorter than " + Short.MAX_VALUE); - } - } - - for (final Entry entry : binaryEntries.entrySet()) { - // actual serialization happens here - result.putShort((short) entry.getKey().limit()); - result.put(entry.getKey()); - result.putShort((short) entry.getValue().limit()); - result.put(entry.getValue()); - } - - // get and return the bytes that have been serialized - Utils.flip(result); - final byte[] encryptionContextBytes = new byte[result.limit()]; - result.get(encryptionContextBytes); - - return encryptionContextBytes; - } catch (CharacterCodingException e) { - throw new IllegalArgumentException("Encryption context contains an invalid unicode character"); - } catch (BufferOverflowException e) { - throw new AwsCryptoException( - "The number of bytes in encryption context exceeds the allowed maximum " + Short.MAX_VALUE, - e); + if (keyBytes.limit() > Short.MAX_VALUE || valueBytes.limit() > Short.MAX_VALUE) { + throw new AwsCryptoException( + "All keys and values in encryption context must be shorter than " + Short.MAX_VALUE); } + } + + for (final Entry entry : binaryEntries.entrySet()) { + // actual serialization happens here + result.putShort((short) entry.getKey().limit()); + result.put(entry.getKey()); + result.putShort((short) entry.getValue().limit()); + result.put(entry.getValue()); + } + + // get and return the bytes that have been serialized + Utils.flip(result); + final byte[] encryptionContextBytes = new byte[result.limit()]; + result.get(encryptionContextBytes); + + return encryptionContextBytes; + } catch (CharacterCodingException e) { + throw new IllegalArgumentException( + "Encryption context contains an invalid unicode character"); + } catch (BufferOverflowException e) { + throw new AwsCryptoException( + "The number of bytes in encryption context exceeds the allowed maximum " + + Short.MAX_VALUE, + e); } + } + + /** + * Deserialize the provided byte array into a map containing key-value pairs comprised of strings. + * + * @param b the bytes to deserialize into a map representing the encryption context. + * @return the map containing key-value pairs comprised of strings. + */ + public static Map deserialize(final byte[] b) { + try { + if (b == null) { + return null; + } + + if (b.length == 0) { + return (Collections.emptyMap()); + } + + final ByteBuffer encryptionContextBytes = ByteBuffer.wrap(b); + + // retrieve the number of entries first + final int entryCount = encryptionContextBytes.getShort(); + if (entryCount <= 0 || entryCount > Short.MAX_VALUE) { + throw new AwsCryptoException( + "The number of entries in encryption context must be greater than 0 and smaller than " + + Short.MAX_VALUE); + } + + final CharsetDecoder decoder = StandardCharsets.UTF_8.newDecoder(); + + // ensure all failures in decoder are reported. + decoder.onMalformedInput(CodingErrorAction.REPORT); + decoder.onUnmappableCharacter(CodingErrorAction.REPORT); + + final Map result = new HashMap<>(entryCount); + for (int i = 0; i < entryCount; i++) { + // retrieve key + final int keyLen = encryptionContextBytes.getShort(); + if (keyLen <= 0 || keyLen > Short.MAX_VALUE) { + throw new AwsCryptoException( + "Key length must be greater than 0 and smaller than " + Short.MAX_VALUE); + } + + final ByteBuffer keyBytes = encryptionContextBytes.slice(); + Utils.limit(keyBytes, keyLen); + Utils.position(encryptionContextBytes, encryptionContextBytes.position() + keyLen); + + final int valueLen = encryptionContextBytes.getShort(); + if (valueLen <= 0 || valueLen > Short.MAX_VALUE) { + throw new AwsCryptoException( + "Value length must be greater than 0 and smaller than " + Short.MAX_VALUE); + } + + // retrieve value + final ByteBuffer valueBytes = encryptionContextBytes.slice(); + Utils.limit(valueBytes, valueLen); + Utils.position(encryptionContextBytes, encryptionContextBytes.position() + valueLen); + + final CharBuffer keyChars = decoder.decode(keyBytes); + final CharBuffer valueChars = decoder.decode(valueBytes); - /** - * Deserialize the provided byte array into a map containing key-value - * pairs comprised of strings. - * - * @param b - * the bytes to deserialize into a map representing the - * encryption context. - * @return - * the map containing key-value pairs comprised of strings. - */ - public static Map deserialize(final byte[] b) { - try { - if (b == null) { - return null; - } - - if (b.length == 0) { - return (Collections. emptyMap()); - } - - final ByteBuffer encryptionContextBytes = ByteBuffer.wrap(b); - - // retrieve the number of entries first - final int entryCount = encryptionContextBytes.getShort(); - if (entryCount <= 0 || entryCount > Short.MAX_VALUE) { - throw new AwsCryptoException( - "The number of entries in encryption context must be greater than 0 and smaller than " - + Short.MAX_VALUE); - } - - final CharsetDecoder decoder = StandardCharsets.UTF_8.newDecoder(); - - // ensure all failures in decoder are reported. - decoder.onMalformedInput(CodingErrorAction.REPORT); - decoder.onUnmappableCharacter(CodingErrorAction.REPORT); - - final Map result = new HashMap<>(entryCount); - for (int i = 0; i < entryCount; i++) { - // retrieve key - final int keyLen = encryptionContextBytes.getShort(); - if (keyLen <= 0 || keyLen > Short.MAX_VALUE) { - throw new AwsCryptoException("Key length must be greater than 0 and smaller than " - + Short.MAX_VALUE); - } - - final ByteBuffer keyBytes = encryptionContextBytes.slice(); - Utils.limit(keyBytes, keyLen); - Utils.position(encryptionContextBytes, encryptionContextBytes.position() + keyLen); - - final int valueLen = encryptionContextBytes.getShort(); - if (valueLen <= 0 || valueLen > Short.MAX_VALUE) { - throw new AwsCryptoException("Value length must be greater than 0 and smaller than " - + Short.MAX_VALUE); - } - - // retrieve value - final ByteBuffer valueBytes = encryptionContextBytes.slice(); - Utils.limit(valueBytes, valueLen); - Utils.position(encryptionContextBytes, encryptionContextBytes.position() + valueLen); - - final CharBuffer keyChars = decoder.decode(keyBytes); - final CharBuffer valueChars = decoder.decode(valueBytes); - - // check for duplicate entries. - if (result.put(keyChars.toString(), valueChars.toString()) != null) { - throw new AwsCryptoException("Encryption context contains duplicate entries."); - } - } - - return result; - } catch (CharacterCodingException e) { - throw new IllegalArgumentException("Encryption context contains an invalid unicode character"); - } catch (BufferUnderflowException e) { - throw new AwsCryptoException("Invalid encryption context. Expected more bytes.", e); + // check for duplicate entries. + if (result.put(keyChars.toString(), valueChars.toString()) != null) { + throw new AwsCryptoException("Encryption context contains duplicate entries."); } + } + + return result; + } catch (CharacterCodingException e) { + throw new IllegalArgumentException( + "Encryption context contains an invalid unicode character"); + } catch (BufferUnderflowException e) { + throw new AwsCryptoException("Invalid encryption context. Expected more bytes.", e); } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/EncryptionHandler.java b/src/main/java/com/amazonaws/encryptionsdk/internal/EncryptionHandler.java index a827a7081..4487419b9 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/EncryptionHandler.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/EncryptionHandler.java @@ -3,6 +3,17 @@ package com.amazonaws.encryptionsdk.internal; +import com.amazonaws.encryptionsdk.CommitmentPolicy; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.MasterKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; +import com.amazonaws.encryptionsdk.model.CiphertextFooters; +import com.amazonaws.encryptionsdk.model.CiphertextHeaders; +import com.amazonaws.encryptionsdk.model.CiphertextType; +import com.amazonaws.encryptionsdk.model.ContentType; +import com.amazonaws.encryptionsdk.model.EncryptionMaterials; +import com.amazonaws.encryptionsdk.model.KeyBlob; import java.io.IOException; import java.security.GeneralSecurityException; import java.security.InvalidKeyException; @@ -13,423 +24,408 @@ import java.security.interfaces.ECPrivateKey; import java.util.List; import java.util.Map; - import javax.crypto.Cipher; import javax.crypto.SecretKey; - -import com.amazonaws.encryptionsdk.CommitmentPolicy; -import com.amazonaws.encryptionsdk.model.CiphertextFooters; -import com.amazonaws.encryptionsdk.model.CiphertextHeaders; -import com.amazonaws.encryptionsdk.model.CiphertextType; -import com.amazonaws.encryptionsdk.model.ContentType; -import com.amazonaws.encryptionsdk.model.EncryptionMaterials; -import com.amazonaws.encryptionsdk.model.KeyBlob; import org.bouncycastle.asn1.ASN1Encodable; import org.bouncycastle.asn1.ASN1Integer; import org.bouncycastle.asn1.ASN1Sequence; import org.bouncycastle.asn1.DERSequence; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.MasterKey; -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; - /** * This class implements the CryptoHandler interface by providing methods for the encryption of * plaintext data. - * - *

- * This class creates the ciphertext headers and delegates the encryption of the plaintext to the + * + *

This class creates the ciphertext headers and delegates the encryption of the plaintext to the * {@link BlockEncryptionHandler} or {@link FrameEncryptionHandler} based on the content type. */ public class EncryptionHandler implements MessageCryptoHandler { - private static final CiphertextType CIPHERTEXT_TYPE = CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA; - - private final EncryptionMaterials encryptionMaterials_; - private final Map encryptionContext_; - private final CryptoAlgorithm cryptoAlgo_; - private final List masterKeys_; - private final List keyBlobs_; - private final SecretKey encryptionKey_; - private final byte version_; - private final CiphertextType type_; - private final byte nonceLen_; - private final PrivateKey trailingSignaturePrivateKey_; - private final MessageDigest trailingDigest_; - private final Signature trailingSig_; - - private final CiphertextHeaders ciphertextHeaders_; - private final byte[] ciphertextHeaderBytes_; - private final CryptoHandler contentCryptoHandler_; - - private boolean firstOperation_ = true; - private boolean complete_ = false; - - private long plaintextBytes_ = 0; - private long plaintextByteLimit_ = -1; - - /** - * Create an encryption handler using the provided master key and encryption context. - * - * @param frameSize The encryption frame size, or zero for a one-shot encryption task - * @param result The EncryptionMaterials with the crypto materials for this encryption job - * @throws AwsCryptoException - * if the encryption context or master key is null. - */ - public EncryptionHandler(int frameSize, EncryptionMaterials result, CommitmentPolicy commitmentPolicy) throws AwsCryptoException { - Utils.assertNonNull(result, "result"); - Utils.assertNonNull(commitmentPolicy, "commitmentPolicy"); - - this.encryptionMaterials_ = result; - this.encryptionContext_ = result.getEncryptionContext(); - if (!commitmentPolicy.algorithmAllowedForEncrypt(result.getAlgorithm())) { - if (commitmentPolicy == CommitmentPolicy.ForbidEncryptAllowDecrypt) { - throw new AwsCryptoException("Configuration conflict. Cannot encrypt due to CommitmentPolicy " + - commitmentPolicy + " requiring only non-committed messages. Algorithm ID was " + - result.getAlgorithm() + - ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); - } else { - throw new AwsCryptoException("Configuration conflict. Cannot encrypt due to CommitmentPolicy " + - commitmentPolicy + " requiring only committed messages. Algorithm ID was " + - result.getAlgorithm() + - ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); - } - } - this.cryptoAlgo_ = result.getAlgorithm(); - this.masterKeys_ = result.getMasterKeys(); - this.keyBlobs_ = result.getEncryptedDataKeys(); - this.trailingSignaturePrivateKey_ = result.getTrailingSignatureKey(); - - if (keyBlobs_.isEmpty()) { - throw new IllegalArgumentException("No encrypted data keys in materials result"); - } - - if (trailingSignaturePrivateKey_ != null) { - try { - TrailingSignatureAlgorithm algorithm = TrailingSignatureAlgorithm.forCryptoAlgorithm(cryptoAlgo_); - trailingDigest_ = MessageDigest.getInstance(algorithm.getMessageDigestAlgorithm()); - trailingSig_ = Signature.getInstance(algorithm.getRawSignatureAlgorithm()); - - trailingSig_.initSign(trailingSignaturePrivateKey_, Utils.getSecureRandom()); - } catch (final GeneralSecurityException ex) { - throw new AwsCryptoException(ex); - } - } else { - trailingDigest_ = null; - trailingSig_ = null; - } - - // set default values - version_ = cryptoAlgo_.getMessageFormatVersion(); - type_ = CIPHERTEXT_TYPE; - nonceLen_ = cryptoAlgo_.getNonceLen(); - - ContentType contentType; - if (frameSize > 0) { - contentType = ContentType.FRAME; - } else if (frameSize == 0) { - contentType = ContentType.SINGLEBLOCK; - } else { - throw Utils.cannotBeNegative("Frame size"); - } - - // Construct the headers - // Included here rather than as a sub-routine so we can set final variables. - // This way we can avoid calculating the keys more times than we need. - final byte[] encryptionContextBytes = EncryptionContextSerializer.serialize(encryptionContext_); - final CiphertextHeaders unsignedHeaders = new CiphertextHeaders(type_, cryptoAlgo_, - encryptionContextBytes, keyBlobs_, contentType, frameSize); - // We use a deterministic IV of zero for the header authentication. - unsignedHeaders.setHeaderNonce(new byte[nonceLen_]); - - // If using a committing crypto algorithm, we also need to calculate the commitment value along - // with the key derivation - if (cryptoAlgo_.isCommitting()) { - final CommittedKey committedKey = CommittedKey.generate(cryptoAlgo_, result.getCleartextDataKey(), unsignedHeaders.getMessageId()); - unsignedHeaders.setSuiteData(committedKey.getCommitment()); - encryptionKey_ = committedKey.getKey(); - } else { - try { - encryptionKey_ = cryptoAlgo_.getEncryptionKeyFromDataKey(result.getCleartextDataKey(), unsignedHeaders); - } catch (final InvalidKeyException ex) { - throw new AwsCryptoException(ex); - } - } - - ciphertextHeaders_ = signCiphertextHeaders(unsignedHeaders); - ciphertextHeaderBytes_ = ciphertextHeaders_.toByteArray(); - byte[] messageId_ = ciphertextHeaders_.getMessageId(); - - switch (contentType) { - case FRAME: - contentCryptoHandler_ = new FrameEncryptionHandler(encryptionKey_, nonceLen_, cryptoAlgo_, messageId_, - frameSize); - break; - case SINGLEBLOCK: - contentCryptoHandler_ = new BlockEncryptionHandler(encryptionKey_, nonceLen_, cryptoAlgo_, messageId_); - break; - default: - // should never get here because a valid content type is always - // set above based on the frame size. - throw new AwsCryptoException("Unknown content type."); - } + private static final CiphertextType CIPHERTEXT_TYPE = + CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA; + + private final EncryptionMaterials encryptionMaterials_; + private final Map encryptionContext_; + private final CryptoAlgorithm cryptoAlgo_; + private final List masterKeys_; + private final List keyBlobs_; + private final SecretKey encryptionKey_; + private final byte version_; + private final CiphertextType type_; + private final byte nonceLen_; + private final PrivateKey trailingSignaturePrivateKey_; + private final MessageDigest trailingDigest_; + private final Signature trailingSig_; + + private final CiphertextHeaders ciphertextHeaders_; + private final byte[] ciphertextHeaderBytes_; + private final CryptoHandler contentCryptoHandler_; + + private boolean firstOperation_ = true; + private boolean complete_ = false; + + private long plaintextBytes_ = 0; + private long plaintextByteLimit_ = -1; + + /** + * Create an encryption handler using the provided master key and encryption context. + * + * @param frameSize The encryption frame size, or zero for a one-shot encryption task + * @param result The EncryptionMaterials with the crypto materials for this encryption job + * @throws AwsCryptoException if the encryption context or master key is null. + */ + public EncryptionHandler( + int frameSize, EncryptionMaterials result, CommitmentPolicy commitmentPolicy) + throws AwsCryptoException { + Utils.assertNonNull(result, "result"); + Utils.assertNonNull(commitmentPolicy, "commitmentPolicy"); + + this.encryptionMaterials_ = result; + this.encryptionContext_ = result.getEncryptionContext(); + if (!commitmentPolicy.algorithmAllowedForEncrypt(result.getAlgorithm())) { + if (commitmentPolicy == CommitmentPolicy.ForbidEncryptAllowDecrypt) { + throw new AwsCryptoException( + "Configuration conflict. Cannot encrypt due to CommitmentPolicy " + + commitmentPolicy + + " requiring only non-committed messages. Algorithm ID was " + + result.getAlgorithm() + + ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); + } else { + throw new AwsCryptoException( + "Configuration conflict. Cannot encrypt due to CommitmentPolicy " + + commitmentPolicy + + " requiring only committed messages. Algorithm ID was " + + result.getAlgorithm() + + ". See: https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/troubleshooting-migration.html"); + } } + this.cryptoAlgo_ = result.getAlgorithm(); + this.masterKeys_ = result.getMasterKeys(); + this.keyBlobs_ = result.getEncryptedDataKeys(); + this.trailingSignaturePrivateKey_ = result.getTrailingSignatureKey(); - /** - * Encrypt a block of bytes from {@code in} putting the plaintext result into {@code out}. - * - *

- * It encrypts by performing the following operations: - *

    - *
  1. if this is the first call to encrypt, write the ciphertext headers to the output being - * returned.
  2. - *
  3. else, pass off the input data to underlying content cryptohandler.
  4. - *
- * - * @param in - * the input byte array. - * @param off - * the offset into the in array where the data to be encrypted starts. - * @param len - * the number of bytes to be encrypted. - * @param out - * the output buffer the encrypted bytes go into. - * @param outOff - * the offset into the output byte array the encrypted data starts at. - * @return the number of bytes written to out and processed - * @throws AwsCryptoException - * if len or offset values are negative. - * @throws BadCiphertextException - * thrown by the underlying cipher handler. - */ - @Override - public ProcessingSummary processBytes(final byte[] in, final int off, final int len, final byte[] out, - final int outOff) - throws AwsCryptoException, BadCiphertextException { - if (len < 0 || off < 0) { - throw new AwsCryptoException(String.format( - "Invalid values for input offset: %d and length: %d", off, len)); - } - - checkPlaintextSizeLimit(len); - - int actualOutLen = 0; - - if (firstOperation_ == true) { - System.arraycopy(ciphertextHeaderBytes_, 0, out, outOff, ciphertextHeaderBytes_.length); - actualOutLen += ciphertextHeaderBytes_.length; - - firstOperation_ = false; - } - - ProcessingSummary contentOut = - contentCryptoHandler_.processBytes(in, off, len, out, outOff + actualOutLen); - actualOutLen += contentOut.getBytesWritten(); - updateTrailingSignature(out, outOff, actualOutLen); - plaintextBytes_ += contentOut.getBytesProcessed(); - return new ProcessingSummary(actualOutLen, contentOut.getBytesProcessed()); + if (keyBlobs_.isEmpty()) { + throw new IllegalArgumentException("No encrypted data keys in materials result"); } - /** - * Finish encryption of the plaintext bytes. - * - * @param out - * space for any resulting output data. - * @param outOff - * offset into out to start copying the data at. - * @return number of bytes written into out. - * @throws BadCiphertextException - * thrown by the underlying cipher handler. - */ - @Override - public int doFinal(final byte[] out, final int outOff) throws BadCiphertextException { - if (complete_) { - throw new IllegalStateException("Attempted to call doFinal twice"); - } + if (trailingSignaturePrivateKey_ != null) { + try { + TrailingSignatureAlgorithm algorithm = + TrailingSignatureAlgorithm.forCryptoAlgorithm(cryptoAlgo_); + trailingDigest_ = MessageDigest.getInstance(algorithm.getMessageDigestAlgorithm()); + trailingSig_ = Signature.getInstance(algorithm.getRawSignatureAlgorithm()); + + trailingSig_.initSign(trailingSignaturePrivateKey_, Utils.getSecureRandom()); + } catch (final GeneralSecurityException ex) { + throw new AwsCryptoException(ex); + } + } else { + trailingDigest_ = null; + trailingSig_ = null; + } - complete_ = true; - - checkPlaintextSizeLimit(0); - - int written = contentCryptoHandler_.doFinal(out, outOff); - updateTrailingSignature(out, outOff, written); - if (cryptoAlgo_.getTrailingSignatureLength() > 0) { - try { - CiphertextFooters footer = new CiphertextFooters(signContent()); - byte[] fBytes = footer.toByteArray(); - System.arraycopy(fBytes, 0, out, outOff + written, fBytes.length); - return written + fBytes.length; - } catch (final SignatureException ex) { - throw new AwsCryptoException(ex); - } - } else { - return written; - } + // set default values + version_ = cryptoAlgo_.getMessageFormatVersion(); + type_ = CIPHERTEXT_TYPE; + nonceLen_ = cryptoAlgo_.getNonceLen(); + + ContentType contentType; + if (frameSize > 0) { + contentType = ContentType.FRAME; + } else if (frameSize == 0) { + contentType = ContentType.SINGLEBLOCK; + } else { + throw Utils.cannotBeNegative("Frame size"); } - private byte[] signContent() throws SignatureException { - if (trailingDigest_ != null) { - if (!trailingSig_.getAlgorithm().contains("ECDSA")) { - throw new UnsupportedOperationException("Signatures calculated in pieces is only supported for ECDSA."); - } - final byte[] digest = trailingDigest_.digest(); - return generateEcdsaFixedLengthSignature(digest); - } - return trailingSig_.sign(); + // Construct the headers + // Included here rather than as a sub-routine so we can set final variables. + // This way we can avoid calculating the keys more times than we need. + final byte[] encryptionContextBytes = EncryptionContextSerializer.serialize(encryptionContext_); + final CiphertextHeaders unsignedHeaders = + new CiphertextHeaders( + type_, cryptoAlgo_, encryptionContextBytes, keyBlobs_, contentType, frameSize); + // We use a deterministic IV of zero for the header authentication. + unsignedHeaders.setHeaderNonce(new byte[nonceLen_]); + + // If using a committing crypto algorithm, we also need to calculate the commitment value along + // with the key derivation + if (cryptoAlgo_.isCommitting()) { + final CommittedKey committedKey = + CommittedKey.generate( + cryptoAlgo_, result.getCleartextDataKey(), unsignedHeaders.getMessageId()); + unsignedHeaders.setSuiteData(committedKey.getCommitment()); + encryptionKey_ = committedKey.getKey(); + } else { + try { + encryptionKey_ = + cryptoAlgo_.getEncryptionKeyFromDataKey(result.getCleartextDataKey(), unsignedHeaders); + } catch (final InvalidKeyException ex) { + throw new AwsCryptoException(ex); + } } - private byte[] generateEcdsaFixedLengthSignature(final byte[] digest) throws SignatureException { - byte[] signature; - // Unfortunately, we need deterministic lengths some signatures are non-deterministic in length. - // So, retry until we get the right length :-( - do { - trailingSig_.update(digest); - signature = trailingSig_.sign(); - if (signature.length != cryptoAlgo_.getTrailingSignatureLength()) { - // Most of the time, a signature of the wrong length can be fixed - // be negating s in the signature relative to the group order. - ASN1Sequence seq = ASN1Sequence.getInstance(signature); - ASN1Integer r = (ASN1Integer) seq.getObjectAt(0); - ASN1Integer s = (ASN1Integer) seq.getObjectAt(1); - ECPrivateKey ecKey = (ECPrivateKey) trailingSignaturePrivateKey_; - s = new ASN1Integer(ecKey.getParams().getOrder().subtract(s.getPositiveValue())); - seq = new DERSequence(new ASN1Encodable[]{r, s}); - try { - signature = seq.getEncoded(); - } catch (IOException ex) { - throw new SignatureException(ex); - } - } - } while (signature.length != cryptoAlgo_.getTrailingSignatureLength()); - return signature; + ciphertextHeaders_ = signCiphertextHeaders(unsignedHeaders); + ciphertextHeaderBytes_ = ciphertextHeaders_.toByteArray(); + byte[] messageId_ = ciphertextHeaders_.getMessageId(); + + switch (contentType) { + case FRAME: + contentCryptoHandler_ = + new FrameEncryptionHandler( + encryptionKey_, nonceLen_, cryptoAlgo_, messageId_, frameSize); + break; + case SINGLEBLOCK: + contentCryptoHandler_ = + new BlockEncryptionHandler(encryptionKey_, nonceLen_, cryptoAlgo_, messageId_); + break; + default: + // should never get here because a valid content type is always + // set above based on the frame size. + throw new AwsCryptoException("Unknown content type."); + } + } + + /** + * Encrypt a block of bytes from {@code in} putting the plaintext result into {@code out}. + * + *

It encrypts by performing the following operations: + * + *

    + *
  1. if this is the first call to encrypt, write the ciphertext headers to the output being + * returned. + *
  2. else, pass off the input data to underlying content cryptohandler. + *
+ * + * @param in the input byte array. + * @param off the offset into the in array where the data to be encrypted starts. + * @param len the number of bytes to be encrypted. + * @param out the output buffer the encrypted bytes go into. + * @param outOff the offset into the output byte array the encrypted data starts at. + * @return the number of bytes written to out and processed + * @throws AwsCryptoException if len or offset values are negative. + * @throws BadCiphertextException thrown by the underlying cipher handler. + */ + @Override + public ProcessingSummary processBytes( + final byte[] in, final int off, final int len, final byte[] out, final int outOff) + throws AwsCryptoException, BadCiphertextException { + if (len < 0 || off < 0) { + throw new AwsCryptoException( + String.format("Invalid values for input offset: %d and length: %d", off, len)); } - /** - * Return the size of the output buffer required for a {@code processBytes} plus a - * {@code doFinal} with an input of inLen bytes. - * - * @param inLen - * the length of the input. - * @return the space required to accommodate a call to processBytes and doFinal with len bytes - * of input. - */ - @Override - public int estimateOutputSize(final int inLen) { - int outSize = 0; - if (firstOperation_ == true) { - outSize += ciphertextHeaderBytes_.length; - } - outSize += contentCryptoHandler_.estimateOutputSize(inLen); + checkPlaintextSizeLimit(len); - if (cryptoAlgo_.getTrailingSignatureLength() > 0) { - outSize += 2; // Length field in footer - outSize += cryptoAlgo_.getTrailingSignatureLength(); - } - return outSize; - } + int actualOutLen = 0; - @Override - public int estimatePartialOutputSize(int inLen) { - int outSize = 0; - if (firstOperation_ == true) { - outSize += ciphertextHeaderBytes_.length; - } - outSize += contentCryptoHandler_.estimatePartialOutputSize(inLen); + if (firstOperation_ == true) { + System.arraycopy(ciphertextHeaderBytes_, 0, out, outOff, ciphertextHeaderBytes_.length); + actualOutLen += ciphertextHeaderBytes_.length; - return outSize; + firstOperation_ = false; } - @Override - public int estimateFinalOutputSize() { - return estimateOutputSize(0); + ProcessingSummary contentOut = + contentCryptoHandler_.processBytes(in, off, len, out, outOff + actualOutLen); + actualOutLen += contentOut.getBytesWritten(); + updateTrailingSignature(out, outOff, actualOutLen); + plaintextBytes_ += contentOut.getBytesProcessed(); + return new ProcessingSummary(actualOutLen, contentOut.getBytesProcessed()); + } + + /** + * Finish encryption of the plaintext bytes. + * + * @param out space for any resulting output data. + * @param outOff offset into out to start copying the data at. + * @return number of bytes written into out. + * @throws BadCiphertextException thrown by the underlying cipher handler. + */ + @Override + public int doFinal(final byte[] out, final int outOff) throws BadCiphertextException { + if (complete_) { + throw new IllegalStateException("Attempted to call doFinal twice"); } - /** - * Return the encryption context. - * - * @return the key-value map containing encryption context. - */ - @Override - public Map getEncryptionContext() { - return encryptionContext_; + complete_ = true; + + checkPlaintextSizeLimit(0); + + int written = contentCryptoHandler_.doFinal(out, outOff); + updateTrailingSignature(out, outOff, written); + if (cryptoAlgo_.getTrailingSignatureLength() > 0) { + try { + CiphertextFooters footer = new CiphertextFooters(signContent()); + byte[] fBytes = footer.toByteArray(); + System.arraycopy(fBytes, 0, out, outOff + written, fBytes.length); + return written + fBytes.length; + } catch (final SignatureException ex) { + throw new AwsCryptoException(ex); + } + } else { + return written; } - - @Override - public CiphertextHeaders getHeaders() { - return ciphertextHeaders_; + } + + private byte[] signContent() throws SignatureException { + if (trailingDigest_ != null) { + if (!trailingSig_.getAlgorithm().contains("ECDSA")) { + throw new UnsupportedOperationException( + "Signatures calculated in pieces is only supported for ECDSA."); + } + final byte[] digest = trailingDigest_.digest(); + return generateEcdsaFixedLengthSignature(digest); } - - @Override - public void setMaxInputLength(long size) { - if (size < 0) { - throw Utils.cannotBeNegative("Max input length"); + return trailingSig_.sign(); + } + + private byte[] generateEcdsaFixedLengthSignature(final byte[] digest) throws SignatureException { + byte[] signature; + // Unfortunately, we need deterministic lengths some signatures are non-deterministic in length. + // So, retry until we get the right length :-( + do { + trailingSig_.update(digest); + signature = trailingSig_.sign(); + if (signature.length != cryptoAlgo_.getTrailingSignatureLength()) { + // Most of the time, a signature of the wrong length can be fixed + // be negating s in the signature relative to the group order. + ASN1Sequence seq = ASN1Sequence.getInstance(signature); + ASN1Integer r = (ASN1Integer) seq.getObjectAt(0); + ASN1Integer s = (ASN1Integer) seq.getObjectAt(1); + ECPrivateKey ecKey = (ECPrivateKey) trailingSignaturePrivateKey_; + s = new ASN1Integer(ecKey.getParams().getOrder().subtract(s.getPositiveValue())); + seq = new DERSequence(new ASN1Encodable[] {r, s}); + try { + signature = seq.getEncoded(); + } catch (IOException ex) { + throw new SignatureException(ex); } - - if (plaintextByteLimit_ == -1 || plaintextByteLimit_ > size) { - plaintextByteLimit_ = size; - } - - // check that we haven't already exceeded the limit - checkPlaintextSizeLimit(0); + } + } while (signature.length != cryptoAlgo_.getTrailingSignatureLength()); + return signature; + } + + /** + * Return the size of the output buffer required for a {@code processBytes} plus a {@code doFinal} + * with an input of inLen bytes. + * + * @param inLen the length of the input. + * @return the space required to accommodate a call to processBytes and doFinal with len bytes of + * input. + */ + @Override + public int estimateOutputSize(final int inLen) { + int outSize = 0; + if (firstOperation_ == true) { + outSize += ciphertextHeaderBytes_.length; } + outSize += contentCryptoHandler_.estimateOutputSize(inLen); - private void checkPlaintextSizeLimit(long additionalBytes) { - if (plaintextByteLimit_ != -1 && plaintextBytes_ + additionalBytes > plaintextByteLimit_) { - throw new IllegalStateException("Plaintext size exceeds max input size limit"); - } + if (cryptoAlgo_.getTrailingSignatureLength() > 0) { + outSize += 2; // Length field in footer + outSize += cryptoAlgo_.getTrailingSignatureLength(); } - - long getMaxInputLength() { - return plaintextByteLimit_; + return outSize; + } + + @Override + public int estimatePartialOutputSize(int inLen) { + int outSize = 0; + if (firstOperation_ == true) { + outSize += ciphertextHeaderBytes_.length; } - - /** - * Compute the MAC tag of the header bytes using the provided key, nonce, AAD, and crypto - * algorithm identifier. - * - * @param nonce - * the nonce to use in computing the MAC tag. - * @param aad - * the AAD to use in computing the MAC tag. - * @return the bytes containing the computed MAC tag. - */ - private byte[] computeHeaderTag(final byte[] nonce, final byte[] aad) { - final CipherHandler cipherHandler = new CipherHandler(encryptionKey_, - Cipher.ENCRYPT_MODE, - cryptoAlgo_); - - return cipherHandler.cipherData(nonce, aad, new byte[0], 0, 0); + outSize += contentCryptoHandler_.estimatePartialOutputSize(inLen); + + return outSize; + } + + @Override + public int estimateFinalOutputSize() { + return estimateOutputSize(0); + } + + /** + * Return the encryption context. + * + * @return the key-value map containing encryption context. + */ + @Override + public Map getEncryptionContext() { + return encryptionContext_; + } + + @Override + public CiphertextHeaders getHeaders() { + return ciphertextHeaders_; + } + + @Override + public void setMaxInputLength(long size) { + if (size < 0) { + throw Utils.cannotBeNegative("Max input length"); } - private CiphertextHeaders signCiphertextHeaders(final CiphertextHeaders unsignedHeaders) { - final byte[] headerFields = unsignedHeaders.serializeAuthenticatedFields(); - final byte[] headerTag = computeHeaderTag(unsignedHeaders.getHeaderNonce(), headerFields); - - unsignedHeaders.setHeaderTag(headerTag); - - return unsignedHeaders; + if (plaintextByteLimit_ == -1 || plaintextByteLimit_ > size) { + plaintextByteLimit_ = size; } - @Override - public List> getMasterKeys() { - //noinspection unchecked - return (List)masterKeys_; // This is unmodifiable - } + // check that we haven't already exceeded the limit + checkPlaintextSizeLimit(0); + } - private void updateTrailingSignature(byte[] input, int offset, int len) { - if (trailingDigest_ != null) { - trailingDigest_.update(input, offset, len); - } else if (trailingSig_ != null) { - try { - trailingSig_.update(input, offset, len); - } catch (final SignatureException ex) { - throw new AwsCryptoException(ex); - } - } + private void checkPlaintextSizeLimit(long additionalBytes) { + if (plaintextByteLimit_ != -1 && plaintextBytes_ + additionalBytes > plaintextByteLimit_) { + throw new IllegalStateException("Plaintext size exceeds max input size limit"); } - - @Override - public boolean isComplete() { - return complete_; + } + + long getMaxInputLength() { + return plaintextByteLimit_; + } + + /** + * Compute the MAC tag of the header bytes using the provided key, nonce, AAD, and crypto + * algorithm identifier. + * + * @param nonce the nonce to use in computing the MAC tag. + * @param aad the AAD to use in computing the MAC tag. + * @return the bytes containing the computed MAC tag. + */ + private byte[] computeHeaderTag(final byte[] nonce, final byte[] aad) { + final CipherHandler cipherHandler = + new CipherHandler(encryptionKey_, Cipher.ENCRYPT_MODE, cryptoAlgo_); + + return cipherHandler.cipherData(nonce, aad, new byte[0], 0, 0); + } + + private CiphertextHeaders signCiphertextHeaders(final CiphertextHeaders unsignedHeaders) { + final byte[] headerFields = unsignedHeaders.serializeAuthenticatedFields(); + final byte[] headerTag = computeHeaderTag(unsignedHeaders.getHeaderNonce(), headerFields); + + unsignedHeaders.setHeaderTag(headerTag); + + return unsignedHeaders; + } + + @Override + public List> getMasterKeys() { + //noinspection unchecked + return (List) masterKeys_; // This is unmodifiable + } + + private void updateTrailingSignature(byte[] input, int offset, int len) { + if (trailingDigest_ != null) { + trailingDigest_.update(input, offset, len); + } else if (trailingSig_ != null) { + try { + trailingSig_.update(input, offset, len); + } catch (final SignatureException ex) { + throw new AwsCryptoException(ex); + } } + } + + @Override + public boolean isComplete() { + return complete_; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/FrameDecryptionHandler.java b/src/main/java/com/amazonaws/encryptionsdk/internal/FrameDecryptionHandler.java index 2c2dc6903..5a889526f 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/FrameDecryptionHandler.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/FrameDecryptionHandler.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,277 +13,260 @@ package com.amazonaws.encryptionsdk.internal; -import java.util.Arrays; - -import javax.crypto.Cipher; -import javax.crypto.SecretKey; - import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.model.CipherFrameHeaders; +import java.util.Arrays; +import javax.crypto.Cipher; +import javax.crypto.SecretKey; /** - * The frame decryption handler is a subclass of the decryption handler and - * thereby provides an implementation of the Cryptography handler. - * - *

- * It implements methods for decrypting content that was encrypted and stored in - * frames. + * The frame decryption handler is a subclass of the decryption handler and thereby provides an + * implementation of the Cryptography handler. + * + *

It implements methods for decrypting content that was encrypted and stored in frames. */ class FrameDecryptionHandler implements CryptoHandler { - private final SecretKey decryptionKey_; - private final CryptoAlgorithm cryptoAlgo_; - private final CipherHandler cipherHandler_; - private final byte[] messageId_; - - private final short nonceLen_; - - private CipherFrameHeaders currentFrameHeaders_; - private final int frameSize_; - private long frameNumber_ = 1; - - boolean complete_ = false; - private byte[] unparsedBytes_ = new byte[0]; - - /** - * Construct a decryption handler for decrypting bytes stored in frames. - * - * @param customerMasterKey - * the master key to use when unwrapping the data key encoded in - * the ciphertext. - */ - public FrameDecryptionHandler(final SecretKey decryptionKey, final short nonceLen, - final CryptoAlgorithm cryptoAlgo, final byte[] messageId, final int frameLen) { - decryptionKey_ = decryptionKey; - nonceLen_ = nonceLen; - cryptoAlgo_ = cryptoAlgo; - messageId_ = messageId; - frameSize_ = frameLen; - cipherHandler_ = new CipherHandler(decryptionKey_, Cipher.DECRYPT_MODE, cryptoAlgo_); + private final SecretKey decryptionKey_; + private final CryptoAlgorithm cryptoAlgo_; + private final CipherHandler cipherHandler_; + private final byte[] messageId_; + + private final short nonceLen_; + + private CipherFrameHeaders currentFrameHeaders_; + private final int frameSize_; + private long frameNumber_ = 1; + + boolean complete_ = false; + private byte[] unparsedBytes_ = new byte[0]; + + /** + * Construct a decryption handler for decrypting bytes stored in frames. + * + * @param customerMasterKey the master key to use when unwrapping the data key encoded in the + * ciphertext. + */ + public FrameDecryptionHandler( + final SecretKey decryptionKey, + final short nonceLen, + final CryptoAlgorithm cryptoAlgo, + final byte[] messageId, + final int frameLen) { + decryptionKey_ = decryptionKey; + nonceLen_ = nonceLen; + cryptoAlgo_ = cryptoAlgo; + messageId_ = messageId; + frameSize_ = frameLen; + cipherHandler_ = new CipherHandler(decryptionKey_, Cipher.DECRYPT_MODE, cryptoAlgo_); + } + + /** + * Decrypt the ciphertext bytes containing content encrypted using frames and put the plaintext + * bytes into out. + * + *

It decrypts by performing the following operations: + * + *

    + *
  1. parse the ciphertext headers + *
  2. parse the ciphertext until encrypted content in a frame is available + *
  3. decrypt the encrypted content + *
  4. return decrypted bytes as output + *
+ * + * @param in the input byte array. + * @param off the offset into the in array where the data to be decrypted starts. + * @param len the number of bytes to be decrypted. + * @param out the output buffer the decrypted plaintext bytes go into. + * @param outOff the offset into the output byte array the decrypted data starts at. + * @return the number of bytes written to out and processed + * @throws BadCiphertextException if frame number is invalid/out-of-order or if the bytes do not + * decrypt correctly. + * @throws AwsCryptoException if the content type found in the headers is not of frame type. + */ + @Override + public ProcessingSummary processBytes( + final byte[] in, final int off, final int len, final byte[] out, final int outOff) + throws BadCiphertextException, AwsCryptoException { + + if (complete_) { + throw new AwsCryptoException("Ciphertext has already been processed."); } - /** - * Decrypt the ciphertext bytes containing content encrypted using frames and put the plaintext - * bytes into out. - * - *

- * It decrypts by performing the following operations: - *

    - *
  1. parse the ciphertext headers
  2. - *
  3. parse the ciphertext until encrypted content in a frame is available
  4. - *
  5. decrypt the encrypted content
  6. - *
  7. return decrypted bytes as output
  8. - *
- * - * @param in - * the input byte array. - * @param off - * the offset into the in array where the data to be decrypted starts. - * @param len - * the number of bytes to be decrypted. - * @param out - * the output buffer the decrypted plaintext bytes go into. - * @param outOff - * the offset into the output byte array the decrypted data starts at. - * @return the number of bytes written to out and processed - * @throws BadCiphertextException - * if frame number is invalid/out-of-order or if the bytes do not decrypt correctly. - * @throws AwsCryptoException - * if the content type found in the headers is not of frame type. - */ - @Override - public ProcessingSummary processBytes(final byte[] in, final int off, final int len, final byte[] out, - final int outOff) - throws BadCiphertextException, AwsCryptoException { - - if (complete_) { - throw new AwsCryptoException("Ciphertext has already been processed."); - } + final long totalBytesToParse = unparsedBytes_.length + (long) len; + if (totalBytesToParse > Integer.MAX_VALUE) { + throw new AwsCryptoException( + "Integer overflow of the total bytes to parse and decrypt occured."); + } - final long totalBytesToParse = unparsedBytes_.length + (long) len; - if (totalBytesToParse > Integer.MAX_VALUE) { - throw new AwsCryptoException( - "Integer overflow of the total bytes to parse and decrypt occured."); + final byte[] bytesToParse = new byte[(int) totalBytesToParse]; + // If there were previously unparsed bytes, add them as the first + // set of bytes to be parsed in this call. + System.arraycopy(unparsedBytes_, 0, bytesToParse, 0, unparsedBytes_.length); + System.arraycopy(in, off, bytesToParse, unparsedBytes_.length, len); + + int actualOutLen = 0; + int totalParsedBytes = 0; + + // Parse available bytes. Stop parsing when there aren't enough + // bytes to complete parsing: + // - the ciphertext headers + // - the cipher frame + while (!complete_ && totalParsedBytes < bytesToParse.length) { + if (currentFrameHeaders_ == null) { + currentFrameHeaders_ = new CipherFrameHeaders(); + currentFrameHeaders_.setNonceLength(nonceLen_); + if (frameSize_ == 0) { + // if frame size in ciphertext headers is 0, the frame size + // will need to be parsed in individual frame headers. + currentFrameHeaders_.includeFrameSize(true); } + } - final byte[] bytesToParse = new byte[(int) totalBytesToParse]; - // If there were previously unparsed bytes, add them as the first - // set of bytes to be parsed in this call. - System.arraycopy(unparsedBytes_, 0, bytesToParse, 0, unparsedBytes_.length); - System.arraycopy(in, off, bytesToParse, unparsedBytes_.length, len); - - int actualOutLen = 0; - int totalParsedBytes = 0; - - // Parse available bytes. Stop parsing when there aren't enough - // bytes to complete parsing: - // - the ciphertext headers - // - the cipher frame - while (!complete_ && totalParsedBytes < bytesToParse.length) { - if (currentFrameHeaders_ == null) { - currentFrameHeaders_ = new CipherFrameHeaders(); - currentFrameHeaders_.setNonceLength(nonceLen_); - if (frameSize_ == 0) { - // if frame size in ciphertext headers is 0, the frame size - // will need to be parsed in individual frame headers. - currentFrameHeaders_.includeFrameSize(true); - } - } - - totalParsedBytes += currentFrameHeaders_.deserialize(bytesToParse, totalParsedBytes); - - // if we have all frame fields, process the encrypted content. - if (currentFrameHeaders_.isComplete() == true) { - int protectedContentLen = -1; - if (currentFrameHeaders_.isFinalFrame()) { - protectedContentLen = currentFrameHeaders_.getFrameContentLength(); - - // The final frame should not be able to exceed the frameLength - if (frameSize_ > 0 && protectedContentLen > frameSize_) { - throw new BadCiphertextException("Final frame length exceeds frame length."); - } - } else { - protectedContentLen = frameSize_; - } - - // include the tag which is added by the underlying cipher. - protectedContentLen += cryptoAlgo_.getTagLen(); - - if ((bytesToParse.length - totalParsedBytes) < protectedContentLen) { - // if we don't have all of the encrypted bytes, break - // until they become available. - break; - } - - final byte[] bytesToDecrypt_ = Arrays.copyOfRange(bytesToParse, totalParsedBytes, totalParsedBytes - + protectedContentLen); - totalParsedBytes += protectedContentLen; - - if (frameNumber_ == Constants.MAX_FRAME_NUMBER) { - throw new BadCiphertextException("Frame number exceeds the maximum allowed value."); - } - - final byte[] decryptedBytes = decryptContent(bytesToDecrypt_, 0, bytesToDecrypt_.length); - - System.arraycopy(decryptedBytes, 0, out, (outOff + actualOutLen), decryptedBytes.length); - actualOutLen += decryptedBytes.length; - frameNumber_++; - - complete_ = currentFrameHeaders_.isFinalFrame(); - // reset frame headers as we are done processing current frame. - currentFrameHeaders_ = null; - } else { - // if there aren't enough bytes to parse cipher frame, - // we can't continue parsing. - break; - } - } + totalParsedBytes += currentFrameHeaders_.deserialize(bytesToParse, totalParsedBytes); + + // if we have all frame fields, process the encrypted content. + if (currentFrameHeaders_.isComplete() == true) { + int protectedContentLen = -1; + if (currentFrameHeaders_.isFinalFrame()) { + protectedContentLen = currentFrameHeaders_.getFrameContentLength(); - if (!complete_) { - // buffer remaining bytes for parsing in the next round. - unparsedBytes_ = Arrays.copyOfRange(bytesToParse, totalParsedBytes, bytesToParse.length); - return new ProcessingSummary(actualOutLen, len); + // The final frame should not be able to exceed the frameLength + if (frameSize_ > 0 && protectedContentLen > frameSize_) { + throw new BadCiphertextException("Final frame length exceeds frame length."); + } } else { - final ProcessingSummary result = new ProcessingSummary(actualOutLen, totalParsedBytes - - unparsedBytes_.length); - unparsedBytes_ = new byte[0]; - return result; + protectedContentLen = frameSize_; } - } - /** - * Finish processing of the bytes. This function does nothing since the - * final frame will be processed and decrypted in processBytes(). - * - * @param out - * space for any resulting output data. - * @param outOff - * offset into out to start copying the data at. - * @return - * 0 - */ - @Override - public int doFinal(final byte[] out, final int outOff) { - if (!complete_) { - throw new BadCiphertextException("Unable to process entire ciphertext."); + // include the tag which is added by the underlying cipher. + protectedContentLen += cryptoAlgo_.getTagLen(); + + if ((bytesToParse.length - totalParsedBytes) < protectedContentLen) { + // if we don't have all of the encrypted bytes, break + // until they become available. + break; } - return 0; - } + final byte[] bytesToDecrypt_ = + Arrays.copyOfRange( + bytesToParse, totalParsedBytes, totalParsedBytes + protectedContentLen); + totalParsedBytes += protectedContentLen; - /** - * Return the size of the output buffer required for a processBytes plus a - * doFinal with an input of inLen bytes. - * - * @param inLen - * the length of the input. - * @return - * the space required to accommodate a call to processBytes and - * doFinal with len bytes of input. - */ - @Override - public int estimateOutputSize(final int inLen) { - int outSize = 0; - - final int totalBytesToDecrypt = unparsedBytes_.length + inLen; - if (totalBytesToDecrypt > 0) { - int frames = totalBytesToDecrypt / frameSize_; - frames += 1; // add one for final frame which might be < frame size. - outSize += (frameSize_ * frames); + if (frameNumber_ == Constants.MAX_FRAME_NUMBER) { + throw new BadCiphertextException("Frame number exceeds the maximum allowed value."); } - return outSize; - } + final byte[] decryptedBytes = decryptContent(bytesToDecrypt_, 0, bytesToDecrypt_.length); - @Override - public int estimatePartialOutputSize(int inLen) { - return estimateOutputSize(inLen); - } + System.arraycopy(decryptedBytes, 0, out, (outOff + actualOutLen), decryptedBytes.length); + actualOutLen += decryptedBytes.length; + frameNumber_++; - @Override - public int estimateFinalOutputSize() { - return 0; + complete_ = currentFrameHeaders_.isFinalFrame(); + // reset frame headers as we are done processing current frame. + currentFrameHeaders_ = null; + } else { + // if there aren't enough bytes to parse cipher frame, + // we can't continue parsing. + break; + } } - /** - * Returns the plaintext bytes of the encrypted content. - * - * @param input - * the input bytes containing the content - * @param off - * the offset into the input array where the data to be decrypted - * starts. - * @param len - * the number of bytes to be decrypted. - * @return - * the plaintext bytes of the encrypted content. - * @throws BadCiphertextException - * if the bytes do not decrypt correctly. - */ - private byte[] decryptContent(final byte[] input, final int off, final int len) throws BadCiphertextException { - final byte[] nonce = currentFrameHeaders_.getNonce(); - - byte[] contentAad = null; - if (currentFrameHeaders_.isFinalFrame() == true) { - contentAad = Utils.generateContentAad( - messageId_, - Constants.FINAL_FRAME_STRING_ID, - (int) frameNumber_, - currentFrameHeaders_.getFrameContentLength()); - } else { - contentAad = Utils.generateContentAad( - messageId_, - Constants.FRAME_STRING_ID, - (int) frameNumber_, - frameSize_); - } + if (!complete_) { + // buffer remaining bytes for parsing in the next round. + unparsedBytes_ = Arrays.copyOfRange(bytesToParse, totalParsedBytes, bytesToParse.length); + return new ProcessingSummary(actualOutLen, len); + } else { + final ProcessingSummary result = + new ProcessingSummary(actualOutLen, totalParsedBytes - unparsedBytes_.length); + unparsedBytes_ = new byte[0]; + return result; + } + } + + /** + * Finish processing of the bytes. This function does nothing since the final frame will be + * processed and decrypted in processBytes(). + * + * @param out space for any resulting output data. + * @param outOff offset into out to start copying the data at. + * @return 0 + */ + @Override + public int doFinal(final byte[] out, final int outOff) { + if (!complete_) { + throw new BadCiphertextException("Unable to process entire ciphertext."); + } - return cipherHandler_.cipherData(nonce, contentAad, input, off, len); + return 0; + } + + /** + * Return the size of the output buffer required for a processBytes plus a doFinal with an input + * of inLen bytes. + * + * @param inLen the length of the input. + * @return the space required to accommodate a call to processBytes and doFinal with len bytes of + * input. + */ + @Override + public int estimateOutputSize(final int inLen) { + int outSize = 0; + + final int totalBytesToDecrypt = unparsedBytes_.length + inLen; + if (totalBytesToDecrypt > 0) { + int frames = totalBytesToDecrypt / frameSize_; + frames += 1; // add one for final frame which might be < frame size. + outSize += (frameSize_ * frames); } - @Override - public boolean isComplete() { - return complete_; + return outSize; + } + + @Override + public int estimatePartialOutputSize(int inLen) { + return estimateOutputSize(inLen); + } + + @Override + public int estimateFinalOutputSize() { + return 0; + } + + /** + * Returns the plaintext bytes of the encrypted content. + * + * @param input the input bytes containing the content + * @param off the offset into the input array where the data to be decrypted starts. + * @param len the number of bytes to be decrypted. + * @return the plaintext bytes of the encrypted content. + * @throws BadCiphertextException if the bytes do not decrypt correctly. + */ + private byte[] decryptContent(final byte[] input, final int off, final int len) + throws BadCiphertextException { + final byte[] nonce = currentFrameHeaders_.getNonce(); + + byte[] contentAad = null; + if (currentFrameHeaders_.isFinalFrame() == true) { + contentAad = + Utils.generateContentAad( + messageId_, + Constants.FINAL_FRAME_STRING_ID, + (int) frameNumber_, + currentFrameHeaders_.getFrameContentLength()); + } else { + contentAad = + Utils.generateContentAad( + messageId_, Constants.FRAME_STRING_ID, (int) frameNumber_, frameSize_); } + + return cipherHandler_.cipherData(nonce, contentAad, input, off, len); + } + + @Override + public boolean isComplete() { + return complete_; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandler.java b/src/main/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandler.java index 53193a0db..d9fc7f639 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandler.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandler.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,360 +13,333 @@ package com.amazonaws.encryptionsdk.internal; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -import javax.crypto.Cipher; -import javax.crypto.SecretKey; - import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.model.CipherFrameHeaders; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import javax.crypto.Cipher; +import javax.crypto.SecretKey; /** - * The frame encryption handler is a subclass of the encryption handler and - * thereby provides an implementation of the Cryptography handler. - * - *

- * It implements methods for encrypting content and storing the encrypted bytes - * in frames. + * The frame encryption handler is a subclass of the encryption handler and thereby provides an + * implementation of the Cryptography handler. + * + *

It implements methods for encrypting content and storing the encrypted bytes in frames. */ class FrameEncryptionHandler implements CryptoHandler { - private final SecretKey encryptionKey_; - private final CryptoAlgorithm cryptoAlgo_; - private final CipherHandler cipherHandler_; - private final int nonceLen_; - private final byte[] messageId_; - private final int frameSize_; - private final int tagLenBytes_; - - private long frameNumber_ = 1; - private boolean isFinalFrame_; - - private final byte[] bytesToFrame_; - private int bytesToFrameLen_; - private boolean complete_ = false; - - /** - * Construct an encryption handler for encrypting bytes and storing them in - * frames. - * - * @param customerMasterKey - * the master key to use when wrapping the data key. - * @param encryptionContext - * the encryption context to use when wrapping the data key. - */ - public FrameEncryptionHandler(final SecretKey encryptionKey, final int nonceLen, final CryptoAlgorithm cryptoAlgo, - final byte[] messageId, final int frameSize) { - encryptionKey_ = encryptionKey; - cryptoAlgo_ = cryptoAlgo; - nonceLen_ = nonceLen; - messageId_ = messageId.clone(); - frameSize_ = frameSize; - tagLenBytes_ = cryptoAlgo_.getTagLen(); - bytesToFrame_ = new byte[frameSize_]; + private final SecretKey encryptionKey_; + private final CryptoAlgorithm cryptoAlgo_; + private final CipherHandler cipherHandler_; + private final int nonceLen_; + private final byte[] messageId_; + private final int frameSize_; + private final int tagLenBytes_; + + private long frameNumber_ = 1; + private boolean isFinalFrame_; + + private final byte[] bytesToFrame_; + private int bytesToFrameLen_; + private boolean complete_ = false; + + /** + * Construct an encryption handler for encrypting bytes and storing them in frames. + * + * @param customerMasterKey the master key to use when wrapping the data key. + * @param encryptionContext the encryption context to use when wrapping the data key. + */ + public FrameEncryptionHandler( + final SecretKey encryptionKey, + final int nonceLen, + final CryptoAlgorithm cryptoAlgo, + final byte[] messageId, + final int frameSize) { + encryptionKey_ = encryptionKey; + cryptoAlgo_ = cryptoAlgo; + nonceLen_ = nonceLen; + messageId_ = messageId.clone(); + frameSize_ = frameSize; + tagLenBytes_ = cryptoAlgo_.getTagLen(); + bytesToFrame_ = new byte[frameSize_]; + bytesToFrameLen_ = 0; + cipherHandler_ = new CipherHandler(encryptionKey_, Cipher.ENCRYPT_MODE, cryptoAlgo_); + } + + /** + * Encrypt a block of bytes from in putting the plaintext result into out. + * + *

It encrypts by performing the following operations: + * + *

    + *
  1. determine the size of encrypted content that can fit into current frame + *
  2. call processBytes() of the underlying cipher to do corresponding cryptographic encryption + * of plaintext + *
  3. check if current frame is fully filled using the processed bytes, write current frame to + * the output being returned. + *
+ * + * @param in the input byte array. + * @param inOff the offset into the in array where the data to be encrypted starts. + * @param inLen the number of bytes to be encrypted. + * @param out the output buffer the encrypted bytes go into. + * @param outOff the offset into the output byte array the encrypted data starts at. + * @return the number of bytes written to out and processed + * @throws InvalidCiphertextException thrown by the underlying cipher handler. + */ + @Override + public ProcessingSummary processBytes( + final byte[] in, final int off, final int len, final byte[] out, final int outOff) + throws BadCiphertextException { + int actualOutLen = 0; + + int size = len; + int offset = off; + while (size > 0) { + final int currentFrameCapacity = frameSize_ - bytesToFrameLen_; + // bind size to the capacity of the current frame + size = Math.min(currentFrameCapacity, size); + + System.arraycopy(in, offset, bytesToFrame_, bytesToFrameLen_, size); + bytesToFrameLen_ += size; + + // check if there is enough bytes to create a frame + if (bytesToFrameLen_ == frameSize_) { + actualOutLen += + writeEncryptedFrame(bytesToFrame_, 0, bytesToFrameLen_, out, outOff + actualOutLen); + + // reset buffer len as a new frame is created in next iteration bytesToFrameLen_ = 0; - cipherHandler_ = new CipherHandler(encryptionKey_, Cipher.ENCRYPT_MODE, cryptoAlgo_); - } + } - /** - * Encrypt a block of bytes from in putting the plaintext result into out. - * - *

- * It encrypts by performing the following operations: - *

    - *
  1. determine the size of encrypted content that can fit into current frame
  2. - *
  3. call processBytes() of the underlying cipher to do corresponding cryptographic encryption - * of plaintext
  4. - *
  5. check if current frame is fully filled using the processed bytes, write current frame to - * the output being returned.
  6. - *
- * - * @param in - * the input byte array. - * @param inOff - * the offset into the in array where the data to be encrypted starts. - * @param inLen - * the number of bytes to be encrypted. - * @param out - * the output buffer the encrypted bytes go into. - * @param outOff - * the offset into the output byte array the encrypted data starts at. - * @return the number of bytes written to out and processed - * @throws InvalidCiphertextException - * thrown by the underlying cipher handler. - */ - @Override - public ProcessingSummary processBytes(final byte[] in, final int off, final int len, final byte[] out, - final int outOff) - throws BadCiphertextException { - int actualOutLen = 0; - - int size = len; - int offset = off; - while (size > 0) { - final int currentFrameCapacity = frameSize_ - bytesToFrameLen_; - // bind size to the capacity of the current frame - size = Math.min(currentFrameCapacity, size); - - System.arraycopy(in, offset, bytesToFrame_, bytesToFrameLen_, size); - bytesToFrameLen_ += size; - - // check if there is enough bytes to create a frame - if (bytesToFrameLen_ == frameSize_) { - actualOutLen += writeEncryptedFrame(bytesToFrame_, 0, bytesToFrameLen_, out, outOff + actualOutLen); - - // reset buffer len as a new frame is created in next iteration - bytesToFrameLen_ = 0; - } - - // update offset by the size of bytes being encrypted. - offset += size; - // update size to the remaining bytes starting at offset. - size = len - offset; - } - - return new ProcessingSummary(actualOutLen, len); + // update offset by the size of bytes being encrypted. + offset += size; + // update size to the remaining bytes starting at offset. + size = len - offset; } - /** - * Finish processing of the bytes by writing out the ciphertext or final - * frame if framing. - * - * @param out - * space for any resulting output data. - * @param outOff - * offset into out to start copying the data at. - * @return - * number of bytes written into out. - * @throws InvalidCiphertextException - * thrown by the underlying cipher handler. + return new ProcessingSummary(actualOutLen, len); + } + + /** + * Finish processing of the bytes by writing out the ciphertext or final frame if framing. + * + * @param out space for any resulting output data. + * @param outOff offset into out to start copying the data at. + * @return number of bytes written into out. + * @throws InvalidCiphertextException thrown by the underlying cipher handler. + */ + @Override + public int doFinal(final byte[] out, final int outOff) throws BadCiphertextException { + isFinalFrame_ = true; + complete_ = true; + return writeEncryptedFrame(bytesToFrame_, 0, bytesToFrameLen_, out, outOff); + } + + /** + * Return the size of the output buffer required for a processBytes plus a doFinal with an input + * of inLen bytes. + * + * @param inLen the length of the input. + * @return the space required to accommodate a call to processBytes and doFinal with len bytes of + * input. + */ + @Override + public int estimateOutputSize(final int inLen) { + int outSize = 0; + int frames = 0; + + // include any bytes held for inclusion in a subsequent frame + int totalContent = bytesToFrameLen_ + inLen; + + // compute the size of the frames that will be constructed + frames = totalContent / frameSize_; + outSize += (frameSize_ * frames); + + // account for remaining data that will need a new frame. + final int leftover = totalContent % frameSize_; + outSize += leftover; + // even if leftover is 0, there will be a final frame. + frames += 1; + + /* + * Calculate overhead of frame headers. */ - @Override - public int doFinal(final byte[] out, final int outOff) throws BadCiphertextException { - isFinalFrame_ = true; - complete_ = true; - return writeEncryptedFrame(bytesToFrame_, 0, bytesToFrameLen_, out, outOff); - } + // nonce and MAC tag. + outSize += frames * (nonceLen_ + tagLenBytes_); - /** - * Return the size of the output buffer required for a processBytes plus a - * doFinal with an input of inLen bytes. - * - * @param inLen - * the length of the input. - * @return - * the space required to accommodate a call to processBytes and - * doFinal with len bytes of input. - */ - @Override - public int estimateOutputSize(final int inLen) { - int outSize = 0; - int frames = 0; + // sequence number for all frames + outSize += frames * (Integer.SIZE / Byte.SIZE); - // include any bytes held for inclusion in a subsequent frame - int totalContent = bytesToFrameLen_ + inLen; + // sequence number end for final frame + outSize += Integer.SIZE / Byte.SIZE; - // compute the size of the frames that will be constructed - frames = totalContent / frameSize_; - outSize += (frameSize_ * frames); + // integer for storing final frame size + outSize += Integer.SIZE / Byte.SIZE; - // account for remaining data that will need a new frame. - final int leftover = totalContent % frameSize_; - outSize += leftover; - // even if leftover is 0, there will be a final frame. - frames += 1; + return outSize; + } - /* - * Calculate overhead of frame headers. - */ - // nonce and MAC tag. - outSize += frames * (nonceLen_ + tagLenBytes_); + @Override + public int estimatePartialOutputSize(int inLen) { + int outSize = 0; + int frames = 0; - // sequence number for all frames - outSize += frames * (Integer.SIZE / Byte.SIZE); + // include any bytes held for inclusion in a subsequent frame + int totalContent = bytesToFrameLen_; + if (inLen >= 0) { + totalContent += inLen; + } - // sequence number end for final frame - outSize += Integer.SIZE / Byte.SIZE; + // compute the size of the frames that will be constructed + frames = totalContent / frameSize_; + outSize += (frameSize_ * frames); - // integer for storing final frame size - outSize += Integer.SIZE / Byte.SIZE; + /* + * Calculate overhead of frame headers. + */ + // nonce and MAC tag. + outSize += frames * (nonceLen_ + tagLenBytes_); - return outSize; - } + // sequence number for all frames + outSize += frames * (Integer.SIZE / Byte.SIZE); - @Override - public int estimatePartialOutputSize(int inLen) { - int outSize = 0; - int frames = 0; + return outSize; + } - // include any bytes held for inclusion in a subsequent frame - int totalContent = bytesToFrameLen_; - if (inLen >= 0) { - totalContent += inLen; - } + @Override + public int estimateFinalOutputSize() { + int outSize = 0; + int frames = 0; - // compute the size of the frames that will be constructed - frames = totalContent / frameSize_; - outSize += (frameSize_ * frames); + // include any bytes held for inclusion in a subsequent frame + int totalContent = bytesToFrameLen_; - /* - * Calculate overhead of frame headers. - */ - // nonce and MAC tag. - outSize += frames * (nonceLen_ + tagLenBytes_); + // compute the size of the frames that will be constructed + frames = totalContent / frameSize_; + outSize += (frameSize_ * frames); - // sequence number for all frames - outSize += frames * (Integer.SIZE / Byte.SIZE); + // account for remaining data that will need a new frame. + final int leftover = totalContent % frameSize_; + outSize += leftover; + // even if leftover is 0, there will be a final frame. + frames += 1; - return outSize; + /* + * Calculate overhead of frame headers. + */ + // nonce and MAC tag. + outSize += frames * (nonceLen_ + tagLenBytes_); + + // sequence number for all frames + outSize += frames * (Integer.SIZE / Byte.SIZE); + + // sequence number end for final frame + outSize += Integer.SIZE / Byte.SIZE; + + // integer for storing final frame size + outSize += Integer.SIZE / Byte.SIZE; + + return outSize; + } + + /** + * We encrypt the bytes, create the headers for the block, and assemble the frame containing the + * headers and the encrypted bytes. + * + * @param in the input byte array. + * @param inOff the offset into the in array where the data to be encrypted starts. + * @param inLen the number of bytes to be encrypted. + * @param out the output buffer the encrypted bytes go into. + * @param outOff the offset into the output byte array the encrypted data starts at. + * @return the number of bytes written to out. + * @throws BadCiphertextException thrown by the underlying cipher handler. + * @throws AwsCryptoException if frame number exceeds the maximum allowed value. + */ + private int writeEncryptedFrame( + final byte[] input, final int off, final int len, final byte[] out, final int outOff) + throws BadCiphertextException, AwsCryptoException { + if (frameNumber_ > Constants.MAX_FRAME_NUMBER + // Make sure we have the appropriate flag set for the final frame; we don't want to accept + // non-final-frame data when there won't be a subsequent frame for it to go into. + || (frameNumber_ == Constants.MAX_FRAME_NUMBER && !isFinalFrame_)) { + throw new AwsCryptoException("Frame number exceeded the maximum allowed value."); } - @Override - public int estimateFinalOutputSize() { - int outSize = 0; - int frames = 0; - - // include any bytes held for inclusion in a subsequent frame - int totalContent = bytesToFrameLen_; + if (out.length == 0) { + return 0; + } - // compute the size of the frames that will be constructed - frames = totalContent / frameSize_; - outSize += (frameSize_ * frames); + int outLen = 0; + + byte[] contentAad; + if (isFinalFrame_ == true) { + contentAad = + Utils.generateContentAad( + messageId_, Constants.FINAL_FRAME_STRING_ID, (int) frameNumber_, len); + } else { + contentAad = + Utils.generateContentAad( + messageId_, Constants.FRAME_STRING_ID, (int) frameNumber_, frameSize_); + } - // account for remaining data that will need a new frame. - final int leftover = totalContent % frameSize_; - outSize += leftover; - // even if leftover is 0, there will be a final frame. - frames += 1; + final byte[] nonce = getNonce(); - /* - * Calculate overhead of frame headers. - */ - // nonce and MAC tag. - outSize += frames * (nonceLen_ + tagLenBytes_); + final byte[] encryptedBytes = cipherHandler_.cipherData(nonce, contentAad, input, off, len); - // sequence number for all frames - outSize += frames * (Integer.SIZE / Byte.SIZE); + // create the cipherblock headers now for the encrypted data + final int encryptedContentLen = encryptedBytes.length - tagLenBytes_; + final CipherFrameHeaders cipherFrameHeaders = + new CipherFrameHeaders((int) frameNumber_, nonce, encryptedContentLen, isFinalFrame_); + final byte[] cipherFrameHeaderBytes = cipherFrameHeaders.toByteArray(); - // sequence number end for final frame - outSize += Integer.SIZE / Byte.SIZE; + // assemble the headers and the encrypted bytes into a single block + System.arraycopy( + cipherFrameHeaderBytes, 0, out, outOff + outLen, cipherFrameHeaderBytes.length); + outLen += cipherFrameHeaderBytes.length; + System.arraycopy(encryptedBytes, 0, out, outOff + outLen, encryptedBytes.length); + outLen += encryptedBytes.length; - // integer for storing final frame size - outSize += Integer.SIZE / Byte.SIZE; + frameNumber_++; - return outSize; - } + return outLen; + } - /** - * We encrypt the bytes, create the headers for the block, and assemble the - * frame containing the headers and the encrypted bytes. - * - * @param in - * the input byte array. - * @param inOff - * the offset into the in array where the data to be encrypted - * starts. - * @param inLen - * the number of bytes to be encrypted. - * @param out - * the output buffer the encrypted bytes go into. - * @param outOff - * the offset into the output byte array the encrypted data - * starts at. - * @return - * the number of bytes written to out. - * @throws BadCiphertextException - * thrown by the underlying cipher handler. - * @throws AwsCryptoException - * if frame number exceeds the maximum allowed value. + private byte[] getNonce() { + /* + * To mitigate the risk of IVs colliding within the same message, we use deterministic IV generation within a + * message. */ - private int writeEncryptedFrame(final byte[] input, final int off, final int len, final byte[] out, final int outOff) - throws BadCiphertextException, AwsCryptoException { - if (frameNumber_ > Constants.MAX_FRAME_NUMBER - // Make sure we have the appropriate flag set for the final frame; we don't want to accept - // non-final-frame data when there won't be a subsequent frame for it to go into. - || (frameNumber_ == Constants.MAX_FRAME_NUMBER && !isFinalFrame_)) { - throw new AwsCryptoException("Frame number exceeded the maximum allowed value."); - } - - if (out.length == 0) { - return 0; - } - - int outLen = 0; - - byte[] contentAad; - if (isFinalFrame_ == true) { - contentAad = Utils.generateContentAad( - messageId_, - Constants.FINAL_FRAME_STRING_ID, - (int) frameNumber_, - len); - } else { - contentAad = Utils.generateContentAad( - messageId_, - Constants.FRAME_STRING_ID, - (int) frameNumber_, - frameSize_); - } - - final byte[] nonce = getNonce(); - - final byte[] encryptedBytes = cipherHandler_.cipherData(nonce, contentAad, input, off, len); - - // create the cipherblock headers now for the encrypted data - final int encryptedContentLen = encryptedBytes.length - tagLenBytes_; - final CipherFrameHeaders cipherFrameHeaders = new CipherFrameHeaders( - (int) frameNumber_, - nonce, - encryptedContentLen, - isFinalFrame_); - final byte[] cipherFrameHeaderBytes = cipherFrameHeaders.toByteArray(); - - // assemble the headers and the encrypted bytes into a single block - System.arraycopy(cipherFrameHeaderBytes, 0, out, outOff + outLen, cipherFrameHeaderBytes.length); - outLen += cipherFrameHeaderBytes.length; - System.arraycopy(encryptedBytes, 0, out, outOff + outLen, encryptedBytes.length); - outLen += encryptedBytes.length; - - frameNumber_++; - - return outLen; - } - private byte[] getNonce() { - /* - * To mitigate the risk of IVs colliding within the same message, we use deterministic IV generation within a - * message. - */ - - if (frameNumber_ < 1) { - // This should never happen - however, since we use a "frame number zero" IV elsewhere (for header auth), - // we must be sure that we don't reuse it here. - throw new IllegalStateException("Illegal frame number"); - } - - if ((int)frameNumber_ == Constants.ENDFRAME_SEQUENCE_NUMBER && !isFinalFrame_) { - throw new IllegalStateException("Too many frames"); - } - - final byte[] nonce = new byte[nonceLen_]; - - ByteBuffer buf = ByteBuffer.wrap(nonce); - buf.order(ByteOrder.BIG_ENDIAN); - // We technically only allocate the low 32 bits for the frame number, and the other bits are defined to be - // zero. However, since MAX_FRAME_NUMBER is 2^32-1, the high-order four bytes of the long will be zero, so the - // big-endian representation will also have zeros in that position. - Utils.position(buf, buf.limit() - Long.BYTES); - buf.putLong(frameNumber_); - - return nonce; + if (frameNumber_ < 1) { + // This should never happen - however, since we use a "frame number zero" IV elsewhere (for + // header auth), + // we must be sure that we don't reuse it here. + throw new IllegalStateException("Illegal frame number"); } - @Override - public boolean isComplete() { - return complete_; + if ((int) frameNumber_ == Constants.ENDFRAME_SEQUENCE_NUMBER && !isFinalFrame_) { + throw new IllegalStateException("Too many frames"); } + + final byte[] nonce = new byte[nonceLen_]; + + ByteBuffer buf = ByteBuffer.wrap(nonce); + buf.order(ByteOrder.BIG_ENDIAN); + // We technically only allocate the low 32 bits for the frame number, and the other bits are + // defined to be + // zero. However, since MAX_FRAME_NUMBER is 2^32-1, the high-order four bytes of the long will + // be zero, so the + // big-endian representation will also have zeros in that position. + Utils.position(buf, buf.limit() - Long.BYTES); + buf.putLong(frameNumber_); + + return nonce; + } + + @Override + public boolean isComplete() { + return complete_; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/HmacKeyDerivationFunction.java b/src/main/java/com/amazonaws/encryptionsdk/internal/HmacKeyDerivationFunction.java index ca2d7cc8a..8c4f4f388 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/HmacKeyDerivationFunction.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/HmacKeyDerivationFunction.java @@ -14,158 +14,153 @@ */ package com.amazonaws.encryptionsdk.internal; +import static org.apache.commons.lang3.Validate.isTrue; + import java.security.GeneralSecurityException; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.security.Provider; import java.util.Arrays; - import javax.crypto.Mac; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; -import static org.apache.commons.lang3.Validate.isTrue; - /** - * HMAC-based Key Derivation Function. - * Adapted from Hkdf.java in aws-dynamodb-encryption-java + * HMAC-based Key Derivation Function. Adapted from Hkdf.java in aws-dynamodb-encryption-java * * @see RFC 5869 */ public final class HmacKeyDerivationFunction { - private static final byte[] EMPTY_ARRAY = new byte[0]; - private final String algorithm; - private final Provider provider; - private SecretKey prk = null; + private static final byte[] EMPTY_ARRAY = new byte[0]; + private final String algorithm; + private final Provider provider; + private SecretKey prk = null; - /** - * Returns an HmacKeyDerivationFunction object using the specified algorithm. - * - * @param algorithm the standard name of the requested MAC algorithm. See the Mac - * section in the Java Cryptography Architecture Standard Algorithm Name - * Documentation for information about standard algorithm - * names. - * @return the new Hkdf object - * @throws NoSuchAlgorithmException if no Provider supports a MacSpi implementation for the - * specified algorithm. - */ - public static HmacKeyDerivationFunction getInstance(final String algorithm) - throws NoSuchAlgorithmException { - // Constructed specifically to sanity-test arguments. - Mac mac = Mac.getInstance(algorithm); - return new HmacKeyDerivationFunction(algorithm, mac.getProvider()); - } + /** + * Returns an HmacKeyDerivationFunction object using the specified algorithm. + * + * @param algorithm the standard name of the requested MAC algorithm. See the Mac section in the + * + * Java Cryptography Architecture Standard Algorithm Name Documentation for information + * about standard algorithm names. + * @return the new Hkdf object + * @throws NoSuchAlgorithmException if no Provider supports a MacSpi implementation for the + * specified algorithm. + */ + public static HmacKeyDerivationFunction getInstance(final String algorithm) + throws NoSuchAlgorithmException { + // Constructed specifically to sanity-test arguments. + Mac mac = Mac.getInstance(algorithm); + return new HmacKeyDerivationFunction(algorithm, mac.getProvider()); + } - /** - * Initializes this Hkdf with input keying material. A default salt of - * HashLen zeros will be used (where HashLen is the length of the return - * value of the supplied algorithm). - * - * @param ikm the Input Keying Material - */ - public void init(final byte[] ikm) { - init(ikm, null); - } + /** + * Initializes this Hkdf with input keying material. A default salt of HashLen zeros will be used + * (where HashLen is the length of the return value of the supplied algorithm). + * + * @param ikm the Input Keying Material + */ + public void init(final byte[] ikm) { + init(ikm, null); + } - /** - * Initializes this Hkdf with input keying material and a salt. If - * salt is null or of length 0, then a default salt of - * HashLen zeros will be used (where HashLen is the length of the return - * value of the supplied algorithm). - * - * @param salt the salt used for key extraction (optional) - * @param ikm the Input Keying Material - */ - public void init(final byte[] ikm, final byte[] salt) { - byte[] realSalt = (salt == null) ? EMPTY_ARRAY : salt.clone(); - byte[] rawKeyMaterial = EMPTY_ARRAY; - try { - Mac extractionMac = Mac.getInstance(algorithm, provider); - if (realSalt.length == 0) { - realSalt = new byte[extractionMac.getMacLength()]; - Arrays.fill(realSalt, (byte) 0); - } - extractionMac.init(new SecretKeySpec(realSalt, algorithm)); - rawKeyMaterial = extractionMac.doFinal(ikm); - this.prk = new SecretKeySpec(rawKeyMaterial, algorithm); - } catch (GeneralSecurityException e) { - // We've already checked all of the parameters so no exceptions - // should be possible here. - throw new RuntimeException("Unexpected exception", e); - } finally { - Arrays.fill(rawKeyMaterial, (byte) 0); // Zeroize temporary array - } - } - - private HmacKeyDerivationFunction(final String algorithm, final Provider provider) { - isTrue(algorithm.startsWith("Hmac"), "Invalid algorithm " + algorithm - + ". Hkdf may only be used with Hmac algorithms."); - this.algorithm = algorithm; - this.provider = provider; + /** + * Initializes this Hkdf with input keying material and a salt. If + * salt is null or of length 0, then a default salt of HashLen zeros will be + * used (where HashLen is the length of the return value of the supplied algorithm). + * + * @param salt the salt used for key extraction (optional) + * @param ikm the Input Keying Material + */ + public void init(final byte[] ikm, final byte[] salt) { + byte[] realSalt = (salt == null) ? EMPTY_ARRAY : salt.clone(); + byte[] rawKeyMaterial = EMPTY_ARRAY; + try { + Mac extractionMac = Mac.getInstance(algorithm, provider); + if (realSalt.length == 0) { + realSalt = new byte[extractionMac.getMacLength()]; + Arrays.fill(realSalt, (byte) 0); + } + extractionMac.init(new SecretKeySpec(realSalt, algorithm)); + rawKeyMaterial = extractionMac.doFinal(ikm); + this.prk = new SecretKeySpec(rawKeyMaterial, algorithm); + } catch (GeneralSecurityException e) { + // We've already checked all of the parameters so no exceptions + // should be possible here. + throw new RuntimeException("Unexpected exception", e); + } finally { + Arrays.fill(rawKeyMaterial, (byte) 0); // Zeroize temporary array } + } - /** - * Returns a pseudorandom key of length bytes. - * - * @param info optional context and application specific information (can be - * a zero-length array). - * @param length the length of the output key in bytes - * @return a pseudorandom key of length bytes. - * @throws IllegalStateException if this object has not been initialized - */ - public byte[] deriveKey(final byte[] info, final int length) throws IllegalStateException { - isTrue(length >= 0, "Length must be a non-negative value."); - assertInitialized(); - final byte[] result = new byte[length]; - Mac mac = createMac(); + private HmacKeyDerivationFunction(final String algorithm, final Provider provider) { + isTrue( + algorithm.startsWith("Hmac"), + "Invalid algorithm " + algorithm + ". Hkdf may only be used with Hmac algorithms."); + this.algorithm = algorithm; + this.provider = provider; + } - isTrue(length <= 255 * mac.getMacLength(), - "Requested keys may not be longer than 255 times the underlying HMAC length."); + /** + * Returns a pseudorandom key of length bytes. + * + * @param info optional context and application specific information (can be a zero-length array). + * @param length the length of the output key in bytes + * @return a pseudorandom key of length bytes. + * @throws IllegalStateException if this object has not been initialized + */ + public byte[] deriveKey(final byte[] info, final int length) throws IllegalStateException { + isTrue(length >= 0, "Length must be a non-negative value."); + assertInitialized(); + final byte[] result = new byte[length]; + Mac mac = createMac(); - byte[] t = EMPTY_ARRAY; - try { - int loc = 0; - byte i = 1; - while (loc < length) { - mac.update(t); - mac.update(info); - mac.update(i); - t = mac.doFinal(); + isTrue( + length <= 255 * mac.getMacLength(), + "Requested keys may not be longer than 255 times the underlying HMAC length."); - for (int x = 0; x < t.length && loc < length; x++, loc++) { - result[loc] = t[x]; - } + byte[] t = EMPTY_ARRAY; + try { + int loc = 0; + byte i = 1; + while (loc < length) { + mac.update(t); + mac.update(info); + mac.update(i); + t = mac.doFinal(); - i++; - } - } finally { - Arrays.fill(t, (byte) 0); // Zeroize temporary array + for (int x = 0; x < t.length && loc < length; x++, loc++) { + result[loc] = t[x]; } - return result; + + i++; + } + } finally { + Arrays.fill(t, (byte) 0); // Zeroize temporary array } + return result; + } - private Mac createMac() { - try { - Mac mac = Mac.getInstance(algorithm, provider); - mac.init(prk); - return mac; - } catch (NoSuchAlgorithmException | InvalidKeyException ex) { - // We've already validated that this algorithm/key is correct. - throw new RuntimeException(ex); - } + private Mac createMac() { + try { + Mac mac = Mac.getInstance(algorithm, provider); + mac.init(prk); + return mac; + } catch (NoSuchAlgorithmException | InvalidKeyException ex) { + // We've already validated that this algorithm/key is correct. + throw new RuntimeException(ex); } + } - /** - * Throws an IllegalStateException if this object has not been - * initialized. - * - * @throws IllegalStateException if this object has not been initialized - */ - private void assertInitialized() throws IllegalStateException { - if (prk == null) { - throw new IllegalStateException("Hkdf has not been initialized"); - } + /** + * Throws an IllegalStateException if this object has not been initialized. + * + * @throws IllegalStateException if this object has not been initialized + */ + private void assertInitialized() throws IllegalStateException { + if (prk == null) { + throw new IllegalStateException("Hkdf has not been initialized"); } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/JceKeyCipher.java b/src/main/java/com/amazonaws/encryptionsdk/internal/JceKeyCipher.java index 643278a71..b6362f424 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/JceKeyCipher.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/JceKeyCipher.java @@ -16,10 +16,6 @@ import com.amazonaws.encryptionsdk.EncryptedDataKey; import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.model.KeyBlob; -import org.apache.commons.lang3.ArrayUtils; - -import javax.crypto.Cipher; -import javax.crypto.SecretKey; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; @@ -27,110 +23,119 @@ import java.security.PrivateKey; import java.security.PublicKey; import java.util.Map; +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import org.apache.commons.lang3.ArrayUtils; -/** - * Abstract class for encrypting and decrypting JCE data keys. - */ +/** Abstract class for encrypting and decrypting JCE data keys. */ public abstract class JceKeyCipher { - private final Key wrappingKey; - private final Key unwrappingKey; - private static final Charset KEY_NAME_ENCODING = StandardCharsets.UTF_8; - - /** - * Returns a new instance of a JceKeyCipher based on the - * Advanced Encryption Standard in Galois/Counter Mode. - * - * @param secretKey The secret key to use for encrypt/decrypt operations. - * @return The JceKeyCipher. - */ - public static JceKeyCipher aesGcm(SecretKey secretKey) { - return new AesGcmJceKeyCipher(secretKey); - } - - /** - * Returns a new instance of a JceKeyCipher based on RSA. - * - * @param wrappingKey The public key to use for encrypting the key. - * @param unwrappingKey The private key to use for decrypting the key. - * @param transformation The transformation. - * @return The JceKeyCipher. - */ - public static JceKeyCipher rsa(PublicKey wrappingKey, PrivateKey unwrappingKey, String transformation) { - return new RsaJceKeyCipher(wrappingKey, unwrappingKey, transformation); - } - - JceKeyCipher(Key wrappingKey, Key unwrappingKey) { - this.wrappingKey = wrappingKey; - this.unwrappingKey = unwrappingKey; - } - - abstract WrappingData buildWrappingCipher(Key key, Map encryptionContext) throws GeneralSecurityException; - - abstract Cipher buildUnwrappingCipher(Key key, byte[] extraInfo, int offset, - Map encryptionContext) throws GeneralSecurityException; - - - /** - * Encrypts the given key, incorporating the given keyName and encryptionContext. - * @param key The key to encrypt. - * @param keyName A UTF-8 encoded representing a name for the key. - * @param keyNamespace A UTF-8 encoded value that namespaces the key. - * @param encryptionContext A key-value mapping of arbitrary, non-secret, UTF-8 encoded strings used - * during encryption and decryption to provide additional authenticated data (AAD). - * @return The encrypted data key. - */ - public EncryptedDataKey encryptKey(final byte[] key, final String keyName, final String keyNamespace, - final Map encryptionContext) { - - final byte[] keyNameBytes = keyName.getBytes(KEY_NAME_ENCODING); - - try { - final JceKeyCipher.WrappingData wData = buildWrappingCipher(wrappingKey, encryptionContext); - final Cipher cipher = wData.cipher; - final byte[] encryptedKey = cipher.doFinal(key); - - final byte[] provInfo; - if (wData.extraInfo.length == 0) { - provInfo = keyNameBytes; - } else { - provInfo = new byte[keyNameBytes.length + wData.extraInfo.length]; - System.arraycopy(keyNameBytes, 0, provInfo, 0, keyNameBytes.length); - System.arraycopy(wData.extraInfo, 0, provInfo, keyNameBytes.length, wData.extraInfo.length); - } - - return new KeyBlob(keyNamespace, provInfo, encryptedKey); - } catch (final GeneralSecurityException gsex) { - throw new AwsCryptoException(gsex); - } + private final Key wrappingKey; + private final Key unwrappingKey; + private static final Charset KEY_NAME_ENCODING = StandardCharsets.UTF_8; + + /** + * Returns a new instance of a JceKeyCipher based on the Advanced Encryption Standard in + * Galois/Counter Mode. + * + * @param secretKey The secret key to use for encrypt/decrypt operations. + * @return The JceKeyCipher. + */ + public static JceKeyCipher aesGcm(SecretKey secretKey) { + return new AesGcmJceKeyCipher(secretKey); + } + + /** + * Returns a new instance of a JceKeyCipher based on RSA. + * + * @param wrappingKey The public key to use for encrypting the key. + * @param unwrappingKey The private key to use for decrypting the key. + * @param transformation The transformation. + * @return The JceKeyCipher. + */ + public static JceKeyCipher rsa( + PublicKey wrappingKey, PrivateKey unwrappingKey, String transformation) { + return new RsaJceKeyCipher(wrappingKey, unwrappingKey, transformation); + } + + JceKeyCipher(Key wrappingKey, Key unwrappingKey) { + this.wrappingKey = wrappingKey; + this.unwrappingKey = unwrappingKey; + } + + abstract WrappingData buildWrappingCipher(Key key, Map encryptionContext) + throws GeneralSecurityException; + + abstract Cipher buildUnwrappingCipher( + Key key, byte[] extraInfo, int offset, Map encryptionContext) + throws GeneralSecurityException; + + /** + * Encrypts the given key, incorporating the given keyName and encryptionContext. + * + * @param key The key to encrypt. + * @param keyName A UTF-8 encoded representing a name for the key. + * @param keyNamespace A UTF-8 encoded value that namespaces the key. + * @param encryptionContext A key-value mapping of arbitrary, non-secret, UTF-8 encoded strings + * used during encryption and decryption to provide additional authenticated data (AAD). + * @return The encrypted data key. + */ + public EncryptedDataKey encryptKey( + final byte[] key, + final String keyName, + final String keyNamespace, + final Map encryptionContext) { + + final byte[] keyNameBytes = keyName.getBytes(KEY_NAME_ENCODING); + + try { + final JceKeyCipher.WrappingData wData = buildWrappingCipher(wrappingKey, encryptionContext); + final Cipher cipher = wData.cipher; + final byte[] encryptedKey = cipher.doFinal(key); + + final byte[] provInfo; + if (wData.extraInfo.length == 0) { + provInfo = keyNameBytes; + } else { + provInfo = new byte[keyNameBytes.length + wData.extraInfo.length]; + System.arraycopy(keyNameBytes, 0, provInfo, 0, keyNameBytes.length); + System.arraycopy(wData.extraInfo, 0, provInfo, keyNameBytes.length, wData.extraInfo.length); + } + + return new KeyBlob(keyNamespace, provInfo, encryptedKey); + } catch (final GeneralSecurityException gsex) { + throw new AwsCryptoException(gsex); } - - /** - * Decrypts the given encrypted data key. - * - * @param edk The encrypted data key. - * @param keyName A UTF-8 encoded String representing a name for the key. - * @param encryptionContext A key-value mapping of arbitrary, non-secret, UTF-8 encoded strings used - * during encryption and decryption to provide additional authenticated data (AAD). - * @return The decrypted key. - * @throws GeneralSecurityException If a problem occurred decrypting the key. - */ - public byte[] decryptKey(final EncryptedDataKey edk, final String keyName, - final Map encryptionContext) throws GeneralSecurityException { - final byte[] keyNameBytes = keyName.getBytes(KEY_NAME_ENCODING); - - final Cipher cipher = buildUnwrappingCipher(unwrappingKey, edk.getProviderInformation(), - keyNameBytes.length, encryptionContext); - return cipher.doFinal(edk.getEncryptedDataKey()); - } - - static class WrappingData { - public final Cipher cipher; - public final byte[] extraInfo; - - WrappingData(final Cipher cipher, final byte[] extraInfo) { - this.cipher = cipher; - this.extraInfo = extraInfo != null ? extraInfo : ArrayUtils.EMPTY_BYTE_ARRAY; - } + } + + /** + * Decrypts the given encrypted data key. + * + * @param edk The encrypted data key. + * @param keyName A UTF-8 encoded String representing a name for the key. + * @param encryptionContext A key-value mapping of arbitrary, non-secret, UTF-8 encoded strings + * used during encryption and decryption to provide additional authenticated data (AAD). + * @return The decrypted key. + * @throws GeneralSecurityException If a problem occurred decrypting the key. + */ + public byte[] decryptKey( + final EncryptedDataKey edk, final String keyName, final Map encryptionContext) + throws GeneralSecurityException { + final byte[] keyNameBytes = keyName.getBytes(KEY_NAME_ENCODING); + + final Cipher cipher = + buildUnwrappingCipher( + unwrappingKey, edk.getProviderInformation(), keyNameBytes.length, encryptionContext); + return cipher.doFinal(edk.getEncryptedDataKey()); + } + + static class WrappingData { + public final Cipher cipher; + public final byte[] extraInfo; + + WrappingData(final Cipher cipher, final byte[] extraInfo) { + this.cipher = cipher; + this.extraInfo = extraInfo != null ? extraInfo : ArrayUtils.EMPTY_BYTE_ARRAY; } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/LazyMessageCryptoHandler.java b/src/main/java/com/amazonaws/encryptionsdk/internal/LazyMessageCryptoHandler.java index bc50183a4..14b4e688e 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/LazyMessageCryptoHandler.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/LazyMessageCryptoHandler.java @@ -1,115 +1,115 @@ package com.amazonaws.encryptionsdk.internal; -import javax.annotation.concurrent.NotThreadSafe; +import com.amazonaws.encryptionsdk.MasterKey; +import com.amazonaws.encryptionsdk.model.CiphertextHeaders; import java.util.List; import java.util.Map; import java.util.function.Function; - -import com.amazonaws.encryptionsdk.MasterKey; -import com.amazonaws.encryptionsdk.model.CiphertextHeaders; +import javax.annotation.concurrent.NotThreadSafe; /** - * A {@link MessageCryptoHandler} that delegates to another MessageCryptoHandler, which is created at the last possible - * moment. Typically, this is used in order to defer the creation of the data key (and associated request to the - * {@link com.amazonaws.encryptionsdk.CryptoMaterialsManager} until the max message size is known. + * A {@link MessageCryptoHandler} that delegates to another MessageCryptoHandler, which is created + * at the last possible moment. Typically, this is used in order to defer the creation of the data + * key (and associated request to the {@link com.amazonaws.encryptionsdk.CryptoMaterialsManager} + * until the max message size is known. */ @NotThreadSafe public class LazyMessageCryptoHandler implements MessageCryptoHandler { - private Function delegateFactory; - private MessageCryptoHandler delegate; - private long maxInputSize = -1; + private Function delegateFactory; + private MessageCryptoHandler delegate; + private long maxInputSize = -1; - public static final class LateBoundInfo { - private final long maxInputSize; + public static final class LateBoundInfo { + private final long maxInputSize; - private LateBoundInfo(long maxInputSize) { - this.maxInputSize = maxInputSize; - } - - public long getMaxInputSize() { - return maxInputSize; - } + private LateBoundInfo(long maxInputSize) { + this.maxInputSize = maxInputSize; } - public LazyMessageCryptoHandler(Function delegateFactory) { - this.delegateFactory = delegateFactory; - this.delegate = null; + public long getMaxInputSize() { + return maxInputSize; } - - private MessageCryptoHandler getDelegate() { - if (delegate == null) { - delegate = delegateFactory.apply(new LateBoundInfo(maxInputSize)); - if (maxInputSize != -1) { - delegate.setMaxInputLength(maxInputSize); - } - - // Release references to the delegate factory, now that we're done with it. - delegateFactory = null; - } - - return delegate; + } + + public LazyMessageCryptoHandler(Function delegateFactory) { + this.delegateFactory = delegateFactory; + this.delegate = null; + } + + private MessageCryptoHandler getDelegate() { + if (delegate == null) { + delegate = delegateFactory.apply(new LateBoundInfo(maxInputSize)); + if (maxInputSize != -1) { + delegate.setMaxInputLength(maxInputSize); + } + + // Release references to the delegate factory, now that we're done with it. + delegateFactory = null; } - @Override - public void setMaxInputLength(long size) { - if (size < 0) { - throw new IllegalArgumentException("Max input size must be non-negative"); - } - - if (delegate == null) { - if (maxInputSize == -1 || maxInputSize > size) { - maxInputSize = size; - } - } else { - delegate.setMaxInputLength(size); - } - } - - @Override - public boolean isComplete() { - // If we haven't generated the delegate, we're definitely not done yet. - return delegate != null && delegate.isComplete(); - } - - /* Operations which autovivify the delegate */ - - @Override - public Map getEncryptionContext() { - return getDelegate().getEncryptionContext(); - } - - @Override - public CiphertextHeaders getHeaders() { - return getDelegate().getHeaders(); - } - - @Override - public ProcessingSummary processBytes(byte[] in, int inOff, int inLen, byte[] out, int outOff) { - return getDelegate().processBytes(in, inOff, inLen, out, outOff); - } - - @Override - public List> getMasterKeys() { - return getDelegate().getMasterKeys(); - } - - @Override - public int doFinal(byte[] out, int outOff) { - return getDelegate().doFinal(out, outOff); - } - - @Override - public int estimateOutputSize(int inLen) { - return getDelegate().estimateOutputSize(inLen); - } + return delegate; + } - @Override - public int estimatePartialOutputSize(int inLen) { - return getDelegate().estimatePartialOutputSize(inLen); + @Override + public void setMaxInputLength(long size) { + if (size < 0) { + throw new IllegalArgumentException("Max input size must be non-negative"); } - @Override - public int estimateFinalOutputSize() { - return getDelegate().estimateFinalOutputSize(); + if (delegate == null) { + if (maxInputSize == -1 || maxInputSize > size) { + maxInputSize = size; + } + } else { + delegate.setMaxInputLength(size); } + } + + @Override + public boolean isComplete() { + // If we haven't generated the delegate, we're definitely not done yet. + return delegate != null && delegate.isComplete(); + } + + /* Operations which autovivify the delegate */ + + @Override + public Map getEncryptionContext() { + return getDelegate().getEncryptionContext(); + } + + @Override + public CiphertextHeaders getHeaders() { + return getDelegate().getHeaders(); + } + + @Override + public ProcessingSummary processBytes(byte[] in, int inOff, int inLen, byte[] out, int outOff) { + return getDelegate().processBytes(in, inOff, inLen, out, outOff); + } + + @Override + public List> getMasterKeys() { + return getDelegate().getMasterKeys(); + } + + @Override + public int doFinal(byte[] out, int outOff) { + return getDelegate().doFinal(out, outOff); + } + + @Override + public int estimateOutputSize(int inLen) { + return getDelegate().estimateOutputSize(inLen); + } + + @Override + public int estimatePartialOutputSize(int inLen) { + return getDelegate().estimatePartialOutputSize(inLen); + } + + @Override + public int estimateFinalOutputSize() { + return getDelegate().estimateFinalOutputSize(); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/MessageCryptoHandler.java b/src/main/java/com/amazonaws/encryptionsdk/internal/MessageCryptoHandler.java index 34445f867..697299653 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/MessageCryptoHandler.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/MessageCryptoHandler.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,41 +13,40 @@ package com.amazonaws.encryptionsdk.internal; -import java.util.List; -import java.util.Map; - import com.amazonaws.encryptionsdk.MasterKey; import com.amazonaws.encryptionsdk.model.CiphertextHeaders; +import java.util.List; +import java.util.Map; public interface MessageCryptoHandler extends CryptoHandler { - /** - * Informs this handler of an upper bound on the input data size. The handler will throw an exception if this bound - * is exceeded, and may use it to perform performance optimizations as well. - * - * If this method is called multiple times, the smallest bound will be used. - * - * @param size An upper bound on the input data size. - */ - void setMaxInputLength(long size); + /** + * Informs this handler of an upper bound on the input data size. The handler will throw an + * exception if this bound is exceeded, and may use it to perform performance optimizations as + * well. + * + *

If this method is called multiple times, the smallest bound will be used. + * + * @param size An upper bound on the input data size. + */ + void setMaxInputLength(long size); - /** - * Return the encryption context used in the generation of the data key used for the encryption - * of content. - * - *

- * During decryption, this value should be obtained by parsing the ciphertext headers that - * encodes this value. - * - * @return the key-value map containing the encryption context. - */ - Map getEncryptionContext(); + /** + * Return the encryption context used in the generation of the data key used for the encryption of + * content. + * + *

During decryption, this value should be obtained by parsing the ciphertext headers that + * encodes this value. + * + * @return the key-value map containing the encryption context. + */ + Map getEncryptionContext(); - CiphertextHeaders getHeaders(); + CiphertextHeaders getHeaders(); - /** - * All used {@link MasterKey}s. For encryption flows, these are all the - * {@link MasterKey}s used to protect the data. In the decryption flow, it is the single - * {@link MasterKey} actually used to decrypt the data. - */ - List> getMasterKeys(); + /** + * All used {@link MasterKey}s. For encryption flows, these are all the {@link + * MasterKey}s used to protect the data. In the decryption flow, it is the single {@link + * MasterKey} actually used to decrypt the data. + */ + List> getMasterKeys(); } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/PrimitivesParser.java b/src/main/java/com/amazonaws/encryptionsdk/internal/PrimitivesParser.java index 037174a19..d7ddc6c2b 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/PrimitivesParser.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/PrimitivesParser.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,246 +13,226 @@ package com.amazonaws.encryptionsdk.internal; +import com.amazonaws.encryptionsdk.exception.ParseException; import java.io.DataOutput; import java.io.IOException; -import com.amazonaws.encryptionsdk.exception.ParseException; - /** - * This class implements methods for parsing the primitives ( - * {@code byte, short, int, long}) in Java from a byte array. + * This class implements methods for parsing the primitives ( {@code byte, short, int, long}) in + * Java from a byte array. */ -//@ non_null_by_default +// @ non_null_by_default public class PrimitivesParser { - /** - * Construct a long value using 8 bytes starting at the specified offset. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return the parsed long value. - */ - //@ private normal_behavior - //@ requires 0 <= off && off <= b.length - Long.BYTES; - //@ ensures \result == Long.asLong(b[off],b[off+1],b[off+2],b[off+3],b[off+4],b[off+5],b[off+6],b[off+7]); - //@ pure spec_public - private static long getLong(final byte[] b, final int off) { - return ((b[off + 7] & 0xFFL)) + ((b[off + 6] & 0xFFL) << 8) + ((b[off + 5] & 0xFFL) << 16) - + ((b[off + 4] & 0xFFL) << 24) + ((b[off + 3] & 0xFFL) << 32) + ((b[off + 2] & 0xFFL) << 40) - + ((b[off + 1] & 0xFFL) << 48) + (((long) b[off]) << 56); - } + /** + * Construct a long value using 8 bytes starting at the specified offset. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the parsed long value. + */ + // @ private normal_behavior + // @ requires 0 <= off && off <= b.length - Long.BYTES; + // @ ensures \result == + // Long.asLong(b[off],b[off+1],b[off+2],b[off+3],b[off+4],b[off+5],b[off+6],b[off+7]); + // @ pure spec_public + private static long getLong(final byte[] b, final int off) { + return ((b[off + 7] & 0xFFL)) + + ((b[off + 6] & 0xFFL) << 8) + + ((b[off + 5] & 0xFFL) << 16) + + ((b[off + 4] & 0xFFL) << 24) + + ((b[off + 3] & 0xFFL) << 32) + + ((b[off + 2] & 0xFFL) << 40) + + ((b[off + 1] & 0xFFL) << 48) + + (((long) b[off]) << 56); + } - /** - * Construct an integer value using 4 bytes starting at the specified offset. - * - * @param b - * the byte array containing the integer value. - * @param off - * the offset in the byte array to use when parsing. - * @return the constructed integer value. - */ - //@ private normal_behavior - //@ requires 0 <= off && off <= b.length - Integer.BYTES; - //@ ensures \result == Integer.asInt(b[off],b[off+1],b[off+2],b[off+3]); - //@ pure spec_public - private static int getInt(final byte[] b, final int off) { - return ((b[off + 3] & 0xFF)) + ((b[off + 2] & 0xFF) << 8) + ((b[off + 1] & 0xFF) << 16) - + ((b[off] & 0xFF) << 24); - } + /** + * Construct an integer value using 4 bytes starting at the specified offset. + * + * @param b the byte array containing the integer value. + * @param off the offset in the byte array to use when parsing. + * @return the constructed integer value. + */ + // @ private normal_behavior + // @ requires 0 <= off && off <= b.length - Integer.BYTES; + // @ ensures \result == Integer.asInt(b[off],b[off+1],b[off+2],b[off+3]); + // @ pure spec_public + private static int getInt(final byte[] b, final int off) { + return ((b[off + 3] & 0xFF)) + + ((b[off + 2] & 0xFF) << 8) + + ((b[off + 1] & 0xFF) << 16) + + ((b[off] & 0xFF) << 24); + } - /** - * Construct a short value using 4 bytes starting at the specified offset. - * - * @param b - * the byte array containing the short value. - * @param off - * the offset in the byte array to use when parsing. - * @return the constructed short value. - */ - //@ private normal_behavior - //@ requires 0 <= off && off <= b.length - Short.BYTES; - //@ ensures \result == Short.asShort(b[off],b[off+1]); - //@ pure spec_public - private static short getShort(final byte[] b, final int off) { - return (short) ((b[off + 1] & 0xFF) + ((b[off] & 0xFF) << 8)); - } + /** + * Construct a short value using 4 bytes starting at the specified offset. + * + * @param b the byte array containing the short value. + * @param off the offset in the byte array to use when parsing. + * @return the constructed short value. + */ + // @ private normal_behavior + // @ requires 0 <= off && off <= b.length - Short.BYTES; + // @ ensures \result == Short.asShort(b[off],b[off+1]); + // @ pure spec_public + private static short getShort(final byte[] b, final int off) { + return (short) ((b[off + 1] & 0xFF) + ((b[off] & 0xFF) << 8)); + } - /** - * Parse a long primitive type in the provided bytes. It looks for - * 8 bytes in the provided bytes starting at the specified off. - * - *

- * If successful, it returns the value of the parsed long type. On failure, - * it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the parsed long value. - * @throws ParseException - * if there are not sufficient bytes. - */ - //@ public normal_behavior - //@ requires 0 <= off && off <= b.length - Long.BYTES; - //@ ensures \result == Long.asLong(b[off],b[off+1],b[off+2],b[off+3],b[off+4],b[off+5],b[off+6],b[off+7]); - //@ also private exceptional_behavior - //@ requires b.length - Long.BYTES < off; - //@ signals_only ParseException; - //@ pure - public static long parseLong(final byte[] b, final int off) throws ParseException { - final int size = Long.SIZE / Byte.SIZE; - final int len = b.length - off; - if (len >= size) { - return getLong(b, off); - } else { - throw new ParseException("Not enough bytes to parse a long."); - } + /** + * Parse a long primitive type in the provided bytes. It looks for 8 bytes in the provided bytes + * starting at the specified off. + * + *

If successful, it returns the value of the parsed long type. On failure, it throws a parse + * exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the parsed long value. + * @throws ParseException if there are not sufficient bytes. + */ + // @ public normal_behavior + // @ requires 0 <= off && off <= b.length - Long.BYTES; + // @ ensures \result == + // Long.asLong(b[off],b[off+1],b[off+2],b[off+3],b[off+4],b[off+5],b[off+6],b[off+7]); + // @ also private exceptional_behavior + // @ requires b.length - Long.BYTES < off; + // @ signals_only ParseException; + // @ pure + public static long parseLong(final byte[] b, final int off) throws ParseException { + final int size = Long.SIZE / Byte.SIZE; + final int len = b.length - off; + if (len >= size) { + return getLong(b, off); + } else { + throw new ParseException("Not enough bytes to parse a long."); } + } - /** - * Parse an integer primitive type in the provided bytes. It looks for - * 4 bytes in the provided bytes starting at the specified off. - * - *

- * If successful, it returns the value of the parsed integer type. On - * failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the parsed integer value. - * @throws ParseException - * if there are not sufficient bytes. - */ - //@ public normal_behavior - //@ requires 0 <= off && off <= b.length - Integer.BYTES; - //@ ensures \result == Integer.asInt(b[off],b[off+1],b[off+2],b[off+3]); - //@ also private exceptional_behavior - //@ requires b.length - Integer.BYTES < off; - //@ signals_only ParseException; - //@ pure - public static int parseInt(final byte[] b, final int off) throws ParseException { - final int size = Integer.SIZE / Byte.SIZE; - final int len = b.length - off; - if (len >= size) { - return getInt(b, off); - } else { - throw new ParseException("Not enough bytes to parse an integer."); - } + /** + * Parse an integer primitive type in the provided bytes. It looks for 4 bytes in the provided + * bytes starting at the specified off. + * + *

If successful, it returns the value of the parsed integer type. On failure, it throws a + * parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the parsed integer value. + * @throws ParseException if there are not sufficient bytes. + */ + // @ public normal_behavior + // @ requires 0 <= off && off <= b.length - Integer.BYTES; + // @ ensures \result == Integer.asInt(b[off],b[off+1],b[off+2],b[off+3]); + // @ also private exceptional_behavior + // @ requires b.length - Integer.BYTES < off; + // @ signals_only ParseException; + // @ pure + public static int parseInt(final byte[] b, final int off) throws ParseException { + final int size = Integer.SIZE / Byte.SIZE; + final int len = b.length - off; + if (len >= size) { + return getInt(b, off); + } else { + throw new ParseException("Not enough bytes to parse an integer."); } + } - /** - * Parse a short primitive type in the provided bytes. It looks for 2 bytes - * in the provided bytes starting at the specified off. - * - *

- * If successful, it returns the value of the parsed short type. On failure, - * it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the parsed short value. - * @throws ParseException - * if there are not sufficient bytes. - */ - //@ public normal_behavior - //@ requires 0 <= off && off <= b.length - Short.BYTES; - //@ ensures \result == Short.asShort(b[off],b[off+1]); - //@ also private exceptional_behavior - //@ requires b.length - Short.BYTES < off; - //@ signals_only ParseException; - //@ pure - public static short parseShort(final byte[] b, final int off) { - final short size = Short.SIZE / Byte.SIZE; - final int len = b.length - off; - if (len >= size) { - return getShort(b, off); - } else { - throw new ParseException("Not enough bytes to parse a short."); - } + /** + * Parse a short primitive type in the provided bytes. It looks for 2 bytes in the provided bytes + * starting at the specified off. + * + *

If successful, it returns the value of the parsed short type. On failure, it throws a parse + * exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the parsed short value. + * @throws ParseException if there are not sufficient bytes. + */ + // @ public normal_behavior + // @ requires 0 <= off && off <= b.length - Short.BYTES; + // @ ensures \result == Short.asShort(b[off],b[off+1]); + // @ also private exceptional_behavior + // @ requires b.length - Short.BYTES < off; + // @ signals_only ParseException; + // @ pure + public static short parseShort(final byte[] b, final int off) { + final short size = Short.SIZE / Byte.SIZE; + final int len = b.length - off; + if (len >= size) { + return getShort(b, off); + } else { + throw new ParseException("Not enough bytes to parse a short."); } + } - /** - * Equivalent to {@link #parseShort(byte[], int)} except the 2 bytes are treated as an unsigned - * value (and thus returned as an into to avoid overflow). - */ - //@ public normal_behavior - //@ requires 0 <= off && off <= b.length - Short.BYTES; - //@ ensures \result == Short.asUnsignedToInt(Short.asShort(b[off], b[off+1])); - //@ ensures \result >= 0 && \result <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ also private exceptional_behavior - //@ requires b.length - Short.BYTES < off; - //@ signals_only ParseException; - //@ pure - public static int parseUnsignedShort(final byte[] b, final int off) { - final int signedResult = parseShort(b, off); - if (signedResult >= 0) { - return signedResult; - } else { - return Constants.UNSIGNED_SHORT_MAX_VAL + 1 + signedResult; - } + /** + * Equivalent to {@link #parseShort(byte[], int)} except the 2 bytes are treated as an unsigned + * value (and thus returned as an into to avoid overflow). + */ + // @ public normal_behavior + // @ requires 0 <= off && off <= b.length - Short.BYTES; + // @ ensures \result == Short.asUnsignedToInt(Short.asShort(b[off], b[off+1])); + // @ ensures \result >= 0 && \result <= Constants.UNSIGNED_SHORT_MAX_VAL; + // @ also private exceptional_behavior + // @ requires b.length - Short.BYTES < off; + // @ signals_only ParseException; + // @ pure + public static int parseUnsignedShort(final byte[] b, final int off) { + final int signedResult = parseShort(b, off); + if (signedResult >= 0) { + return signedResult; + } else { + return Constants.UNSIGNED_SHORT_MAX_VAL + 1 + signedResult; } + } - /** - * Writes 2 bytes containing the unsigned value {@code uShort} to {@code out}. - */ - //@ // left as TODO because OpenJML/Specs does not have sufficiently detailed - //@ // specs for java.io.DataOutput - //@ public normal_behavior - //@ requires 0 <= uShort && uShort < -Short.MIN_VALUE-Short.MIN_VALUE; - //@// assignable TODO ... - //@// ensures TODO ... - public static void writeUnsignedShort(final DataOutput out, final int uShort) throws IOException { - if (uShort < 0 || uShort > Constants.UNSIGNED_SHORT_MAX_VAL) { - throw new IllegalArgumentException("Unsigned shorts must be between 0 and " - + Constants.UNSIGNED_SHORT_MAX_VAL); - } - if (uShort < Short.MAX_VALUE) { - out.writeShort(uShort); - } else { - out.writeShort(uShort - Constants.UNSIGNED_SHORT_MAX_VAL - 1); - } + /** Writes 2 bytes containing the unsigned value {@code uShort} to {@code out}. */ + // @ // left as TODO because OpenJML/Specs does not have sufficiently detailed + // @ // specs for java.io.DataOutput + // @ public normal_behavior + // @ requires 0 <= uShort && uShort < -Short.MIN_VALUE-Short.MIN_VALUE; + // @// assignable TODO ... + // @// ensures TODO ... + public static void writeUnsignedShort(final DataOutput out, final int uShort) throws IOException { + if (uShort < 0 || uShort > Constants.UNSIGNED_SHORT_MAX_VAL) { + throw new IllegalArgumentException( + "Unsigned shorts must be between 0 and " + Constants.UNSIGNED_SHORT_MAX_VAL); + } + if (uShort < Short.MAX_VALUE) { + out.writeShort(uShort); + } else { + out.writeShort(uShort - Constants.UNSIGNED_SHORT_MAX_VAL - 1); } + } - /** - * Parse a single byte in the provided bytes. It looks for a byte in the - * provided bytes starting at the specified off. - * - *

- * If successful, it returns the value of the parsed byte. On failure, it - * throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the parsed byte value. - * @throws ParseException - * if there are not sufficient bytes. - */ - //@ public normal_behavior - //@ requires 0 <= off && off <= b.length - Byte.BYTES; - //@ ensures \result == b[off]; - //@ also private exceptional_behavior - //@ requires b.length - Byte.BYTES < off; - //@ signals_only ParseException; - //@ pure - public static byte parseByte(final byte[] b, final int off) { - final int size = 1; - final int len = b.length - off; - if (len >= size) { - return b[off]; - } else { - throw new ParseException("Not enough bytes to parse a byte."); - } + /** + * Parse a single byte in the provided bytes. It looks for a byte in the provided bytes starting + * at the specified off. + * + *

If successful, it returns the value of the parsed byte. On failure, it throws a parse + * exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the parsed byte value. + * @throws ParseException if there are not sufficient bytes. + */ + // @ public normal_behavior + // @ requires 0 <= off && off <= b.length - Byte.BYTES; + // @ ensures \result == b[off]; + // @ also private exceptional_behavior + // @ requires b.length - Byte.BYTES < off; + // @ signals_only ParseException; + // @ pure + public static byte parseByte(final byte[] b, final int off) { + final int size = 1; + final int len = b.length - off; + if (len >= size) { + return b[off]; + } else { + throw new ParseException("Not enough bytes to parse a byte."); } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/ProcessingSummary.java b/src/main/java/com/amazonaws/encryptionsdk/internal/ProcessingSummary.java index 91a3c7d98..b71e26fff 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/ProcessingSummary.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/ProcessingSummary.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -14,21 +14,21 @@ package com.amazonaws.encryptionsdk.internal; public class ProcessingSummary { - public static final ProcessingSummary ZERO = new ProcessingSummary(0, 0); + public static final ProcessingSummary ZERO = new ProcessingSummary(0, 0); - private final int bytesWritten; - private final int bytesProcessed; + private final int bytesWritten; + private final int bytesProcessed; - public ProcessingSummary(final int bytesWritten, final int bytesProcessed) { - this.bytesWritten = bytesWritten; - this.bytesProcessed = bytesProcessed; - } + public ProcessingSummary(final int bytesWritten, final int bytesProcessed) { + this.bytesWritten = bytesWritten; + this.bytesProcessed = bytesProcessed; + } - public int getBytesProcessed() { - return bytesProcessed; - } + public int getBytesProcessed() { + return bytesProcessed; + } - public int getBytesWritten() { - return bytesWritten; - } + public int getBytesWritten() { + return bytesWritten; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/RsaJceKeyCipher.java b/src/main/java/com/amazonaws/encryptionsdk/internal/RsaJceKeyCipher.java index c830f5487..233cffe37 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/RsaJceKeyCipher.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/RsaJceKeyCipher.java @@ -13,11 +13,6 @@ package com.amazonaws.encryptionsdk.internal; -import org.apache.commons.lang3.ArrayUtils; - -import javax.crypto.Cipher; -import javax.crypto.spec.OAEPParameterSpec; -import javax.crypto.spec.PSource; import java.security.GeneralSecurityException; import java.security.Key; import java.security.PrivateKey; @@ -28,82 +23,88 @@ import java.util.logging.Logger; import java.util.regex.Matcher; import java.util.regex.Pattern; +import javax.crypto.Cipher; +import javax.crypto.spec.OAEPParameterSpec; +import javax.crypto.spec.PSource; +import org.apache.commons.lang3.ArrayUtils; -/** - * A JceKeyCipher based on RSA. - */ +/** A JceKeyCipher based on RSA. */ class RsaJceKeyCipher extends JceKeyCipher { - private static final Logger LOGGER = Logger.getLogger(RsaJceKeyCipher.class.getName()); - // MGF1 with SHA-224 isn't really supported, but we include it in the regex because we need it - // for proper handling of the algorithm. - private static final Pattern SUPPORTED_TRANSFORMATIONS = - Pattern.compile("RSA/ECB/(?:PKCS1Padding|OAEPWith(SHA-(?:1|224|256|384|512))AndMGF1Padding)", - Pattern.CASE_INSENSITIVE); - private final AlgorithmParameterSpec parameterSpec_; - private final String transformation_; + private static final Logger LOGGER = Logger.getLogger(RsaJceKeyCipher.class.getName()); + // MGF1 with SHA-224 isn't really supported, but we include it in the regex because we need it + // for proper handling of the algorithm. + private static final Pattern SUPPORTED_TRANSFORMATIONS = + Pattern.compile( + "RSA/ECB/(?:PKCS1Padding|OAEPWith(SHA-(?:1|224|256|384|512))AndMGF1Padding)", + Pattern.CASE_INSENSITIVE); + private final AlgorithmParameterSpec parameterSpec_; + private final String transformation_; - RsaJceKeyCipher(PublicKey wrappingKey, PrivateKey unwrappingKey, String transformation) { - super(wrappingKey, unwrappingKey); + RsaJceKeyCipher(PublicKey wrappingKey, PrivateKey unwrappingKey, String transformation) { + super(wrappingKey, unwrappingKey); - final Matcher matcher = SUPPORTED_TRANSFORMATIONS.matcher(transformation); - if (matcher.matches()) { - final String hashUnknownCase = matcher.group(1); - if (hashUnknownCase != null) { - // OAEP mode a.k.a PKCS #1v2 - final String hash = hashUnknownCase.toUpperCase(); - transformation_ = "RSA/ECB/OAEPPadding"; + final Matcher matcher = SUPPORTED_TRANSFORMATIONS.matcher(transformation); + if (matcher.matches()) { + final String hashUnknownCase = matcher.group(1); + if (hashUnknownCase != null) { + // OAEP mode a.k.a PKCS #1v2 + final String hash = hashUnknownCase.toUpperCase(); + transformation_ = "RSA/ECB/OAEPPadding"; - final MGF1ParameterSpec mgf1Spec; - switch (hash) { - case "SHA-1": - mgf1Spec = MGF1ParameterSpec.SHA1; - break; - case "SHA-224": - LOGGER.warning(transformation + " is not officially supported by the JceMasterKey"); - mgf1Spec = MGF1ParameterSpec.SHA224; - break; - case "SHA-256": - mgf1Spec = MGF1ParameterSpec.SHA256; - break; - case "SHA-384": - mgf1Spec = MGF1ParameterSpec.SHA384; - break; - case "SHA-512": - mgf1Spec = MGF1ParameterSpec.SHA512; - break; - default: - throw new IllegalArgumentException("Unsupported algorithm: " + transformation); - } - parameterSpec_ = new OAEPParameterSpec(hash, "MGF1", mgf1Spec, PSource.PSpecified.DEFAULT); - } else { - // PKCS #1 v1.x - transformation_ = transformation; - parameterSpec_ = null; - } - } else { + final MGF1ParameterSpec mgf1Spec; + switch (hash) { + case "SHA-1": + mgf1Spec = MGF1ParameterSpec.SHA1; + break; + case "SHA-224": LOGGER.warning(transformation + " is not officially supported by the JceMasterKey"); - // Unsupported transformation, just use exactly what we are given - transformation_ = transformation; - parameterSpec_ = null; + mgf1Spec = MGF1ParameterSpec.SHA224; + break; + case "SHA-256": + mgf1Spec = MGF1ParameterSpec.SHA256; + break; + case "SHA-384": + mgf1Spec = MGF1ParameterSpec.SHA384; + break; + case "SHA-512": + mgf1Spec = MGF1ParameterSpec.SHA512; + break; + default: + throw new IllegalArgumentException("Unsupported algorithm: " + transformation); } + parameterSpec_ = new OAEPParameterSpec(hash, "MGF1", mgf1Spec, PSource.PSpecified.DEFAULT); + } else { + // PKCS #1 v1.x + transformation_ = transformation; + parameterSpec_ = null; + } + } else { + LOGGER.warning(transformation + " is not officially supported by the JceMasterKey"); + // Unsupported transformation, just use exactly what we are given + transformation_ = transformation; + parameterSpec_ = null; } + } - @Override - WrappingData buildWrappingCipher(Key key, Map encryptionContext) throws GeneralSecurityException { - final Cipher cipher = Cipher.getInstance(transformation_); - cipher.init(Cipher.ENCRYPT_MODE, key, parameterSpec_); - return new WrappingData(cipher, ArrayUtils.EMPTY_BYTE_ARRAY); - } + @Override + WrappingData buildWrappingCipher(Key key, Map encryptionContext) + throws GeneralSecurityException { + final Cipher cipher = Cipher.getInstance(transformation_); + cipher.init(Cipher.ENCRYPT_MODE, key, parameterSpec_); + return new WrappingData(cipher, ArrayUtils.EMPTY_BYTE_ARRAY); + } - @Override - Cipher buildUnwrappingCipher(Key key, byte[] extraInfo, int offset, Map encryptionContext) throws GeneralSecurityException { - if (extraInfo.length != offset) { - throw new IllegalArgumentException("Extra info must be empty for RSA keys"); - } - - final Cipher cipher = Cipher.getInstance(transformation_); - cipher.init(Cipher.DECRYPT_MODE, key, parameterSpec_); - return cipher; + @Override + Cipher buildUnwrappingCipher( + Key key, byte[] extraInfo, int offset, Map encryptionContext) + throws GeneralSecurityException { + if (extraInfo.length != offset) { + throw new IllegalArgumentException("Extra info must be empty for RSA keys"); } + + final Cipher cipher = Cipher.getInstance(transformation_); + cipher.init(Cipher.DECRYPT_MODE, key, parameterSpec_); + return cipher; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/SignaturePolicy.java b/src/main/java/com/amazonaws/encryptionsdk/internal/SignaturePolicy.java index 767cee6d9..3e949f8d9 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/SignaturePolicy.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/SignaturePolicy.java @@ -3,18 +3,18 @@ import com.amazonaws.encryptionsdk.CryptoAlgorithm; public enum SignaturePolicy { - AllowEncryptAllowDecrypt { - @Override - public boolean algorithmAllowedForDecrypt(CryptoAlgorithm algorithm) { - return true; - } - }, - AllowEncryptForbidDecrypt { - @Override - public boolean algorithmAllowedForDecrypt(CryptoAlgorithm algorithm) { - return algorithm.getTrailingSignatureLength() == 0; - } - }; + AllowEncryptAllowDecrypt { + @Override + public boolean algorithmAllowedForDecrypt(CryptoAlgorithm algorithm) { + return true; + } + }, + AllowEncryptForbidDecrypt { + @Override + public boolean algorithmAllowedForDecrypt(CryptoAlgorithm algorithm) { + return algorithm.getTrailingSignatureLength() == 0; + } + }; - public abstract boolean algorithmAllowedForDecrypt(CryptoAlgorithm algorithm); + public abstract boolean algorithmAllowedForDecrypt(CryptoAlgorithm algorithm); } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/TrailingSignatureAlgorithm.java b/src/main/java/com/amazonaws/encryptionsdk/internal/TrailingSignatureAlgorithm.java index ee9b1b1e9..875ea9a21 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/TrailingSignatureAlgorithm.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/TrailingSignatureAlgorithm.java @@ -1,5 +1,13 @@ package com.amazonaws.encryptionsdk.internal; +import static com.amazonaws.encryptionsdk.internal.Utils.bigIntegerToByteArray; +import static com.amazonaws.encryptionsdk.internal.Utils.encodeBase64String; +import static java.math.BigInteger.ONE; +import static java.math.BigInteger.ZERO; +import static org.apache.commons.lang3.Validate.isInstanceOf; +import static org.apache.commons.lang3.Validate.notNull; + +import com.amazonaws.encryptionsdk.CryptoAlgorithm; import java.math.BigInteger; import java.security.AlgorithmParameters; import java.security.GeneralSecurityException; @@ -18,193 +26,190 @@ import java.security.spec.InvalidParameterSpecException; import java.util.Arrays; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; - -import static com.amazonaws.encryptionsdk.internal.Utils.bigIntegerToByteArray; -import static com.amazonaws.encryptionsdk.internal.Utils.encodeBase64String; -import static java.math.BigInteger.ONE; -import static java.math.BigInteger.ZERO; -import static org.apache.commons.lang3.Validate.isInstanceOf; -import static org.apache.commons.lang3.Validate.notNull; - /** * Provides a consistent interface across various trailing signature algorithms. * - * NOTE: This is not a stable API and may undergo breaking changes in the future. + *

NOTE: This is not a stable API and may undergo breaking changes in the future. */ public abstract class TrailingSignatureAlgorithm { - private TrailingSignatureAlgorithm() { - /* Do not allow arbitrary subclasses */ + private TrailingSignatureAlgorithm() { + /* Do not allow arbitrary subclasses */ + } + + public abstract String getMessageDigestAlgorithm(); + + public abstract String getRawSignatureAlgorithm(); + + public abstract String getHashAndSignAlgorithm(); + + public abstract PublicKey deserializePublicKey(String keyString); + + public abstract String serializePublicKey(PublicKey key); + + public abstract KeyPair generateKey() throws GeneralSecurityException; + + /* Standards for Efficient Cryptography over a prime field */ + private static final String SEC_PRIME_FIELD_PREFIX = "secp"; + + private static final class ECDSASignatureAlgorithm extends TrailingSignatureAlgorithm { + private final ECGenParameterSpec ecSpec; + private final ECParameterSpec ecParameterSpec; + private final String messageDigestAlgorithm; + private final String hashAndSignAlgorithm; + private static final String ELLIPTIC_CURVE_ALGORITHM = "EC"; + /* Constants used by SEC-1 v2 point compression and decompression algorithms */ + private static final BigInteger TWO = BigInteger.valueOf(2); + private static final BigInteger THREE = BigInteger.valueOf(3); + private static final BigInteger FOUR = BigInteger.valueOf(4); + + private ECDSASignatureAlgorithm( + ECGenParameterSpec ecSpec, String messageDigestAlgorithm, String hashAndSignAlgorithm) { + if (!ecSpec.getName().startsWith(SEC_PRIME_FIELD_PREFIX)) { + throw new IllegalStateException("Non-prime curves are not supported at this time"); + } + + this.ecSpec = ecSpec; + this.messageDigestAlgorithm = messageDigestAlgorithm; + this.hashAndSignAlgorithm = hashAndSignAlgorithm; + + try { + final AlgorithmParameters parameters = + AlgorithmParameters.getInstance(ELLIPTIC_CURVE_ALGORITHM); + parameters.init(ecSpec); + this.ecParameterSpec = parameters.getParameterSpec(ECParameterSpec.class); + } catch (NoSuchAlgorithmException | InvalidParameterSpecException e) { + throw new IllegalStateException("Invalid algorithm", e); + } } - public abstract String getMessageDigestAlgorithm(); - public abstract String getRawSignatureAlgorithm(); - public abstract String getHashAndSignAlgorithm(); - public abstract PublicKey deserializePublicKey(String keyString); - public abstract String serializePublicKey(PublicKey key); - public abstract KeyPair generateKey() throws GeneralSecurityException; - - /* Standards for Efficient Cryptography over a prime field */ - private static final String SEC_PRIME_FIELD_PREFIX = "secp"; - - private static final class ECDSASignatureAlgorithm extends TrailingSignatureAlgorithm { - private final ECGenParameterSpec ecSpec; - private final ECParameterSpec ecParameterSpec; - private final String messageDigestAlgorithm; - private final String hashAndSignAlgorithm; - private static final String ELLIPTIC_CURVE_ALGORITHM = "EC"; - /* Constants used by SEC-1 v2 point compression and decompression algorithms */ - private static final BigInteger TWO = BigInteger.valueOf(2); - private static final BigInteger THREE = BigInteger.valueOf(3); - private static final BigInteger FOUR = BigInteger.valueOf(4); - - private ECDSASignatureAlgorithm(ECGenParameterSpec ecSpec, String messageDigestAlgorithm, String hashAndSignAlgorithm) { - if (!ecSpec.getName().startsWith(SEC_PRIME_FIELD_PREFIX)) { - throw new IllegalStateException("Non-prime curves are not supported at this time"); - } - - this.ecSpec = ecSpec; - this.messageDigestAlgorithm = messageDigestAlgorithm; - this.hashAndSignAlgorithm = hashAndSignAlgorithm; - - try { - final AlgorithmParameters parameters = AlgorithmParameters.getInstance(ELLIPTIC_CURVE_ALGORITHM); - parameters.init(ecSpec); - this.ecParameterSpec = parameters.getParameterSpec(ECParameterSpec.class); - } catch (NoSuchAlgorithmException | InvalidParameterSpecException e) { - throw new IllegalStateException("Invalid algorithm", e); - } - } - - @Override - public String toString() { - return "ECDSASignatureAlgorithm(curve=" + ecSpec.getName() + ")"; - } - - @Override - public String getMessageDigestAlgorithm() { - return messageDigestAlgorithm; - } - - @Override - public String getRawSignatureAlgorithm() { - return "NONEwithECDSA"; - } - - @Override - public String getHashAndSignAlgorithm() { - return hashAndSignAlgorithm; - } - - /** - * Decodes a compressed elliptic curve point as described in SEC-1 v2 section 2.3.4 - * - * @param keyString The serialized and compressed public key - * @return The PublicKey - * @see http://www.secg.org/sec1-v2.pdf - */ - @Override - public PublicKey deserializePublicKey(String keyString) { - notNull(keyString, "keyString is required"); - - final byte[] decodedKey = Utils.decodeBase64String(keyString); - final BigInteger x = new BigInteger(1, Arrays.copyOfRange(decodedKey, 1, decodedKey.length)); - - final byte compressedY = decodedKey[0]; - final BigInteger yOrder; - - if (compressedY == TWO.byteValue()) { - yOrder = ZERO; - } else if (compressedY == THREE.byteValue()) { - yOrder = ONE; - } else { - throw new IllegalArgumentException("Compressed y value was invalid"); - } - - final BigInteger p = ((ECFieldFp) ecParameterSpec.getCurve().getField()).getP(); - final BigInteger a = ecParameterSpec.getCurve().getA(); - final BigInteger b = ecParameterSpec.getCurve().getB(); - - //alpha must be equal to y^2, this is validated below - final BigInteger alpha = x.modPow(THREE, p) - .add(a.multiply(x).mod(p)) - .add(b) - .mod(p); - - final BigInteger beta; - if (p.mod(FOUR).equals(THREE)) { - beta = alpha.modPow(p.add(ONE).divide(FOUR), p); - } else { - throw new IllegalArgumentException("Curve not supported at this time"); - } - - final BigInteger y = beta.mod(TWO).equals(yOrder) ? beta : p.subtract(beta); - - //Validate that Y is a root of Y^2 to prevent invalid point attacks - if (!alpha.equals(y.modPow(TWO, p))) { - throw new IllegalArgumentException("Y was invalid"); - } - - try { - return KeyFactory.getInstance(ELLIPTIC_CURVE_ALGORITHM).generatePublic( - new ECPublicKeySpec(new ECPoint(x, y), ecParameterSpec)); - } catch (InvalidKeySpecException | NoSuchAlgorithmException e) { - throw new IllegalStateException("Invalid algorithm", e); - } - } - - /** - * Encodes a compressed elliptic curve point as described in SEC-1 v2 section 2.3.3 - * - * @param key The Elliptic Curve public key to compress and serialize - * @return The serialized and compressed public key - * @see http://www.secg.org/sec1-v2.pdf - */ - @Override - public String serializePublicKey(PublicKey key) { - notNull(key, "key is required"); - isInstanceOf(ECPublicKey.class, key, "key must be an instance of ECPublicKey"); - - final BigInteger x = ((ECPublicKey) key).getW().getAffineX(); - final BigInteger y = ((ECPublicKey) key).getW().getAffineY(); - final BigInteger compressedY = y.mod(TWO).equals(ZERO) ? TWO : THREE; - - final byte[] xBytes = bigIntegerToByteArray(x, - ecParameterSpec.getCurve().getField().getFieldSize() / Byte.SIZE); - - final byte[] compressedKey = new byte[xBytes.length + 1]; - System.arraycopy(xBytes, 0, compressedKey, 1, xBytes.length); - compressedKey[0] = compressedY.byteValue(); - - return encodeBase64String(compressedKey); - } - - @Override - public KeyPair generateKey() throws GeneralSecurityException { - KeyPairGenerator keyGen = KeyPairGenerator.getInstance(ELLIPTIC_CURVE_ALGORITHM); - keyGen.initialize(ecSpec, Utils.getSecureRandom()); - - return keyGen.generateKeyPair(); - } + @Override + public String toString() { + return "ECDSASignatureAlgorithm(curve=" + ecSpec.getName() + ")"; } - private static final ECDSASignatureAlgorithm SHA256_ECDSA_P256 - = new ECDSASignatureAlgorithm(new ECGenParameterSpec(SEC_PRIME_FIELD_PREFIX + "256r1"), "SHA-256", "SHA256withECDSA"); - private static final ECDSASignatureAlgorithm SHA384_ECDSA_P384 - = new ECDSASignatureAlgorithm(new ECGenParameterSpec(SEC_PRIME_FIELD_PREFIX + "384r1"), "SHA-384", "SHA384withECDSA"); - - public static TrailingSignatureAlgorithm forCryptoAlgorithm(CryptoAlgorithm algorithm) { - switch (algorithm) { - case ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256: - return SHA256_ECDSA_P256; - case ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384: - case ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384: - case ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384: - return SHA384_ECDSA_P384; - default: - throw new IllegalStateException("Algorithm does not support trailing signature"); - } + @Override + public String getMessageDigestAlgorithm() { + return messageDigestAlgorithm; + } + + @Override + public String getRawSignatureAlgorithm() { + return "NONEwithECDSA"; } -} + @Override + public String getHashAndSignAlgorithm() { + return hashAndSignAlgorithm; + } + + /** + * Decodes a compressed elliptic curve point as described in SEC-1 v2 section 2.3.4 + * + * @param keyString The serialized and compressed public key + * @return The PublicKey + * @see http://www.secg.org/sec1-v2.pdf + */ + @Override + public PublicKey deserializePublicKey(String keyString) { + notNull(keyString, "keyString is required"); + + final byte[] decodedKey = Utils.decodeBase64String(keyString); + final BigInteger x = new BigInteger(1, Arrays.copyOfRange(decodedKey, 1, decodedKey.length)); + + final byte compressedY = decodedKey[0]; + final BigInteger yOrder; + + if (compressedY == TWO.byteValue()) { + yOrder = ZERO; + } else if (compressedY == THREE.byteValue()) { + yOrder = ONE; + } else { + throw new IllegalArgumentException("Compressed y value was invalid"); + } + + final BigInteger p = ((ECFieldFp) ecParameterSpec.getCurve().getField()).getP(); + final BigInteger a = ecParameterSpec.getCurve().getA(); + final BigInteger b = ecParameterSpec.getCurve().getB(); + + // alpha must be equal to y^2, this is validated below + final BigInteger alpha = x.modPow(THREE, p).add(a.multiply(x).mod(p)).add(b).mod(p); + + final BigInteger beta; + if (p.mod(FOUR).equals(THREE)) { + beta = alpha.modPow(p.add(ONE).divide(FOUR), p); + } else { + throw new IllegalArgumentException("Curve not supported at this time"); + } + + final BigInteger y = beta.mod(TWO).equals(yOrder) ? beta : p.subtract(beta); + + // Validate that Y is a root of Y^2 to prevent invalid point attacks + if (!alpha.equals(y.modPow(TWO, p))) { + throw new IllegalArgumentException("Y was invalid"); + } + + try { + return KeyFactory.getInstance(ELLIPTIC_CURVE_ALGORITHM) + .generatePublic(new ECPublicKeySpec(new ECPoint(x, y), ecParameterSpec)); + } catch (InvalidKeySpecException | NoSuchAlgorithmException e) { + throw new IllegalStateException("Invalid algorithm", e); + } + } + + /** + * Encodes a compressed elliptic curve point as described in SEC-1 v2 section 2.3.3 + * + * @param key The Elliptic Curve public key to compress and serialize + * @return The serialized and compressed public key + * @see http://www.secg.org/sec1-v2.pdf + */ + @Override + public String serializePublicKey(PublicKey key) { + notNull(key, "key is required"); + isInstanceOf(ECPublicKey.class, key, "key must be an instance of ECPublicKey"); + + final BigInteger x = ((ECPublicKey) key).getW().getAffineX(); + final BigInteger y = ((ECPublicKey) key).getW().getAffineY(); + final BigInteger compressedY = y.mod(TWO).equals(ZERO) ? TWO : THREE; + + final byte[] xBytes = + bigIntegerToByteArray( + x, ecParameterSpec.getCurve().getField().getFieldSize() / Byte.SIZE); + + final byte[] compressedKey = new byte[xBytes.length + 1]; + System.arraycopy(xBytes, 0, compressedKey, 1, xBytes.length); + compressedKey[0] = compressedY.byteValue(); + + return encodeBase64String(compressedKey); + } + + @Override + public KeyPair generateKey() throws GeneralSecurityException { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance(ELLIPTIC_CURVE_ALGORITHM); + keyGen.initialize(ecSpec, Utils.getSecureRandom()); + + return keyGen.generateKeyPair(); + } + } + + private static final ECDSASignatureAlgorithm SHA256_ECDSA_P256 = + new ECDSASignatureAlgorithm( + new ECGenParameterSpec(SEC_PRIME_FIELD_PREFIX + "256r1"), "SHA-256", "SHA256withECDSA"); + private static final ECDSASignatureAlgorithm SHA384_ECDSA_P384 = + new ECDSASignatureAlgorithm( + new ECGenParameterSpec(SEC_PRIME_FIELD_PREFIX + "384r1"), "SHA-384", "SHA384withECDSA"); + + public static TrailingSignatureAlgorithm forCryptoAlgorithm(CryptoAlgorithm algorithm) { + switch (algorithm) { + case ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256: + return SHA256_ECDSA_P256; + case ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384: + case ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384: + case ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384: + return SHA384_ECDSA_P384; + default: + throw new IllegalStateException("Algorithm does not support trailing signature"); + } + } +} diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/Utils.java b/src/main/java/com/amazonaws/encryptionsdk/internal/Utils.java index 0c76d1e2d..815d785aa 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/Utils.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/Utils.java @@ -13,313 +13,314 @@ import java.util.Comparator; import java.util.WeakHashMap; import java.util.concurrent.atomic.AtomicLong; - import org.apache.commons.lang3.ArrayUtils; import org.bouncycastle.util.encoders.Base64; -/** - * Internal utility methods. - */ +/** Internal utility methods. */ public final class Utils { - // SecureRandom objects can both be expensive to initialize and incur synchronization costs. - // This allows us to minimize both initializations and keep SecureRandom usage thread local - // to avoid lock contention. - private static final ThreadLocal LOCAL_RANDOM = new ThreadLocal() { - @Override - protected SecureRandom initialValue() { + // SecureRandom objects can both be expensive to initialize and incur synchronization costs. + // This allows us to minimize both initializations and keep SecureRandom usage thread local + // to avoid lock contention. + private static final ThreadLocal LOCAL_RANDOM = + new ThreadLocal() { + @Override + protected SecureRandom initialValue() { final SecureRandom rnd = new SecureRandom(); rnd.nextBoolean(); // Force seeding return rnd; - } - }; - - private Utils() { - // Prevent instantiation + } + }; + + private Utils() { + // Prevent instantiation + } + + /* + * In some areas we need to be able to assign a total order over Java objects - generally with some primary sort, + * but we need a fallback sort that always works in order to ensure that we don't falsely claim objects A and B + * are equal just because the primary sort declares them to have equal rank. + * + * To do this, we'll define a fallback sort that assigns an arbitrary order to all objects. This order is + * implemented by first comparing hashcode, and in the rare case where we are asked to compare two objects with + * equal hashcode, we explicitly assign an index to them - using a WeakHashMap to track this index - and sort + * based on this index. + */ + private static AtomicLong FALLBACK_COUNTER = new AtomicLong(0); + private static WeakHashMap FALLBACK_COMPARATOR_MAP = new WeakHashMap<>(); + + private static synchronized long getFallbackObjectId(Object object) { + return FALLBACK_COMPARATOR_MAP.computeIfAbsent( + object, ignored -> FALLBACK_COUNTER.incrementAndGet()); + } + + /** + * Provides an arbitrary but consistent total ordering over all objects. This comparison + * function will return 0 if and only if a == b, and otherwise will return arbitrarily either -1 + * or 1, but will do so in a way that results in a consistent total order. + * + * @param a + * @param b + * @return -1 or 1 (consistently) if a != b; 0 if a == b. + */ + public static int compareObjectIdentity(Object a, Object b) { + if (a == b) { + return 0; } - /* - * In some areas we need to be able to assign a total order over Java objects - generally with some primary sort, - * but we need a fallback sort that always works in order to ensure that we don't falsely claim objects A and B - * are equal just because the primary sort declares them to have equal rank. - * - * To do this, we'll define a fallback sort that assigns an arbitrary order to all objects. This order is - * implemented by first comparing hashcode, and in the rare case where we are asked to compare two objects with - * equal hashcode, we explicitly assign an index to them - using a WeakHashMap to track this index - and sort - * based on this index. - */ - private static AtomicLong FALLBACK_COUNTER = new AtomicLong(0); - private static WeakHashMap FALLBACK_COMPARATOR_MAP = new WeakHashMap<>(); - - private static synchronized long getFallbackObjectId(Object object) { - return FALLBACK_COMPARATOR_MAP.computeIfAbsent(object, ignored -> FALLBACK_COUNTER.incrementAndGet()); + if (a == null) { + return -1; } - /** - * Provides an arbitrary but consistent total ordering over all objects. This comparison function will - * return 0 if and only if a == b, and otherwise will return arbitrarily either -1 or 1, but will do so in a way - * that results in a consistent total order. - * - * @param a - * @param b - * @return -1 or 1 (consistently) if a != b; 0 if a == b. - */ - public static int compareObjectIdentity(Object a, Object b) { - if (a == b) { - return 0; - } - - if (a == null) { - return -1; - } - - if (b == null) { - return 1; - } - - int hashCompare = Integer.compare(System.identityHashCode(a), System.identityHashCode(b)); - if (hashCompare != 0) { - return hashCompare; - } - - // Unfortunately these objects have identical hashcodes, so we need to find some other way to compare them. - // We'll do this by mapping them to an incrementing counter, and comparing their assigned IDs instead. - int fallbackCompare = Long.compare(getFallbackObjectId(a), getFallbackObjectId(b)); - if (fallbackCompare == 0) { - throw new AssertionError("Failed to assign unique order to objects"); - } - - return fallbackCompare; + if (b == null) { + return 1; } - public static long saturatingAdd(long a, long b) { - long r = a + b; + int hashCompare = Integer.compare(System.identityHashCode(a), System.identityHashCode(b)); + if (hashCompare != 0) { + return hashCompare; + } - if (a > 0 && b > 0 && r < a) { - return Long.MAX_VALUE; - } + // Unfortunately these objects have identical hashcodes, so we need to find some other way to + // compare them. + // We'll do this by mapping them to an incrementing counter, and comparing their assigned IDs + // instead. + int fallbackCompare = Long.compare(getFallbackObjectId(a), getFallbackObjectId(b)); + if (fallbackCompare == 0) { + throw new AssertionError("Failed to assign unique order to objects"); + } - if (a < 0 && b < 0 && r > a) { - return Long.MIN_VALUE; - } + return fallbackCompare; + } - // If the signs between a and b differ, overflow is impossible. + public static long saturatingAdd(long a, long b) { + long r = a + b; - return r; + if (a > 0 && b > 0 && r < a) { + return Long.MAX_VALUE; } - /** - * Comparator that performs a lexicographical comparison of byte arrays, treating them as unsigned. - */ - public static class ComparingByteArrays implements Comparator, Serializable { - // We don't really need to be serializable, but it doesn't hurt, and FindBugs gets annoyed if we're not. - private static final long serialVersionUID = 0xdf641037ffe509e2L; - - @Override public int compare(byte[] o1, byte[] o2) { - return new ComparingByteBuffers().compare(ByteBuffer.wrap(o1), ByteBuffer.wrap(o2)); - } + if (a < 0 && b < 0 && r > a) { + return Long.MIN_VALUE; } - public static class ComparingByteBuffers implements Comparator, Serializable { - private static final long serialVersionUID = 0xa3c4a7300fbbf043L; - - @Override public int compare(ByteBuffer o1, ByteBuffer o2) { - o1 = o1.slice(); - o2 = o2.slice(); + // If the signs between a and b differ, overflow is impossible. - int commonLength = Math.min(o1.remaining(), o2.remaining()); + return r; + } - for (int i = 0; i < commonLength; i++) { - // Perform zero-extension as we want to treat the bytes as unsigned - int v1 = o1.get(i) & 0xFF; - int v2 = o2.get(i) & 0xFF; + /** + * Comparator that performs a lexicographical comparison of byte arrays, treating them as + * unsigned. + */ + public static class ComparingByteArrays implements Comparator, Serializable { + // We don't really need to be serializable, but it doesn't hurt, and FindBugs gets annoyed if + // we're not. + private static final long serialVersionUID = 0xdf641037ffe509e2L; - if (v1 != v2) { - return v1 - v2; - } - } - - // The longer buffer is bigger (0x00 comes after end-of-buffer) - return o1.remaining() - o2.remaining(); - } + @Override + public int compare(byte[] o1, byte[] o2) { + return new ComparingByteBuffers().compare(ByteBuffer.wrap(o1), ByteBuffer.wrap(o2)); } + } - /** - * Throws {@link NullPointerException} with message {@code paramName} if {@code object} is null. - * - * @param object - * value to be null-checked - * @param paramName - * message for the potential {@link NullPointerException} - * @return {@code object} - * @throws NullPointerException - * if {@code object} is null - */ - public static T assertNonNull(final T object, final String paramName) throws NullPointerException { - if (object == null) { - throw new NullPointerException(paramName + " must not be null"); - } - return object; - } + public static class ComparingByteBuffers implements Comparator, Serializable { + private static final long serialVersionUID = 0xa3c4a7300fbbf043L; - /** - * Returns a possibly truncated version of {@code arr} which is guaranteed to be exactly - * {@code len} elements long. If {@code arr} is already exactly {@code len} elements long, then - * {@code arr} is returned without copy or modification. If {@code arr} is longer than - * {@code len}, then a truncated copy is returned. If {@code arr} is shorter than {@code len} - * then this throws an {@link IllegalArgumentException}. - */ - public static byte[] truncate(final byte[] arr, final int len) throws IllegalArgumentException { - if (arr.length == len) { - return arr; - } else if (arr.length > len) { - return Arrays.copyOf(arr, len); - } else { - throw new IllegalArgumentException("arr is not at least " + len + " elements long"); - } - } + @Override + public int compare(ByteBuffer o1, ByteBuffer o2) { + o1 = o1.slice(); + o2 = o2.slice(); - public static SecureRandom getSecureRandom() { - return LOCAL_RANDOM.get(); - } + int commonLength = Math.min(o1.remaining(), o2.remaining()); - /** - * Generate the AAD bytes to use when encrypting/decrypting content. The - * generated AAD is a block of bytes containing the provided message - * identifier, the string identifier, the sequence number, and the length of - * the content. - * - * @param messageId - * the unique message identifier for the ciphertext. - * @param idString - * the string describing the type of content processed. - * @param seqNum - * the sequence number. - * @param len - * the length of the content. - * @return - * the bytes containing the generated AAD. - */ - static byte[] generateContentAad(final byte[] messageId, final String idString, final int seqNum, final long len) { - final byte[] idBytes = idString.getBytes(StandardCharsets.UTF_8); - final int aadLen = messageId.length + idBytes.length + Integer.SIZE / Byte.SIZE + Long.SIZE / Byte.SIZE; - final ByteBuffer aad = ByteBuffer.allocate(aadLen); - - aad.put(messageId); - aad.put(idBytes); - aad.putInt(seqNum); - aad.putLong(len); - - return aad.array(); - } + for (int i = 0; i < commonLength; i++) { + // Perform zero-extension as we want to treat the bytes as unsigned + int v1 = o1.get(i) & 0xFF; + int v2 = o2.get(i) & 0xFF; - static IllegalArgumentException cannotBeNegative(String field) { - return new IllegalArgumentException(field + " cannot be negative"); - } + if (v1 != v2) { + return v1 - v2; + } + } - /** - * Equivalent to calling {@link ByteBuffer#flip()} but in a manner which is - * safe when compiled on Java 9 or newer but used on Java 8 or older. - */ - public static ByteBuffer flip(final ByteBuffer buff) { - ((Buffer) buff).flip(); - return buff; + // The longer buffer is bigger (0x00 comes after end-of-buffer) + return o1.remaining() - o2.remaining(); } - - /** - * Equivalent to calling {@link ByteBuffer#clear()} but in a manner which is - * safe when compiled on Java 9 or newer but used on Java 8 or older. - */ - public static ByteBuffer clear(final ByteBuffer buff) { - ((Buffer) buff).clear(); - return buff; + } + + /** + * Throws {@link NullPointerException} with message {@code paramName} if {@code object} is null. + * + * @param object value to be null-checked + * @param paramName message for the potential {@link NullPointerException} + * @return {@code object} + * @throws NullPointerException if {@code object} is null + */ + public static T assertNonNull(final T object, final String paramName) + throws NullPointerException { + if (object == null) { + throw new NullPointerException(paramName + " must not be null"); } - - /** - * Equivalent to calling {@link ByteBuffer#position(int)} but in a manner which is - * safe when compiled on Java 9 or newer but used on Java 8 or older. - */ - public static ByteBuffer position(final ByteBuffer buff, final int newPosition) { - ((Buffer) buff).position(newPosition); - return buff; + return object; + } + + /** + * Returns a possibly truncated version of {@code arr} which is guaranteed to be exactly {@code + * len} elements long. If {@code arr} is already exactly {@code len} elements long, then {@code + * arr} is returned without copy or modification. If {@code arr} is longer than {@code len}, then + * a truncated copy is returned. If {@code arr} is shorter than {@code len} then this throws an + * {@link IllegalArgumentException}. + */ + public static byte[] truncate(final byte[] arr, final int len) throws IllegalArgumentException { + if (arr.length == len) { + return arr; + } else if (arr.length > len) { + return Arrays.copyOf(arr, len); + } else { + throw new IllegalArgumentException("arr is not at least " + len + " elements long"); } - - /** - * Equivalent to calling {@link ByteBuffer#limit(int)} but in a manner which is - * safe when compiled on Java 9 or newer but used on Java 8 or older. - */ - public static ByteBuffer limit(final ByteBuffer buff, final int newLimit) { - ((Buffer) buff).limit(newLimit); - return buff; + } + + public static SecureRandom getSecureRandom() { + return LOCAL_RANDOM.get(); + } + + /** + * Generate the AAD bytes to use when encrypting/decrypting content. The generated AAD is a block + * of bytes containing the provided message identifier, the string identifier, the sequence + * number, and the length of the content. + * + * @param messageId the unique message identifier for the ciphertext. + * @param idString the string describing the type of content processed. + * @param seqNum the sequence number. + * @param len the length of the content. + * @return the bytes containing the generated AAD. + */ + static byte[] generateContentAad( + final byte[] messageId, final String idString, final int seqNum, final long len) { + final byte[] idBytes = idString.getBytes(StandardCharsets.UTF_8); + final int aadLen = + messageId.length + idBytes.length + Integer.SIZE / Byte.SIZE + Long.SIZE / Byte.SIZE; + final ByteBuffer aad = ByteBuffer.allocate(aadLen); + + aad.put(messageId); + aad.put(idBytes); + aad.putInt(seqNum); + aad.putLong(len); + + return aad.array(); + } + + static IllegalArgumentException cannotBeNegative(String field) { + return new IllegalArgumentException(field + " cannot be negative"); + } + + /** + * Equivalent to calling {@link ByteBuffer#flip()} but in a manner which is safe when compiled on + * Java 9 or newer but used on Java 8 or older. + */ + public static ByteBuffer flip(final ByteBuffer buff) { + ((Buffer) buff).flip(); + return buff; + } + + /** + * Equivalent to calling {@link ByteBuffer#clear()} but in a manner which is safe when compiled on + * Java 9 or newer but used on Java 8 or older. + */ + public static ByteBuffer clear(final ByteBuffer buff) { + ((Buffer) buff).clear(); + return buff; + } + + /** + * Equivalent to calling {@link ByteBuffer#position(int)} but in a manner which is safe when + * compiled on Java 9 or newer but used on Java 8 or older. + */ + public static ByteBuffer position(final ByteBuffer buff, final int newPosition) { + ((Buffer) buff).position(newPosition); + return buff; + } + + /** + * Equivalent to calling {@link ByteBuffer#limit(int)} but in a manner which is safe when compiled + * on Java 9 or newer but used on Java 8 or older. + */ + public static ByteBuffer limit(final ByteBuffer buff, final int newLimit) { + ((Buffer) buff).limit(newLimit); + return buff; + } + + /** + * Takes a Base64-encoded String, decodes it, and returns contents as a byte array. + * + * @param encoded Base64 encoded String + * @return decoded data as a byte array + */ + public static byte[] decodeBase64String(final String encoded) { + return encoded.isEmpty() ? ArrayUtils.EMPTY_BYTE_ARRAY : Base64.decode(encoded); + } + + /** + * Takes data in a byte array, encodes them in Base64, and returns the result as a String. + * + * @param data The data to encode. + * @return Base64 string that encodes the {@code data}. + */ + public static String encodeBase64String(final byte[] data) { + return Base64.toBase64String(data); + } + + /** + * Removes the leading zero sign byte from the byte array representation of a BigInteger (if + * present) and left pads with zeroes to produce a byte array of the given length. + * + * @param bigInteger The BigInteger to convert to a byte array + * @param length The length of the byte array, must be at least as long as the BigInteger byte + * array without the sign byte + * @return The byte array + */ + public static byte[] bigIntegerToByteArray(final BigInteger bigInteger, final int length) { + byte[] rawBytes = bigInteger.toByteArray(); + // If rawBytes is already the correct length, return it. + if (rawBytes.length == length) { + return rawBytes; } - /** - * Takes a Base64-encoded String, decodes it, and returns contents as a byte array. - * - * @param encoded Base64 encoded String - * @return decoded data as a byte array - */ - public static byte[] decodeBase64String(final String encoded) { - return encoded.isEmpty() ? ArrayUtils.EMPTY_BYTE_ARRAY : Base64.decode(encoded); + // If we're exactly one byte too large, but we have a leading zero byte, remove it and return. + if (rawBytes.length == length + 1 && rawBytes[0] == 0) { + return Arrays.copyOfRange(rawBytes, 1, rawBytes.length); } - /** - * Takes data in a byte array, encodes them in Base64, and returns the result as a String. - * - * @param data The data to encode. - * @return Base64 string that encodes the {@code data}. - */ - public static String encodeBase64String(final byte[] data) { - return Base64.toBase64String(data); + if (rawBytes.length > length) { + throw new IllegalArgumentException( + "Length must be at least as long as the BigInteger byte array " + + "without the sign byte"); } - /** - * Removes the leading zero sign byte from the byte array representation of a BigInteger (if present) - * and left pads with zeroes to produce a byte array of the given length. - * @param bigInteger The BigInteger to convert to a byte array - * @param length The length of the byte array, must be at least - * as long as the BigInteger byte array without the sign byte - * @return The byte array - */ - public static byte[] bigIntegerToByteArray(final BigInteger bigInteger, final int length) { - byte[] rawBytes = bigInteger.toByteArray(); - // If rawBytes is already the correct length, return it. - if (rawBytes.length == length) { - return rawBytes; - } - - // If we're exactly one byte too large, but we have a leading zero byte, remove it and return. - if(rawBytes.length == length + 1 && rawBytes[0] == 0) { - return Arrays.copyOfRange(rawBytes, 1, rawBytes.length); - } - - if (rawBytes.length > length) { - throw new IllegalArgumentException("Length must be at least as long as the BigInteger byte array " + - "without the sign byte"); - } - - final byte[] paddedResult = new byte[length]; - System.arraycopy(rawBytes, 0, paddedResult, length - rawBytes.length, rawBytes.length); - return paddedResult; + final byte[] paddedResult = new byte[length]; + System.arraycopy(rawBytes, 0, paddedResult, length - rawBytes.length, rawBytes.length); + return paddedResult; + } + + /** + * Returns true if the prefix of the given length for the input arrays are equal. This method will + * return as soon as the first difference is found, and is thus not constant-time. + * + * @param a The first array. + * @param b The second array. + * @param length The length of the prefix to compare. + * @return True if the prefixes are equal, false otherwise. + */ + public static boolean arrayPrefixEquals(final byte[] a, final byte[] b, final int length) { + if (a == null || b == null || a.length < length || b.length < length) { + return false; } - - /** - * Returns true if the prefix of the given length for the input arrays are equal. - * This method will return as soon as the first difference is found, and is thus not constant-time. - * - * @param a The first array. - * @param b The second array. - * @param length The length of the prefix to compare. - * @return True if the prefixes are equal, false otherwise. - */ - public static boolean arrayPrefixEquals(final byte[] a, final byte[] b, final int length) { - if (a == null || b == null || a.length < length || b.length < length) { - return false; - } - for (int x = 0; x < length; x++) { - if (a[x] != b[x]) { - return false; - } - } - return true; + for (int x = 0; x < length; x++) { + if (a[x] != b[x]) { + return false; + } } + return true; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/VersionInfo.java b/src/main/java/com/amazonaws/encryptionsdk/internal/VersionInfo.java index 496568360..502c3399a 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/VersionInfo.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/VersionInfo.java @@ -1,36 +1,35 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. */ package com.amazonaws.encryptionsdk.internal; -import java.util.Properties; + import java.io.IOException; +import java.util.Properties; -/** - * This class specifies the versioning system for the AWS KMS encryption client. - */ +/** This class specifies the versioning system for the AWS KMS encryption client. */ public class VersionInfo { - public static final String USER_AGENT_PREFIX = "AwsCrypto/"; - public static final String UNKNOWN_VERSION = "unknown"; - /* - * Loads the version of the library - */ - public static String loadUserAgent() { - try { - final Properties properties = new Properties(); - properties.load(ClassLoader.getSystemResourceAsStream("project.properties")); - return USER_AGENT_PREFIX + properties.getProperty("version"); - } catch (final IOException ex) { - return USER_AGENT_PREFIX + UNKNOWN_VERSION; - } + public static final String USER_AGENT_PREFIX = "AwsCrypto/"; + public static final String UNKNOWN_VERSION = "unknown"; + /* + * Loads the version of the library + */ + public static String loadUserAgent() { + try { + final Properties properties = new Properties(); + properties.load(ClassLoader.getSystemResourceAsStream("project.properties")); + return USER_AGENT_PREFIX + properties.getProperty("version"); + } catch (final IOException ex) { + return USER_AGENT_PREFIX + UNKNOWN_VERSION; } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/internal/package-info.java b/src/main/java/com/amazonaws/encryptionsdk/internal/package-info.java index 46d3fe389..c3f3f34ed 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/internal/package-info.java +++ b/src/main/java/com/amazonaws/encryptionsdk/internal/package-info.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -16,47 +16,29 @@ * algorithms. The package also includes auxiliary classes that implement serialization of * encryption context, parser for deserializing bytes into primitives, and generation of random * bytes. - * - * No classes in this package are intended for public consumption. They - * may be changed at any time without concern for API compatibility. - * + * + *

No classes in this package are intended for public consumption. They may be changed at any + * time without concern for API compatibility. + * *

    - *
  • - * the CryptoHandler interface that defines the contract for the methods that must be implemented by - * classes that perform encryption and decryption in this library. - * - *
  • - * the EncryptionHandler and DecryptionHandler classes handle the creation and parsing of the - * ciphertext headers as described in the message format. These two classes delegate the actual - * encryption and decryption of content to the Block and Frame handlers. - * - *
  • - * the BlockEncryptionHandler and BlockDecryptionHandler classes handle the encryption and - * decryption of content stored as a single-block as described in the message format. - * - *
  • - * the FrameEncryptionHandler and FrameDecryptionHandler classes handle the encryption and - * decryption of content stored as frames as described in the message format. - * - *
  • - * the CipherHandler that provides methods to cryptographically transform bytes using a block - * cipher. Currently, it only uses AES-GCM block cipher. - * - *
  • - * the EncContextSerializer provides methods to serialize a map containing the encryption context - * into bytes, and deserialize bytes into a map containing the encryption context. - * - *
  • - * the PrimitivesParser provides methods to parse primitive types from bytes. These methods are used - * by deserialization code. - * - *
  • - * the ContentAadGenerator provides methods to generate the Additional Authenticated Data (AAD) used - * in encrypting the content. - * - *
  • - * the Constants class that contains the constants and default values used in the library. - * + *
  • the CryptoHandler interface that defines the contract for the methods that must be + * implemented by classes that perform encryption and decryption in this library. + *
  • the EncryptionHandler and DecryptionHandler classes handle the creation and parsing of the + * ciphertext headers as described in the message format. These two classes delegate the + * actual encryption and decryption of content to the Block and Frame handlers. + *
  • the BlockEncryptionHandler and BlockDecryptionHandler classes handle the encryption and + * decryption of content stored as a single-block as described in the message format. + *
  • the FrameEncryptionHandler and FrameDecryptionHandler classes handle the encryption and + * decryption of content stored as frames as described in the message format. + *
  • the CipherHandler that provides methods to cryptographically transform bytes using a block + * cipher. Currently, it only uses AES-GCM block cipher. + *
  • the EncContextSerializer provides methods to serialize a map containing the encryption + * context into bytes, and deserialize bytes into a map containing the encryption context. + *
  • the PrimitivesParser provides methods to parse primitive types from bytes. These methods + * are used by deserialization code. + *
  • the ContentAadGenerator provides methods to generate the Additional Authenticated Data + * (AAD) used in encrypting the content. + *
  • the Constants class that contains the constants and default values used in the library. *
*/ package com.amazonaws.encryptionsdk.internal; diff --git a/src/main/java/com/amazonaws/encryptionsdk/jce/JceMasterKey.java b/src/main/java/com/amazonaws/encryptionsdk/jce/JceMasterKey.java index 4995066cb..3d40d0b07 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/jce/JceMasterKey.java +++ b/src/main/java/com/amazonaws/encryptionsdk/jce/JceMasterKey.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -21,9 +21,6 @@ import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; import com.amazonaws.encryptionsdk.internal.JceKeyCipher; import com.amazonaws.encryptionsdk.internal.Utils; - -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; import java.nio.charset.StandardCharsets; import java.security.Key; import java.security.PrivateKey; @@ -32,128 +29,153 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; /** * Represents a {@link MasterKey} backed by one (or more) JCE {@link Key}s. Instances of this should - * only be acquired using {@link #getInstance(SecretKey, String, String, String)} or - * {@link #getInstance(PublicKey, PrivateKey, String, String, String)}. + * only be acquired using {@link #getInstance(SecretKey, String, String, String)} or {@link + * #getInstance(PublicKey, PrivateKey, String, String, String)}. */ public class JceMasterKey extends MasterKey { - private final String providerName_; - private final String keyId_; - private final byte[] keyIdBytes_; - private final JceKeyCipher jceKeyCipher_; - - /** - * Returns a {@code JceMasterKey} backed by {@code key} using {@code wrappingAlgorithm}. - * Currently "{@code AES/GCM/NoPadding}" is the only supported value for - * {@code wrappingAlgorithm}. - * - * @param key - * key used to wrap/unwrap (encrypt/decrypt) {@link DataKey}s - * @param provider - * @param keyId - * @param wrappingAlgorithm - * @return - */ - public static JceMasterKey getInstance(final SecretKey key, final String provider, final String keyId, - final String wrappingAlgorithm) { - switch (wrappingAlgorithm.toUpperCase()) { - case "AES/GCM/NOPADDING": - return new JceMasterKey(provider, keyId, JceKeyCipher.aesGcm(key)); - default: - throw new IllegalArgumentException("Right now only AES/GCM/NoPadding is supported"); + private final String providerName_; + private final String keyId_; + private final byte[] keyIdBytes_; + private final JceKeyCipher jceKeyCipher_; - } + /** + * Returns a {@code JceMasterKey} backed by {@code key} using {@code wrappingAlgorithm}. Currently + * "{@code AES/GCM/NoPadding}" is the only supported value for {@code wrappingAlgorithm}. + * + * @param key key used to wrap/unwrap (encrypt/decrypt) {@link DataKey}s + * @param provider + * @param keyId + * @param wrappingAlgorithm + * @return + */ + public static JceMasterKey getInstance( + final SecretKey key, + final String provider, + final String keyId, + final String wrappingAlgorithm) { + switch (wrappingAlgorithm.toUpperCase()) { + case "AES/GCM/NOPADDING": + return new JceMasterKey(provider, keyId, JceKeyCipher.aesGcm(key)); + default: + throw new IllegalArgumentException("Right now only AES/GCM/NoPadding is supported"); } + } - /** - * Returns a {@code JceMasterKey} backed by {@code unwrappingKey} and {@code wrappingKey} using - * {@code wrappingAlgorithm}. Currently only RSA algorithms are supported for - * {@code wrappingAlgorithm}. {@code wrappingAlgorithm}. If {@code unwrappingKey} is - * {@code null} then the returned {@link JceMasterKey} can only be used for encryption. - * - * @param wrappingKey - * key used to wrap (encrypt) {@link DataKey}s - * @param unwrappingKey - * (Optional) key used to unwrap (decrypt) {@link DataKey}s. - */ - public static JceMasterKey getInstance(final PublicKey wrappingKey, final PrivateKey unwrappingKey, - final String provider, final String keyId, - final String wrappingAlgorithm) { - if (wrappingAlgorithm.toUpperCase().startsWith("RSA/ECB/")) { - return new JceMasterKey(provider, keyId, JceKeyCipher.rsa(wrappingKey, unwrappingKey, wrappingAlgorithm)); - } - throw new UnsupportedOperationException("Currently only RSA asymmetric algorithms are supported"); + /** + * Returns a {@code JceMasterKey} backed by {@code unwrappingKey} and {@code wrappingKey} using + * {@code wrappingAlgorithm}. Currently only RSA algorithms are supported for {@code + * wrappingAlgorithm}. {@code wrappingAlgorithm}. If {@code unwrappingKey} is {@code null} then + * the returned {@link JceMasterKey} can only be used for encryption. + * + * @param wrappingKey key used to wrap (encrypt) {@link DataKey}s + * @param unwrappingKey (Optional) key used to unwrap (decrypt) {@link DataKey}s. + */ + public static JceMasterKey getInstance( + final PublicKey wrappingKey, + final PrivateKey unwrappingKey, + final String provider, + final String keyId, + final String wrappingAlgorithm) { + if (wrappingAlgorithm.toUpperCase().startsWith("RSA/ECB/")) { + return new JceMasterKey( + provider, keyId, JceKeyCipher.rsa(wrappingKey, unwrappingKey, wrappingAlgorithm)); } + throw new UnsupportedOperationException( + "Currently only RSA asymmetric algorithms are supported"); + } - protected JceMasterKey(final String providerName, final String keyId, final JceKeyCipher jceKeyCipher) { - providerName_ = providerName; - keyId_ = keyId; - keyIdBytes_ = keyId_.getBytes(StandardCharsets.UTF_8); - jceKeyCipher_ = jceKeyCipher; - } + protected JceMasterKey( + final String providerName, final String keyId, final JceKeyCipher jceKeyCipher) { + providerName_ = providerName; + keyId_ = keyId; + keyIdBytes_ = keyId_.getBytes(StandardCharsets.UTF_8); + jceKeyCipher_ = jceKeyCipher; + } - @Override - public String getProviderId() { - return providerName_; - } + @Override + public String getProviderId() { + return providerName_; + } - @Override - public String getKeyId() { - return keyId_; - } + @Override + public String getKeyId() { + return keyId_; + } - @Override - public DataKey generateDataKey(final CryptoAlgorithm algorithm, - final Map encryptionContext) { - final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; - Utils.getSecureRandom().nextBytes(rawKey); - EncryptedDataKey encryptedDataKey = jceKeyCipher_.encryptKey(rawKey, keyId_, providerName_, encryptionContext); - return new DataKey<>(new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()), - encryptedDataKey.getEncryptedDataKey(), encryptedDataKey.getProviderInformation(), this); - } + @Override + public DataKey generateDataKey( + final CryptoAlgorithm algorithm, final Map encryptionContext) { + final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; + Utils.getSecureRandom().nextBytes(rawKey); + EncryptedDataKey encryptedDataKey = + jceKeyCipher_.encryptKey(rawKey, keyId_, providerName_, encryptionContext); + return new DataKey<>( + new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()), + encryptedDataKey.getEncryptedDataKey(), + encryptedDataKey.getProviderInformation(), + this); + } - @Override - public DataKey encryptDataKey(final CryptoAlgorithm algorithm, - final Map encryptionContext, - final DataKey dataKey) { - final SecretKey key = dataKey.getKey(); - if (!key.getFormat().equals("RAW")) { - throw new IllegalArgumentException("Can only re-encrypt data keys which are in RAW format, not " - + dataKey.getKey().getFormat()); - } - if (!key.getAlgorithm().equalsIgnoreCase(algorithm.getDataKeyAlgo())) { - throw new IllegalArgumentException("Incorrect key algorithm. Expected " + key.getAlgorithm() - + " but got " + algorithm.getKeyAlgo()); - } - EncryptedDataKey encryptedDataKey = jceKeyCipher_.encryptKey(key.getEncoded(), keyId_, providerName_, encryptionContext); - return new DataKey<>(key, encryptedDataKey.getEncryptedDataKey(), encryptedDataKey.getProviderInformation(), this); + @Override + public DataKey encryptDataKey( + final CryptoAlgorithm algorithm, + final Map encryptionContext, + final DataKey dataKey) { + final SecretKey key = dataKey.getKey(); + if (!key.getFormat().equals("RAW")) { + throw new IllegalArgumentException( + "Can only re-encrypt data keys which are in RAW format, not " + + dataKey.getKey().getFormat()); + } + if (!key.getAlgorithm().equalsIgnoreCase(algorithm.getDataKeyAlgo())) { + throw new IllegalArgumentException( + "Incorrect key algorithm. Expected " + + key.getAlgorithm() + + " but got " + + algorithm.getKeyAlgo()); } + EncryptedDataKey encryptedDataKey = + jceKeyCipher_.encryptKey(key.getEncoded(), keyId_, providerName_, encryptionContext); + return new DataKey<>( + key, + encryptedDataKey.getEncryptedDataKey(), + encryptedDataKey.getProviderInformation(), + this); + } - @Override - public DataKey decryptDataKey(final CryptoAlgorithm algorithm, - final Collection encryptedDataKeys, - final Map encryptionContext) - throws UnsupportedProviderException, AwsCryptoException { - final List exceptions = new ArrayList<>(); - // Find an encrypted key who's provider and info match us - for (final EncryptedDataKey edk : encryptedDataKeys) { - try { - if (edk.getProviderId().equals(getProviderId()) - && Utils.arrayPrefixEquals(edk.getProviderInformation(), keyIdBytes_, keyIdBytes_.length)) { - final byte[] decryptedKey = jceKeyCipher_.decryptKey(edk, keyId_, encryptionContext); + @Override + public DataKey decryptDataKey( + final CryptoAlgorithm algorithm, + final Collection encryptedDataKeys, + final Map encryptionContext) + throws UnsupportedProviderException, AwsCryptoException { + final List exceptions = new ArrayList<>(); + // Find an encrypted key who's provider and info match us + for (final EncryptedDataKey edk : encryptedDataKeys) { + try { + if (edk.getProviderId().equals(getProviderId()) + && Utils.arrayPrefixEquals( + edk.getProviderInformation(), keyIdBytes_, keyIdBytes_.length)) { + final byte[] decryptedKey = jceKeyCipher_.decryptKey(edk, keyId_, encryptionContext); - // Validate that the decrypted key length is as expected - if (decryptedKey.length == algorithm.getDataKeyLength()) { - return new DataKey<>(new SecretKeySpec(decryptedKey, algorithm.getDataKeyAlgo()), - edk.getEncryptedDataKey(), edk.getProviderInformation(), this); - } - } - } catch (final Exception ex) { - exceptions.add(ex); - } + // Validate that the decrypted key length is as expected + if (decryptedKey.length == algorithm.getDataKeyLength()) { + return new DataKey<>( + new SecretKeySpec(decryptedKey, algorithm.getDataKeyAlgo()), + edk.getEncryptedDataKey(), + edk.getProviderInformation(), + this); + } } - throw buildCannotDecryptDksException(exceptions); + } catch (final Exception ex) { + exceptions.add(ex); + } } + throw buildCannotDecryptDksException(exceptions); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/jce/KeyStoreProvider.java b/src/main/java/com/amazonaws/encryptionsdk/jce/KeyStoreProvider.java index 1e92cd2bd..57f44aae9 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/jce/KeyStoreProvider.java +++ b/src/main/java/com/amazonaws/encryptionsdk/jce/KeyStoreProvider.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,6 +13,14 @@ package com.amazonaws.encryptionsdk.jce; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.DataKey; +import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.MasterKeyProvider; +import com.amazonaws.encryptionsdk.MasterKeyRequest; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.NoSuchMasterKeyException; +import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; import java.nio.charset.StandardCharsets; import java.security.KeyStore; import java.security.KeyStore.Entry; @@ -30,160 +38,165 @@ import java.util.List; import java.util.Map; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.DataKey; -import com.amazonaws.encryptionsdk.EncryptedDataKey; -import com.amazonaws.encryptionsdk.MasterKeyProvider; -import com.amazonaws.encryptionsdk.MasterKeyRequest; -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; -import com.amazonaws.encryptionsdk.exception.NoSuchMasterKeyException; -import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; - /** - * This {@link MasterKeyProvider} provides keys backed by a JCE {@link KeyStore}. Please see - * {@link #decryptDataKey(CryptoAlgorithm, Collection, Map)} for an of how decryption is managed and - * see {@link #getMasterKeysForEncryption(MasterKeyRequest)} for an explanation of how encryption is + * This {@link MasterKeyProvider} provides keys backed by a JCE {@link KeyStore}. Please see {@link + * #decryptDataKey(CryptoAlgorithm, Collection, Map)} for an of how decryption is managed and see + * {@link #getMasterKeysForEncryption(MasterKeyRequest)} for an explanation of how encryption is * managed. */ public class KeyStoreProvider extends MasterKeyProvider { - private final String providerName_; - private final KeyStore keystore_; - private final ProtectionParameter protection_; - private final String wrappingAlgorithm_; - private final String keyAlgorithm_; - private final List aliasNames_; + private final String providerName_; + private final KeyStore keystore_; + private final ProtectionParameter protection_; + private final String wrappingAlgorithm_; + private final String keyAlgorithm_; + private final List aliasNames_; - /** - * Creates an instance of this class using {@code wrappingAlgorithm} which will work - * for decrypt only. - */ - public KeyStoreProvider(final KeyStore keystore, final ProtectionParameter protection, - final String providerName, final String wrappingAlgorithm) { - this(keystore, protection, providerName, wrappingAlgorithm, new String[0]); - } + /** + * Creates an instance of this class using {@code wrappingAlgorithm} which will work for + * decrypt only. + */ + public KeyStoreProvider( + final KeyStore keystore, + final ProtectionParameter protection, + final String providerName, + final String wrappingAlgorithm) { + this(keystore, protection, providerName, wrappingAlgorithm, new String[0]); + } - /** - * Creates an instance of this class using {@code wrappingAlgorithm} which will encrypt data to - * the keys specified by {@code aliasNames}. - */ - public KeyStoreProvider(final KeyStore keystore, final ProtectionParameter protection, - final String providerName, final String wrappingAlgorithm, final String... aliasNames) { - keystore_ = keystore; - protection_ = protection; - wrappingAlgorithm_ = wrappingAlgorithm; - aliasNames_ = Arrays.asList(aliasNames); - providerName_ = providerName; - keyAlgorithm_ = wrappingAlgorithm.split("/", 2)[0].toUpperCase(); - } + /** + * Creates an instance of this class using {@code wrappingAlgorithm} which will encrypt data to + * the keys specified by {@code aliasNames}. + */ + public KeyStoreProvider( + final KeyStore keystore, + final ProtectionParameter protection, + final String providerName, + final String wrappingAlgorithm, + final String... aliasNames) { + keystore_ = keystore; + protection_ = protection; + wrappingAlgorithm_ = wrappingAlgorithm; + aliasNames_ = Arrays.asList(aliasNames); + providerName_ = providerName; + keyAlgorithm_ = wrappingAlgorithm.split("/", 2)[0].toUpperCase(); + } - /** - * Returns a {@link JceMasterKey} corresponding to the entry in the {@link KeyStore} with the - * specified alias and compatible algorithm. - */ - @Override - public JceMasterKey getMasterKey(final String provider, final String keyId) throws UnsupportedProviderException, - NoSuchMasterKeyException { - if (!canProvide(provider)) { - throw new UnsupportedProviderException(); - } - final JceMasterKey result = internalGetMasterKey(provider, keyId); - if (result == null) { - throw new NoSuchMasterKeyException(); - } else { - return result; - } + /** + * Returns a {@link JceMasterKey} corresponding to the entry in the {@link KeyStore} with the + * specified alias and compatible algorithm. + */ + @Override + public JceMasterKey getMasterKey(final String provider, final String keyId) + throws UnsupportedProviderException, NoSuchMasterKeyException { + if (!canProvide(provider)) { + throw new UnsupportedProviderException(); } - - private JceMasterKey internalGetMasterKey(final String provider, final String keyId) { - final Entry entry; - try { - entry = keystore_.getEntry(keyId, keystore_.isKeyEntry(keyId) ? protection_ : null); - } catch (NoSuchAlgorithmException | UnrecoverableEntryException | KeyStoreException e) { - throw new UnsupportedProviderException(e); - } - if (entry == null) { - throw new NoSuchMasterKeyException(); - } - if (entry instanceof SecretKeyEntry) { - final SecretKeyEntry skEntry = (SecretKeyEntry) entry; - if (!skEntry.getSecretKey().getAlgorithm().equals(keyAlgorithm_)) { - return null; - } - return JceMasterKey.getInstance(skEntry.getSecretKey(), provider, keyId, wrappingAlgorithm_); - } else if (entry instanceof PrivateKeyEntry) { - final PrivateKeyEntry pkEntry = (PrivateKeyEntry) entry; - if (!pkEntry.getPrivateKey().getAlgorithm().equals(keyAlgorithm_)) { - return null; - } - return JceMasterKey.getInstance(pkEntry.getCertificate().getPublicKey(), pkEntry.getPrivateKey(), provider, - keyId, wrappingAlgorithm_); - } else if (entry instanceof TrustedCertificateEntry) { - final TrustedCertificateEntry certEntry = (TrustedCertificateEntry) entry; - if (!certEntry.getTrustedCertificate().getPublicKey().getAlgorithm().equals(keyAlgorithm_)) { - return null; - } - return JceMasterKey.getInstance(certEntry.getTrustedCertificate().getPublicKey(), null, provider, keyId, - wrappingAlgorithm_); - } else { - throw new NoSuchMasterKeyException(); - } + final JceMasterKey result = internalGetMasterKey(provider, keyId); + if (result == null) { + throw new NoSuchMasterKeyException(); + } else { + return result; } + } - /** - * Returns "JavaKeyStore". - */ - @Override - public String getDefaultProviderId() { - return providerName_; + private JceMasterKey internalGetMasterKey(final String provider, final String keyId) { + final Entry entry; + try { + entry = keystore_.getEntry(keyId, keystore_.isKeyEntry(keyId) ? protection_ : null); + } catch (NoSuchAlgorithmException | UnrecoverableEntryException | KeyStoreException e) { + throw new UnsupportedProviderException(e); + } + if (entry == null) { + throw new NoSuchMasterKeyException(); } + if (entry instanceof SecretKeyEntry) { + final SecretKeyEntry skEntry = (SecretKeyEntry) entry; + if (!skEntry.getSecretKey().getAlgorithm().equals(keyAlgorithm_)) { + return null; + } + return JceMasterKey.getInstance(skEntry.getSecretKey(), provider, keyId, wrappingAlgorithm_); + } else if (entry instanceof PrivateKeyEntry) { + final PrivateKeyEntry pkEntry = (PrivateKeyEntry) entry; + if (!pkEntry.getPrivateKey().getAlgorithm().equals(keyAlgorithm_)) { + return null; + } + return JceMasterKey.getInstance( + pkEntry.getCertificate().getPublicKey(), + pkEntry.getPrivateKey(), + provider, + keyId, + wrappingAlgorithm_); + } else if (entry instanceof TrustedCertificateEntry) { + final TrustedCertificateEntry certEntry = (TrustedCertificateEntry) entry; + if (!certEntry.getTrustedCertificate().getPublicKey().getAlgorithm().equals(keyAlgorithm_)) { + return null; + } + return JceMasterKey.getInstance( + certEntry.getTrustedCertificate().getPublicKey(), + null, + provider, + keyId, + wrappingAlgorithm_); + } else { + throw new NoSuchMasterKeyException(); + } + } - /** - * Returns {@link JceMasterKey}s corresponding to the {@code aliasNames} passed into the - * constructor. - */ - @Override - public List getMasterKeysForEncryption(final MasterKeyRequest request) { - if (aliasNames_ != null) { - final List result = new ArrayList<>(); - for (final String alias : aliasNames_) { - result.add(getMasterKey(alias)); - } - return result; - } else { - return Collections.emptyList(); - } + /** Returns "JavaKeyStore". */ + @Override + public String getDefaultProviderId() { + return providerName_; + } + + /** + * Returns {@link JceMasterKey}s corresponding to the {@code aliasNames} passed into the + * constructor. + */ + @Override + public List getMasterKeysForEncryption(final MasterKeyRequest request) { + if (aliasNames_ != null) { + final List result = new ArrayList<>(); + for (final String alias : aliasNames_) { + result.add(getMasterKey(alias)); + } + return result; + } else { + return Collections.emptyList(); } + } - /** - * Attempts to decrypts the {@code encryptedDataKeys} by first iterating through all - * {@code aliasNames} specified in the constructor and then over - * all other compatible keys in the {@link KeyStore}. This includes - * {@code TrustedCertificates} as well as standard key entries. - */ - @Override - public DataKey decryptDataKey(final CryptoAlgorithm algorithm, - final Collection encryptedDataKeys, - final Map encryptionContext) - throws UnsupportedProviderException, AwsCryptoException { - final List exceptions = new ArrayList<>(); - for (final EncryptedDataKey edk : encryptedDataKeys) { - try { - if (canProvide(edk.getProviderId())) { - final String alias = new String(edk.getProviderInformation(), StandardCharsets.UTF_8); - if (keystore_.isKeyEntry(alias)) { - final DataKey result = getMasterKey(alias).decryptDataKey(algorithm, - Collections.singletonList(edk), - encryptionContext); - if (result != null) { - return result; - } - } - } - } catch (final Exception ex) { - exceptions.add(ex); + /** + * Attempts to decrypts the {@code encryptedDataKeys} by first iterating through all {@code + * aliasNames} specified in the constructor and then over all other compatible keys in + * the {@link KeyStore}. This includes {@code TrustedCertificates} as well as standard key + * entries. + */ + @Override + public DataKey decryptDataKey( + final CryptoAlgorithm algorithm, + final Collection encryptedDataKeys, + final Map encryptionContext) + throws UnsupportedProviderException, AwsCryptoException { + final List exceptions = new ArrayList<>(); + for (final EncryptedDataKey edk : encryptedDataKeys) { + try { + if (canProvide(edk.getProviderId())) { + final String alias = new String(edk.getProviderInformation(), StandardCharsets.UTF_8); + if (keystore_.isKeyEntry(alias)) { + final DataKey result = + getMasterKey(alias) + .decryptDataKey(algorithm, Collections.singletonList(edk), encryptionContext); + if (result != null) { + return result; } + } } - - throw buildCannotDecryptDksException(exceptions); + } catch (final Exception ex) { + exceptions.add(ex); + } } + + throw buildCannotDecryptDksException(exceptions); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/jce/package-info.java b/src/main/java/com/amazonaws/encryptionsdk/jce/package-info.java index de5b5e6f1..c74b29501 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/jce/package-info.java +++ b/src/main/java/com/amazonaws/encryptionsdk/jce/package-info.java @@ -1,18 +1,18 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. */ /** - * Contains logic necessary to create {@link com.amazonaws.encryptionsdk.MasterKey}s with - * raw cryptographic keys, {@link java.security.Key}s, or {@link java.security.KeyStore}. + * Contains logic necessary to create {@link com.amazonaws.encryptionsdk.MasterKey}s with raw + * cryptographic keys, {@link java.security.Key}s, or {@link java.security.KeyStore}. */ package com.amazonaws.encryptionsdk.jce; diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKey.java b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKey.java index b970dbcb4..0658816ad 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKey.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKey.java @@ -3,12 +3,7 @@ package com.amazonaws.encryptionsdk.kms; -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.*; -import java.util.function.Supplier; +import static com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo.*; import com.amazonaws.AmazonServiceException; import com.amazonaws.AmazonWebServiceRequest; @@ -23,418 +18,406 @@ import com.amazonaws.services.kms.model.EncryptResult; import com.amazonaws.services.kms.model.GenerateDataKeyRequest; import com.amazonaws.services.kms.model.GenerateDataKeyResult; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.function.Supplier; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; -import static com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo.*; - - -//= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.5 -//# MUST implement the Master Key Interface (../master-key- -//# interface.md#interface) +// = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.5 +// # MUST implement the Master Key Interface (../master-key- +// # interface.md#interface) // -//= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.7 -//# MUST be unchanged from the Master Key interface. +// = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.7 +// # MUST be unchanged from the Master Key interface. // -//= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.8 -//# MUST be unchanged from the Master Key interface. +// = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.8 +// # MUST be unchanged from the Master Key interface. /** - * Represents a single Aws KMS key - * and is used to encrypt/decrypt data with - * {@link AwsCrypto}. - * This key may be a multi region key, - * in which case this component - * is able to recognize - * different regional replicas - * of this multi region key as the same. + * Represents a single Aws KMS key and is used to encrypt/decrypt data with {@link AwsCrypto}. This + * key may be a multi region key, in which case this component is able to recognize different + * regional replicas of this multi region key as the same. */ -public final class AwsKmsMrkAwareMasterKey extends MasterKey implements KmsMethods { - private static final String USER_AGENT = VersionInfo.loadUserAgent(); - private final AWSKMS kmsClient_; - private final List grantTokens_ = new ArrayList<>(); - private final String awsKmsIdentifier_; - private final MasterKeyProvider sourceProvider_; - - private static T updateUserAgent(T request) { - request.getRequestClientOptions().appendUserAgent(USER_AGENT); - - return request; - } - - /** - * A light builder method. - * - * @see KmsMasterKey#getInstance(Supplier, String, MasterKeyProvider) - * @param kms An AWS KMS Client - * @param awsKmsIdentifier An identifier for an AWS KMS key. May be a raw resource. - */ - static AwsKmsMrkAwareMasterKey getInstance( - final AWSKMS kms, - final String awsKmsIdentifier, - final MasterKeyProvider provider - ) { - return new AwsKmsMrkAwareMasterKey(awsKmsIdentifier, kms, provider); +public final class AwsKmsMrkAwareMasterKey extends MasterKey + implements KmsMethods { + private static final String USER_AGENT = VersionInfo.loadUserAgent(); + private final AWSKMS kmsClient_; + private final List grantTokens_ = new ArrayList<>(); + private final String awsKmsIdentifier_; + private final MasterKeyProvider sourceProvider_; + + private static T updateUserAgent(T request) { + request.getRequestClientOptions().appendUserAgent(USER_AGENT); + + return request; + } + + /** + * A light builder method. + * + * @see KmsMasterKey#getInstance(Supplier, String, MasterKeyProvider) + * @param kms An AWS KMS Client + * @param awsKmsIdentifier An identifier for an AWS KMS key. May be a raw resource. + */ + static AwsKmsMrkAwareMasterKey getInstance( + final AWSKMS kms, + final String awsKmsIdentifier, + final MasterKeyProvider provider) { + return new AwsKmsMrkAwareMasterKey(awsKmsIdentifier, kms, provider); + } + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 + // # On initialization, the caller MUST provide: + private AwsKmsMrkAwareMasterKey( + final String awsKmsIdentifier, + final AWSKMS kmsClient, + final MasterKeyProvider provider) { + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 + // # The AWS KMS key identifier MUST NOT be null or empty. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 + // # The AWS KMS + // # key identifier MUST be a valid identifier (aws-kms-key-arn.md#a- + // # valid-aws-kms-identifier). + validAwsKmsIdentifier(awsKmsIdentifier); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 + // # The AWS KMS SDK client MUST not be null. + if (kmsClient == null) { + throw new IllegalArgumentException( + "AwsKmsMrkAwareMasterKey must be configured with an AWS KMS client."); } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 - //# On initialization, the caller MUST provide: - private AwsKmsMrkAwareMasterKey( - final String awsKmsIdentifier, - final AWSKMS kmsClient, - final MasterKeyProvider provider - ) { - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 - //# The AWS KMS key identifier MUST NOT be null or empty. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 - //# The AWS KMS - //# key identifier MUST be a valid identifier (aws-kms-key-arn.md#a- - //# valid-aws-kms-identifier). - validAwsKmsIdentifier(awsKmsIdentifier); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 - //# The AWS KMS SDK client MUST not be null. - if (kmsClient == null) { - throw new IllegalArgumentException("AwsKmsMrkAwareMasterKey must be configured with an AWS KMS client."); - } - - /* Precondition: A provider is required. */ - if (provider == null) { - throw new IllegalArgumentException("AwsKmsMrkAwareMasterKey must be configured with a source provider."); - } - - kmsClient_ = kmsClient; - awsKmsIdentifier_ = awsKmsIdentifier; - sourceProvider_ = provider; + /* Precondition: A provider is required. */ + if (provider == null) { + throw new IllegalArgumentException( + "AwsKmsMrkAwareMasterKey must be configured with a source provider."); } - @Override - public String getProviderId() { - return sourceProvider_.getDefaultProviderId(); + kmsClient_ = kmsClient; + awsKmsIdentifier_ = awsKmsIdentifier; + sourceProvider_ = provider; + } + + @Override + public String getProviderId() { + return sourceProvider_.getDefaultProviderId(); + } + + @Override + public String getKeyId() { + return awsKmsIdentifier_; + } + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 + // # The master key MUST be able to be configured with an optional list of + // # Grant Tokens. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 + // = type=exception + // # This configuration SHOULD be on initialization and + // # SHOULD be immutable. + // The existing KMS Master Key + // sets grants in this way, so we continue this interface. + /** Clears and sets all grant tokens on this instance. This is not thread safe. */ + @Override + public void setGrantTokens(final List grantTokens) { + grantTokens_.clear(); + grantTokens_.addAll(grantTokens); + } + + @Override + public List getGrantTokens() { + return grantTokens_; + } + + @Override + public void addGrantToken(final String grantToken) { + grantTokens_.add(grantToken); + } + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // # The inputs MUST be the same as the Master Key Generate Data Key + // # (../master-key-interface.md#generate-data-key) interface. + /** + * This is identical behavior to + * + * @see KmsMasterKey#generateDataKey(CryptoAlgorithm, Map) + */ + @Override + public DataKey generateDataKey( + final CryptoAlgorithm algorithm, final Map encryptionContext) { + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // # This + // # master key MUST use the configured AWS KMS client to make an AWS KMS + // # GenerateDatakey (https://docs.aws.amazon.com/kms/latest/APIReference/ + // # API_GenerateDataKey.html) request constructed as follows: + final GenerateDataKeyResult gdkResult = + kmsClient_.generateDataKey( + updateUserAgent( + new GenerateDataKeyRequest() + .withKeyId(awsKmsIdentifier_) + .withNumberOfBytes(algorithm.getDataKeyLength()) + .withEncryptionContext(encryptionContext) + .withGrantTokens(grantTokens_))); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // # If the call succeeds the AWS KMS Generate Data Key response's + // # "Plaintext" MUST match the key derivation input length specified by + // # the algorithm suite included in the input. + if (gdkResult.getPlaintext().limit() != algorithm.getDataKeyLength()) { + throw new IllegalStateException("Received an unexpected number of bytes from KMS"); } - @Override - public String getKeyId() { - return awsKmsIdentifier_; - } + final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; + gdkResult.getPlaintext().get(rawKey); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 - //# The master key MUST be able to be configured with an optional list of - //# Grant Tokens. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 - //= type=exception - //# This configuration SHOULD be on initialization and - //# SHOULD be immutable. - // The existing KMS Master Key - // sets grants in this way, so we continue this interface. - /** - * Clears and sets all grant tokens on this instance. - * This is not thread safe. - */ - @Override - public void setGrantTokens(final List grantTokens) { - grantTokens_.clear(); - grantTokens_.addAll(grantTokens); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // # The response's "KeyId" + // # MUST be valid. + final String gdkResultKeyId = gdkResult.getKeyId(); + /* Exceptional Postcondition: Must have an AWS KMS ARN from AWS KMS generateDataKey. */ + if (parseInfoFromKeyArn(gdkResultKeyId) == null) { + throw new IllegalStateException("Received an empty or invalid keyId from KMS"); } - @Override - public List getGrantTokens() { - return grantTokens_; + final byte[] encryptedKey = new byte[gdkResult.getCiphertextBlob().remaining()]; + gdkResult.getCiphertextBlob().get(encryptedKey); + + final SecretKeySpec key = new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // # The output MUST be the same as the Master Key Generate Data Key + // # (../master-key-interface.md#generate-data-key) interface. + return new DataKey<>( + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // # The response's "Plaintext" MUST be the plaintext in + // # the output. + key, + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // # The response's cipher text blob MUST be used as the + // # returned as the ciphertext for the encrypted data key in the output. + encryptedKey, + gdkResultKeyId.getBytes(StandardCharsets.UTF_8), + this); + } + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 + // # The inputs MUST be the same as the Master Key Encrypt Data Key + // # (../master-key-interface.md#encrypt-data-key) interface. + /** @see KmsMasterKey#encryptDataKey(CryptoAlgorithm, Map, DataKey) */ + @Override + public DataKey encryptDataKey( + final CryptoAlgorithm algorithm, + final Map encryptionContext, + final DataKey dataKey) { + final SecretKey key = dataKey.getKey(); + /* Precondition: The key format MUST be RAW. */ + if (!key.getFormat().equals("RAW")) { + throw new IllegalArgumentException("Only RAW encoded keys are supported"); } - @Override - public void addGrantToken(final String grantToken) { - grantTokens_.add(grantToken); + try { + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 + // # The master + // # key MUST use the configured AWS KMS client to make an AWS KMS Encrypt + // # (https://docs.aws.amazon.com/kms/latest/APIReference/ + // # API_Encrypt.html) request constructed as follows: + final EncryptResult encryptResult = + kmsClient_.encrypt( + updateUserAgent( + new EncryptRequest() + .withKeyId(awsKmsIdentifier_) + .withPlaintext(ByteBuffer.wrap(key.getEncoded())) + .withEncryptionContext(encryptionContext) + .withGrantTokens(grantTokens_))); + + final byte[] edk = new byte[encryptResult.getCiphertextBlob().remaining()]; + encryptResult.getCiphertextBlob().get(edk); + final String encryptResultKeyId = encryptResult.getKeyId(); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 + // # The AWS KMS Encrypt response MUST contain a valid "KeyId". + /* Postcondition: Must have an AWS KMS ARN from AWS KMS encrypt. */ + if (parseInfoFromKeyArn(encryptResultKeyId) == null) { + throw new IllegalStateException("Received an empty or invalid keyId from KMS"); + } + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 + // # The output MUST be the same as the Master Key Encrypt Data Key + // # (../master-key-interface.md#encrypt-data-key) interface. + return new DataKey<>( + dataKey.getKey(), + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 + // # The + // # response's cipher text blob MUST be used as the "ciphertext" for the + // # encrypted data key. + edk, + encryptResultKeyId.getBytes(StandardCharsets.UTF_8), + this); + } catch (final AmazonServiceException asex) { + throw new AwsCryptoException(asex); } - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //# The inputs MUST be the same as the Master Key Generate Data Key - //# (../master-key-interface.md#generate-data-key) interface. - /** - * This is identical behavior to - * @see KmsMasterKey#generateDataKey(CryptoAlgorithm, Map) - */ - @Override - public DataKey generateDataKey(final CryptoAlgorithm algorithm, - final Map encryptionContext) { - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //# This - //# master key MUST use the configured AWS KMS client to make an AWS KMS - //# GenerateDatakey (https://docs.aws.amazon.com/kms/latest/APIReference/ - //# API_GenerateDataKey.html) request constructed as follows: - final GenerateDataKeyResult gdkResult = kmsClient_.generateDataKey(updateUserAgent( - new GenerateDataKeyRequest() - .withKeyId(awsKmsIdentifier_) - .withNumberOfBytes(algorithm.getDataKeyLength()) - .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens_) - )); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //# If the call succeeds the AWS KMS Generate Data Key response's - //# "Plaintext" MUST match the key derivation input length specified by - //# the algorithm suite included in the input. - if (gdkResult.getPlaintext().limit() != algorithm.getDataKeyLength()) { - throw new IllegalStateException("Received an unexpected number of bytes from KMS"); - } - - final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; - gdkResult.getPlaintext().get(rawKey); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //# The response's "KeyId" - //# MUST be valid. - final String gdkResultKeyId = gdkResult.getKeyId(); - /* Exceptional Postcondition: Must have an AWS KMS ARN from AWS KMS generateDataKey. */ - if (parseInfoFromKeyArn(gdkResultKeyId) == null) { - throw new IllegalStateException("Received an empty or invalid keyId from KMS"); - } - - final byte[] encryptedKey = new byte[gdkResult.getCiphertextBlob().remaining()]; - gdkResult.getCiphertextBlob().get(encryptedKey); - - final SecretKeySpec key = new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //# The output MUST be the same as the Master Key Generate Data Key - //# (../master-key-interface.md#generate-data-key) interface. - return new DataKey<>( - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //# The response's "Plaintext" MUST be the plaintext in - //# the output. - key, - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //# The response's cipher text blob MUST be used as the - //# returned as the ciphertext for the encrypted data key in the output. - encryptedKey, - gdkResultKeyId.getBytes(StandardCharsets.UTF_8), - this - ); + } + + /** + * Will attempt to decrypt if awsKmsArnMatchForDecrypt returns true in {@link + * AwsKmsMrkAwareMasterKey#filterEncryptedDataKeys(String, AwsKmsCmkArnInfo, EncryptedDataKey)}. + * An extension of {@link KmsMasterKey#decryptDataKey(CryptoAlgorithm, Collection, Map)} but with + * an awareness of the properties of multi-Region keys. + */ + @Override + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # The inputs MUST be the same as the Master Key Decrypt Data Key + // # (../master-key-interface.md#decrypt-data-key) interface. + public DataKey decryptDataKey( + final CryptoAlgorithm algorithm, + final Collection encryptedDataKeys, + final Map encryptionContext) + throws AwsCryptoException { + final List exceptions = new ArrayList<>(); + final String providerId = this.getProviderId(); + + return encryptedDataKeys.stream() + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # The set of encrypted data keys MUST first be filtered to match this + // # master key's configuration. + .filter(edk -> filterEncryptedDataKeys(providerId, awsKmsIdentifier_, edk)) + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # For each encrypted data key in the filtered set, one at a time, the + // # master key MUST attempt to decrypt the data key. + .map( + edk -> { + try { + return decryptSingleEncryptedDataKey( + this, + kmsClient_, + awsKmsIdentifier_, + grantTokens_, + algorithm, + edk, + encryptionContext); + } catch (final AmazonServiceException amazonServiceException) { + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # If this attempt + // # results in an error, then these errors MUST be collected. + exceptions.add(amazonServiceException); + } + return null; + }) + /* Need to filter null + * because an Optional + * of a null is crazy. + * Therefore `findFirst` will throw + * if it sees `null`. + */ + .filter(Objects::nonNull) + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # If the AWS KMS response satisfies the requirements then it MUST be + // # use and this function MUST return and not attempt to decrypt any more + // # encrypted data keys. + /* Order is important. + * Process the encrypted data keys in the order they exist in the encrypted message. + */ + .findFirst() + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # If all the input encrypted data keys have been processed then this + // # function MUST yield an error that includes all the collected errors. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # The output MUST be the same as the Master Key Decrypt Data Key + // # (../master-key-interface.md#decrypt-data-key) interface. + /* Exceptional Postcondition: Master key was unable to decrypt. */ + .orElseThrow(() -> buildCannotDecryptDksException(exceptions)); + } + + /** + * Pure function for decrypting and encrypted data key. This is refactored out of `decryptDataKey` + * to facilitate testing to ensure correctness. + */ + static DataKey decryptSingleEncryptedDataKey( + final AwsKmsMrkAwareMasterKey masterKey, + final AWSKMS client, + final String awsKmsIdentifier, + final List grantTokens, + final CryptoAlgorithm algorithm, + final EncryptedDataKey edk, + final Map encryptionContext) { + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # To decrypt the encrypted data key this master key MUST use the + // # configured AWS KMS client to make an AWS KMS Decrypt + // # (https://docs.aws.amazon.com/kms/latest/APIReference/ + // # API_Decrypt.html) request constructed as follows: + final DecryptResult decryptResult = + client.decrypt( + updateUserAgent( + new DecryptRequest() + .withCiphertextBlob(ByteBuffer.wrap(edk.getEncryptedDataKey())) + .withEncryptionContext(encryptionContext) + .withGrantTokens(grantTokens) + .withKeyId(awsKmsIdentifier))); + + final String decryptResultKeyId = decryptResult.getKeyId(); + /* Exceptional Postcondition: Must have a CMK ARN from AWS KMS to match. */ + if (decryptResultKeyId == null) { + throw new IllegalStateException("Received an empty keyId from KMS"); } - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 - //# The inputs MUST be the same as the Master Key Encrypt Data Key - //# (../master-key-interface.md#encrypt-data-key) interface. - /** - * @see KmsMasterKey#encryptDataKey(CryptoAlgorithm, Map, DataKey) - */ - @Override - public DataKey encryptDataKey(final CryptoAlgorithm algorithm, - final Map encryptionContext, - final DataKey dataKey) { - final SecretKey key = dataKey.getKey(); - /* Precondition: The key format MUST be RAW. */ - if (!key.getFormat().equals("RAW")) { - throw new IllegalArgumentException("Only RAW encoded keys are supported"); - } - - try { - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 - //# The master - //# key MUST use the configured AWS KMS client to make an AWS KMS Encrypt - //# (https://docs.aws.amazon.com/kms/latest/APIReference/ - //# API_Encrypt.html) request constructed as follows: - final EncryptResult encryptResult = kmsClient_.encrypt(updateUserAgent( - new EncryptRequest() - .withKeyId(awsKmsIdentifier_) - .withPlaintext(ByteBuffer.wrap(key.getEncoded())) - .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens_))); - - final byte[] edk = new byte[encryptResult.getCiphertextBlob().remaining()]; - encryptResult.getCiphertextBlob().get(edk); - final String encryptResultKeyId = encryptResult.getKeyId(); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 - //# The AWS KMS Encrypt response MUST contain a valid "KeyId". - /* Postcondition: Must have an AWS KMS ARN from AWS KMS encrypt. */ - if (parseInfoFromKeyArn(encryptResultKeyId) == null) { - throw new IllegalStateException("Received an empty or invalid keyId from KMS"); - } - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 - //# The output MUST be the same as the Master Key Encrypt Data Key - //# (../master-key-interface.md#encrypt-data-key) interface. - return new DataKey<>( - dataKey.getKey(), - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 - //# The - //# response's cipher text blob MUST be used as the "ciphertext" for the - //# encrypted data key. - edk, - encryptResultKeyId.getBytes(StandardCharsets.UTF_8), - this - ); - } catch (final AmazonServiceException asex) { - throw new AwsCryptoException(asex); - } + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # If the call succeeds then the response's "KeyId" MUST be equal to the + // # configured AWS KMS key identifier otherwise the function MUST collect + // # an error. + if (!awsKmsIdentifier.equals(decryptResultKeyId)) { + throw new IllegalStateException( + "Received an invalid response from KMS Decrypt call: Unexpected keyId."); } - /** - * Will attempt to decrypt if awsKmsArnMatchForDecrypt returns true in - * {@link AwsKmsMrkAwareMasterKey#filterEncryptedDataKeys(String, AwsKmsCmkArnInfo, EncryptedDataKey)}. - * An extension of - * {@link KmsMasterKey#decryptDataKey(CryptoAlgorithm, Collection, Map)} - * but with an awareness of the properties of multi-Region keys. - */ - @Override - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# The inputs MUST be the same as the Master Key Decrypt Data Key - //# (../master-key-interface.md#decrypt-data-key) interface. - public DataKey decryptDataKey( - final CryptoAlgorithm algorithm, - final Collection encryptedDataKeys, - final Map encryptionContext - ) throws AwsCryptoException { - final List exceptions = new ArrayList<>(); - final String providerId = this.getProviderId(); - - return encryptedDataKeys - .stream() - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# The set of encrypted data keys MUST first be filtered to match this - //# master key's configuration. - .filter(edk -> filterEncryptedDataKeys(providerId, awsKmsIdentifier_, edk)) - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# For each encrypted data key in the filtered set, one at a time, the - //# master key MUST attempt to decrypt the data key. - .map(edk -> { - try { - return decryptSingleEncryptedDataKey( - this, - kmsClient_, - awsKmsIdentifier_, - grantTokens_, - algorithm, - edk, - encryptionContext - ); - } catch (final AmazonServiceException amazonServiceException) { - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# If this attempt - //# results in an error, then these errors MUST be collected. - exceptions.add(amazonServiceException); - } - return null; - }) - /* Need to filter null - * because an Optional - * of a null is crazy. - * Therefore `findFirst` will throw - * if it sees `null`. - */ - .filter(Objects::nonNull) - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# If the AWS KMS response satisfies the requirements then it MUST be - //# use and this function MUST return and not attempt to decrypt any more - //# encrypted data keys. - /* Order is important. - * Process the encrypted data keys in the order they exist in the encrypted message. - */ - .findFirst() - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# If all the input encrypted data keys have been processed then this - //# function MUST yield an error that includes all the collected errors. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# The output MUST be the same as the Master Key Decrypt Data Key - //# (../master-key-interface.md#decrypt-data-key) interface. - /* Exceptional Postcondition: Master key was unable to decrypt. */ - .orElseThrow(() -> buildCannotDecryptDksException(exceptions)); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # The response's "Plaintext"'s length MUST equal the length + // # required by the requested algorithm suite otherwise the function MUST + // # collect an error. + if (decryptResult.getPlaintext().limit() != algorithm.getDataKeyLength()) { + throw new IllegalStateException("Received an unexpected number of bytes from KMS"); } - /** - * Pure function for decrypting and encrypted data key. - * This is refactored out of `decryptDataKey` - * to facilitate testing to ensure correctness. - * - */ - static DataKey decryptSingleEncryptedDataKey( - final AwsKmsMrkAwareMasterKey masterKey, - final AWSKMS client, - final String awsKmsIdentifier, - final List grantTokens, - final CryptoAlgorithm algorithm, - final EncryptedDataKey edk, - final Map encryptionContext - ) { - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# To decrypt the encrypted data key this master key MUST use the - //# configured AWS KMS client to make an AWS KMS Decrypt - //# (https://docs.aws.amazon.com/kms/latest/APIReference/ - //# API_Decrypt.html) request constructed as follows: - final DecryptResult decryptResult = client.decrypt(updateUserAgent( - new DecryptRequest() - .withCiphertextBlob(ByteBuffer.wrap(edk.getEncryptedDataKey())) - .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens) - .withKeyId(awsKmsIdentifier))); - - final String decryptResultKeyId = decryptResult.getKeyId(); - /* Exceptional Postcondition: Must have a CMK ARN from AWS KMS to match. */ - if (decryptResultKeyId == null) { - throw new IllegalStateException("Received an empty keyId from KMS"); - } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# If the call succeeds then the response's "KeyId" MUST be equal to the - //# configured AWS KMS key identifier otherwise the function MUST collect - //# an error. - if (!awsKmsIdentifier.equals(decryptResultKeyId)) { - throw new IllegalStateException("Received an invalid response from KMS Decrypt call: Unexpected keyId."); - } - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# The response's "Plaintext"'s length MUST equal the length - //# required by the requested algorithm suite otherwise the function MUST - //# collect an error. - if (decryptResult.getPlaintext().limit() != algorithm.getDataKeyLength()) { - throw new IllegalStateException("Received an unexpected number of bytes from KMS"); - } - - final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; - decryptResult.getPlaintext().get(rawKey); - - return new DataKey<>( - new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()), - edk.getEncryptedDataKey(), - edk.getProviderInformation(), - masterKey); + final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; + decryptResult.getPlaintext().get(rawKey); + + return new DataKey<>( + new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()), + edk.getEncryptedDataKey(), + edk.getProviderInformation(), + masterKey); + } + + /** + * A pure function to filter encrypted data keys. This function is refactored out from + * `decryptDataKey` to facilitate testing and ensure correctness. + * + *

An AWS KMS Master key should only attempt to process an Encrypted Data Key if the + * information in the Encrypted Data Key matches the master keys configuration. + */ + static boolean filterEncryptedDataKeys( + final String providerId, final String awsKmsIdentifier_, final EncryptedDataKey edk) { + final String edkKeyId = new String(edk.getProviderInformation(), StandardCharsets.UTF_8); + + final AwsKmsCmkArnInfo providerArnInfo = parseInfoFromKeyArn(edkKeyId); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # Additionally each provider info MUST be a valid AWS KMS ARN + // # (aws-kms-key-arn.md#a-valid-aws-kms-arn) with a resource type of + // # "key". + if (providerArnInfo == null || !"key".equals(providerArnInfo.getResourceType())) { + throw new IllegalStateException("Invalid provider info in message."); } - /** - * A pure function to filter encrypted data keys. - * This function is refactored out from `decryptDataKey` - * to facilitate testing and ensure correctness. - * - * An AWS KMS Master key should only attempt - * to process an Encrypted Data Key - * if the information in the Encrypted Data Key - * matches the master keys configuration. - * - */ - static boolean filterEncryptedDataKeys ( - final String providerId, - final String awsKmsIdentifier_, - final EncryptedDataKey edk - ) { - final String edkKeyId = new String(edk.getProviderInformation(), StandardCharsets.UTF_8); - - final AwsKmsCmkArnInfo providerArnInfo = parseInfoFromKeyArn(edkKeyId); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# Additionally each provider info MUST be a valid AWS KMS ARN - //# (aws-kms-key-arn.md#a-valid-aws-kms-arn) with a resource type of - //# "key". - if (providerArnInfo == null || !"key".equals(providerArnInfo.getResourceType())) { - throw new IllegalStateException("Invalid provider info in message."); - } - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //# To match the encrypted data key's - //# provider ID MUST exactly match the value "aws-kms" and the the - //# function AWS KMS MRK Match for Decrypt (aws-kms-mrk-match-for- - //# decrypt.md#implementation) called with the configured AWS KMS key - //# identifier and the encrypted data key's provider info MUST return - //# "true". - return edk.getProviderId().equals(providerId) && - awsKmsArnMatchForDecrypt(awsKmsIdentifier_, edkKeyId); - } + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // # To match the encrypted data key's + // # provider ID MUST exactly match the value "aws-kms" and the the + // # function AWS KMS MRK Match for Decrypt (aws-kms-mrk-match-for- + // # decrypt.md#implementation) called with the configured AWS KMS key + // # identifier and the encrypted data key's provider info MUST return + // # "true". + return edk.getProviderId().equals(providerId) + && awsKmsArnMatchForDecrypt(awsKmsIdentifier_, edkKeyId); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyProvider.java b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyProvider.java index 8c0e2874f..5cec7b51e 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyProvider.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyProvider.java @@ -3,6 +3,12 @@ package com.amazonaws.encryptionsdk.kms; +import static com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo.*; +import static com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo.parseInfoFromKeyArn; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; + import com.amazonaws.SdkClientException; import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.AWSCredentialsProvider; @@ -16,816 +22,786 @@ import com.amazonaws.services.kms.AWSKMS; import com.amazonaws.services.kms.AWSKMSClient; import com.amazonaws.services.kms.AWSKMSClientBuilder; - import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; -import static com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo.*; -import static com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo.parseInfoFromKeyArn; -import static java.util.Arrays.asList; -import static java.util.Collections.emptyList; -import static java.util.Collections.singletonList; - -//= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.5 -//# MUST implement the Master Key Provider Interface (../master-key- -//# provider-interface.md#interface) +// = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.5 +// # MUST implement the Master Key Provider Interface (../master-key- +// # provider-interface.md#interface) /** - * Represents a list Aws KMS keys - * and is used to encrypt/decrypt data with - * {@link AwsCrypto}. - * Some of these keys may be multi region keys, - * in which case this component - * is able to recognize - * different regional replicas - * of this multi region key as the same. + * Represents a list Aws KMS keys and is used to encrypt/decrypt data with {@link AwsCrypto}. Some + * of these keys may be multi region keys, in which case this component is able to recognize + * different regional replicas of this multi region key as the same. */ -public final class AwsKmsMrkAwareMasterKeyProvider extends MasterKeyProvider { - private static final String PROVIDER_NAME = "aws-kms"; - private final List keyIds_; - private final List grantTokens_; - - private final boolean isDiscovery_; - private final DiscoveryFilter discoveryFilter_; - private final String discoveryMrkRegion_; - - private final KmsMasterKeyProvider.RegionalClientSupplier regionalClientSupplier_; - private final String defaultRegion_; - - public static class Builder implements Cloneable { - private String defaultRegion_ = getSdkDefaultRegion(); - private Optional regionalClientSupplier_ = Optional.empty(); - private AWSKMSClientBuilder templateBuilder_ = null; - private DiscoveryFilter discoveryFilter_ = null; - private String discoveryMrkRegion_ = this.defaultRegion_; - - Builder() { - // Default access: Don't allow outside classes to extend this class - } - - public Builder clone() { - try { - AwsKmsMrkAwareMasterKeyProvider.Builder cloned = (AwsKmsMrkAwareMasterKeyProvider.Builder) super.clone(); - - if (templateBuilder_ != null) { - cloned.templateBuilder_ = cloneClientBuilder(templateBuilder_); - } - - return cloned; - } catch (CloneNotSupportedException e) { - throw new Error("Impossible: CloneNotSupportedException", e); - } - } - - /** - * Sets the default region. - * This region will be used when specifying key IDs for encryption or in - * {@link AwsKmsMrkAwareMasterKeyProvider#getMasterKey(String)} that are not full ARNs, but are instead bare key IDs or - * aliases. - *

- * If the default region is not specified, - * the AWS SDK default region will be used. - * @see KmsMasterKeyProvider.Builder#withDefaultRegion(String) - * @param defaultRegion The default region to use. - */ - public AwsKmsMrkAwareMasterKeyProvider.Builder withDefaultRegion(String defaultRegion) { - this.defaultRegion_ = defaultRegion; - return this; - } - - /** - * Sets the region contacted for multi-region keys - * when in Discovery mode. - * This region will be used when a multi-region key is discovered - * on decrypt by {@link AwsKmsMrkAwareMasterKeyProvider#getMasterKey(String)}. - *

- * - * @param discoveryMrkRegion The region to contact to attempt to decrypt multi-region keys. - */ - public AwsKmsMrkAwareMasterKeyProvider.Builder withDiscoveryMrkRegion(String discoveryMrkRegion) { - this.discoveryMrkRegion_ = discoveryMrkRegion; - return this; - } - - /** - * Provides a custom factory function that will vend KMS clients. This is provided for advanced use cases which - * require complete control over the client construction process. - *

- * Because the regional client supplier fully controls the client construction process, it is not possible to - * configure the client through methods such as {@link #withCredentials(AWSCredentialsProvider)} or - * {@link #withClientBuilder(AWSKMSClientBuilder)}; if you try to use these in combination, an - * {@link IllegalStateException} will be thrown. - * - * @see KmsMasterKeyProvider.Builder#withCustomClientFactory(KmsMasterKeyProvider.RegionalClientSupplier) - */ - public AwsKmsMrkAwareMasterKeyProvider.Builder withCustomClientFactory(KmsMasterKeyProvider.RegionalClientSupplier regionalClientSupplier) { - if (templateBuilder_ != null) { - throw clientSupplierComboException(); - } - - regionalClientSupplier_ = Optional.of(regionalClientSupplier); - return this; - } - - private RuntimeException clientSupplierComboException() { - return new IllegalStateException("withCustomClientFactory cannot be used in conjunction with " + - "withCredentials or withClientBuilder"); - } - - /** - * Configures the {@link AwsKmsMrkAwareMasterKeyProvider} to use specific credentials. If a builder was previously set, - * this will override whatever credentials it set. - * - * @see KmsMasterKeyProvider.Builder#withCredentials(AWSCredentialsProvider) - */ - public AwsKmsMrkAwareMasterKeyProvider.Builder withCredentials(AWSCredentialsProvider credentialsProvider) { - if (regionalClientSupplier_.isPresent()) { - throw clientSupplierComboException(); - } - - if (templateBuilder_ == null) { - templateBuilder_ = AWSKMSClientBuilder.standard(); - } - - templateBuilder_.setCredentials(credentialsProvider); - - return this; - } - - /** - * Configures the {@link AwsKmsMrkAwareMasterKeyProvider} to use specific credentials. If a builder was previously set, - * this will override whatever credentials it set. - * - * @see KmsMasterKeyProvider.Builder#withCredentials(AWSCredentials) - */ - public AwsKmsMrkAwareMasterKeyProvider.Builder withCredentials(AWSCredentials credentials) { - return withCredentials(new AWSStaticCredentialsProvider(credentials)); - } - - /** - * Configures the {@link AwsKmsMrkAwareMasterKeyProvider} to use settings from this {@link AWSKMSClientBuilder} to - * configure KMS clients. Note that the region set on this builder will be ignored, but all other settings - * will be propagated into the regional clients. - *

- * This method will overwrite any credentials set using {@link #withCredentials(AWSCredentialsProvider)}. - * - * @see KmsMasterKeyProvider.Builder#withClientBuilder(AWSKMSClientBuilder) - */ - public AwsKmsMrkAwareMasterKeyProvider.Builder withClientBuilder(AWSKMSClientBuilder builder) { - if (regionalClientSupplier_.isPresent()) { - throw clientSupplierComboException(); - } - final AWSKMSClientBuilder newBuilder = cloneClientBuilder(builder); - this.templateBuilder_ = newBuilder; - - return this; - } - - /** - * Builds the master key provider in Discovery Mode. - * In Discovery Mode the KMS Master Key Provider will attempt to decrypt using any - * key identifier it discovers in the encrypted message. - * KMS Master Key Providers in Discovery Mode will not encrypt data keys. - * - * @see KmsMasterKeyProvider.Builder#buildDiscovery() - * - */ - public AwsKmsMrkAwareMasterKeyProvider buildDiscovery() { - final boolean isDiscovery = true; - - return new AwsKmsMrkAwareMasterKeyProvider( - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //# The regional client - //# supplier MUST be defined in discovery mode. - regionalClientSupplier_.orElse(clientFactory(new ConcurrentHashMap<>(), templateBuilder_)), - defaultRegion_, - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //# The key id list MUST be empty in discovery mode. - emptyList(), - emptyList(), - isDiscovery, - discoveryFilter_, - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //# In - //# discovery mode if a default MRK Region is not configured the AWS SDK - //# Default Region MUST be used. - discoveryMrkRegion_ == null ? defaultRegion_ : discoveryMrkRegion_ - ); - } - - /** - * Builds the master key provider in Discovery Mode with a {@link DiscoveryFilter}. - * In Discovery Mode the KMS Master Key Provider will attempt to decrypt using any - * key identifier it discovers in the encrypted message that is accepted by the - * {@code filter}. - * KMS Master Key Providers in Discovery Mode will not encrypt data keys. - * - * @see KmsMasterKeyProvider.Builder#buildDiscovery(DiscoveryFilter) - */ - public AwsKmsMrkAwareMasterKeyProvider buildDiscovery(DiscoveryFilter filter) { - discoveryFilter_ = filter; +public final class AwsKmsMrkAwareMasterKeyProvider + extends MasterKeyProvider { + private static final String PROVIDER_NAME = "aws-kms"; + private final List keyIds_; + private final List grantTokens_; + + private final boolean isDiscovery_; + private final DiscoveryFilter discoveryFilter_; + private final String discoveryMrkRegion_; + + private final KmsMasterKeyProvider.RegionalClientSupplier regionalClientSupplier_; + private final String defaultRegion_; + + public static class Builder implements Cloneable { + private String defaultRegion_ = getSdkDefaultRegion(); + private Optional regionalClientSupplier_ = + Optional.empty(); + private AWSKMSClientBuilder templateBuilder_ = null; + private DiscoveryFilter discoveryFilter_ = null; + private String discoveryMrkRegion_ = this.defaultRegion_; + + Builder() { + // Default access: Don't allow outside classes to extend this class + } - return buildDiscovery(); - } + public Builder clone() { + try { + AwsKmsMrkAwareMasterKeyProvider.Builder cloned = + (AwsKmsMrkAwareMasterKeyProvider.Builder) super.clone(); - /** - * Builds the master key provider in Strict Mode. - * KMS Master Key Providers in Strict Mode will only attempt to decrypt using - * key ARNs listed in {@code keyIds}. - * KMS Master Key Providers in Strict Mode will encrypt data keys using the keys - * listed in {@code keyIds} - *

- * In Strict Mode, one or more CMKs must be provided. - * For Master Key Providers that will only be used for encryption, - * you can use any valid KMS key identifier. - * For providers that will be used for decryption, - * you must use the key ARN; - * key ids, alias names, and alias ARNs are not supported. - * - * @see KmsMasterKeyProvider.Builder#buildStrict(List) - */ - public AwsKmsMrkAwareMasterKeyProvider buildStrict(List keyIds) { - final boolean isDiscovery = false; - - return new AwsKmsMrkAwareMasterKeyProvider( - regionalClientSupplier_.orElse(clientFactory(new ConcurrentHashMap<>(), templateBuilder_)), - defaultRegion_, - new ArrayList(keyIds), - emptyList(), - isDiscovery, - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //# A discovery filter MUST NOT be configured in strict mode. - null, - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //# A default MRK Region MUST NOT be configured in strict mode. - null - ); + if (templateBuilder_ != null) { + cloned.templateBuilder_ = cloneClientBuilder(templateBuilder_); } - /** - * Builds the master key provider in strict mode. - * KMS Master Key Providers in Strict Mode will only attempt to decrypt using - * key ARNs listed in {@code keyIds}. - * KMS Master Key Providers in Strict Mode will encrypt data keys using the keys - * listed in {@code keyIds} - *

- * In Strict Mode, one or more CMKs must be provided. - * For Master Key Providers that will only be used for encryption, - * you can use any valid KMS key identifier. - * For providers that will be used for decryption, - * you must use the key ARN; - * key ids, alias names, and alias ARNs are not supported. - * - * @see KmsMasterKeyProvider.Builder#buildStrict(String...) - * - */ - public AwsKmsMrkAwareMasterKeyProvider buildStrict(String... keyIds) { - return buildStrict(asList(keyIds)); - } + return cloned; + } catch (CloneNotSupportedException e) { + throw new Error("Impossible: CloneNotSupportedException", e); + } + } - static KmsMasterKeyProvider.RegionalClientSupplier clientFactory( - ConcurrentHashMap clientCache, - AWSKMSClientBuilder templateBuilder - ) { - - // Clone again; this MKP builder might be reused to build a second MKP with different creds. - AWSKMSClientBuilder builder = templateBuilder != null ? cloneClientBuilder(templateBuilder) - : AWSKMSClientBuilder.standard(); - - return region -> { - /* Check for early return (Postcondition): If a client already exists, use that. */ - if (clientCache.containsKey(region)) { - return clientCache.get(region); - } - - // We can't just use computeIfAbsent as we need to avoid leaking KMS clients if we're asked to decrypt - // an EDK with a bogus region in its ARN. So we'll install a request handler to identify the first - // successful call, and cache it when we see that. - final KmsMasterKeyProvider.SuccessfulRequestCacher cacher = new KmsMasterKeyProvider.SuccessfulRequestCacher(clientCache, region); - final ArrayList handlers = new ArrayList<>(); - if (builder.getRequestHandlers() != null) { - handlers.addAll(builder.getRequestHandlers()); - } - handlers.add(cacher); - - final AWSKMS kms = cloneClientBuilder(builder) - .withRegion(region) - .withRequestHandlers(handlers.toArray(new RequestHandler2[handlers.size()])) - .build(); - return cacher.setClient(kms); - }; - } + /** + * Sets the default region. This region will be used when specifying key IDs for encryption or + * in {@link AwsKmsMrkAwareMasterKeyProvider#getMasterKey(String)} that are not full ARNs, but + * are instead bare key IDs or aliases. + * + *

If the default region is not specified, the AWS SDK default region will be used. + * + * @see KmsMasterKeyProvider.Builder#withDefaultRegion(String) + * @param defaultRegion The default region to use. + */ + public AwsKmsMrkAwareMasterKeyProvider.Builder withDefaultRegion(String defaultRegion) { + this.defaultRegion_ = defaultRegion; + return this; + } - static AWSKMSClientBuilder cloneClientBuilder(final AWSKMSClientBuilder builder) { - // We need to copy all arguments out of the builder in case it's mutated later on. - // Unfortunately AWSKMSClientBuilder doesn't support .clone() so we'll have to do it by hand. - - if (builder.getEndpoint() != null) { - // We won't be able to set the region later if a custom endpoint is set. - throw new IllegalArgumentException("Setting endpoint configuration is not compatible with passing a " + - "builder to the KmsMasterKeyProvider. Use withCustomClientFactory" + - " instead."); - } - - final AWSKMSClientBuilder newBuilder = AWSKMSClient.builder(); - newBuilder.setClientConfiguration(builder.getClientConfiguration()); - newBuilder.setCredentials(builder.getCredentials()); - newBuilder.setEndpointConfiguration(builder.getEndpoint()); - newBuilder.setMetricsCollector(builder.getMetricsCollector()); - if (builder.getRequestHandlers() != null) { - newBuilder.setRequestHandlers(builder.getRequestHandlers().toArray(new RequestHandler2[0])); - } - return newBuilder; - } + /** + * Sets the region contacted for multi-region keys when in Discovery mode. This region will be + * used when a multi-region key is discovered on decrypt by {@link + * AwsKmsMrkAwareMasterKeyProvider#getMasterKey(String)}. + * + *

+ * + * @param discoveryMrkRegion The region to contact to attempt to decrypt multi-region keys. + */ + public AwsKmsMrkAwareMasterKeyProvider.Builder withDiscoveryMrkRegion( + String discoveryMrkRegion) { + this.discoveryMrkRegion_ = discoveryMrkRegion; + return this; + } - /** - * The AWS SDK has a default process for evaluating the default Region. - * This returns null if no default region is found. - * Because a default region _may_ not be needed. - * - */ - private static String getSdkDefaultRegion() { - try { - return new com.amazonaws.regions.DefaultAwsRegionProviderChain().getRegion(); - } catch (SdkClientException ex) { - return null; - } - } + /** + * Provides a custom factory function that will vend KMS clients. This is provided for advanced + * use cases which require complete control over the client construction process. + * + *

Because the regional client supplier fully controls the client construction process, it is + * not possible to configure the client through methods such as {@link + * #withCredentials(AWSCredentialsProvider)} or {@link #withClientBuilder(AWSKMSClientBuilder)}; + * if you try to use these in combination, an {@link IllegalStateException} will be thrown. + * + * @see + * KmsMasterKeyProvider.Builder#withCustomClientFactory(KmsMasterKeyProvider.RegionalClientSupplier) + */ + public AwsKmsMrkAwareMasterKeyProvider.Builder withCustomClientFactory( + KmsMasterKeyProvider.RegionalClientSupplier regionalClientSupplier) { + if (templateBuilder_ != null) { + throw clientSupplierComboException(); + } + + regionalClientSupplier_ = Optional.of(regionalClientSupplier); + return this; } - public static AwsKmsMrkAwareMasterKeyProvider.Builder builder() { - return new AwsKmsMrkAwareMasterKeyProvider.Builder(); + private RuntimeException clientSupplierComboException() { + return new IllegalStateException( + "withCustomClientFactory cannot be used in conjunction with " + + "withCredentials or withClientBuilder"); } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //# On initialization the caller MUST provide: - private AwsKmsMrkAwareMasterKeyProvider( - - KmsMasterKeyProvider.RegionalClientSupplier supplier, - String defaultRegion, - List keyIds, - List grantTokens, - boolean isDiscovery, - DiscoveryFilter discoveryFilter, - String discoveryMrkRegion - ) { - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //# The key id list MUST NOT be empty or null in strict mode. - if (!isDiscovery && (keyIds == null || keyIds.isEmpty())) { - throw new IllegalArgumentException("Strict mode must be configured with a non-empty " + - "list of keyIds."); - } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //# The key id - //# list MUST NOT contain any null or empty string values. - if (!isDiscovery && (keyIds.contains(null) || keyIds.contains(""))) { - throw new IllegalArgumentException("Strict mode cannot be configured with a " + - "null key identifier."); - } + /** + * Configures the {@link AwsKmsMrkAwareMasterKeyProvider} to use specific credentials. If a + * builder was previously set, this will override whatever credentials it set. + * + * @see KmsMasterKeyProvider.Builder#withCredentials(AWSCredentialsProvider) + */ + public AwsKmsMrkAwareMasterKeyProvider.Builder withCredentials( + AWSCredentialsProvider credentialsProvider) { + if (regionalClientSupplier_.isPresent()) { + throw clientSupplierComboException(); + } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //# All AWS KMS - //# key identifiers are be passed to Assert AWS KMS MRK are unique (aws- - //# kms-mrk-are-unique.md#Implementation) and the function MUST return - //# success. - assertMrksAreUnique(keyIds); - /* Precondition: A region is required to contact AWS KMS. - * This is an edge case because the default region will be the same as the SDK default, - * but it is still possible. - */ - if ( - !isDiscovery && - defaultRegion == null && - keyIds - .stream() - .map(identifier -> parseInfoFromKeyArn(identifier)) - .anyMatch(info -> info == null) - ) { - throw new AwsCryptoException("Can't use non-ARN key identifiers or aliases when " + - "no default region is set"); - } - /* Precondition (untested): Discovery filter is only valid in discovery mode. */ - if (!isDiscovery && discoveryFilter != null) { - throw new IllegalArgumentException("Strict mode cannot be configured with a " + - "discovery filter."); - } - /* Precondition (untested): Discovery mode can not have any keys to filter. */ - if (isDiscovery && !keyIds.isEmpty()) { - throw new IllegalArgumentException("Discovery mode can not be configured with keys."); - } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //# If an AWS SDK Default Region can not be - //# obtained initialization MUST fail. - if (isDiscovery && discoveryMrkRegion == null) { - throw new IllegalArgumentException("Discovery MRK region can not be null."); - } + if (templateBuilder_ == null) { + templateBuilder_ = AWSKMSClientBuilder.standard(); + } - this.regionalClientSupplier_ = supplier; - this.defaultRegion_ = defaultRegion; - this.keyIds_ = Collections.unmodifiableList(new ArrayList<>(keyIds)); + templateBuilder_.setCredentials(credentialsProvider); - this.isDiscovery_ = isDiscovery; - this.discoveryFilter_ = discoveryFilter; - this.discoveryMrkRegion_ = discoveryMrkRegion; - this.grantTokens_ = grantTokens; + return this; } - //= compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 - //# The caller MUST provide: /** - * Refactored into a pure function - * to facilitate testing and correctness. + * Configures the {@link AwsKmsMrkAwareMasterKeyProvider} to use specific credentials. If a + * builder was previously set, this will override whatever credentials it set. * + * @see KmsMasterKeyProvider.Builder#withCredentials(AWSCredentials) */ - static void assertMrksAreUnique(List keyIdentifiers) { - - List duplicateMultiRegionKeyIdentifiers = keyIdentifiers - .stream() - /* Collect a map of resource to identifier. - * This lets me group duplicates by "resource". - * This is because the identifier can be either an ARN or a raw identifier. - * By having the both the key id and the identifier I can ensure the uniqueness of - * the key id and the error message to the caller can contain both identifiers - * to facilitate debugging. - */ - .collect(Collectors.groupingBy(AwsKmsMrkAwareMasterKeyProvider::getResourceForResourceTypeKey)) - .entrySet() - .stream() - //= compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 - //# If there are zero duplicate resource ids between the multi-region - //# keys, this function MUST exit successfully - .filter(maybeDuplicate -> maybeDuplicate.getValue().size() > 1) - //= compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 - //# If the list does not contain any multi-Region keys (aws-kms-key- - //# arn.md#identifying-an-aws-kms-multi-region-key) this function MUST - //# exit successfully. - // - /* Postcondition: Filter out duplicate resources that are not multi-region keys. - * I expect only have duplicates of specific multi-region keys. - * In JSON something like - * { - * "mrk-edb7fe6942894d32ac46dbb1c922d574" : [ - * "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - * "arn:aws:kms:us-east-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" - * ] - * } - */ - .filter(maybeMrk -> isMRK(maybeMrk.getKey())) - /* Flatten the duplicate identifiers into a single list. */ - .flatMap(mrkEntry -> mrkEntry.getValue().stream()) - .collect(Collectors.toList()); - - //= compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 - //# If any duplicate multi-region resource ids exist, this function MUST - //# yield an error that includes all identifiers with duplicate resource - //# ids not only the first duplicate found. - if (duplicateMultiRegionKeyIdentifiers.size() > 1) { - throw new IllegalArgumentException("Duplicate multi-region keys are not allowed:\n" + - String.join(", ", duplicateMultiRegionKeyIdentifiers)); - } + public AwsKmsMrkAwareMasterKeyProvider.Builder withCredentials(AWSCredentials credentials) { + return withCredentials(new AWSStaticCredentialsProvider(credentials)); } /** - * Helper method for - * @see AwsKmsMrkAwareMasterKeyProvider#assertMrksAreUnique(List) + * Configures the {@link AwsKmsMrkAwareMasterKeyProvider} to use settings from this {@link + * AWSKMSClientBuilder} to configure KMS clients. Note that the region set on this builder will + * be ignored, but all other settings will be propagated into the regional clients. * - * Refoactored into a pure function - * to simplify testing and ensure correctness. + *

This method will overwrite any credentials set using {@link + * #withCredentials(AWSCredentialsProvider)}. * + * @see KmsMasterKeyProvider.Builder#withClientBuilder(AWSKMSClientBuilder) */ - static String getResourceForResourceTypeKey(String identifier) { - final AwsKmsCmkArnInfo info = parseInfoFromKeyArn(identifier); - /* Check for early return (Postcondition): Non-ARNs may be raw resources. - * Raw aliases ('alias/my-key') - * or key ids ('mrk-edb7fe6942894d32ac46dbb1c922d574'). - */ - if (info == null) return identifier; - - /* Check for early return (Postcondition): Return the identifier for non-key resource types. - * I only care about duplicate multi-region *keys*. - * Any other resource type - * should get filtered out. - * I return the entire identifier - * on the off chance that - * a customer has created - * an alias with a name `mrk-*`. - * This way such an alias - * can never accidentally - * collided with an existing multi-region key - * or a duplicate alias. - */ - if (!info.getResourceType().equals("key")) { - return identifier; - } - - /* Postcondition: Return the key id. - * This will be used - * to find different regional replicas of - * the same multi-region key - * because the key id for replicas is always the same. - */ - return info.getResource(); + public AwsKmsMrkAwareMasterKeyProvider.Builder withClientBuilder(AWSKMSClientBuilder builder) { + if (regionalClientSupplier_.isPresent()) { + throw clientSupplierComboException(); + } + final AWSKMSClientBuilder newBuilder = cloneClientBuilder(builder); + this.templateBuilder_ = newBuilder; + + return this; } /** - * Returns "aws-kms" + * Builds the master key provider in Discovery Mode. In Discovery Mode the KMS Master Key + * Provider will attempt to decrypt using any key identifier it discovers in the encrypted + * message. KMS Master Key Providers in Discovery Mode will not encrypt data keys. + * + * @see KmsMasterKeyProvider.Builder#buildDiscovery() */ - @Override - public String getDefaultProviderId() { - return PROVIDER_NAME; + public AwsKmsMrkAwareMasterKeyProvider buildDiscovery() { + final boolean isDiscovery = true; + + return new AwsKmsMrkAwareMasterKeyProvider( + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // # The regional client + // # supplier MUST be defined in discovery mode. + regionalClientSupplier_.orElse( + clientFactory(new ConcurrentHashMap<>(), templateBuilder_)), + defaultRegion_, + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // # The key id list MUST be empty in discovery mode. + emptyList(), + emptyList(), + isDiscovery, + discoveryFilter_, + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // # In + // # discovery mode if a default MRK Region is not configured the AWS SDK + // # Default Region MUST be used. + discoveryMrkRegion_ == null ? defaultRegion_ : discoveryMrkRegion_); } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# The input MUST be the same as the Master Key Provider Get Master Key - //# (../master-key-provider-interface.md#get-master-key) interface. /** - * Added flexibility in matching multi-Region keys from different regions. + * Builds the master key provider in Discovery Mode with a {@link DiscoveryFilter}. In Discovery + * Mode the KMS Master Key Provider will attempt to decrypt using any key identifier it + * discovers in the encrypted message that is accepted by the {@code filter}. KMS Master Key + * Providers in Discovery Mode will not encrypt data keys. * - * @see KmsMasterKey#getMasterKey(String, String) + * @see KmsMasterKeyProvider.Builder#buildDiscovery(DiscoveryFilter) */ - @Override - public AwsKmsMrkAwareMasterKey getMasterKey( - final String providerId, - final String requestedKeyArn - ) throws UnsupportedProviderException, NoSuchMasterKeyException { - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# The function MUST only provide master keys if the input provider id - //# equals "aws-kms". - if (!canProvide(providerId)) { - throw new UnsupportedProviderException(); - } - - /* There SHOULD only be one match. - * An unambiguous multi-region key for the family - * of related multi-region keys is required. - * See `assertMrksAreUnique`. - * However, in the case of single region keys or aliases, - * duplicates _are_ possible. - */ - Optional matchedArn = keyIds_ - .stream() - .filter(t -> awsKmsArnMatchForDecrypt(t, requestedKeyArn)) - .findFirst(); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# In strict mode, the requested AWS KMS key ARN MUST - //# match a member of the configured key ids by using AWS KMS MRK Match - //# for Decrypt (aws-kms-mrk-match-for-decrypt.md#implementation) - //# otherwise this function MUST error. - if (!isDiscovery_ && !matchedArn.isPresent()) { - throw new NoSuchMasterKeyException("Key must be in supplied list of keyIds."); - } - - final AwsKmsCmkArnInfo requestedKeyArnInfo = parseInfoFromKeyArn(requestedKeyArn); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# In discovery mode, the requested - //# AWS KMS key identifier MUST be a well formed AWS KMS ARN. - /* Precondition: Discovery mode requires requestedKeyArn be an ARN. - * This function is called on the encrypt path. - * It _may_ be the case that a raw key id, for example, was configured. - */ - if (isDiscovery_ && requestedKeyArnInfo == null) { - throw new NoSuchMasterKeyException("Cannot use AWS KMS identifiers " - + "when in discovery mode."); - } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# In - //# discovery mode if a discovery filter is configured the requested AWS - //# KMS key ARN's "partition" MUST match the discovery filter's - //# "partition" and the AWS KMS key ARN's "account" MUST exist in the - //# discovery filter's account id set. - if (isDiscovery_ && discoveryFilter_ != null && - !discoveryFilter_.allowsPartitionAndAccount(requestedKeyArnInfo.getPartition(), requestedKeyArnInfo.getAccountId()) - ) { - throw new NoSuchMasterKeyException("Cannot use key in partition " + requestedKeyArnInfo.getPartition() + - " with account id " + requestedKeyArnInfo.getAccountId() + " with configured discovery filter."); - } + public AwsKmsMrkAwareMasterKeyProvider buildDiscovery(DiscoveryFilter filter) { + discoveryFilter_ = filter; - final String regionName_ = extractRegion( - defaultRegion_, - discoveryMrkRegion_, - matchedArn, - requestedKeyArnInfo, - isDiscovery_ - ); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# An AWS KMS client - //# MUST be obtained by calling the regional client supplier with this - //# AWS Region. - AWSKMS kms = regionalClientSupplier_.getClient(regionName_); - - String keyIdentifier = isDiscovery_ - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# In discovery mode a AWS KMS MRK Aware Master Key (aws-kms-mrk-aware- - //# master-key.md) MUST be returned configured with - ? requestedKeyArnInfo.toString(regionName_) - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# In strict mode a AWS KMS MRK Aware Master Key (aws-kms-mrk-aware- - //# master-key.md) MUST be returned configured with - : matchedArn.get(); - - final AwsKmsMrkAwareMasterKey result = AwsKmsMrkAwareMasterKey - .getInstance(kms, keyIdentifier, this); - result.setGrantTokens(grantTokens_); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# The output MUST be the same as the Master Key Provider Get Master Key - //# (../master-key-provider-interface.md#get-master-key) interface. - return result; + return buildDiscovery(); } /** - * Select the correct region from multiple default configurations - * and potentially related multi-Region keys from different regions. + * Builds the master key provider in Strict Mode. KMS Master Key Providers in Strict Mode will + * only attempt to decrypt using key ARNs listed in {@code keyIds}. KMS Master Key Providers in + * Strict Mode will encrypt data keys using the keys listed in {@code keyIds} * - * Refactored into a pure function to facilitate testing and ensure correctness. + *

In Strict Mode, one or more CMKs must be provided. For Master Key Providers that will only + * be used for encryption, you can use any valid KMS key identifier. For providers that will be + * used for decryption, you must use the key ARN; key ids, alias names, and alias ARNs are not + * supported. * + * @see KmsMasterKeyProvider.Builder#buildStrict(List) */ - static String extractRegion( - final String defaultRegion, - final String discoveryMrkRegion, - final Optional matchedArn, - final AwsKmsCmkArnInfo requestedKeyArnInfo, - final boolean isDiscovery - ) { - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# If the requested AWS KMS key identifier is not a well formed ARN the - //# AWS Region MUST be the configured default region this SHOULD be - //# obtained from the AWS SDK. - if (requestedKeyArnInfo == null) return defaultRegion; - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# Otherwise if the requested AWS KMS key - //# identifier is identified as a multi-Region key (aws-kms-key- - //# arn.md#identifying-an-aws-kms-multi-region-key), then AWS Region MUST - //# be the region from the AWS KMS key ARN stored in the provider info - //# from the encrypted data key. - if ( - !isMRK(requestedKeyArnInfo.getResource()) || - !requestedKeyArnInfo.getResourceType().equals("key") - ) { - return requestedKeyArnInfo.getRegion(); - } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# Otherwise if the mode is discovery then - //# the AWS Region MUST be the discovery MRK region. - if (isDiscovery) return discoveryMrkRegion; - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //# Finally if the - //# provider info is identified as a multi-Region key (aws-kms-key- - //# arn.md#identifying-an-aws-kms-multi-region-key) the AWS Region MUST - //# be the region from the AWS KMS key in the configured key ids matched - //# to the requested AWS KMS key by using AWS KMS MRK Match for Decrypt - //# (aws-kms-mrk-match-for-decrypt.md#implementation). - return parseInfoFromKeyArn(matchedArn.get()).getRegion(); + public AwsKmsMrkAwareMasterKeyProvider buildStrict(List keyIds) { + final boolean isDiscovery = false; + + return new AwsKmsMrkAwareMasterKeyProvider( + regionalClientSupplier_.orElse( + clientFactory(new ConcurrentHashMap<>(), templateBuilder_)), + defaultRegion_, + new ArrayList(keyIds), + emptyList(), + isDiscovery, + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // # A discovery filter MUST NOT be configured in strict mode. + null, + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // # A default MRK Region MUST NOT be configured in strict mode. + null); } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 - //# The input MUST be the same as the Master Key Provider Get Master Keys - //# For Encryption (../master-key-provider-interface.md#get-master-keys- - //# for-encryption) interface. /** - * Returns all CMKs provided to the constructor of this object. - * @see KmsMasterKey#getMasterKeysForEncryption(MasterKeyRequest) + * Builds the master key provider in strict mode. KMS Master Key Providers in Strict Mode will + * only attempt to decrypt using key ARNs listed in {@code keyIds}. KMS Master Key Providers in + * Strict Mode will encrypt data keys using the keys listed in {@code keyIds} + * + *

In Strict Mode, one or more CMKs must be provided. For Master Key Providers that will only + * be used for encryption, you can use any valid KMS key identifier. For providers that will be + * used for decryption, you must use the key ARN; key ids, alias names, and alias ARNs are not + * supported. + * + * @see KmsMasterKeyProvider.Builder#buildStrict(String...) */ - @Override - public List getMasterKeysForEncryption(final MasterKeyRequest request) { - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 - //# If the configured mode is discovery the function MUST return an empty - //# list. - if (isDiscovery_) { - return emptyList(); + public AwsKmsMrkAwareMasterKeyProvider buildStrict(String... keyIds) { + return buildStrict(asList(keyIds)); + } + + static KmsMasterKeyProvider.RegionalClientSupplier clientFactory( + ConcurrentHashMap clientCache, AWSKMSClientBuilder templateBuilder) { + + // Clone again; this MKP builder might be reused to build a second MKP with different creds. + AWSKMSClientBuilder builder = + templateBuilder != null + ? cloneClientBuilder(templateBuilder) + : AWSKMSClientBuilder.standard(); + + return region -> { + /* Check for early return (Postcondition): If a client already exists, use that. */ + if (clientCache.containsKey(region)) { + return clientCache.get(region); } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 - //# If the configured mode is strict this function MUST return a - //# list of master keys obtained by calling Get Master Key (aws-kms-mrk- - //# aware-master-key-provider.md#get-master-key) for each AWS KMS key - //# identifier in the configured key ids - List result = new ArrayList<>(keyIds_.size()); - for (String id : keyIds_) { - result.add(getMasterKey(id)); + + // We can't just use computeIfAbsent as we need to avoid leaking KMS clients if we're asked + // to decrypt + // an EDK with a bogus region in its ARN. So we'll install a request handler to identify the + // first + // successful call, and cache it when we see that. + final KmsMasterKeyProvider.SuccessfulRequestCacher cacher = + new KmsMasterKeyProvider.SuccessfulRequestCacher(clientCache, region); + final ArrayList handlers = new ArrayList<>(); + if (builder.getRequestHandlers() != null) { + handlers.addAll(builder.getRequestHandlers()); } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 - //# The output MUST be the same as the Master Key Provider Get Master - //# Keys For Encryption (../master-key-provider-interface.md#get-master- - //# keys-for-encryption) interface. - return result; + handlers.add(cacher); + + final AWSKMS kms = + cloneClientBuilder(builder) + .withRegion(region) + .withRequestHandlers(handlers.toArray(new RequestHandler2[handlers.size()])) + .build(); + return cacher.setClient(kms); + }; + } + + static AWSKMSClientBuilder cloneClientBuilder(final AWSKMSClientBuilder builder) { + // We need to copy all arguments out of the builder in case it's mutated later on. + // Unfortunately AWSKMSClientBuilder doesn't support .clone() so we'll have to do it by hand. + + if (builder.getEndpoint() != null) { + // We won't be able to set the region later if a custom endpoint is set. + throw new IllegalArgumentException( + "Setting endpoint configuration is not compatible with passing a " + + "builder to the KmsMasterKeyProvider. Use withCustomClientFactory" + + " instead."); + } + + final AWSKMSClientBuilder newBuilder = AWSKMSClient.builder(); + newBuilder.setClientConfiguration(builder.getClientConfiguration()); + newBuilder.setCredentials(builder.getCredentials()); + newBuilder.setEndpointConfiguration(builder.getEndpoint()); + newBuilder.setMetricsCollector(builder.getMetricsCollector()); + if (builder.getRequestHandlers() != null) { + newBuilder.setRequestHandlers(builder.getRequestHandlers().toArray(new RequestHandler2[0])); + } + return newBuilder; } - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //# The input MUST be the same as the Master Key Provider Decrypt Data - //# Key (../master-key-provider-interface.md#decrypt-data-key) interface. /** - * @see KmsMasterKey#decryptDataKey(CryptoAlgorithm, Collection, Map) - * @throws AwsCryptoException + * The AWS SDK has a default process for evaluating the default Region. This returns null if no + * default region is found. Because a default region _may_ not be needed. */ - @Override - public DataKey decryptDataKey(final CryptoAlgorithm algorithm, - final Collection encryptedDataKeys, - final Map encryptionContext) - throws AwsCryptoException { - final List exceptions = new ArrayList<>(); - - return encryptedDataKeys - .stream() - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //# The set of encrypted data keys MUST first be filtered to match this - //# master key's configuration. - .filter(edk -> { - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //# To match the encrypted data key's - //# provider ID MUST exactly match the value "aws-kms". - if (!canProvide(edk.getProviderId())) return false; - - final String providerInfo = new String(edk.getProviderInformation(), StandardCharsets.UTF_8); - final AwsKmsCmkArnInfo providerArnInfo = parseInfoFromKeyArn(providerInfo); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //# Additionally - //# each provider info MUST be a valid AWS KMS ARN (aws-kms-key-arn.md#a- - //# valid-aws-kms-arn) with a resource type of "key". - if (providerArnInfo == null || !"key".equals(providerArnInfo.getResourceType())) { - throw new IllegalStateException("Invalid provider info in message."); - } - return true; - }) - .map(edk -> { - try { - final String keyArn = new String(edk.getProviderInformation(), StandardCharsets.UTF_8); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //# For each encrypted data key in the filtered set, one at a time, the - //# master key provider MUST call Get Master Key (aws-kms-mrk-aware- - //# master-key-provider.md#get-master-key) with the encrypted data key's - //# provider info as the AWS KMS key ARN. - // This will throw if we can't use this key for whatever reason - return getMasterKey( - edk.getProviderId(), - keyArn) - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //# It MUST call Decrypt Data Key - //# (aws-kms-mrk-aware-master-key.md#decrypt-data-key) on this master key - //# with the input algorithm, this single encrypted data key, and the - //# input encryption context. - .decryptDataKey(algorithm, singletonList(edk), encryptionContext); - } catch (final Exception ex) { - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //# If this attempt results in an error, then - //# these errors MUST be collected. - exceptions.add(ex); - return null; - } - }) - /* Need to filter null because an Optional of a null is crazy. - * `findFirst` will throw if it sees `null`. - */ - .filter(Objects::nonNull) - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //# If the decrypt data key call is - //# successful, then this function MUST return this result and not - //# attempt to decrypt any more encrypted data keys. - .findFirst() - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //# If all the input encrypted data keys have been processed then this - //# function MUST yield an error that includes all the collected errors. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //# The output MUST be the same as the Master Key Provider Decrypt Data - //# Key (../master-key-provider-interface.md#decrypt-data-key) interface. - .orElseThrow(() -> buildCannotDecryptDksException(exceptions)); + private static String getSdkDefaultRegion() { + try { + return new com.amazonaws.regions.DefaultAwsRegionProviderChain().getRegion(); + } catch (SdkClientException ex) { + return null; + } + } + } + + public static AwsKmsMrkAwareMasterKeyProvider.Builder builder() { + return new AwsKmsMrkAwareMasterKeyProvider.Builder(); + } + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // # On initialization the caller MUST provide: + private AwsKmsMrkAwareMasterKeyProvider( + KmsMasterKeyProvider.RegionalClientSupplier supplier, + String defaultRegion, + List keyIds, + List grantTokens, + boolean isDiscovery, + DiscoveryFilter discoveryFilter, + String discoveryMrkRegion) { + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // # The key id list MUST NOT be empty or null in strict mode. + if (!isDiscovery && (keyIds == null || keyIds.isEmpty())) { + throw new IllegalArgumentException( + "Strict mode must be configured with a non-empty " + "list of keyIds."); + } + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // # The key id + // # list MUST NOT contain any null or empty string values. + if (!isDiscovery && (keyIds.contains(null) || keyIds.contains(""))) { + throw new IllegalArgumentException( + "Strict mode cannot be configured with a " + "null key identifier."); } - public List getGrantTokens() { - return new ArrayList<>(grantTokens_); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // # All AWS KMS + // # key identifiers are be passed to Assert AWS KMS MRK are unique (aws- + // # kms-mrk-are-unique.md#Implementation) and the function MUST return + // # success. + assertMrksAreUnique(keyIds); + /* Precondition: A region is required to contact AWS KMS. + * This is an edge case because the default region will be the same as the SDK default, + * but it is still possible. + */ + if (!isDiscovery + && defaultRegion == null + && keyIds.stream() + .map(identifier -> parseInfoFromKeyArn(identifier)) + .anyMatch(info -> info == null)) { + throw new AwsCryptoException( + "Can't use non-ARN key identifiers or aliases when " + "no default region is set"); + } + /* Precondition (untested): Discovery filter is only valid in discovery mode. */ + if (!isDiscovery && discoveryFilter != null) { + throw new IllegalArgumentException( + "Strict mode cannot be configured with a " + "discovery filter."); + } + /* Precondition (untested): Discovery mode can not have any keys to filter. */ + if (isDiscovery && !keyIds.isEmpty()) { + throw new IllegalArgumentException("Discovery mode can not be configured with keys."); + } + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // # If an AWS SDK Default Region can not be + // # obtained initialization MUST fail. + if (isDiscovery && discoveryMrkRegion == null) { + throw new IllegalArgumentException("Discovery MRK region can not be null."); } - /** - * Returns a new {@link AwsKmsMrkAwareMasterKeyProvider} that is configured identically to this one, except with the given list - * of grant tokens. The grant token list in the returned provider is immutable (but can be further overridden by - * invoking withGrantTokens again). - * + this.regionalClientSupplier_ = supplier; + this.defaultRegion_ = defaultRegion; + this.keyIds_ = Collections.unmodifiableList(new ArrayList<>(keyIds)); + + this.isDiscovery_ = isDiscovery; + this.discoveryFilter_ = discoveryFilter; + this.discoveryMrkRegion_ = discoveryMrkRegion; + this.grantTokens_ = grantTokens; + } + + // = compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 + // # The caller MUST provide: + /** Refactored into a pure function to facilitate testing and correctness. */ + static void assertMrksAreUnique(List keyIdentifiers) { + + List duplicateMultiRegionKeyIdentifiers = + keyIdentifiers.stream() + /* Collect a map of resource to identifier. + * This lets me group duplicates by "resource". + * This is because the identifier can be either an ARN or a raw identifier. + * By having the both the key id and the identifier I can ensure the uniqueness of + * the key id and the error message to the caller can contain both identifiers + * to facilitate debugging. + */ + .collect( + Collectors.groupingBy( + AwsKmsMrkAwareMasterKeyProvider::getResourceForResourceTypeKey)) + .entrySet() + .stream() + // = compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 + // # If there are zero duplicate resource ids between the multi-region + // # keys, this function MUST exit successfully + .filter(maybeDuplicate -> maybeDuplicate.getValue().size() > 1) + // = compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 + // # If the list does not contain any multi-Region keys (aws-kms-key- + // # arn.md#identifying-an-aws-kms-multi-region-key) this function MUST + // # exit successfully. + // + /* Postcondition: Filter out duplicate resources that are not multi-region keys. + * I expect only have duplicates of specific multi-region keys. + * In JSON something like + * { + * "mrk-edb7fe6942894d32ac46dbb1c922d574" : [ + * "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + * "arn:aws:kms:us-east-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" + * ] + * } + */ + .filter(maybeMrk -> isMRK(maybeMrk.getKey())) + /* Flatten the duplicate identifiers into a single list. */ + .flatMap(mrkEntry -> mrkEntry.getValue().stream()) + .collect(Collectors.toList()); + + // = compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 + // # If any duplicate multi-region resource ids exist, this function MUST + // # yield an error that includes all identifiers with duplicate resource + // # ids not only the first duplicate found. + if (duplicateMultiRegionKeyIdentifiers.size() > 1) { + throw new IllegalArgumentException( + "Duplicate multi-region keys are not allowed:\n" + + String.join(", ", duplicateMultiRegionKeyIdentifiers)); + } + } + + /** + * Helper method for + * + * @see AwsKmsMrkAwareMasterKeyProvider#assertMrksAreUnique(List) + *

Refoactored into a pure function to simplify testing and ensure correctness. + */ + static String getResourceForResourceTypeKey(String identifier) { + final AwsKmsCmkArnInfo info = parseInfoFromKeyArn(identifier); + /* Check for early return (Postcondition): Non-ARNs may be raw resources. + * Raw aliases ('alias/my-key') + * or key ids ('mrk-edb7fe6942894d32ac46dbb1c922d574'). + */ + if (info == null) return identifier; + + /* Check for early return (Postcondition): Return the identifier for non-key resource types. + * I only care about duplicate multi-region *keys*. + * Any other resource type + * should get filtered out. + * I return the entire identifier + * on the off chance that + * a customer has created + * an alias with a name `mrk-*`. + * This way such an alias + * can never accidentally + * collided with an existing multi-region key + * or a duplicate alias. */ - public AwsKmsMrkAwareMasterKeyProvider withGrantTokens(List grantTokens) { - grantTokens = Collections.unmodifiableList(new ArrayList<>(grantTokens)); - - return new AwsKmsMrkAwareMasterKeyProvider( - regionalClientSupplier_, - defaultRegion_, - keyIds_, - grantTokens, - isDiscovery_, - discoveryFilter_, - discoveryMrkRegion_ - ); + if (!info.getResourceType().equals("key")) { + return identifier; } - /** - * Returns a new {@link AwsKmsMrkAwareMasterKeyProvider} that is configured identically to this one, except with the given list - * of grant tokens. The grant token list in the returned provider is immutable (but can be further overridden by - * invoking withGrantTokens again). - * + /* Postcondition: Return the key id. + * This will be used + * to find different regional replicas of + * the same multi-region key + * because the key id for replicas is always the same. + */ + return info.getResource(); + } + + /** Returns "aws-kms" */ + @Override + public String getDefaultProviderId() { + return PROVIDER_NAME; + } + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # The input MUST be the same as the Master Key Provider Get Master Key + // # (../master-key-provider-interface.md#get-master-key) interface. + /** + * Added flexibility in matching multi-Region keys from different regions. + * + * @see KmsMasterKey#getMasterKey(String, String) + */ + @Override + public AwsKmsMrkAwareMasterKey getMasterKey(final String providerId, final String requestedKeyArn) + throws UnsupportedProviderException, NoSuchMasterKeyException { + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # The function MUST only provide master keys if the input provider id + // # equals "aws-kms". + if (!canProvide(providerId)) { + throw new UnsupportedProviderException(); + } + + /* There SHOULD only be one match. + * An unambiguous multi-region key for the family + * of related multi-region keys is required. + * See `assertMrksAreUnique`. + * However, in the case of single region keys or aliases, + * duplicates _are_ possible. + */ + Optional matchedArn = + keyIds_.stream().filter(t -> awsKmsArnMatchForDecrypt(t, requestedKeyArn)).findFirst(); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # In strict mode, the requested AWS KMS key ARN MUST + // # match a member of the configured key ids by using AWS KMS MRK Match + // # for Decrypt (aws-kms-mrk-match-for-decrypt.md#implementation) + // # otherwise this function MUST error. + if (!isDiscovery_ && !matchedArn.isPresent()) { + throw new NoSuchMasterKeyException("Key must be in supplied list of keyIds."); + } + + final AwsKmsCmkArnInfo requestedKeyArnInfo = parseInfoFromKeyArn(requestedKeyArn); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # In discovery mode, the requested + // # AWS KMS key identifier MUST be a well formed AWS KMS ARN. + /* Precondition: Discovery mode requires requestedKeyArn be an ARN. + * This function is called on the encrypt path. + * It _may_ be the case that a raw key id, for example, was configured. */ - public AwsKmsMrkAwareMasterKeyProvider withGrantTokens(String... grantTokens) { - return withGrantTokens(asList(grantTokens)); + if (isDiscovery_ && requestedKeyArnInfo == null) { + throw new NoSuchMasterKeyException( + "Cannot use AWS KMS identifiers " + "when in discovery mode."); + } + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # In + // # discovery mode if a discovery filter is configured the requested AWS + // # KMS key ARN's "partition" MUST match the discovery filter's + // # "partition" and the AWS KMS key ARN's "account" MUST exist in the + // # discovery filter's account id set. + if (isDiscovery_ + && discoveryFilter_ != null + && !discoveryFilter_.allowsPartitionAndAccount( + requestedKeyArnInfo.getPartition(), requestedKeyArnInfo.getAccountId())) { + throw new NoSuchMasterKeyException( + "Cannot use key in partition " + + requestedKeyArnInfo.getPartition() + + " with account id " + + requestedKeyArnInfo.getAccountId() + + " with configured discovery filter."); } + final String regionName_ = + extractRegion( + defaultRegion_, discoveryMrkRegion_, matchedArn, requestedKeyArnInfo, isDiscovery_); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # An AWS KMS client + // # MUST be obtained by calling the regional client supplier with this + // # AWS Region. + AWSKMS kms = regionalClientSupplier_.getClient(regionName_); + + String keyIdentifier = + isDiscovery_ + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # In discovery mode a AWS KMS MRK Aware Master Key (aws-kms-mrk-aware- + // # master-key.md) MUST be returned configured with + ? requestedKeyArnInfo.toString(regionName_) + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # In strict mode a AWS KMS MRK Aware Master Key (aws-kms-mrk-aware- + // # master-key.md) MUST be returned configured with + : matchedArn.get(); + + final AwsKmsMrkAwareMasterKey result = + AwsKmsMrkAwareMasterKey.getInstance(kms, keyIdentifier, this); + result.setGrantTokens(grantTokens_); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # The output MUST be the same as the Master Key Provider Get Master Key + // # (../master-key-provider-interface.md#get-master-key) interface. + return result; + } + + /** + * Select the correct region from multiple default configurations and potentially related + * multi-Region keys from different regions. + * + *

Refactored into a pure function to facilitate testing and ensure correctness. + */ + static String extractRegion( + final String defaultRegion, + final String discoveryMrkRegion, + final Optional matchedArn, + final AwsKmsCmkArnInfo requestedKeyArnInfo, + final boolean isDiscovery) { + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # If the requested AWS KMS key identifier is not a well formed ARN the + // # AWS Region MUST be the configured default region this SHOULD be + // # obtained from the AWS SDK. + if (requestedKeyArnInfo == null) return defaultRegion; + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # Otherwise if the requested AWS KMS key + // # identifier is identified as a multi-Region key (aws-kms-key- + // # arn.md#identifying-an-aws-kms-multi-region-key), then AWS Region MUST + // # be the region from the AWS KMS key ARN stored in the provider info + // # from the encrypted data key. + if (!isMRK(requestedKeyArnInfo.getResource()) + || !requestedKeyArnInfo.getResourceType().equals("key")) { + return requestedKeyArnInfo.getRegion(); + } + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # Otherwise if the mode is discovery then + // # the AWS Region MUST be the discovery MRK region. + if (isDiscovery) return discoveryMrkRegion; + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // # Finally if the + // # provider info is identified as a multi-Region key (aws-kms-key- + // # arn.md#identifying-an-aws-kms-multi-region-key) the AWS Region MUST + // # be the region from the AWS KMS key in the configured key ids matched + // # to the requested AWS KMS key by using AWS KMS MRK Match for Decrypt + // # (aws-kms-mrk-match-for-decrypt.md#implementation). + return parseInfoFromKeyArn(matchedArn.get()).getRegion(); + } + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 + // # The input MUST be the same as the Master Key Provider Get Master Keys + // # For Encryption (../master-key-provider-interface.md#get-master-keys- + // # for-encryption) interface. + /** + * Returns all CMKs provided to the constructor of this object. + * + * @see KmsMasterKey#getMasterKeysForEncryption(MasterKeyRequest) + */ + @Override + public List getMasterKeysForEncryption(final MasterKeyRequest request) { + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 + // # If the configured mode is discovery the function MUST return an empty + // # list. + if (isDiscovery_) { + return emptyList(); + } + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 + // # If the configured mode is strict this function MUST return a + // # list of master keys obtained by calling Get Master Key (aws-kms-mrk- + // # aware-master-key-provider.md#get-master-key) for each AWS KMS key + // # identifier in the configured key ids + List result = new ArrayList<>(keyIds_.size()); + for (String id : keyIds_) { + result.add(getMasterKey(id)); + } + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 + // # The output MUST be the same as the Master Key Provider Get Master + // # Keys For Encryption (../master-key-provider-interface.md#get-master- + // # keys-for-encryption) interface. + return result; + } + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // # The input MUST be the same as the Master Key Provider Decrypt Data + // # Key (../master-key-provider-interface.md#decrypt-data-key) interface. + /** + * @see KmsMasterKey#decryptDataKey(CryptoAlgorithm, Collection, Map) + * @throws AwsCryptoException + */ + @Override + public DataKey decryptDataKey( + final CryptoAlgorithm algorithm, + final Collection encryptedDataKeys, + final Map encryptionContext) + throws AwsCryptoException { + final List exceptions = new ArrayList<>(); + + return encryptedDataKeys.stream() + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // # The set of encrypted data keys MUST first be filtered to match this + // # master key's configuration. + .filter( + edk -> { + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // # To match the encrypted data key's + // # provider ID MUST exactly match the value "aws-kms". + if (!canProvide(edk.getProviderId())) return false; + + final String providerInfo = + new String(edk.getProviderInformation(), StandardCharsets.UTF_8); + final AwsKmsCmkArnInfo providerArnInfo = parseInfoFromKeyArn(providerInfo); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // # Additionally + // # each provider info MUST be a valid AWS KMS ARN (aws-kms-key-arn.md#a- + // # valid-aws-kms-arn) with a resource type of "key". + if (providerArnInfo == null || !"key".equals(providerArnInfo.getResourceType())) { + throw new IllegalStateException("Invalid provider info in message."); + } + return true; + }) + .map( + edk -> { + try { + final String keyArn = + new String(edk.getProviderInformation(), StandardCharsets.UTF_8); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // # For each encrypted data key in the filtered set, one at a time, the + // # master key provider MUST call Get Master Key (aws-kms-mrk-aware- + // # master-key-provider.md#get-master-key) with the encrypted data key's + // # provider info as the AWS KMS key ARN. + // This will throw if we can't use this key for whatever reason + return getMasterKey(edk.getProviderId(), keyArn) + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // # It MUST call Decrypt Data Key + // # (aws-kms-mrk-aware-master-key.md#decrypt-data-key) on this master key + // # with the input algorithm, this single encrypted data key, and the + // # input encryption context. + .decryptDataKey(algorithm, singletonList(edk), encryptionContext); + } catch (final Exception ex) { + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // # If this attempt results in an error, then + // # these errors MUST be collected. + exceptions.add(ex); + return null; + } + }) + /* Need to filter null because an Optional of a null is crazy. + * `findFirst` will throw if it sees `null`. + */ + .filter(Objects::nonNull) + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // # If the decrypt data key call is + // # successful, then this function MUST return this result and not + // # attempt to decrypt any more encrypted data keys. + .findFirst() + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // # If all the input encrypted data keys have been processed then this + // # function MUST yield an error that includes all the collected errors. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // # The output MUST be the same as the Master Key Provider Decrypt Data + // # Key (../master-key-provider-interface.md#decrypt-data-key) interface. + .orElseThrow(() -> buildCannotDecryptDksException(exceptions)); + } + + public List getGrantTokens() { + return new ArrayList<>(grantTokens_); + } + + /** + * Returns a new {@link AwsKmsMrkAwareMasterKeyProvider} that is configured identically to this + * one, except with the given list of grant tokens. The grant token list in the returned provider + * is immutable (but can be further overridden by invoking withGrantTokens again). + */ + public AwsKmsMrkAwareMasterKeyProvider withGrantTokens(List grantTokens) { + grantTokens = Collections.unmodifiableList(new ArrayList<>(grantTokens)); + + return new AwsKmsMrkAwareMasterKeyProvider( + regionalClientSupplier_, + defaultRegion_, + keyIds_, + grantTokens, + isDiscovery_, + discoveryFilter_, + discoveryMrkRegion_); + } + + /** + * Returns a new {@link AwsKmsMrkAwareMasterKeyProvider} that is configured identically to this + * one, except with the given list of grant tokens. The grant token list in the returned provider + * is immutable (but can be further overridden by invoking withGrantTokens again). + */ + public AwsKmsMrkAwareMasterKeyProvider withGrantTokens(String... grantTokens) { + return withGrantTokens(asList(grantTokens)); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/DiscoveryFilter.java b/src/main/java/com/amazonaws/encryptionsdk/kms/DiscoveryFilter.java index 32e9313f1..1eda3541b 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/DiscoveryFilter.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/DiscoveryFilter.java @@ -3,52 +3,51 @@ package com.amazonaws.encryptionsdk.kms; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashSet; /** - * This class stores the configuration for filtering AWS KMS CMK ARNs - * by AWS account ID and partition. + * This class stores the configuration for filtering AWS KMS CMK ARNs by AWS account ID and + * partition. * - * The filter allows a KMS CMK if its partition matches {@code partition} - * and its accountId is included in {@code accountIds}. + *

The filter allows a KMS CMK if its partition matches {@code partition} and its accountId is + * included in {@code accountIds}. */ public final class DiscoveryFilter { - private final String partition_; - private final Collection accountIds_; - - public DiscoveryFilter(String partition, String... accountIds) { - this(partition, Arrays.asList(accountIds)); + private final String partition_; + private final Collection accountIds_; + + public DiscoveryFilter(String partition, String... accountIds) { + this(partition, Arrays.asList(accountIds)); + } + + public DiscoveryFilter(String partition, Collection accountIds) { + if (partition == null || partition.isEmpty()) { + throw new IllegalArgumentException( + "Discovery filter cannot be configured without " + "a partition."); + } else if (accountIds == null || accountIds.isEmpty()) { + throw new IllegalArgumentException( + "Discovery filter cannot be configured without " + "account IDs."); + } else if (accountIds.contains(null) || accountIds.contains("")) { + throw new IllegalArgumentException( + "Discovery filter cannot be configured with " + "null or empty account IDs."); } - public DiscoveryFilter(String partition, Collection accountIds) { - if (partition == null || partition.isEmpty()) { - throw new IllegalArgumentException("Discovery filter cannot be configured without " + - "a partition."); - } else if (accountIds == null || accountIds.isEmpty()) { - throw new IllegalArgumentException("Discovery filter cannot be configured without " + - "account IDs."); - } else if (accountIds.contains(null) || accountIds.contains("")) { - throw new IllegalArgumentException("Discovery filter cannot be configured with " + - "null or empty account IDs."); - } - - partition_ = partition; - accountIds_ = new HashSet(accountIds); - } + partition_ = partition; + accountIds_ = new HashSet(accountIds); + } - public String getPartition() { - return partition_; - } + public String getPartition() { + return partition_; + } - public Collection getAccountIds() { - return Collections.unmodifiableSet(new HashSet<>(accountIds_)); - } + public Collection getAccountIds() { + return Collections.unmodifiableSet(new HashSet<>(accountIds_)); + } - public boolean allowsPartitionAndAccount(String partition, String accountId) { - return (partition_.equals(partition) && accountIds_.contains(accountId)); - } + public boolean allowsPartitionAndAccount(String partition, String accountId) { + return (partition_.equals(partition) && accountIds_.contains(accountId)); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java index e4783d6d2..2a74d0c1e 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,20 +13,8 @@ package com.amazonaws.encryptionsdk.kms; -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.function.Supplier; - import com.amazonaws.AmazonServiceException; import com.amazonaws.AmazonWebServiceRequest; -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.DataKey; @@ -43,146 +31,169 @@ import com.amazonaws.services.kms.model.EncryptResult; import com.amazonaws.services.kms.model.GenerateDataKeyRequest; import com.amazonaws.services.kms.model.GenerateDataKeyResult; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; /** - * Represents a single Customer Master Key (CMK) and is used to encrypt/decrypt data with - * {@link AwsCrypto}. + * Represents a single Customer Master Key (CMK) and is used to encrypt/decrypt data with {@link + * AwsCrypto}. * - * This component is not multi-Region key aware, - * and will treat every AWS KMS identifier as regionally isolated. + *

This component is not multi-Region key aware, and will treat every AWS KMS identifier as + * regionally isolated. */ public final class KmsMasterKey extends MasterKey implements KmsMethods { - private static final String USER_AGENT = VersionInfo.loadUserAgent(); - private final Supplier kms_; - private final MasterKeyProvider sourceProvider_; - private final String id_; - private final List grantTokens_ = new ArrayList<>(); - - private T updateUserAgent(T request) { - request.getRequestClientOptions().appendUserAgent(USER_AGENT); - - return request; - } - - static KmsMasterKey getInstance(final Supplier kms, final String id, - final MasterKeyProvider provider) { - return new KmsMasterKey(kms, id, provider); - } - - private KmsMasterKey(final Supplier kms, final String id, final MasterKeyProvider provider) { - kms_ = kms; - id_ = id; - sourceProvider_ = provider; - } - - @Override - public String getProviderId() { - return sourceProvider_.getDefaultProviderId(); - } - - @Override - public String getKeyId() { - return id_; - } - - @Override - public DataKey generateDataKey(final CryptoAlgorithm algorithm, - final Map encryptionContext) { - final GenerateDataKeyResult gdkResult = kms_.get().generateDataKey(updateUserAgent( - new GenerateDataKeyRequest() + private static final String USER_AGENT = VersionInfo.loadUserAgent(); + private final Supplier kms_; + private final MasterKeyProvider sourceProvider_; + private final String id_; + private final List grantTokens_ = new ArrayList<>(); + + private T updateUserAgent(T request) { + request.getRequestClientOptions().appendUserAgent(USER_AGENT); + + return request; + } + + static KmsMasterKey getInstance( + final Supplier kms, final String id, final MasterKeyProvider provider) { + return new KmsMasterKey(kms, id, provider); + } + + private KmsMasterKey( + final Supplier kms, final String id, final MasterKeyProvider provider) { + kms_ = kms; + id_ = id; + sourceProvider_ = provider; + } + + @Override + public String getProviderId() { + return sourceProvider_.getDefaultProviderId(); + } + + @Override + public String getKeyId() { + return id_; + } + + @Override + public DataKey generateDataKey( + final CryptoAlgorithm algorithm, final Map encryptionContext) { + final GenerateDataKeyResult gdkResult = + kms_.get() + .generateDataKey( + updateUserAgent( + new GenerateDataKeyRequest() .withKeyId(getKeyId()) .withNumberOfBytes(algorithm.getDataKeyLength()) .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens_) - )); - final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; - gdkResult.getPlaintext().get(rawKey); - if (gdkResult.getPlaintext().remaining() > 0) { - throw new IllegalStateException("Recieved an unexpected number of bytes from KMS"); - } - final byte[] encryptedKey = new byte[gdkResult.getCiphertextBlob().remaining()]; - gdkResult.getCiphertextBlob().get(encryptedKey); - - final SecretKeySpec key = new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()); - return new DataKey<>(key, encryptedKey, gdkResult.getKeyId().getBytes(StandardCharsets.UTF_8), this); + .withGrantTokens(grantTokens_))); + final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; + gdkResult.getPlaintext().get(rawKey); + if (gdkResult.getPlaintext().remaining() > 0) { + throw new IllegalStateException("Recieved an unexpected number of bytes from KMS"); } - - @Override - public void setGrantTokens(final List grantTokens) { - grantTokens_.clear(); - grantTokens_.addAll(grantTokens); - } - - @Override - public List getGrantTokens() { - return grantTokens_; + final byte[] encryptedKey = new byte[gdkResult.getCiphertextBlob().remaining()]; + gdkResult.getCiphertextBlob().get(encryptedKey); + + final SecretKeySpec key = new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()); + return new DataKey<>( + key, encryptedKey, gdkResult.getKeyId().getBytes(StandardCharsets.UTF_8), this); + } + + @Override + public void setGrantTokens(final List grantTokens) { + grantTokens_.clear(); + grantTokens_.addAll(grantTokens); + } + + @Override + public List getGrantTokens() { + return grantTokens_; + } + + @Override + public void addGrantToken(final String grantToken) { + grantTokens_.add(grantToken); + } + + @Override + public DataKey encryptDataKey( + final CryptoAlgorithm algorithm, + final Map encryptionContext, + final DataKey dataKey) { + final SecretKey key = dataKey.getKey(); + if (!key.getFormat().equals("RAW")) { + throw new IllegalArgumentException("Only RAW encoded keys are supported"); } - - @Override - public void addGrantToken(final String grantToken) { - grantTokens_.add(grantToken); + try { + final EncryptResult encryptResult = + kms_.get() + .encrypt( + updateUserAgent( + new EncryptRequest() + .withKeyId(id_) + .withPlaintext(ByteBuffer.wrap(key.getEncoded())) + .withEncryptionContext(encryptionContext) + .withGrantTokens(grantTokens_))); + final byte[] edk = new byte[encryptResult.getCiphertextBlob().remaining()]; + encryptResult.getCiphertextBlob().get(edk); + return new DataKey<>( + dataKey.getKey(), edk, encryptResult.getKeyId().getBytes(StandardCharsets.UTF_8), this); + } catch (final AmazonServiceException asex) { + throw new AwsCryptoException(asex); } - - @Override - public DataKey encryptDataKey(final CryptoAlgorithm algorithm, - final Map encryptionContext, - final DataKey dataKey) { - final SecretKey key = dataKey.getKey(); - if (!key.getFormat().equals("RAW")) { - throw new IllegalArgumentException("Only RAW encoded keys are supported"); + } + + @Override + public DataKey decryptDataKey( + final CryptoAlgorithm algorithm, + final Collection encryptedDataKeys, + final Map encryptionContext) + throws UnsupportedProviderException, AwsCryptoException { + final List exceptions = new ArrayList<>(); + for (final EncryptedDataKey edk : encryptedDataKeys) { + try { + final String edkKeyId = new String(edk.getProviderInformation(), StandardCharsets.UTF_8); + if (!edkKeyId.equals(id_)) { + continue; } - try { - final EncryptResult encryptResult = kms_.get().encrypt(updateUserAgent( - new EncryptRequest() - .withKeyId(id_) - .withPlaintext(ByteBuffer.wrap(key.getEncoded())) + final DecryptResult decryptResult = + kms_.get() + .decrypt( + updateUserAgent( + new DecryptRequest() + .withCiphertextBlob(ByteBuffer.wrap(edk.getEncryptedDataKey())) .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens_))); - final byte[] edk = new byte[encryptResult.getCiphertextBlob().remaining()]; - encryptResult.getCiphertextBlob().get(edk); - return new DataKey<>(dataKey.getKey(), edk, encryptResult.getKeyId().getBytes(StandardCharsets.UTF_8), this); - } catch (final AmazonServiceException asex) { - throw new AwsCryptoException(asex); + .withGrantTokens(grantTokens_) + .withKeyId(edkKeyId))); + if (decryptResult.getKeyId() == null) { + throw new IllegalStateException("Received an empty keyId from KMS"); } - } - - @Override - public DataKey decryptDataKey(final CryptoAlgorithm algorithm, - final Collection encryptedDataKeys, - final Map encryptionContext) - throws UnsupportedProviderException, AwsCryptoException { - final List exceptions = new ArrayList<>(); - for (final EncryptedDataKey edk : encryptedDataKeys) { - try { - final String edkKeyId = new String(edk.getProviderInformation(), StandardCharsets.UTF_8); - if (!edkKeyId.equals(id_)) { - continue; - } - final DecryptResult decryptResult = kms_.get().decrypt(updateUserAgent( - new DecryptRequest() - .withCiphertextBlob(ByteBuffer.wrap(edk.getEncryptedDataKey())) - .withEncryptionContext(encryptionContext) - .withGrantTokens(grantTokens_) - .withKeyId(edkKeyId))); - if (decryptResult.getKeyId() == null) { - throw new IllegalStateException("Received an empty keyId from KMS"); - } - if (decryptResult.getKeyId().equals(id_)) { - final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; - decryptResult.getPlaintext().get(rawKey); - if (decryptResult.getPlaintext().remaining() > 0) { - throw new IllegalStateException("Received an unexpected number of bytes from KMS"); - } - return new DataKey<>( - new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()), - edk.getEncryptedDataKey(), - edk.getProviderInformation(), this); - } - } catch (final AmazonServiceException awsex) { - exceptions.add(awsex); - } + if (decryptResult.getKeyId().equals(id_)) { + final byte[] rawKey = new byte[algorithm.getDataKeyLength()]; + decryptResult.getPlaintext().get(rawKey); + if (decryptResult.getPlaintext().remaining() > 0) { + throw new IllegalStateException("Received an unexpected number of bytes from KMS"); + } + return new DataKey<>( + new SecretKeySpec(rawKey, algorithm.getDataKeyAlgo()), + edk.getEncryptedDataKey(), + edk.getProviderInformation(), + this); } - - throw buildCannotDecryptDksException(exceptions); + } catch (final AmazonServiceException awsex) { + exceptions.add(awsex); + } } + + throw buildCannotDecryptDksException(exceptions); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java index 8a1837de7..10aa8f985 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java @@ -3,22 +3,12 @@ package com.amazonaws.encryptionsdk.kms; +import static com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo.parseInfoFromKeyArn; import static java.util.Arrays.asList; import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Supplier; - import com.amazonaws.AmazonServiceException; -import com.amazonaws.ClientConfiguration; import com.amazonaws.Request; import com.amazonaws.Response; import com.amazonaws.auth.AWSCredentials; @@ -35,540 +25,556 @@ import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; import com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo; import com.amazonaws.handlers.RequestHandler2; -import com.amazonaws.regions.Region; -import com.amazonaws.regions.Regions; -import com.amazonaws.regions.RegionUtils; import com.amazonaws.services.kms.AWSKMS; import com.amazonaws.services.kms.AWSKMSClient; import com.amazonaws.services.kms.AWSKMSClientBuilder; -import static com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo.parseInfoFromKeyArn; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Supplier; /** * Provides {@link MasterKey}s backed by the AWS Key Management Service. This object is regional and * if you want to use keys from multiple regions, you'll need multiple copies of this object. * - * This component is not multi-Region key aware, and will treat every AWS KMS identifier as + *

This component is not multi-Region key aware, and will treat every AWS KMS identifier as * regionally isolated. */ public class KmsMasterKeyProvider extends MasterKeyProvider implements KmsMethods { - private static final String PROVIDER_NAME = "aws-kms"; - private final List keyIds_; - private final List grantTokens_; - - private final boolean isDiscovery_; - private final DiscoveryFilter discoveryFilter_; - - private final RegionalClientSupplier regionalClientSupplier_; - private final String defaultRegion_; - - @FunctionalInterface - public interface RegionalClientSupplier { - /** - * Supplies an AWSKMS instance to use for a given region. The {@link KmsMasterKeyProvider} will not cache the - * result of this function. - * - * @param regionName The region to get a client for - * @return The client to use, or null if this region cannot or should not be used. - */ - AWSKMS getClient(String regionName); - } - - public static class Builder implements Cloneable { - private String defaultRegion_ = null; - private RegionalClientSupplier regionalClientSupplier_ = null; - private AWSKMSClientBuilder templateBuilder_ = null; - private DiscoveryFilter discoveryFilter_ = null; - - Builder() { - // Default access: Don't allow outside classes to extend this class - } - - public Builder clone() { - try { - Builder cloned = (Builder) super.clone(); + private static final String PROVIDER_NAME = "aws-kms"; + private final List keyIds_; + private final List grantTokens_; - if (templateBuilder_ != null) { - cloned.templateBuilder_ = cloneClientBuilder(templateBuilder_); - } - - return cloned; - } catch (CloneNotSupportedException e) { - throw new Error("Impossible: CloneNotSupportedException", e); - } - } - - /** - * Sets the default region. This region will be used when specifying key IDs for encryption or in - * {@link KmsMasterKeyProvider#getMasterKey(String)} that are not full ARNs, but are instead bare key IDs or - * aliases. - * - * If the default region is not specified, only full key ARNs will be usable. - * - * @param defaultRegion The default region to use. - * @return - */ - public Builder withDefaultRegion(String defaultRegion) { - this.defaultRegion_ = defaultRegion; - return this; - } - - /** - * Provides a custom factory function that will vend KMS clients. This is provided for advanced use cases which - * require complete control over the client construction process. - * - * Because the regional client supplier fully controls the client construction process, it is not possible to - * configure the client through methods such as {@link #withCredentials(AWSCredentialsProvider)} or - * {@link #withClientBuilder(AWSKMSClientBuilder)}; if you try to use these in combination, an - * {@link IllegalStateException} will be thrown. - * - * @param regionalClientSupplier - * @return - */ - public Builder withCustomClientFactory(RegionalClientSupplier regionalClientSupplier) { - if (templateBuilder_ != null) { - throw clientSupplierComboException(); - } - - regionalClientSupplier_ = regionalClientSupplier; - return this; - } + private final boolean isDiscovery_; + private final DiscoveryFilter discoveryFilter_; - private RuntimeException clientSupplierComboException() { - return new IllegalStateException("withCustomClientFactory cannot be used in conjunction with " + - "withCredentials or withClientBuilder"); - } - - /** - * Configures the {@link KmsMasterKeyProvider} to use specific credentials. If a builder was previously set, - * this will override whatever credentials it set. - * @param credentialsProvider - * @return - */ - public Builder withCredentials(AWSCredentialsProvider credentialsProvider) { - if (regionalClientSupplier_ != null) { - throw clientSupplierComboException(); - } - - if (templateBuilder_ == null) { - templateBuilder_ = AWSKMSClientBuilder.standard(); - } - - templateBuilder_.setCredentials(credentialsProvider); - - return this; - } - - /** - * Configures the {@link KmsMasterKeyProvider} to use specific credentials. If a builder was previously set, - * this will override whatever credentials it set. - * @param credentials - * @return - */ - public Builder withCredentials(AWSCredentials credentials) { - return withCredentials(new AWSStaticCredentialsProvider(credentials)); - } + private final RegionalClientSupplier regionalClientSupplier_; + private final String defaultRegion_; - /** - * Configures the {@link KmsMasterKeyProvider} to use settings from this {@link AWSKMSClientBuilder} to - * configure KMS clients. Note that the region set on this builder will be ignored, but all other settings - * will be propagated into the regional clients. - * - * This method will overwrite any credentials set using {@link #withCredentials(AWSCredentialsProvider)}. - * - * @param builder - * @return - */ - public Builder withClientBuilder(AWSKMSClientBuilder builder) { - if (regionalClientSupplier_ != null) { - throw clientSupplierComboException(); - } - final AWSKMSClientBuilder newBuilder = cloneClientBuilder(builder); - - - this.templateBuilder_ = newBuilder; - - return this; - } + @FunctionalInterface + public interface RegionalClientSupplier { + /** + * Supplies an AWSKMS instance to use for a given region. The {@link KmsMasterKeyProvider} will + * not cache the result of this function. + * + * @param regionName The region to get a client for + * @return The client to use, or null if this region cannot or should not be used. + */ + AWSKMS getClient(String regionName); + } - AWSKMSClientBuilder cloneClientBuilder(final AWSKMSClientBuilder builder) { - // We need to copy all arguments out of the builder in case it's mutated later on. - // Unfortunately AWSKMSClientBuilder doesn't support .clone() so we'll have to do it by hand. - - if (builder.getEndpoint() != null) { - // We won't be able to set the region later if a custom endpoint is set. - throw new IllegalArgumentException("Setting endpoint configuration is not compatible with passing a " + - "builder to the KmsMasterKeyProvider. Use withCustomClientFactory" + - " instead."); - } - - final AWSKMSClientBuilder newBuilder = AWSKMSClient.builder(); - newBuilder.setClientConfiguration(builder.getClientConfiguration()); - newBuilder.setCredentials(builder.getCredentials()); - newBuilder.setEndpointConfiguration(builder.getEndpoint()); - newBuilder.setMetricsCollector(builder.getMetricsCollector()); - if (builder.getRequestHandlers() != null) { - newBuilder.setRequestHandlers(builder.getRequestHandlers().toArray(new RequestHandler2[0])); - } - return newBuilder; - } + public static class Builder implements Cloneable { + private String defaultRegion_ = null; + private RegionalClientSupplier regionalClientSupplier_ = null; + private AWSKMSClientBuilder templateBuilder_ = null; + private DiscoveryFilter discoveryFilter_ = null; - /** - * Builds the master key provider in Discovery Mode. - * In Discovery Mode the KMS Master Key Provider will attempt to decrypt using any - * key identifier it discovers in the encrypted message. - * KMS Master Key Providers in Discovery Mode will not encrypt data keys. - * - * @return - */ - public KmsMasterKeyProvider buildDiscovery() { - final boolean isDiscovery = true; - RegionalClientSupplier supplier = clientFactory(); - - return new KmsMasterKeyProvider(supplier, defaultRegion_, emptyList(), emptyList(), isDiscovery, discoveryFilter_); - } - - /** - * Builds the master key provider in Discovery Mode with a {@link DiscoveryFilter}. - * In Discovery Mode the KMS Master Key Provider will attempt to decrypt using any - * key identifier it discovers in the encrypted message that is accepted by the - * {@code filter}. - * KMS Master Key Providers in Discovery Mode will not encrypt data keys. - * - * @param filter - * @return - */ - public KmsMasterKeyProvider buildDiscovery(DiscoveryFilter filter) { - if (filter == null) { - throw new IllegalArgumentException("Discovery filter must not be null if specifying " + - "a discovery filter."); - } - discoveryFilter_ = filter; - - return buildDiscovery(); - } + Builder() { + // Default access: Don't allow outside classes to extend this class + } - /** - * Builds the master key provider in Strict Mode. - * KMS Master Key Providers in Strict Mode will only attempt to decrypt using - * key ARNs listed in {@code keyIds}. - * KMS Master Key Providers in Strict Mode will encrypt data keys using the keys - * listed in {@code keyIds} - * - * In Strict Mode, one or more CMKs must be provided. - * For providers that will only be used for encryption, - * you can use any valid KMS key identifier. - * For providers that will be used for decryption, - * you must use the key ARN; - * key ids, alias names, and alias ARNs are not supported. - * - * @param keyIds - * @return - */ - public KmsMasterKeyProvider buildStrict(List keyIds) { - if (keyIds == null) { - throw new IllegalArgumentException("Strict mode must be configured with a non-empty " + - "list of keyIds."); - } - - final boolean isDiscovery = false; - RegionalClientSupplier supplier = clientFactory(); - - return new KmsMasterKeyProvider(supplier, defaultRegion_, new ArrayList(keyIds), emptyList(), isDiscovery, null); - } + public Builder clone() { + try { + Builder cloned = (Builder) super.clone(); - /** - * Builds the master key provider in strict mode. - * KMS Master Key Providers in Strict Mode will only attempt to decrypt using - * key ARNs listed in {@code keyIds}. - * KMS Master Key Providers in Strict Mode will encrypt data keys using the keys - * listed in {@code keyIds} - * - * In Strict Mode, one or more CMKs must be provided. - * For providers that will only be used for encryption, - * you can use any valid KMS key identifier. - * For providers that will be used for decryption, - * you must use the key ARN; - * key ids, alias names, and alias ARNs are not supported. - * - * @param keyIds - * @return - */ - public KmsMasterKeyProvider buildStrict(String... keyIds) { - return buildStrict(asList(keyIds)); + if (templateBuilder_ != null) { + cloned.templateBuilder_ = cloneClientBuilder(templateBuilder_); } - RegionalClientSupplier clientFactory() { - if (regionalClientSupplier_ != null) { - return regionalClientSupplier_; - } + return cloned; + } catch (CloneNotSupportedException e) { + throw new Error("Impossible: CloneNotSupportedException", e); + } + } - // Clone again; this MKP builder might be reused to build a second MKP with different creds. - AWSKMSClientBuilder builder = templateBuilder_ != null ? cloneClientBuilder(templateBuilder_) - : AWSKMSClientBuilder.standard(); + /** + * Sets the default region. This region will be used when specifying key IDs for encryption or + * in {@link KmsMasterKeyProvider#getMasterKey(String)} that are not full ARNs, but are instead + * bare key IDs or aliases. + * + *

If the default region is not specified, only full key ARNs will be usable. + * + * @param defaultRegion The default region to use. + * @return + */ + public Builder withDefaultRegion(String defaultRegion) { + this.defaultRegion_ = defaultRegion; + return this; + } - ConcurrentHashMap clientCache = new ConcurrentHashMap<>(); - snoopClientCache(clientCache); + /** + * Provides a custom factory function that will vend KMS clients. This is provided for advanced + * use cases which require complete control over the client construction process. + * + *

Because the regional client supplier fully controls the client construction process, it is + * not possible to configure the client through methods such as {@link + * #withCredentials(AWSCredentialsProvider)} or {@link #withClientBuilder(AWSKMSClientBuilder)}; + * if you try to use these in combination, an {@link IllegalStateException} will be thrown. + * + * @param regionalClientSupplier + * @return + */ + public Builder withCustomClientFactory(RegionalClientSupplier regionalClientSupplier) { + if (templateBuilder_ != null) { + throw clientSupplierComboException(); + } - return region -> { - AWSKMS kms = clientCache.get(region); + regionalClientSupplier_ = regionalClientSupplier; + return this; + } - if (kms != null) return kms; + private RuntimeException clientSupplierComboException() { + return new IllegalStateException( + "withCustomClientFactory cannot be used in conjunction with " + + "withCredentials or withClientBuilder"); + } - // We can't just use computeIfAbsent as we need to avoid leaking KMS clients if we're asked to decrypt - // an EDK with a bogus region in its ARN. So we'll install a request handler to identify the first - // successful call, and cache it when we see that. - SuccessfulRequestCacher cacher = new SuccessfulRequestCacher(clientCache, region); - ArrayList handlers = new ArrayList<>(); - if (builder.getRequestHandlers() != null) { - handlers.addAll(builder.getRequestHandlers()); - } - handlers.add(cacher); + /** + * Configures the {@link KmsMasterKeyProvider} to use specific credentials. If a builder was + * previously set, this will override whatever credentials it set. + * + * @param credentialsProvider + * @return + */ + public Builder withCredentials(AWSCredentialsProvider credentialsProvider) { + if (regionalClientSupplier_ != null) { + throw clientSupplierComboException(); + } - kms = cloneClientBuilder(builder) - .withRegion(region) - .withRequestHandlers(handlers.toArray(new RequestHandler2[handlers.size()])) - .build(); + if (templateBuilder_ == null) { + templateBuilder_ = AWSKMSClientBuilder.standard(); + } - return cacher.setClient(kms); - }; - } + templateBuilder_.setCredentials(credentialsProvider); - protected void snoopClientCache(ConcurrentHashMap map) { - // no-op - this is a test hook - } + return this; } - static class SuccessfulRequestCacher extends RequestHandler2 { - private final ConcurrentHashMap cache_; - private final String region_; - private AWSKMS client_; + /** + * Configures the {@link KmsMasterKeyProvider} to use specific credentials. If a builder was + * previously set, this will override whatever credentials it set. + * + * @param credentials + * @return + */ + public Builder withCredentials(AWSCredentials credentials) { + return withCredentials(new AWSStaticCredentialsProvider(credentials)); + } - volatile boolean ranBefore_ = false; + /** + * Configures the {@link KmsMasterKeyProvider} to use settings from this {@link + * AWSKMSClientBuilder} to configure KMS clients. Note that the region set on this builder will + * be ignored, but all other settings will be propagated into the regional clients. + * + *

This method will overwrite any credentials set using {@link + * #withCredentials(AWSCredentialsProvider)}. + * + * @param builder + * @return + */ + public Builder withClientBuilder(AWSKMSClientBuilder builder) { + if (regionalClientSupplier_ != null) { + throw clientSupplierComboException(); + } + final AWSKMSClientBuilder newBuilder = cloneClientBuilder(builder); - SuccessfulRequestCacher( - final ConcurrentHashMap cache, - final String region - ) { - this.region_ = region; - this.cache_ = cache; - } + this.templateBuilder_ = newBuilder; - public AWSKMS setClient(final AWSKMS client) { - client_ = client; - return client; - } + return this; + } - @Override public void afterResponse(final Request request, final Response response) { - if (ranBefore_) return; - ranBefore_ = true; + AWSKMSClientBuilder cloneClientBuilder(final AWSKMSClientBuilder builder) { + // We need to copy all arguments out of the builder in case it's mutated later on. + // Unfortunately AWSKMSClientBuilder doesn't support .clone() so we'll have to do it by hand. + + if (builder.getEndpoint() != null) { + // We won't be able to set the region later if a custom endpoint is set. + throw new IllegalArgumentException( + "Setting endpoint configuration is not compatible with passing a " + + "builder to the KmsMasterKeyProvider. Use withCustomClientFactory" + + " instead."); + } + + final AWSKMSClientBuilder newBuilder = AWSKMSClient.builder(); + newBuilder.setClientConfiguration(builder.getClientConfiguration()); + newBuilder.setCredentials(builder.getCredentials()); + newBuilder.setEndpointConfiguration(builder.getEndpoint()); + newBuilder.setMetricsCollector(builder.getMetricsCollector()); + if (builder.getRequestHandlers() != null) { + newBuilder.setRequestHandlers(builder.getRequestHandlers().toArray(new RequestHandler2[0])); + } + return newBuilder; + } - cache_.putIfAbsent(region_, client_); - } + /** + * Builds the master key provider in Discovery Mode. In Discovery Mode the KMS Master Key + * Provider will attempt to decrypt using any key identifier it discovers in the encrypted + * message. KMS Master Key Providers in Discovery Mode will not encrypt data keys. + * + * @return + */ + public KmsMasterKeyProvider buildDiscovery() { + final boolean isDiscovery = true; + RegionalClientSupplier supplier = clientFactory(); - @Override public void afterError(final Request request, final Response response, final Exception e) { - if (ranBefore_) return; - if (e instanceof AmazonServiceException) { - ranBefore_ = true; - cache_.putIfAbsent(region_, client_); - } - } + return new KmsMasterKeyProvider( + supplier, defaultRegion_, emptyList(), emptyList(), isDiscovery, discoveryFilter_); } - public static Builder builder() { - return new Builder(); + /** + * Builds the master key provider in Discovery Mode with a {@link DiscoveryFilter}. In Discovery + * Mode the KMS Master Key Provider will attempt to decrypt using any key identifier it + * discovers in the encrypted message that is accepted by the {@code filter}. KMS Master Key + * Providers in Discovery Mode will not encrypt data keys. + * + * @param filter + * @return + */ + public KmsMasterKeyProvider buildDiscovery(DiscoveryFilter filter) { + if (filter == null) { + throw new IllegalArgumentException( + "Discovery filter must not be null if specifying " + "a discovery filter."); + } + discoveryFilter_ = filter; + + return buildDiscovery(); } - KmsMasterKeyProvider( - RegionalClientSupplier supplier, - String defaultRegion, - List keyIds, - List grantTokens, - boolean isDiscovery, - DiscoveryFilter discoveryFilter - ) { - if (!isDiscovery && (keyIds == null || keyIds.isEmpty())) { - throw new IllegalArgumentException("Strict mode must be configured with a non-empty " + - "list of keyIds."); - } - if (!isDiscovery && keyIds.contains(null)) { - throw new IllegalArgumentException("Strict mode cannot be configured with a " + - "null key identifier."); - } - if (!isDiscovery && discoveryFilter != null) { - throw new IllegalArgumentException("Strict mode cannot be configured with a " + - "discovery filter."); - } - // If we don't have a default region, we need to check that all key IDs will be usable - if (!isDiscovery && defaultRegion == null) { - for (String keyId : keyIds) { - final AwsKmsCmkArnInfo arnInfo = parseInfoFromKeyArn(keyId); - if (arnInfo == null) { - throw new AwsCryptoException("Can't use non-ARN key identifiers or aliases when " + - "no default region is set"); - } - } - } - + /** + * Builds the master key provider in Strict Mode. KMS Master Key Providers in Strict Mode will + * only attempt to decrypt using key ARNs listed in {@code keyIds}. KMS Master Key Providers in + * Strict Mode will encrypt data keys using the keys listed in {@code keyIds} + * + *

In Strict Mode, one or more CMKs must be provided. For providers that will only be used + * for encryption, you can use any valid KMS key identifier. For providers that will be used for + * decryption, you must use the key ARN; key ids, alias names, and alias ARNs are not supported. + * + * @param keyIds + * @return + */ + public KmsMasterKeyProvider buildStrict(List keyIds) { + if (keyIds == null) { + throw new IllegalArgumentException( + "Strict mode must be configured with a non-empty " + "list of keyIds."); + } - this.regionalClientSupplier_ = supplier; - this.defaultRegion_ = defaultRegion; - this.keyIds_ = Collections.unmodifiableList(new ArrayList<>(keyIds)); + final boolean isDiscovery = false; + RegionalClientSupplier supplier = clientFactory(); - this.isDiscovery_ = isDiscovery; - this.discoveryFilter_ = discoveryFilter; - this.grantTokens_ = grantTokens; - } - - private static RegionalClientSupplier defaultProvider() { - return builder().clientFactory(); + return new KmsMasterKeyProvider( + supplier, defaultRegion_, new ArrayList(keyIds), emptyList(), isDiscovery, null); } /** - * Returns "aws-kms" + * Builds the master key provider in strict mode. KMS Master Key Providers in Strict Mode will + * only attempt to decrypt using key ARNs listed in {@code keyIds}. KMS Master Key Providers in + * Strict Mode will encrypt data keys using the keys listed in {@code keyIds} + * + *

In Strict Mode, one or more CMKs must be provided. For providers that will only be used + * for encryption, you can use any valid KMS key identifier. For providers that will be used for + * decryption, you must use the key ARN; key ids, alias names, and alias ARNs are not supported. + * + * @param keyIds + * @return */ - @Override - public String getDefaultProviderId() { - return PROVIDER_NAME; + public KmsMasterKeyProvider buildStrict(String... keyIds) { + return buildStrict(asList(keyIds)); } - @Override - public KmsMasterKey getMasterKey(final String provider, final String keyId) throws UnsupportedProviderException, - NoSuchMasterKeyException { - if (!canProvide(provider)) { - throw new UnsupportedProviderException(); - } - - if (!isDiscovery_ && !keyIds_.contains(keyId)) { - throw new NoSuchMasterKeyException("Key must be in supplied list of keyIds."); - } - - final AwsKmsCmkArnInfo arnInfo = parseInfoFromKeyArn(keyId); + RegionalClientSupplier clientFactory() { + if (regionalClientSupplier_ != null) { + return regionalClientSupplier_; + } + + // Clone again; this MKP builder might be reused to build a second MKP with different creds. + AWSKMSClientBuilder builder = + templateBuilder_ != null + ? cloneClientBuilder(templateBuilder_) + : AWSKMSClientBuilder.standard(); + + ConcurrentHashMap clientCache = new ConcurrentHashMap<>(); + snoopClientCache(clientCache); + + return region -> { + AWSKMS kms = clientCache.get(region); + + if (kms != null) return kms; + + // We can't just use computeIfAbsent as we need to avoid leaking KMS clients if we're asked + // to decrypt + // an EDK with a bogus region in its ARN. So we'll install a request handler to identify the + // first + // successful call, and cache it when we see that. + SuccessfulRequestCacher cacher = new SuccessfulRequestCacher(clientCache, region); + ArrayList handlers = new ArrayList<>(); + if (builder.getRequestHandlers() != null) { + handlers.addAll(builder.getRequestHandlers()); + } + handlers.add(cacher); + + kms = + cloneClientBuilder(builder) + .withRegion(region) + .withRequestHandlers(handlers.toArray(new RequestHandler2[handlers.size()])) + .build(); + + return cacher.setClient(kms); + }; + } - if (isDiscovery_ && discoveryFilter_ != null && (arnInfo == null)) { - throw new NoSuchMasterKeyException("Cannot use non-ARN key identifiers or aliases if " - + "discovery filter is configured."); - } else if (isDiscovery_ && discoveryFilter_ != null && - !discoveryFilter_.allowsPartitionAndAccount(arnInfo.getPartition(), arnInfo.getAccountId())) { - throw new NoSuchMasterKeyException("Cannot use key in partition " + arnInfo.getPartition() + - " with account id " + arnInfo.getAccountId() + " with configured discovery filter."); - } + protected void snoopClientCache(ConcurrentHashMap map) { + // no-op - this is a test hook + } + } - String regionName = defaultRegion_; - if (arnInfo != null) { - regionName = arnInfo.getRegion(); - } + static class SuccessfulRequestCacher extends RequestHandler2 { + private final ConcurrentHashMap cache_; + private final String region_; + private AWSKMS client_; - String regionName_ = regionName; + volatile boolean ranBefore_ = false; - Supplier kmsSupplier = () -> { - AWSKMS kms = regionalClientSupplier_.getClient(regionName_); - if (kms == null) { - throw new AwsCryptoException("Can't use keys from region " + regionName_); - } - return kms; - }; + SuccessfulRequestCacher(final ConcurrentHashMap cache, final String region) { + this.region_ = region; + this.cache_ = cache; + } - final KmsMasterKey result = KmsMasterKey.getInstance(kmsSupplier, keyId, this); - result.setGrantTokens(grantTokens_); - return result; + public AWSKMS setClient(final AWSKMS client) { + client_ = client; + return client; } - /** - * Returns all CMKs provided to the constructor of this object. - */ @Override - public List getMasterKeysForEncryption(final MasterKeyRequest request) { - if (keyIds_ == null) { - return emptyList(); - } - List result = new ArrayList<>(keyIds_.size()); - for (String id : keyIds_) { - result.add(getMasterKey(id)); - } - return result; + public void afterResponse(final Request request, final Response response) { + if (ranBefore_) return; + ranBefore_ = true; + + cache_.putIfAbsent(region_, client_); } @Override - public DataKey decryptDataKey(final CryptoAlgorithm algorithm, - final Collection encryptedDataKeys, final Map encryptionContext) - throws AwsCryptoException { - final List exceptions = new ArrayList<>(); - for (final EncryptedDataKey edk : encryptedDataKeys) { - if (canProvide(edk.getProviderId())) { - try { - final String keyArn = new String(edk.getProviderInformation(), StandardCharsets.UTF_8); - // This will throw if we can't use this key for whatever reason - return getMasterKey(keyArn).decryptDataKey(algorithm, singletonList(edk), encryptionContext); - } catch (final Exception ex) { - exceptions.add(ex); - } - } + public void afterError( + final Request request, final Response response, final Exception e) { + if (ranBefore_) return; + if (e instanceof AmazonServiceException) { + ranBefore_ = true; + cache_.putIfAbsent(region_, client_); + } + } + } + + public static Builder builder() { + return new Builder(); + } + + KmsMasterKeyProvider( + RegionalClientSupplier supplier, + String defaultRegion, + List keyIds, + List grantTokens, + boolean isDiscovery, + DiscoveryFilter discoveryFilter) { + if (!isDiscovery && (keyIds == null || keyIds.isEmpty())) { + throw new IllegalArgumentException( + "Strict mode must be configured with a non-empty " + "list of keyIds."); + } + if (!isDiscovery && keyIds.contains(null)) { + throw new IllegalArgumentException( + "Strict mode cannot be configured with a " + "null key identifier."); + } + if (!isDiscovery && discoveryFilter != null) { + throw new IllegalArgumentException( + "Strict mode cannot be configured with a " + "discovery filter."); + } + // If we don't have a default region, we need to check that all key IDs will be usable + if (!isDiscovery && defaultRegion == null) { + for (String keyId : keyIds) { + final AwsKmsCmkArnInfo arnInfo = parseInfoFromKeyArn(keyId); + if (arnInfo == null) { + throw new AwsCryptoException( + "Can't use non-ARN key identifiers or aliases when " + "no default region is set"); } - throw buildCannotDecryptDksException(exceptions); + } } - /** - * @deprecated This method is inherently not thread safe. Use {@link KmsMasterKey#setGrantTokens(List)} instead. - * {@link KmsMasterKeyProvider}s constructed using the builder will throw an exception on attempts to modify the - * list of grant tokens. - */ - @Deprecated - @Override - public void setGrantTokens(final List grantTokens) { - try { - this.grantTokens_.clear(); - this.grantTokens_.addAll(grantTokens); - } catch (UnsupportedOperationException e) { - throw grantTokenError(); - } + this.regionalClientSupplier_ = supplier; + this.defaultRegion_ = defaultRegion; + this.keyIds_ = Collections.unmodifiableList(new ArrayList<>(keyIds)); + + this.isDiscovery_ = isDiscovery; + this.discoveryFilter_ = discoveryFilter; + this.grantTokens_ = grantTokens; + } + + private static RegionalClientSupplier defaultProvider() { + return builder().clientFactory(); + } + + /** Returns "aws-kms" */ + @Override + public String getDefaultProviderId() { + return PROVIDER_NAME; + } + + @Override + public KmsMasterKey getMasterKey(final String provider, final String keyId) + throws UnsupportedProviderException, NoSuchMasterKeyException { + if (!canProvide(provider)) { + throw new UnsupportedProviderException(); } - @Override - public List getGrantTokens() { - return new ArrayList<>(grantTokens_); + if (!isDiscovery_ && !keyIds_.contains(keyId)) { + throw new NoSuchMasterKeyException("Key must be in supplied list of keyIds."); } - /** - * @deprecated This method is inherently not thread safe. Use {@link #withGrantTokens(List)} or - * {@link KmsMasterKey#setGrantTokens(List)} instead. {@link KmsMasterKeyProvider}s constructed using the builder - * will throw an exception on attempts to modify the list of grant tokens. - */ - @Deprecated - @Override - public void addGrantToken(final String grantToken) { - try { - grantTokens_.add(grantToken); - } catch (UnsupportedOperationException e) { - throw grantTokenError(); - } + final AwsKmsCmkArnInfo arnInfo = parseInfoFromKeyArn(keyId); + + if (isDiscovery_ && discoveryFilter_ != null && (arnInfo == null)) { + throw new NoSuchMasterKeyException( + "Cannot use non-ARN key identifiers or aliases if " + "discovery filter is configured."); + } else if (isDiscovery_ + && discoveryFilter_ != null + && !discoveryFilter_.allowsPartitionAndAccount( + arnInfo.getPartition(), arnInfo.getAccountId())) { + throw new NoSuchMasterKeyException( + "Cannot use key in partition " + + arnInfo.getPartition() + + " with account id " + + arnInfo.getAccountId() + + " with configured discovery filter."); } - private RuntimeException grantTokenError() { - return new IllegalStateException("This master key provider is immutable. Use withGrantTokens instead."); + String regionName = defaultRegion_; + if (arnInfo != null) { + regionName = arnInfo.getRegion(); } - /** - * Returns a new {@link KmsMasterKeyProvider} that is configured identically to this one, except with the given list - * of grant tokens. The grant token list in the returned provider is immutable (but can be further overridden by - * invoking withGrantTokens again). - * @param grantTokens - * @return - */ - public KmsMasterKeyProvider withGrantTokens(List grantTokens) { - grantTokens = Collections.unmodifiableList(new ArrayList<>(grantTokens)); + String regionName_ = regionName; - return new KmsMasterKeyProvider(regionalClientSupplier_, defaultRegion_, keyIds_, grantTokens, isDiscovery_, discoveryFilter_); - } + Supplier kmsSupplier = + () -> { + AWSKMS kms = regionalClientSupplier_.getClient(regionName_); + if (kms == null) { + throw new AwsCryptoException("Can't use keys from region " + regionName_); + } + return kms; + }; - /** - * Returns a new {@link KmsMasterKeyProvider} that is configured identically to this one, except with the given list - * of grant tokens. The grant token list in the returned provider is immutable (but can be further overridden by - * invoking withGrantTokens again). - * @param grantTokens - * @return - */ - public KmsMasterKeyProvider withGrantTokens(String... grantTokens) { - return withGrantTokens(asList(grantTokens)); - } + final KmsMasterKey result = KmsMasterKey.getInstance(kmsSupplier, keyId, this); + result.setGrantTokens(grantTokens_); + return result; + } + /** Returns all CMKs provided to the constructor of this object. */ + @Override + public List getMasterKeysForEncryption(final MasterKeyRequest request) { + if (keyIds_ == null) { + return emptyList(); + } + List result = new ArrayList<>(keyIds_.size()); + for (String id : keyIds_) { + result.add(getMasterKey(id)); + } + return result; + } + + @Override + public DataKey decryptDataKey( + final CryptoAlgorithm algorithm, + final Collection encryptedDataKeys, + final Map encryptionContext) + throws AwsCryptoException { + final List exceptions = new ArrayList<>(); + for (final EncryptedDataKey edk : encryptedDataKeys) { + if (canProvide(edk.getProviderId())) { + try { + final String keyArn = new String(edk.getProviderInformation(), StandardCharsets.UTF_8); + // This will throw if we can't use this key for whatever reason + return getMasterKey(keyArn) + .decryptDataKey(algorithm, singletonList(edk), encryptionContext); + } catch (final Exception ex) { + exceptions.add(ex); + } + } + } + throw buildCannotDecryptDksException(exceptions); + } + + /** + * @deprecated This method is inherently not thread safe. Use {@link + * KmsMasterKey#setGrantTokens(List)} instead. {@link KmsMasterKeyProvider}s constructed using + * the builder will throw an exception on attempts to modify the list of grant tokens. + */ + @Deprecated + @Override + public void setGrantTokens(final List grantTokens) { + try { + this.grantTokens_.clear(); + this.grantTokens_.addAll(grantTokens); + } catch (UnsupportedOperationException e) { + throw grantTokenError(); + } + } + + @Override + public List getGrantTokens() { + return new ArrayList<>(grantTokens_); + } + + /** + * @deprecated This method is inherently not thread safe. Use {@link #withGrantTokens(List)} or + * {@link KmsMasterKey#setGrantTokens(List)} instead. {@link KmsMasterKeyProvider}s + * constructed using the builder will throw an exception on attempts to modify the list of + * grant tokens. + */ + @Deprecated + @Override + public void addGrantToken(final String grantToken) { + try { + grantTokens_.add(grantToken); + } catch (UnsupportedOperationException e) { + throw grantTokenError(); + } + } + + private RuntimeException grantTokenError() { + return new IllegalStateException( + "This master key provider is immutable. Use withGrantTokens instead."); + } + + /** + * Returns a new {@link KmsMasterKeyProvider} that is configured identically to this one, except + * with the given list of grant tokens. The grant token list in the returned provider is immutable + * (but can be further overridden by invoking withGrantTokens again). + * + * @param grantTokens + * @return + */ + public KmsMasterKeyProvider withGrantTokens(List grantTokens) { + grantTokens = Collections.unmodifiableList(new ArrayList<>(grantTokens)); + + return new KmsMasterKeyProvider( + regionalClientSupplier_, + defaultRegion_, + keyIds_, + grantTokens, + isDiscovery_, + discoveryFilter_); + } + + /** + * Returns a new {@link KmsMasterKeyProvider} that is configured identically to this one, except + * with the given list of grant tokens. The grant token list in the returned provider is immutable + * (but can be further overridden by invoking withGrantTokens again). + * + * @param grantTokens + * @return + */ + public KmsMasterKeyProvider withGrantTokens(String... grantTokens) { + return withGrantTokens(asList(grantTokens)); + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMethods.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMethods.java index 46671ac5f..632874372 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMethods.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMethods.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -15,22 +15,14 @@ import java.util.List; -/** - * Methods common to all classes which interact with KMS. - */ +/** Methods common to all classes which interact with KMS. */ public interface KmsMethods { - /** - * Sets the {@code grantTokens} which should be submitted to KMS when calling it. - */ - public void setGrantTokens(List grantTokens); + /** Sets the {@code grantTokens} which should be submitted to KMS when calling it. */ + public void setGrantTokens(List grantTokens); - /** - * Returns the grantTokens which this object sends to KMS when calling it. - */ - public List getGrantTokens(); + /** Returns the grantTokens which this object sends to KMS when calling it. */ + public List getGrantTokens(); - /** - * Adds {@code grantToken} to the list of grantTokens sent to KMS when this class calls it. - */ - public void addGrantToken(String grantToken); + /** Adds {@code grantToken} to the list of grantTokens sent to KMS when this class calls it. */ + public void addGrantToken(String grantToken); } diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/package-info.java b/src/main/java/com/amazonaws/encryptionsdk/kms/package-info.java index 0e5182bb8..e69745e23 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/package-info.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/package-info.java @@ -1,18 +1,18 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. */ /** - * Contains logic necessary to create {@link com.amazonaws.encryptionsdk.MasterKey}s backed - * by AWS KMS keys. + * Contains logic necessary to create {@link com.amazonaws.encryptionsdk.MasterKey}s backed by AWS + * KMS keys. */ package com.amazonaws.encryptionsdk.kms; diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/CipherBlockHeaders.java b/src/main/java/com/amazonaws/encryptionsdk/model/CipherBlockHeaders.java index e1c789ae6..9e8a033ff 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/CipherBlockHeaders.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/CipherBlockHeaders.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,381 +13,349 @@ package com.amazonaws.encryptionsdk.model; -import java.nio.ByteBuffer; -import java.util.Arrays; - import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.exception.ParseException; import com.amazonaws.encryptionsdk.internal.Constants; import com.amazonaws.encryptionsdk.internal.PrimitivesParser; +import java.nio.ByteBuffer; +import java.util.Arrays; /** - * This class implements the headers for the encrypted content stored in a - * single block. These headers are parsed and used when the encrypted content - * in the single block is decrypted. - * - *

- * It contains the following fields in order: + * This class implements the headers for the encrypted content stored in a single block. These + * headers are parsed and used when the encrypted content in the single block is decrypted. + * + *

It contains the following fields in order: + * *

    - *
  1. nonce
  2. - *
  3. length of content
  4. + *
  5. nonce + *
  6. length of content *
*/ -//@ non_null_by_default +// @ non_null_by_default public final class CipherBlockHeaders { - //@ spec_public nullable - private byte[] nonce_; - //@ spec_public - private long contentLength_ = -1; + // @ spec_public nullable + private byte[] nonce_; + // @ spec_public + private long contentLength_ = -1; + + // This is set after the nonce length is parsed in the CiphertextHeaders + // during decryption. This can be set only using its setter. + // @ spec_public + private short nonceLength_ = 0; + // @ public invariant nonceLength_ >= 0; - // This is set after the nonce length is parsed in the CiphertextHeaders - // during decryption. This can be set only using its setter. - //@ spec_public - private short nonceLength_ = 0; - //@ public invariant nonceLength_ >= 0; + // @ spec_public + private boolean isComplete_; - //@ spec_public - private boolean isComplete_; + /** Default constructor. */ + // @ public normal_behavior + // @ ensures nonce_ == null; + // @ ensures contentLength_ == -1; + // @ ensures nonceLength_ == 0; + // @ ensures isComplete_ == false; + public CipherBlockHeaders() {} - /** - * Default constructor. - */ - //@ public normal_behavior - //@ ensures nonce_ == null; - //@ ensures contentLength_ == -1; - //@ ensures nonceLength_ == 0; - //@ ensures isComplete_ == false; - public CipherBlockHeaders() { + /** + * Construct the single block headers using the provided nonce and length of content. + * + * @param nonce the bytes containing the nonce. + * @param contentLen the length of the content in the block. + */ + // @ public normal_behavior + // @ requires nonce != null && nonce.length <= Constants.MAX_NONCE_LENGTH; + // @ ensures \fresh(nonce_) && nonce_.length == nonce.length; + // @ ensures Arrays.equalArrays(nonce_, nonce); + // @ ensures contentLength_ == contentLen; + // @ ensures nonceLength_ == 0; + // @ ensures isComplete_ == false; + // @ also private exceptional_behavior + // @ requires nonce == null || nonce.length > Constants.MAX_NONCE_LENGTH; + // @ signals_only AwsCryptoException; + // @ pure + public CipherBlockHeaders(/*@ nullable @*/ final byte[] nonce, final long contentLen) { + if (nonce == null) { + throw new AwsCryptoException("Nonce cannot be null."); + } + if (nonce.length > Constants.MAX_NONCE_LENGTH) { + throw new AwsCryptoException( + "Nonce length is greater than the maximum value of an unsigned byte."); } - /** - * Construct the single block headers using the provided nonce - * and length of content. - * - * @param nonce - * the bytes containing the nonce. - * @param contentLen - * the length of the content in the block. - */ - //@ public normal_behavior - //@ requires nonce != null && nonce.length <= Constants.MAX_NONCE_LENGTH; - //@ ensures \fresh(nonce_) && nonce_.length == nonce.length; - //@ ensures Arrays.equalArrays(nonce_, nonce); - //@ ensures contentLength_ == contentLen; - //@ ensures nonceLength_ == 0; - //@ ensures isComplete_ == false; - //@ also private exceptional_behavior - //@ requires nonce == null || nonce.length > Constants.MAX_NONCE_LENGTH; - //@ signals_only AwsCryptoException; - //@ pure - public CipherBlockHeaders(/*@ nullable @*/ final byte[] nonce, final long contentLen) { - if (nonce == null) { - throw new AwsCryptoException("Nonce cannot be null."); - } - if (nonce.length > Constants.MAX_NONCE_LENGTH) { - throw new AwsCryptoException( - "Nonce length is greater than the maximum value of an unsigned byte."); - } + nonce_ = nonce.clone(); + contentLength_ = contentLen; + } - nonce_ = nonce.clone(); - contentLength_ = contentLen; - } + /** + * Serialize the header into a byte array. + * + * @return the serialized bytes of the header. + */ + /*@ public normal_behavior + @ requires nonce_ != null; + @ old int nLen = nonce_.length; + @ requires nonce_.length <= Integer.MAX_VALUE - (Long.SIZE / Byte.SIZE); + @ ensures \result.length == nonce_.length + (Long.SIZE / Byte.SIZE); + @ ensures (\forall int i; 0<=i && iIf successful, it returns the size of the parsed bytes which is the nonce length. On + * failure, it throws a parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the size of the parsed bytes which is the nonce length. + * @throws ParseException if there are not sufficient bytes to parse the nonce. + */ + // @ private normal_behavior + // @ requires nonceLength_ > 0; + // @ requires 0 <= off; + // @ requires b.length - off >= nonceLength_; + // @ assignable nonce_; + // @ ensures nonce_ != null && \fresh(nonce_); + // @ ensures Arrays.equalArrays(b, off, nonce_, 0, nonceLength_); + // @ ensures \result == nonceLength_; + // @ also private exceptional_behavior + // @ // add exceptions from arrays.copyofrange + // @ requires b.length - off < nonceLength_; + // @ assignable \nothing; + // @ signals_only ParseException; + private int parseNonce(final byte[] b, final int off) throws ParseException { + final int bytesToParseLen = b.length - off; + if (bytesToParseLen >= nonceLength_) { + nonce_ = Arrays.copyOfRange(b, off, off + nonceLength_); + return nonceLength_; + } else { + throw new ParseException("Not enough bytes to parse nonce"); } + } - /** - * Parse the nonce in the provided bytes. It looks for bytes of size - * defined by the nonce length in the provided bytes starting at the - * specified off. - * - *

- * If successful, it returns the size of the parsed bytes which is the nonce - * length. On failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the size of the parsed bytes which is the nonce length. - * @throws ParseException - * if there are not sufficient bytes to parse the nonce. - */ - //@ private normal_behavior - //@ requires nonceLength_ > 0; - //@ requires 0 <= off; - //@ requires b.length - off >= nonceLength_; - //@ assignable nonce_; - //@ ensures nonce_ != null && \fresh(nonce_); - //@ ensures Arrays.equalArrays(b, off, nonce_, 0, nonceLength_); - //@ ensures \result == nonceLength_; - //@ also private exceptional_behavior - //@ // add exceptions from arrays.copyofrange - //@ requires b.length - off < nonceLength_; - //@ assignable \nothing; - //@ signals_only ParseException; - private int parseNonce(final byte[] b, final int off) throws ParseException { - final int bytesToParseLen = b.length - off; - if (bytesToParseLen >= nonceLength_) { - nonce_ = Arrays.copyOfRange(b, off, off + nonceLength_); - return nonceLength_; - } else { - throw new ParseException("Not enough bytes to parse nonce"); - } + /** + * Parse the content length in the provided bytes. It looks for 8 bytes representing a long + * primitive type in the provided bytes starting at the specified off. + * + *

If successful, it returns the size of the parsed bytes which is the size of the long + * primitive type. On failure, it throws a parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the size of the parsed bytes which is the size of the long primitive type. + * @throws ParseException if there are not sufficient bytes to parse the content length. + */ + // @ private behavior + // @ requires off >= 0; + // @ requires b.length - off >= Long.BYTES; + // @ old long len = + // Long.asLong(b[off],b[off+1],b[off+2],b[off+3],b[off+4],b[off+5],b[off+6],b[off+7]); + // @ assignable contentLength_; + // @ ensures len >= 0; + // @ ensures contentLength_ == len; + // @ ensures \result == Long.BYTES; + // @ signals_only BadCiphertextException; + // @ signals (BadCiphertextException) len < 0 && contentLength_ == len; + // @ also private exceptional_behavior + // @ requires b.length - off < Long.BYTES; + // @ assignable \nothing; + // @ signals_only ParseException; + private int parseContentLength(final byte[] b, final int off) throws ParseException { + contentLength_ = PrimitivesParser.parseLong(b, off); + if (contentLength_ < 0) { + throw new BadCiphertextException("Invalid content length in ciphertext"); } + return Long.SIZE / Byte.SIZE; + } - /** - * Parse the content length in the provided bytes. It looks for 8 bytes - * representing a long primitive type in the provided bytes starting at the - * specified off. - * - *

- * If successful, it returns the size of the parsed bytes which is the size - * of the long primitive type. On failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the size of the parsed bytes which is the size of the long - * primitive type. - * @throws ParseException - * if there are not sufficient bytes to parse the content - * length. - */ - //@ private behavior - //@ requires off >= 0; - //@ requires b.length - off >= Long.BYTES; - //@ old long len = Long.asLong(b[off],b[off+1],b[off+2],b[off+3],b[off+4],b[off+5],b[off+6],b[off+7]); - //@ assignable contentLength_; - //@ ensures len >= 0; - //@ ensures contentLength_ == len; - //@ ensures \result == Long.BYTES; - //@ signals_only BadCiphertextException; - //@ signals (BadCiphertextException) len < 0 && contentLength_ == len; - //@ also private exceptional_behavior - //@ requires b.length - off < Long.BYTES; - //@ assignable \nothing; - //@ signals_only ParseException; - private int parseContentLength(final byte[] b, final int off) throws ParseException { - contentLength_ = PrimitivesParser.parseLong(b, off); - if (contentLength_ < 0) { - throw new BadCiphertextException("Invalid content length in ciphertext"); - } - return Long.SIZE / Byte.SIZE; + /** + * Deserialize the provided bytes starting at the specified offset to construct an instance of + * this class. + * + *

This method parses the provided bytes for the individual fields in this class. This methods + * also supports partial parsing where not all the bytes required for parsing the fields + * successfully are available. + * + * @param b the byte array to deserialize. + * @param off the offset in the byte array to use for deserialization. + * @return the number of bytes consumed in deserialization. + */ + /*@ public normal_behavior + @ requires b == null; + @ assignable \nothing; + @ ensures \result == 0; + @ also + @ // case: do not need to parse either value + @ public normal_behavior + @ requires b != null && contentLength_ >= 0 && (nonce_ != null || nonceLength_ == 0); + @ assignable isComplete_; + @ ensures \result == 0; + @ ensures isComplete_; + @ also + @ // case: parse nonce (parse exception) + @ public normal_behavior + @ requires b != null && nonce_ == null && nonceLength_ > 0; + @ requires b.length - off < nonceLength_; + @ assignable \nothing; + @ ensures \result == 0; + @ also + @ // case: parse nonce (normally) and not content length + @ public normal_behavior + @ requires b != null && nonce_ == null && nonceLength_ > 0; + @ requires off >= 0 && b.length - off >= nonceLength_; + @ requires contentLength_ >= 0; + @ assignable nonce_, isComplete_; + @ ensures nonce_ != null && \fresh(nonce_); + @ ensures Arrays.equalArrays(b, off, nonce_, 0, nonceLength_); + @ ensures \result == nonceLength_; + @ ensures isComplete_; + @ also + @ // case: do not parse nonce and parse content length (parse exception) + @ public normal_behavior + @ requires b != null && (nonce_ != null || nonceLength_ == 0); + @ requires contentLength_ < 0; + @ requires b.length - off < Long.BYTES; + @ assignable \nothing; + @ ensures \result == 0; + @ also + @ // case: parse nonce (normally) and parse content length (parse exception) + @ public normal_behavior + @ requires b != null && nonce_ == null && nonceLength_ > 0; + @ requires off >= 0 && b.length - off >= nonceLength_; + @ requires contentLength_ < 0; + @ requires b.length - (off + nonceLength_) < Long.BYTES; + @ assignable nonce_; + @ ensures Arrays.equalArrays(b, off, nonce_, 0, nonceLength_); + @ ensures \result == nonceLength_; + @ also + @ // case: do not parse nonce and parse content length (normally) + @ public behavior + @ requires b != null && (nonce_ != null || nonceLength_ == 0); + @ requires contentLength_ < 0; + @ requires off >= 0; + @ requires b.length - off >= Long.BYTES; + @ assignable contentLength_, isComplete_; + @ ensures isComplete_ && contentLength_ >= 0; + @ ensures contentLength_ == Long.asLong(b[off], b[off+1], b[off+2], b[off+3], + @ b[off+4], b[off+5], b[off+6], b[off+7]); + @ ensures \result == Long.BYTES; + @ signals_only BadCiphertextException; + @ signals (BadCiphertextException) contentLength_ < 0 && isComplete_ == \old(isComplete_); + @ also + @ // case: parse both normally + @ public behavior + @ old int nLen = nonceLength_; + @ requires b != null; + @ requires nonce_ == null && nonceLength_ > 0 && contentLength_ < 0; + @ requires off >= 0 && b.length - off >= nonceLength_; + @ requires b.length - (off + nonceLength_) >= Long.BYTES; + @ requires nonceLength_ <= Integer.MAX_VALUE - Long.BYTES; + @ assignable nonce_, contentLength_, isComplete_; + @ ensures isComplete_ && contentLength_ >= 0; + @ ensures Arrays.equalArrays(b, off, nonce_, 0, nonceLength_); + @ ensures contentLength_ == Long.asLong(b[nLen+off], b[nLen+off+1], b[nLen+off+2], + @ b[nLen+off+3], b[nLen+off+4], b[nLen+off+5], + @ b[nLen+off+6], b[nLen+off+7]); + @ ensures \result == nonceLength_ + Long.BYTES; + @ signals_only BadCiphertextException; + @ signals (BadCiphertextException) (contentLength_ < 0 && isComplete_ == \old(isComplete_) + @ && Arrays.equalArrays(b, off, nonce_, 0, nonceLength_)); + @*/ + public int deserialize(/*@ nullable */ final byte[] b, final int off) { + if (b == null) { + return 0; } - /** - * Deserialize the provided bytes starting at the specified offset to - * construct an instance of this class. - * - *

- * This method parses the provided bytes for the individual fields in this - * class. This methods also supports partial parsing where not all the bytes - * required for parsing the fields successfully are available. - * - * @param b - * the byte array to deserialize. - * @param off - * the offset in the byte array to use for deserialization. - * @return - * the number of bytes consumed in deserialization. - */ - /*@ public normal_behavior - @ requires b == null; - @ assignable \nothing; - @ ensures \result == 0; - @ also - @ // case: do not need to parse either value - @ public normal_behavior - @ requires b != null && contentLength_ >= 0 && (nonce_ != null || nonceLength_ == 0); - @ assignable isComplete_; - @ ensures \result == 0; - @ ensures isComplete_; - @ also - @ // case: parse nonce (parse exception) - @ public normal_behavior - @ requires b != null && nonce_ == null && nonceLength_ > 0; - @ requires b.length - off < nonceLength_; - @ assignable \nothing; - @ ensures \result == 0; - @ also - @ // case: parse nonce (normally) and not content length - @ public normal_behavior - @ requires b != null && nonce_ == null && nonceLength_ > 0; - @ requires off >= 0 && b.length - off >= nonceLength_; - @ requires contentLength_ >= 0; - @ assignable nonce_, isComplete_; - @ ensures nonce_ != null && \fresh(nonce_); - @ ensures Arrays.equalArrays(b, off, nonce_, 0, nonceLength_); - @ ensures \result == nonceLength_; - @ ensures isComplete_; - @ also - @ // case: do not parse nonce and parse content length (parse exception) - @ public normal_behavior - @ requires b != null && (nonce_ != null || nonceLength_ == 0); - @ requires contentLength_ < 0; - @ requires b.length - off < Long.BYTES; - @ assignable \nothing; - @ ensures \result == 0; - @ also - @ // case: parse nonce (normally) and parse content length (parse exception) - @ public normal_behavior - @ requires b != null && nonce_ == null && nonceLength_ > 0; - @ requires off >= 0 && b.length - off >= nonceLength_; - @ requires contentLength_ < 0; - @ requires b.length - (off + nonceLength_) < Long.BYTES; - @ assignable nonce_; - @ ensures Arrays.equalArrays(b, off, nonce_, 0, nonceLength_); - @ ensures \result == nonceLength_; - @ also - @ // case: do not parse nonce and parse content length (normally) - @ public behavior - @ requires b != null && (nonce_ != null || nonceLength_ == 0); - @ requires contentLength_ < 0; - @ requires off >= 0; - @ requires b.length - off >= Long.BYTES; - @ assignable contentLength_, isComplete_; - @ ensures isComplete_ && contentLength_ >= 0; - @ ensures contentLength_ == Long.asLong(b[off], b[off+1], b[off+2], b[off+3], - @ b[off+4], b[off+5], b[off+6], b[off+7]); - @ ensures \result == Long.BYTES; - @ signals_only BadCiphertextException; - @ signals (BadCiphertextException) contentLength_ < 0 && isComplete_ == \old(isComplete_); - @ also - @ // case: parse both normally - @ public behavior - @ old int nLen = nonceLength_; - @ requires b != null; - @ requires nonce_ == null && nonceLength_ > 0 && contentLength_ < 0; - @ requires off >= 0 && b.length - off >= nonceLength_; - @ requires b.length - (off + nonceLength_) >= Long.BYTES; - @ requires nonceLength_ <= Integer.MAX_VALUE - Long.BYTES; - @ assignable nonce_, contentLength_, isComplete_; - @ ensures isComplete_ && contentLength_ >= 0; - @ ensures Arrays.equalArrays(b, off, nonce_, 0, nonceLength_); - @ ensures contentLength_ == Long.asLong(b[nLen+off], b[nLen+off+1], b[nLen+off+2], - @ b[nLen+off+3], b[nLen+off+4], b[nLen+off+5], - @ b[nLen+off+6], b[nLen+off+7]); - @ ensures \result == nonceLength_ + Long.BYTES; - @ signals_only BadCiphertextException; - @ signals (BadCiphertextException) (contentLength_ < 0 && isComplete_ == \old(isComplete_) - @ && Arrays.equalArrays(b, off, nonce_, 0, nonceLength_)); - @*/ - public int deserialize(/*@ nullable */ final byte[] b, final int off) { - if (b == null) { - return 0; - } - - //@ assert b != null; - int parsedBytes = 0; - try { - if (nonceLength_ > 0 && nonce_ == null) { - parsedBytes += parseNonce(b, off + parsedBytes); - } - - if (contentLength_ < 0) { - parsedBytes += parseContentLength(b, off + parsedBytes); - } + // @ assert b != null; + int parsedBytes = 0; + try { + if (nonceLength_ > 0 && nonce_ == null) { + parsedBytes += parseNonce(b, off + parsedBytes); + } - isComplete_ = true; - } catch (ParseException e) { - // this results when we do partial parsing and there aren't enough - // bytes to parse; so just return the bytes parsed thus far. - } + if (contentLength_ < 0) { + parsedBytes += parseContentLength(b, off + parsedBytes); + } - return parsedBytes; + isComplete_ = true; + } catch (ParseException e) { + // this results when we do partial parsing and there aren't enough + // bytes to parse; so just return the bytes parsed thus far. } - /** - * Check if this object has all the header fields populated and available - * for reading. - * - * @return - * true if this object containing the single block header fields - * is complete; false otherwise. - */ - //@ public normal_behavior - //@ ensures \result == isComplete_; - //@ pure - public boolean isComplete() { - return isComplete_; - } + return parsedBytes; + } - /** - * Return the nonce set in the single block header. - * - * @return - * the bytes containing the nonce set in the single block header. - */ - //@ public normal_behavior - //@ requires nonce_ == null; - //@ ensures \result == null; - //@ also public normal_behavior - //@ requires nonce_ != null; - //@ ensures \result != null; - //@ ensures \fresh(\result); - //@ ensures \result != null; - //@ ensures \result.length == nonce_.length; - //@ ensures java.util.Arrays.equalArrays(\result,nonce_); - //@ pure nullable - public byte[] getNonce() { - return nonce_ != null ? nonce_.clone() : null; - } + /** + * Check if this object has all the header fields populated and available for reading. + * + * @return true if this object containing the single block header fields is complete; false + * otherwise. + */ + // @ public normal_behavior + // @ ensures \result == isComplete_; + // @ pure + public boolean isComplete() { + return isComplete_; + } - /** - * Return the content length set in the single block header. - * - * @return - * the content length set in the single block header. - */ - //@ public normal_behavior - //@ ensures \result == contentLength_; - //@ pure - public long getContentLength() { - return contentLength_; - } + /** + * Return the nonce set in the single block header. + * + * @return the bytes containing the nonce set in the single block header. + */ + // @ public normal_behavior + // @ requires nonce_ == null; + // @ ensures \result == null; + // @ also public normal_behavior + // @ requires nonce_ != null; + // @ ensures \result != null; + // @ ensures \fresh(\result); + // @ ensures \result != null; + // @ ensures \result.length == nonce_.length; + // @ ensures java.util.Arrays.equalArrays(\result,nonce_); + // @ pure nullable + public byte[] getNonce() { + return nonce_ != null ? nonce_.clone() : null; + } - /** - * Set the length of the nonce used in the encryption of the content stored - * in the single block. - * - * @param nonceLength - * the length of the nonce used in the encryption of the content - * stored in the single block. - */ - //@ public normal_behavior - //@ requires nonceLength >= 0; - //@ assignable nonceLength_; - //@ ensures nonceLength_ == nonceLength; - public void setNonceLength(final short nonceLength) { - nonceLength_ = nonceLength; - } + /** + * Return the content length set in the single block header. + * + * @return the content length set in the single block header. + */ + // @ public normal_behavior + // @ ensures \result == contentLength_; + // @ pure + public long getContentLength() { + return contentLength_; + } + + /** + * Set the length of the nonce used in the encryption of the content stored in the single block. + * + * @param nonceLength the length of the nonce used in the encryption of the content stored in the + * single block. + */ + // @ public normal_behavior + // @ requires nonceLength >= 0; + // @ assignable nonceLength_; + // @ ensures nonceLength_ == nonceLength; + public void setNonceLength(final short nonceLength) { + nonceLength_ = nonceLength; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/CipherFrameHeaders.java b/src/main/java/com/amazonaws/encryptionsdk/model/CipherFrameHeaders.java index 10a3d0221..fb27fd23d 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/CipherFrameHeaders.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/CipherFrameHeaders.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,319 +13,275 @@ package com.amazonaws.encryptionsdk.model; -import java.io.ByteArrayOutputStream; -import java.io.DataOutputStream; -import java.io.IOException; -import java.util.Arrays; - import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.exception.ParseException; import com.amazonaws.encryptionsdk.internal.Constants; import com.amazonaws.encryptionsdk.internal.PrimitivesParser; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Arrays; /** - * This class implements the headers for the encrypted content stored in a - * frame. These headers are parsed and used when the encrypted content in the - * frame is decrypted. - * - *

- * It contains the following fields in order: + * This class implements the headers for the encrypted content stored in a frame. These headers are + * parsed and used when the encrypted content in the frame is decrypted. + * + *

It contains the following fields in order: + * *

    - *
  1. final sequence number marker if final frame
  2. - *
  3. sequence number
  4. - *
  5. nonce
  6. - *
  7. length of content in frame
  8. + *
  9. final sequence number marker if final frame + *
  10. sequence number + *
  11. nonce + *
  12. length of content in frame *
*/ public final class CipherFrameHeaders { - private int sequenceNumber_ = 0; // this is okay since sequence numbers in - // frames start at 1 - private byte[] nonce_; - private int frameContentLength_ = -1; - - // This is set after the nonce length is parsed in the CiphertextHeaders - // during decryption. This can be set only using its setter. - private short nonceLength_ = 0; - - private boolean includeFrameSize_; - private boolean isComplete_; - private boolean isFinalFrame_; - - /** - * Default constructor. - */ - public CipherFrameHeaders() { + private int sequenceNumber_ = 0; // this is okay since sequence numbers in + // frames start at 1 + private byte[] nonce_; + private int frameContentLength_ = -1; + + // This is set after the nonce length is parsed in the CiphertextHeaders + // during decryption. This can be set only using its setter. + private short nonceLength_ = 0; + + private boolean includeFrameSize_; + private boolean isComplete_; + private boolean isFinalFrame_; + + /** Default constructor. */ + public CipherFrameHeaders() {} + + /** + * Construct the frame headers using the provided sequence number, nonce, length of content, and + * boolean value indicating if it is the final frame. + * + * @param sequenceNumber the sequence number of the frame + * @param nonce the bytes containing the nonce. + * @param frameContentLen the length of the content in the frame. + * @param isFinal boolean value indicating if it is the final frame. + */ + public CipherFrameHeaders( + final int sequenceNumber, + final byte[] nonce, + final int frameContentLen, + final boolean isFinal) { + sequenceNumber_ = sequenceNumber; + + if (nonce == null) { + throw new AwsCryptoException("Nonce cannot be null."); } - - /** - * Construct the frame headers using the provided sequence number, nonce, - * length of content, and boolean value indicating if it is the final frame. - * - * @param sequenceNumber - * the sequence number of the frame - * @param nonce - * the bytes containing the nonce. - * @param frameContentLen - * the length of the content in the frame. - * @param isFinal - * boolean value indicating if it is the final frame. - */ - public CipherFrameHeaders(final int sequenceNumber, final byte[] nonce, final int frameContentLen, - final boolean isFinal) { - sequenceNumber_ = sequenceNumber; - - if (nonce == null) { - throw new AwsCryptoException("Nonce cannot be null."); - } - if (nonce.length > Constants.MAX_NONCE_LENGTH) { - throw new AwsCryptoException( - "Nonce length is greater than the maximum value of an unsigned byte."); - } - - nonce_ = nonce.clone(); - isFinalFrame_ = isFinal; - frameContentLength_ = frameContentLen; + if (nonce.length > Constants.MAX_NONCE_LENGTH) { + throw new AwsCryptoException( + "Nonce length is greater than the maximum value of an unsigned byte."); } - /** - * Serialize the header into a byte array. - * - * @return - * the serialized bytes of the header. - */ - public byte[] toByteArray() { - try { - ByteArrayOutputStream outBytes = new ByteArrayOutputStream(); - DataOutputStream dataStream = new DataOutputStream(outBytes); - - if (isFinalFrame_) { - dataStream.writeInt(Constants.ENDFRAME_SEQUENCE_NUMBER); - } - - dataStream.writeInt(sequenceNumber_); - dataStream.write(nonce_); - - if (includeFrameSize_ || isFinalFrame_) { - dataStream.writeInt(frameContentLength_); - } - - dataStream.close(); - return outBytes.toByteArray(); - } catch (IOException e) { - throw new AwsCryptoException("Failed to serialize cipher frame headers", e); - } + nonce_ = nonce.clone(); + isFinalFrame_ = isFinal; + frameContentLength_ = frameContentLen; + } + + /** + * Serialize the header into a byte array. + * + * @return the serialized bytes of the header. + */ + public byte[] toByteArray() { + try { + ByteArrayOutputStream outBytes = new ByteArrayOutputStream(); + DataOutputStream dataStream = new DataOutputStream(outBytes); + + if (isFinalFrame_) { + dataStream.writeInt(Constants.ENDFRAME_SEQUENCE_NUMBER); + } + + dataStream.writeInt(sequenceNumber_); + dataStream.write(nonce_); + + if (includeFrameSize_ || isFinalFrame_) { + dataStream.writeInt(frameContentLength_); + } + + dataStream.close(); + return outBytes.toByteArray(); + } catch (IOException e) { + throw new AwsCryptoException("Failed to serialize cipher frame headers", e); } - - /** - * Parse the sequence number in the provided bytes. It looks for 4 bytes - * representing a integer primitive type in the provided bytes starting at - * the specified offset. - * - *

- * If successful, it returns the size of the parsed bytes which is the size - * of the integer primitive type. On failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the size of the parsed bytes which is the size of the integer - * primitive type. - * @throws ParseException - * if there are not sufficient bytes to parse the sequence - * number. - */ - private int parseSequenceNumber(final byte[] b, final int off) throws ParseException { - sequenceNumber_ = PrimitivesParser.parseInt(b, off); - return Integer.SIZE / Byte.SIZE; + } + + /** + * Parse the sequence number in the provided bytes. It looks for 4 bytes representing a integer + * primitive type in the provided bytes starting at the specified offset. + * + *

If successful, it returns the size of the parsed bytes which is the size of the integer + * primitive type. On failure, it throws a parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the size of the parsed bytes which is the size of the integer primitive type. + * @throws ParseException if there are not sufficient bytes to parse the sequence number. + */ + private int parseSequenceNumber(final byte[] b, final int off) throws ParseException { + sequenceNumber_ = PrimitivesParser.parseInt(b, off); + return Integer.SIZE / Byte.SIZE; + } + + /** + * Parse the nonce in the provided bytes. It looks for bytes of size defined by the nonce length + * in the provided bytes starting at the specified off. + * + *

If successful, it returns the size of the parsed bytes which is the nonce length. On + * failure, it throws a parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the size of the parsed bytes which is the nonce length. + * @throws ParseException if there are not sufficient bytes to parse the nonce. + */ + private int parseNonce(final byte[] b, final int off) throws ParseException { + final int bytesToParseLen = b.length - off; + if (bytesToParseLen >= nonceLength_) { + nonce_ = Arrays.copyOfRange(b, off, off + nonceLength_); + return nonceLength_; + } else { + throw new ParseException("Not enough bytes to parse nonce"); } - - /** - * Parse the nonce in the provided bytes. It looks for bytes of size - * defined by the nonce length in the provided bytes starting at the - * specified off. - * - *

- * If successful, it returns the size of the parsed bytes which is the nonce - * length. On failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the size of the parsed bytes which is the nonce length. - * @throws ParseException - * if there are not sufficient bytes to parse the nonce. - */ - private int parseNonce(final byte[] b, final int off) throws ParseException { - final int bytesToParseLen = b.length - off; - if (bytesToParseLen >= nonceLength_) { - nonce_ = Arrays.copyOfRange(b, off, off + nonceLength_); - return nonceLength_; - } else { - throw new ParseException("Not enough bytes to parse nonce"); - } + } + + /** + * Parse the frame content length in the provided bytes. It looks for 4 bytes representing an + * integer primitive type in the provided bytes starting at the specified off. + * + *

If successful, it returns the size of the parsed bytes which is the size of the integer + * primitive type. On failure, it throws a parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the size of the parsed bytes which is the size of the integer primitive type. + * @throws ParseException if there are not sufficient bytes to parse the frame content length. + */ + private int parseFrameContentLength(final byte[] b, final int off) throws ParseException { + frameContentLength_ = PrimitivesParser.parseInt(b, off); + if (frameContentLength_ < 0) { + throw new BadCiphertextException("Invalid frame length in ciphertext"); } - - /** - * Parse the frame content length in the provided bytes. It looks for 4 - * bytes representing an integer primitive type in the provided bytes - * starting at the specified off. - * - *

- * If successful, it returns the size of the parsed bytes which is the size - * of the integer primitive type. On failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the size of the parsed bytes which is the size of the integer - * primitive type. - * @throws ParseException - * if there are not sufficient bytes to parse the frame content - * length. - */ - private int parseFrameContentLength(final byte[] b, final int off) throws ParseException { - frameContentLength_ = PrimitivesParser.parseInt(b, off); - if (frameContentLength_ < 0) { - throw new BadCiphertextException("Invalid frame length in ciphertext"); - } - return Integer.SIZE / Byte.SIZE; + return Integer.SIZE / Byte.SIZE; + } + + /** + * Deserialize the provided bytes starting at the specified offset to construct an instance of + * this class. + * + *

This method parses the provided bytes for the individual fields in this class. This methods + * also supports partial parsing where not all the bytes required for parsing the fields + * successfully are available. + * + * @param b the byte array to deserialize. + * @param off the offset in the byte array to use for deserialization. + * @return the number of bytes consumed in deserialization. + */ + public int deserialize(final byte[] b, final int off) { + if (b == null) { + return 0; } - /** - * Deserialize the provided bytes starting at the specified offset to - * construct an instance of this class. - * - *

- * This method parses the provided bytes for the individual fields in this - * class. This methods also supports partial parsing where not all the bytes - * required for parsing the fields successfully are available. - * - * @param b - * the byte array to deserialize. - * @param off - * the offset in the byte array to use for deserialization. - * @return - * the number of bytes consumed in deserialization. - */ - public int deserialize(final byte[] b, final int off) { - if (b == null) { - return 0; - } - - int parsedBytes = 0; - try { - if (sequenceNumber_ == 0) { - parsedBytes += parseSequenceNumber(b, off + parsedBytes); - } - - // parse the sequence number again if the sequence number parsed in - // the previous call is the final frame marker and this frame hasn't - // already been marked final. - if (sequenceNumber_ == Constants.ENDFRAME_SEQUENCE_NUMBER && !isFinalFrame_) { - parsedBytes += parseSequenceNumber(b, off + parsedBytes); - isFinalFrame_ = true; - } - - if (nonceLength_ > 0 && nonce_ == null) { - parsedBytes += parseNonce(b, off + parsedBytes); - } - - if (includeFrameSize_ || isFinalFrame_) { - if (frameContentLength_ < 0) { - parsedBytes += parseFrameContentLength(b, off + parsedBytes); - } - } - - isComplete_ = true; - } catch (ParseException e) { - // this results when we do partial parsing and there aren't enough - // bytes to parse; so just return the bytes parsed thus far. + int parsedBytes = 0; + try { + if (sequenceNumber_ == 0) { + parsedBytes += parseSequenceNumber(b, off + parsedBytes); + } + + // parse the sequence number again if the sequence number parsed in + // the previous call is the final frame marker and this frame hasn't + // already been marked final. + if (sequenceNumber_ == Constants.ENDFRAME_SEQUENCE_NUMBER && !isFinalFrame_) { + parsedBytes += parseSequenceNumber(b, off + parsedBytes); + isFinalFrame_ = true; + } + + if (nonceLength_ > 0 && nonce_ == null) { + parsedBytes += parseNonce(b, off + parsedBytes); + } + + if (includeFrameSize_ || isFinalFrame_) { + if (frameContentLength_ < 0) { + parsedBytes += parseFrameContentLength(b, off + parsedBytes); } + } - return parsedBytes; - } - - /** - * Return if the frame is a final frame. The final frame is identified as - * the frame containing the final sequence number marker. - * - * @return - * true if final frame; false otherwise. - */ - public boolean isFinalFrame() { - return isFinalFrame_; - } - - /** - * Check if this object has all the header fields populated and available - * for reading. - * - * @return - * true if this object containing the single block header fields - * is complete; false otherwise. - */ - public boolean isComplete() { - return isComplete_; - } - - /** - * Return the nonce set in the frame header. - * - * @return - * the bytes containing the nonce set in the frame header. - */ - public byte[] getNonce() { - return nonce_ != null ? nonce_.clone() : null; - } - - /** - * Return the frame content length set in the frame header. - * - * @return - * the frame content length set in the frame header. - */ - public int getFrameContentLength() { - return frameContentLength_; - } - - /** - * Return the frame sequence number set in the frame header. - * - * @return - * the frame sequence number set in the frame header. - */ - public int getSequenceNumber() { - return sequenceNumber_; + isComplete_ = true; + } catch (ParseException e) { + // this results when we do partial parsing and there aren't enough + // bytes to parse; so just return the bytes parsed thus far. } - /** - * Set the length of the nonce used in the encryption of the content in the - * frame. - * - * @param nonceLength - * the length of the nonce used in the encryption of the content - * in the frame. - */ - public void setNonceLength(final short nonceLength) { - nonceLength_ = nonceLength; - } - - /** - * Set the flag to specify whether the frame length needs to be included or - * parsed in the header. - * - * @param value - * true if the frame length needs to be included or parsed in the - * header; false otherwise - */ - public void includeFrameSize(final boolean value) { - includeFrameSize_ = true; - } + return parsedBytes; + } + + /** + * Return if the frame is a final frame. The final frame is identified as the frame containing the + * final sequence number marker. + * + * @return true if final frame; false otherwise. + */ + public boolean isFinalFrame() { + return isFinalFrame_; + } + + /** + * Check if this object has all the header fields populated and available for reading. + * + * @return true if this object containing the single block header fields is complete; false + * otherwise. + */ + public boolean isComplete() { + return isComplete_; + } + + /** + * Return the nonce set in the frame header. + * + * @return the bytes containing the nonce set in the frame header. + */ + public byte[] getNonce() { + return nonce_ != null ? nonce_.clone() : null; + } + + /** + * Return the frame content length set in the frame header. + * + * @return the frame content length set in the frame header. + */ + public int getFrameContentLength() { + return frameContentLength_; + } + + /** + * Return the frame sequence number set in the frame header. + * + * @return the frame sequence number set in the frame header. + */ + public int getSequenceNumber() { + return sequenceNumber_; + } + + /** + * Set the length of the nonce used in the encryption of the content in the frame. + * + * @param nonceLength the length of the nonce used in the encryption of the content in the frame. + */ + public void setNonceLength(final short nonceLength) { + nonceLength_ = nonceLength; + } + + /** + * Set the flag to specify whether the frame length needs to be included or parsed in the header. + * + * @param value true if the frame length needs to be included or parsed in the header; false + * otherwise + */ + public void includeFrameSize(final boolean value) { + includeFrameSize_ = true; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextFooters.java b/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextFooters.java index 061bfd32a..e67cda02e 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextFooters.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextFooters.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,116 +13,119 @@ package com.amazonaws.encryptionsdk.model; -import java.io.ByteArrayOutputStream; -import java.io.DataOutputStream; -import java.io.IOException; -import java.util.Arrays; - import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.ParseException; import com.amazonaws.encryptionsdk.internal.Constants; import com.amazonaws.encryptionsdk.internal.PrimitivesParser; import com.amazonaws.encryptionsdk.internal.Utils; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Arrays; /** * This class encapsulates the optional footer information which follows the actual protected * content. - * - * It contains the following fields in order: + * + *

It contains the following fields in order: + * *

    - *
  1. AuthLength - 2 bytes - *
  2. MAuth - {@code AuthLength} bytes + *
  3. AuthLength - 2 bytes + *
  4. MAuth - {@code AuthLength} bytes *
*/ public class CiphertextFooters { - private int authLength_ = -1; - private byte[] mAuth_ = null; - private boolean isComplete_ = false; + private int authLength_ = -1; + private byte[] mAuth_ = null; + private boolean isComplete_ = false; - public CiphertextFooters() { - // Do nothing - } + public CiphertextFooters() { + // Do nothing + } - public CiphertextFooters(final byte[] mAuth) { - final int length = Utils.assertNonNull(mAuth, "mAuth").length; - if (length < 0 || length > Constants.UNSIGNED_SHORT_MAX_VAL) { - throw new IllegalArgumentException("Invalid length for mAuth: " + length); - } - authLength_ = length; - mAuth_ = mAuth.clone(); - isComplete_ = true; + public CiphertextFooters(final byte[] mAuth) { + final int length = Utils.assertNonNull(mAuth, "mAuth").length; + if (length < 0 || length > Constants.UNSIGNED_SHORT_MAX_VAL) { + throw new IllegalArgumentException("Invalid length for mAuth: " + length); } + authLength_ = length; + mAuth_ = mAuth.clone(); + isComplete_ = true; + } - /** - * Parses the footers from the {@code b} starting at offset {@code off} and returns the number - * of bytes parsed/consumed. - */ - public int deserialize(final byte[] b, final int off) throws ParseException { - if (b == null) { - return 0; - } - int parsedBytes = 0; - try { - if (authLength_ < 0) { - parsedBytes += parseLength(b, off + parsedBytes); - } - if (mAuth_ == null) { - parsedBytes += parseMauth(b, off + parsedBytes); - } - isComplete_ = true; - } catch (ParseException e) { - // this results when we do partial parsing and there aren't enough - // bytes to parse; ignore it and return the bytes parsed thus far. - } - return parsedBytes; + /** + * Parses the footers from the {@code b} starting at offset {@code off} and returns the number of + * bytes parsed/consumed. + */ + public int deserialize(final byte[] b, final int off) throws ParseException { + if (b == null) { + return 0; } - - public int getAuthLength() { - return authLength_; + int parsedBytes = 0; + try { + if (authLength_ < 0) { + parsedBytes += parseLength(b, off + parsedBytes); + } + if (mAuth_ == null) { + parsedBytes += parseMauth(b, off + parsedBytes); + } + isComplete_ = true; + } catch (ParseException e) { + // this results when we do partial parsing and there aren't enough + // bytes to parse; ignore it and return the bytes parsed thus far. } + return parsedBytes; + } - public byte[] getMAuth() { - return (mAuth_ != null) ? mAuth_.clone() : null; - } + public int getAuthLength() { + return authLength_; + } - /** - * Check if this object has all the header fields populated and available for reading. - * - * @return true if this object containing the single block header fields is complete; false - * otherwise. - */ - public boolean isComplete() { - return isComplete_; - } + public byte[] getMAuth() { + return (mAuth_ != null) ? mAuth_.clone() : null; + } - public byte[] toByteArray() { - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(baos)) { - PrimitivesParser.writeUnsignedShort(dos, authLength_); - dos.write(mAuth_); - dos.close(); - baos.close(); - return baos.toByteArray(); - } catch (final IOException ex) { - throw new AwsCryptoException(ex); - } - } + /** + * Check if this object has all the header fields populated and available for reading. + * + * @return true if this object containing the single block header fields is complete; false + * otherwise. + */ + public boolean isComplete() { + return isComplete_; + } - private int parseLength(final byte[] b, final int off) throws ParseException { - authLength_ = PrimitivesParser.parseUnsignedShort(b, off); - return 2; + public byte[] toByteArray() { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos)) { + PrimitivesParser.writeUnsignedShort(dos, authLength_); + dos.write(mAuth_); + dos.close(); + baos.close(); + return baos.toByteArray(); + } catch (final IOException ex) { + throw new AwsCryptoException(ex); } + } + + private int parseLength(final byte[] b, final int off) throws ParseException { + authLength_ = PrimitivesParser.parseUnsignedShort(b, off); + return 2; + } - private int parseMauth(final byte[] b, final int off) throws ParseException { - final int len = b.length - off; - if (len >= authLength_) { - mAuth_ = Arrays.copyOfRange(b, off, off + authLength_); - return authLength_; - } else { - throw new ParseException("Not enough bytes to parse mAuth, " - + " needed at least " + authLength_ + " bytes, but only had " - + len + " bytes"); - - } + private int parseMauth(final byte[] b, final int off) throws ParseException { + final int len = b.length - off; + if (len >= authLength_) { + mAuth_ = Arrays.copyOfRange(b, off, off + authLength_); + return authLength_; + } else { + throw new ParseException( + "Not enough bytes to parse mAuth, " + + " needed at least " + + authLength_ + + " bytes, but only had " + + len + + " bytes"); } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextHeaders.java b/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextHeaders.java index c992f9f10..94a35cc21 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextHeaders.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextHeaders.java @@ -3,897 +3,865 @@ package com.amazonaws.encryptionsdk.model; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; +import com.amazonaws.encryptionsdk.exception.ParseException; +import com.amazonaws.encryptionsdk.internal.Constants; +import com.amazonaws.encryptionsdk.internal.EncryptionContextSerializer; +import com.amazonaws.encryptionsdk.internal.PrimitivesParser; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; -import java.nio.ByteBuffer; import java.security.SecureRandom; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; -import com.amazonaws.encryptionsdk.exception.ParseException; -import com.amazonaws.encryptionsdk.internal.Constants; -import com.amazonaws.encryptionsdk.internal.EncryptionContextSerializer; -import com.amazonaws.encryptionsdk.internal.PrimitivesParser; - /** - * This class implements the headers for the message (ciphertext) produced by - * this library. These headers are parsed and used when the ciphertext is - * decrypted. + * This class implements the headers for the message (ciphertext) produced by this library. These + * headers are parsed and used when the ciphertext is decrypted. * - * See https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/message-format.html - * for a detailed description of the fields that make up the encrypted message header. + *

See https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/message-format.html for + * a detailed description of the fields that make up the encrypted message header. * - *

- * It is important to note that all but the last two header fields are checked - * for their integrity during decryption using AES-GCM with the nonce and MAC tag - * values supplied in the last two fields of the header. + *

It is important to note that all but the last two header fields are checked for their + * integrity during decryption using AES-GCM with the nonce and MAC tag values supplied in the last + * two fields of the header. */ public class CiphertextHeaders { - /** - * When passed as maxEncryptedDataKeys, indicates that no maximum should be enforced (i.e., any number of EDKs are allowed). - */ - public static final int NO_MAX_ENCRYPTED_DATA_KEYS = 0; - - private static final SecureRandom RND = new SecureRandom(); - private byte version_ = -1; - private byte typeVal_; // don't set this to -1 since Java byte is signed - // while this value is unsigned and can go up to 128. - private short cryptoAlgoVal_ = -1; - private byte[] messageId_; - private int encryptionContextLen_ = -1; - private byte[] encryptionContext_ = new byte[0]; - private int cipherKeyCount_ = -1; - private List cipherKeyBlobs_; - private byte contentTypeVal_ = -1; - private int reservedField_ = -1; - private short nonceLen_ = -1; - private int frameLength_ = -1; - - private byte[] headerNonce_; - private byte[] headerTag_; - - private int suiteDataLen_ = -1; - private byte[] suiteData_; - - // internal variables - private int currKeyBlobIndex_ = 0; - private boolean isComplete_; - private int maxEncryptedDataKeys_ = NO_MAX_ENCRYPTED_DATA_KEYS; - - /** - * Default constructor. - */ - public CiphertextHeaders() { - } - - /** - * Construct the ciphertext headers using the provided values. - * - * @param version - * the version to set in the header. - * @param type - * the type to set in the header. - * @param cryptoAlgo - * the CryptoAlgorithm enum to encode in the header. - * @param encryptionContext - * the bytes containing the encryption context to set in the - * header. - * @param keyBlobs - * list of keyBlobs containing the key provider id, key - * provider info, and encrypted data key to encode in the header. - * @param contentType - * the content type to set in the header. - * @param frameSize - * the frame payload size to set in the header. - * - * @deprecated {@link #CiphertextHeaders(CiphertextType, CryptoAlgorithm, byte[], List, ContentType, int)} - */ - @Deprecated - public CiphertextHeaders(final byte version, final CiphertextType type, final CryptoAlgorithm cryptoAlgo, - final byte[] encryptionContext, final List keyBlobs, final ContentType contentType, - final int frameSize) { - this(type, assertVersionCompatibility(version, cryptoAlgo), encryptionContext, keyBlobs, contentType, frameSize); - } - - // Utility method since there isn't another good way to check the argument prior to calling a second constructor - private static CryptoAlgorithm assertVersionCompatibility(final byte version, final CryptoAlgorithm cryptoAlgo) { - if (version != cryptoAlgo.getMessageFormatVersion()) { - throw new IllegalArgumentException("Version must match the message format version from the type"); - } - return cryptoAlgo; - } - - /** - * Construct the ciphertext headers using the provided values. - * - * @param type - * the type to set in the header. - * @param cryptoAlgo - * the CryptoAlgorithm enum to encode in the header. - * @param encryptionContext - * the bytes containing the encryption context to set in the - * header. - * @param keyBlobs - * list of keyBlobs containing the key provider id, key - * provider info, and encrypted data key to encode in the header. - * @param contentType - * the content type to set in the header. - * @param frameSize - * the frame payload size to set in the header. - */ - public CiphertextHeaders(final CiphertextType type, final CryptoAlgorithm cryptoAlgo, - final byte[] encryptionContext, final List keyBlobs, final ContentType contentType, - final int frameSize) { - - version_ = cryptoAlgo.getMessageFormatVersion(); - typeVal_ = type.getValue(); - - cryptoAlgoVal_ = cryptoAlgo.getValue(); - - encryptionContext_ = encryptionContext.clone(); - if (encryptionContext_.length > Constants.UNSIGNED_SHORT_MAX_VAL) { - throw new AwsCryptoException("Size of encryption context exceeds the allowed maximum " - + Constants.UNSIGNED_SHORT_MAX_VAL); - } - encryptionContextLen_ = encryptionContext.length; - - // we only support the encoding of 1 data key in the cipher blob. - cipherKeyCount_ = keyBlobs.size(); - cipherKeyBlobs_ = new ArrayList<>(keyBlobs); - - contentTypeVal_ = contentType.getValue(); - reservedField_ = 0; - nonceLen_ = cryptoAlgo.getNonceLen(); - - // generate random bytes and assign them as the unique identifier of the - // message wrapped by this header. - messageId_ = new byte[cryptoAlgo.getMessageIdLength()]; - RND.nextBytes(messageId_); - - frameLength_ = frameSize; - - // Completed by construction - isComplete_ = true; - } - - /** - * Check if this object has all the header fields populated and available - * for reading. - * - * @return - * true if this object containing the single block header fields - * is complete; false otherwise. - */ - public Boolean isComplete() { - return isComplete_; - } - - /** - * Parse the version in the provided bytes. It looks for a - * single byte in the provided bytes starting at the specified off. - * - * @see {@link ParsingStep} - */ - private int parseVersion(final byte[] b, final int off) throws ParseException { - if (version_ >= 0) { - return 0; - } - version_ = PrimitivesParser.parseByte(b, off); - return 1; - } - - /** - * Sets appropriate constants and parameters for v1 parsing - */ - private int configV1(final byte[] b, final int off) { - suiteDataLen_ = -1; - return 0; - } - - /** - * Sets appropriate constants and parameters for v2 parsing - */ - private int configV2(final byte[] b, final int off) { - suiteDataLen_ = getCryptoAlgoId().getSuiteDataLength(); - typeVal_ = CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA.getValue(); - headerNonce_ = getCryptoAlgoId().getHeaderNonce(); - if (headerNonce_ == null) { - throw new IllegalStateException("Message format v2 requires the algorithm to specify a header nonce."); - } - if (headerNonce_.length > Short.MAX_VALUE) { - throw new IllegalStateException("Message format v2 requires the algorithm to specify a header nonce with " + - "length less than 2^15."); - } - nonceLen_ = (short) headerNonce_.length; - return 0; - } - - /** - * Parse the type in the provided bytes. It looks for a - * single byte in the provided bytes starting at the specified off. - * - * @see {@link ParsingStep} - */ - private int parseType(final byte[] b, final int off) throws ParseException { - if (typeVal_ != 0) { - return 0; - } - typeVal_ = PrimitivesParser.parseByte(b, off); - if (CiphertextType.deserialize(typeVal_) == null) { - throw new BadCiphertextException("Invalid ciphertext type."); - } - return 1; - } - - /** - * Parse the algorithm identifier in the provided bytes. It looks for 2 - * bytes representing a short primitive type in the provided bytes starting - * at the specified off. - * - * @see {@link ParsingStep} - */ - private int parseAlgoId(final byte[] b, final int off) throws ParseException { - if (cryptoAlgoVal_ >= 0) { - return 0; - } - cryptoAlgoVal_ = PrimitivesParser.parseShort(b, off); - if (CryptoAlgorithm.deserialize(version_, cryptoAlgoVal_) == null) { - throw new BadCiphertextException("Invalid algorithm identifier in ciphertext"); - } - return Short.SIZE / Byte.SIZE; - } - - /** - * Parse the message ID in the provided bytes. It looks for bytes of the - * size defined by the message identifier length in the provided bytes - * starting at the specified off. - * - * @see {@link ParsingStep} - */ - private int parseMessageId(final byte[] b, final int off) throws ParseException { - if (messageId_ != null) { - return 0; - } - final int len = b.length - off; - final int messageIdLen = getCryptoAlgoId().getMessageIdLength(); - if (len >= messageIdLen) { - messageId_ = Arrays.copyOfRange(b, off, off + messageIdLen); - return messageIdLen; - } else { - throw new ParseException("Not enough bytes to parse serial number"); - } - } - - /** - * Parses suite specific data - * - * @see {@link ParsingStep} - */ - private int parseSuiteData(final byte[] b, final int off) throws ParseException { - if (suiteData_ != null) { - return 0; - } - final int len = b.length - off; - if (len >= suiteDataLen_) { - suiteData_ = Arrays.copyOfRange(b, off, off + suiteDataLen_); - return suiteDataLen_; - } else { - throw new ParseException("Not enough bytes to parse suite specific data"); - } - } - - /** - * Parse the length of the encryption context in the provided bytes. It - * looks for 2 bytes representing a short primitive type in the provided - * bytes starting at the specified off. - * - * @see {@link ParsingStep} - */ - private int parseEncryptionContextLen(final byte[] b, final int off) throws ParseException { - if (encryptionContextLen_ >= 0) { - return 0; - } - encryptionContextLen_ = PrimitivesParser.parseUnsignedShort(b, off); - if (encryptionContextLen_ < 0) { - throw new BadCiphertextException("Invalid encryption context length in ciphertext"); - } - return Short.SIZE / Byte.SIZE; - } - - /** - * Parse the encryption context in the provided bytes. It looks for bytes of - * size defined by the encryption context length in the provided bytes - * starting at the specified off. - * - * @see {@link ParsingStep} - */ - private int parseEncryptionContext(final byte[] b, final int off) throws ParseException { - if (encryptionContextLen_ < encryptionContext_.length) { - throw new IllegalStateException("Parsed encryption context is in an invalid state. Size exceeds parsed " + - "encryption context length."); - } - if (encryptionContextLen_ == encryptionContext_.length) { - return 0; - } - final int len = b.length - off; - if (len >= encryptionContextLen_) { - encryptionContext_ = Arrays.copyOfRange(b, off, off + encryptionContextLen_); - return encryptionContextLen_; - } else { - throw new ParseException("Not enough bytes to parse encryption context"); - } - } - - /** - * Parse the data key count in the provided bytes. It looks for 2 bytes - * representing a short primitive type in the provided bytes starting at the - * specified off. - * - * @see {@link ParsingStep} - */ - private int parseEncryptedDataKeyCount(final byte[] b, final int off) throws ParseException { - if (cipherKeyCount_ >= 0) { - return 0; - } - cipherKeyCount_ = PrimitivesParser.parseUnsignedShort(b, off); - if (cipherKeyCount_ < 0) { - throw new BadCiphertextException("Invalid cipher key count in ciphertext"); - } - if (maxEncryptedDataKeys_ > 0 && cipherKeyCount_ > maxEncryptedDataKeys_) { - throw new AwsCryptoException("Ciphertext encrypted data keys exceed maxEncryptedDataKeys"); - } - cipherKeyBlobs_ = Arrays.asList(new KeyBlob[cipherKeyCount_]); - return Short.SIZE / Byte.SIZE; - } - - /** - * Parses the list of encrypted key blobs. - * Unlike many of the other parsing methods, this one can make partial progress. - * To indicate this partial progress it throws a {@link PartialParseException} containing - * the number of parsed bytes. - * - * @see {@link ParsingStep} - */ - private int parseEncryptedKeyBlobList(final byte[] b, final int off) throws PartialParseException { - int parsedBytes = 0; - try { - if (cipherKeyCount_ > 0) { - while (currKeyBlobIndex_ < cipherKeyCount_) { - if (cipherKeyBlobs_.get(currKeyBlobIndex_) == null) { - cipherKeyBlobs_.set(currKeyBlobIndex_, new KeyBlob()); - } - if (cipherKeyBlobs_.get(currKeyBlobIndex_).isComplete() == false) { - parsedBytes += parseEncryptedKeyBlob(b, off + parsedBytes); - // check if we had enough bytes to parse the key blob - if (cipherKeyBlobs_.get(currKeyBlobIndex_).isComplete() == false) { - throw new ParseException("Not enough bytes to parse key blob"); - } - } - currKeyBlobIndex_++; - } + /** + * When passed as maxEncryptedDataKeys, indicates that no maximum should be enforced (i.e., any + * number of EDKs are allowed). + */ + public static final int NO_MAX_ENCRYPTED_DATA_KEYS = 0; + + private static final SecureRandom RND = new SecureRandom(); + private byte version_ = -1; + private byte typeVal_; // don't set this to -1 since Java byte is signed + // while this value is unsigned and can go up to 128. + private short cryptoAlgoVal_ = -1; + private byte[] messageId_; + private int encryptionContextLen_ = -1; + private byte[] encryptionContext_ = new byte[0]; + private int cipherKeyCount_ = -1; + private List cipherKeyBlobs_; + private byte contentTypeVal_ = -1; + private int reservedField_ = -1; + private short nonceLen_ = -1; + private int frameLength_ = -1; + + private byte[] headerNonce_; + private byte[] headerTag_; + + private int suiteDataLen_ = -1; + private byte[] suiteData_; + + // internal variables + private int currKeyBlobIndex_ = 0; + private boolean isComplete_; + private int maxEncryptedDataKeys_ = NO_MAX_ENCRYPTED_DATA_KEYS; + + /** Default constructor. */ + public CiphertextHeaders() {} + + /** + * Construct the ciphertext headers using the provided values. + * + * @param version the version to set in the header. + * @param type the type to set in the header. + * @param cryptoAlgo the CryptoAlgorithm enum to encode in the header. + * @param encryptionContext the bytes containing the encryption context to set in the header. + * @param keyBlobs list of keyBlobs containing the key provider id, key provider info, and + * encrypted data key to encode in the header. + * @param contentType the content type to set in the header. + * @param frameSize the frame payload size to set in the header. + * @deprecated {@link #CiphertextHeaders(CiphertextType, CryptoAlgorithm, byte[], List, + * ContentType, int)} + */ + @Deprecated + public CiphertextHeaders( + final byte version, + final CiphertextType type, + final CryptoAlgorithm cryptoAlgo, + final byte[] encryptionContext, + final List keyBlobs, + final ContentType contentType, + final int frameSize) { + this( + type, + assertVersionCompatibility(version, cryptoAlgo), + encryptionContext, + keyBlobs, + contentType, + frameSize); + } + + // Utility method since there isn't another good way to check the argument prior to calling a + // second constructor + private static CryptoAlgorithm assertVersionCompatibility( + final byte version, final CryptoAlgorithm cryptoAlgo) { + if (version != cryptoAlgo.getMessageFormatVersion()) { + throw new IllegalArgumentException( + "Version must match the message format version from the type"); + } + return cryptoAlgo; + } + + /** + * Construct the ciphertext headers using the provided values. + * + * @param type the type to set in the header. + * @param cryptoAlgo the CryptoAlgorithm enum to encode in the header. + * @param encryptionContext the bytes containing the encryption context to set in the header. + * @param keyBlobs list of keyBlobs containing the key provider id, key provider info, and + * encrypted data key to encode in the header. + * @param contentType the content type to set in the header. + * @param frameSize the frame payload size to set in the header. + */ + public CiphertextHeaders( + final CiphertextType type, + final CryptoAlgorithm cryptoAlgo, + final byte[] encryptionContext, + final List keyBlobs, + final ContentType contentType, + final int frameSize) { + + version_ = cryptoAlgo.getMessageFormatVersion(); + typeVal_ = type.getValue(); + + cryptoAlgoVal_ = cryptoAlgo.getValue(); + + encryptionContext_ = encryptionContext.clone(); + if (encryptionContext_.length > Constants.UNSIGNED_SHORT_MAX_VAL) { + throw new AwsCryptoException( + "Size of encryption context exceeds the allowed maximum " + + Constants.UNSIGNED_SHORT_MAX_VAL); + } + encryptionContextLen_ = encryptionContext.length; + + // we only support the encoding of 1 data key in the cipher blob. + cipherKeyCount_ = keyBlobs.size(); + cipherKeyBlobs_ = new ArrayList<>(keyBlobs); + + contentTypeVal_ = contentType.getValue(); + reservedField_ = 0; + nonceLen_ = cryptoAlgo.getNonceLen(); + + // generate random bytes and assign them as the unique identifier of the + // message wrapped by this header. + messageId_ = new byte[cryptoAlgo.getMessageIdLength()]; + RND.nextBytes(messageId_); + + frameLength_ = frameSize; + + // Completed by construction + isComplete_ = true; + } + + /** + * Check if this object has all the header fields populated and available for reading. + * + * @return true if this object containing the single block header fields is complete; false + * otherwise. + */ + public Boolean isComplete() { + return isComplete_; + } + + /** + * Parse the version in the provided bytes. It looks for a single byte in the provided bytes + * starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseVersion(final byte[] b, final int off) throws ParseException { + if (version_ >= 0) { + return 0; + } + version_ = PrimitivesParser.parseByte(b, off); + return 1; + } + + /** Sets appropriate constants and parameters for v1 parsing */ + private int configV1(final byte[] b, final int off) { + suiteDataLen_ = -1; + return 0; + } + + /** Sets appropriate constants and parameters for v2 parsing */ + private int configV2(final byte[] b, final int off) { + suiteDataLen_ = getCryptoAlgoId().getSuiteDataLength(); + typeVal_ = CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA.getValue(); + headerNonce_ = getCryptoAlgoId().getHeaderNonce(); + if (headerNonce_ == null) { + throw new IllegalStateException( + "Message format v2 requires the algorithm to specify a header nonce."); + } + if (headerNonce_.length > Short.MAX_VALUE) { + throw new IllegalStateException( + "Message format v2 requires the algorithm to specify a header nonce with " + + "length less than 2^15."); + } + nonceLen_ = (short) headerNonce_.length; + return 0; + } + + /** + * Parse the type in the provided bytes. It looks for a single byte in the provided bytes starting + * at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseType(final byte[] b, final int off) throws ParseException { + if (typeVal_ != 0) { + return 0; + } + typeVal_ = PrimitivesParser.parseByte(b, off); + if (CiphertextType.deserialize(typeVal_) == null) { + throw new BadCiphertextException("Invalid ciphertext type."); + } + return 1; + } + + /** + * Parse the algorithm identifier in the provided bytes. It looks for 2 bytes representing a short + * primitive type in the provided bytes starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseAlgoId(final byte[] b, final int off) throws ParseException { + if (cryptoAlgoVal_ >= 0) { + return 0; + } + cryptoAlgoVal_ = PrimitivesParser.parseShort(b, off); + if (CryptoAlgorithm.deserialize(version_, cryptoAlgoVal_) == null) { + throw new BadCiphertextException("Invalid algorithm identifier in ciphertext"); + } + return Short.SIZE / Byte.SIZE; + } + + /** + * Parse the message ID in the provided bytes. It looks for bytes of the size defined by the + * message identifier length in the provided bytes starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseMessageId(final byte[] b, final int off) throws ParseException { + if (messageId_ != null) { + return 0; + } + final int len = b.length - off; + final int messageIdLen = getCryptoAlgoId().getMessageIdLength(); + if (len >= messageIdLen) { + messageId_ = Arrays.copyOfRange(b, off, off + messageIdLen); + return messageIdLen; + } else { + throw new ParseException("Not enough bytes to parse serial number"); + } + } + + /** + * Parses suite specific data + * + * @see {@link ParsingStep} + */ + private int parseSuiteData(final byte[] b, final int off) throws ParseException { + if (suiteData_ != null) { + return 0; + } + final int len = b.length - off; + if (len >= suiteDataLen_) { + suiteData_ = Arrays.copyOfRange(b, off, off + suiteDataLen_); + return suiteDataLen_; + } else { + throw new ParseException("Not enough bytes to parse suite specific data"); + } + } + + /** + * Parse the length of the encryption context in the provided bytes. It looks for 2 bytes + * representing a short primitive type in the provided bytes starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseEncryptionContextLen(final byte[] b, final int off) throws ParseException { + if (encryptionContextLen_ >= 0) { + return 0; + } + encryptionContextLen_ = PrimitivesParser.parseUnsignedShort(b, off); + if (encryptionContextLen_ < 0) { + throw new BadCiphertextException("Invalid encryption context length in ciphertext"); + } + return Short.SIZE / Byte.SIZE; + } + + /** + * Parse the encryption context in the provided bytes. It looks for bytes of size defined by the + * encryption context length in the provided bytes starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseEncryptionContext(final byte[] b, final int off) throws ParseException { + if (encryptionContextLen_ < encryptionContext_.length) { + throw new IllegalStateException( + "Parsed encryption context is in an invalid state. Size exceeds parsed " + + "encryption context length."); + } + if (encryptionContextLen_ == encryptionContext_.length) { + return 0; + } + final int len = b.length - off; + if (len >= encryptionContextLen_) { + encryptionContext_ = Arrays.copyOfRange(b, off, off + encryptionContextLen_); + return encryptionContextLen_; + } else { + throw new ParseException("Not enough bytes to parse encryption context"); + } + } + + /** + * Parse the data key count in the provided bytes. It looks for 2 bytes representing a short + * primitive type in the provided bytes starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseEncryptedDataKeyCount(final byte[] b, final int off) throws ParseException { + if (cipherKeyCount_ >= 0) { + return 0; + } + cipherKeyCount_ = PrimitivesParser.parseUnsignedShort(b, off); + if (cipherKeyCount_ < 0) { + throw new BadCiphertextException("Invalid cipher key count in ciphertext"); + } + if (maxEncryptedDataKeys_ > 0 && cipherKeyCount_ > maxEncryptedDataKeys_) { + throw new AwsCryptoException("Ciphertext encrypted data keys exceed maxEncryptedDataKeys"); + } + cipherKeyBlobs_ = Arrays.asList(new KeyBlob[cipherKeyCount_]); + return Short.SIZE / Byte.SIZE; + } + + /** + * Parses the list of encrypted key blobs. Unlike many of the other parsing methods, this one can + * make partial progress. To indicate this partial progress it throws a {@link + * PartialParseException} containing the number of parsed bytes. + * + * @see {@link ParsingStep} + */ + private int parseEncryptedKeyBlobList(final byte[] b, final int off) + throws PartialParseException { + int parsedBytes = 0; + try { + if (cipherKeyCount_ > 0) { + while (currKeyBlobIndex_ < cipherKeyCount_) { + if (cipherKeyBlobs_.get(currKeyBlobIndex_) == null) { + cipherKeyBlobs_.set(currKeyBlobIndex_, new KeyBlob()); + } + if (cipherKeyBlobs_.get(currKeyBlobIndex_).isComplete() == false) { + parsedBytes += parseEncryptedKeyBlob(b, off + parsedBytes); + // check if we had enough bytes to parse the key blob + if (cipherKeyBlobs_.get(currKeyBlobIndex_).isComplete() == false) { + throw new ParseException("Not enough bytes to parse key blob"); } - } catch (final ParseException ex) { - throw new PartialParseException(ex, parsedBytes); - } - return parsedBytes; - } - - /** - * Parse the encrypted key blob. It delegates the parsing to the methods in - * the key blob class. - * - * @see {@link ParsingStep} - */ - private int parseEncryptedKeyBlob(final byte[] b, final int off) throws ParseException { - return cipherKeyBlobs_.get(currKeyBlobIndex_).deserialize(b, off); - } - - /** - * Parse the content type in the provided bytes. It looks for a - * single byte in the provided bytes starting at the specified off. - * - * @see {@link ParsingStep} - */ - private int parseContentType(final byte[] b, final int off) throws ParseException { - if (contentTypeVal_ >= 0) { - return 0; - } - contentTypeVal_ = PrimitivesParser.parseByte(b, off); - if (ContentType.deserialize(contentTypeVal_) == null) { - throw new BadCiphertextException("Invalid content type in ciphertext."); - } - return 1; - } - - /** - * Parse reserved field in the provided bytes. It looks for 4 bytes - * representing an integer primitive type in the provided bytes starting at - * the specified off. - * - * @see {@link ParsingStep} - */ - private int parseReservedField(final byte[] b, final int off) throws ParseException { - if (reservedField_ >= 0) { - return 0; - } - reservedField_ = PrimitivesParser.parseInt(b, off); - if (reservedField_ != 0) { - throw new BadCiphertextException("Invalid value for reserved field in ciphertext"); - } - return Integer.SIZE / Byte.SIZE; - } - - /** - * Parse the length of the nonce in the provided bytes. It looks for a - * single byte in the provided bytes starting at the specified off. - * - * @see {@link ParsingStep} - */ - private int parseNonceLen(final byte[] b, final int off) throws ParseException { - if (nonceLen_ >= 0) { - return 0; - } - nonceLen_ = PrimitivesParser.parseByte(b, off); - if (nonceLen_ < 0) { - throw new BadCiphertextException("Invalid nonce length in ciphertext"); - } - return 1; - } - - /** - * Parse the frame payload length in the provided bytes. It looks for 4 - * bytes representing an integer primitive type in the provided bytes - * starting at the specified off. - * - * @see {@link ParsingStep} - */ - private int parseFramePayloadLength(final byte[] b, final int off) throws ParseException { - if (frameLength_ >= 0) { - return 0; - } - frameLength_ = PrimitivesParser.parseInt(b, off); - if (frameLength_ < 0) { - throw new BadCiphertextException("Invalid frame length in ciphertext"); - } - return Integer.SIZE / Byte.SIZE; - } - - /** - * Parse the header nonce in the provided bytes. It looks for bytes of the - * size defined by the nonce length in the provided bytes starting at the - * specified off. - * - * @see {@link ParsingStep} - */ - private int parseHeaderNonce(final byte[] b, final int off) throws ParseException { - if (nonceLen_ == 0 || headerNonce_ != null) { - return 0; - } - final int len = b.length - off; - if (len >= nonceLen_) { - headerNonce_ = Arrays.copyOfRange(b, off, off + nonceLen_); - return nonceLen_; - } else { - throw new ParseException("Not enough bytes to parse header nonce"); - } - } - - /** - * Parse the header tag in the provided bytes. It uses the crypto algorithm - * identifier to determine the length of the tag to parse. It looks for - * bytes of size defined by the tag length in the provided bytes starting at - * the specified off. - * - * @see {@link ParsingStep} - */ - private int parseHeaderTag(final byte[] b, final int off) throws ParseException { - if (headerTag_ != null) { - return 0; - } - final int len = b.length - off; - final CryptoAlgorithm cryptoAlgo = CryptoAlgorithm.deserialize(version_, cryptoAlgoVal_); - final int tagLen = cryptoAlgo.getTagLen(); - if (len >= tagLen) { - headerTag_ = Arrays.copyOfRange(b, off, off + tagLen); - return tagLen; - } else { - throw new ParseException("Not enough bytes to parse header tag"); - } - } - - /** - * Marks a deserialization operation as complete. - * This method always succeeds while consuming zero bytes. - * It sets {@link #isComplete_} to {@code true}. - * - * @see {@link ParsingStep} - */ - private int parseComplete(final byte[] b, final int off) throws ParseException { - isComplete_ = true; - return 0; - } - - /** - * Deserialize the provided bytes starting at the specified offset to - * construct an instance of this class. - * - *

- * This method parses the provided bytes for the individual fields in this - * class. This methods also supports partial parsing where not all the bytes - * required for parsing the fields successfully are available. - * - * @param b - * the byte array to deserialize. - * @param off - * the offset in the byte array to use for deserialization. - * @param maxEncryptedDataKeys - * the maximum number of EDKs to deserialize; zero indicates no maximum - * @return - * the number of bytes consumed in deserialization. - */ - public int deserialize(final byte[] b, final int off, int maxEncryptedDataKeys) throws ParseException { - if (b == null) { - return 0; - } - - maxEncryptedDataKeys_ = maxEncryptedDataKeys; - - int parsedBytes = 0; - try { - parsedBytes += parseVersion(b, off + parsedBytes); - - final ParsingStep[] steps; - switch (version_) { - case 1: // Initial version - steps = new ParsingStep[]{ - this::configV1, - this::parseType, - this::parseAlgoId, - this::parseMessageId, - this::parseEncryptionContextLen, - this::parseEncryptionContext, - this::parseEncryptedDataKeyCount, - this::parseEncryptedKeyBlobList, - this::parseContentType, - this::parseReservedField, - this::parseNonceLen, - this::parseFramePayloadLength, - this::parseHeaderNonce, - this::parseHeaderTag, - this::parseComplete}; - break; - case 2: - steps = new ParsingStep[]{ - this::parseAlgoId, - this::configV2, // Must come after we've parsed the algorithm - this::parseMessageId, - this::parseEncryptionContextLen, - this::parseEncryptionContext, - this::parseEncryptedDataKeyCount, - this::parseEncryptedKeyBlobList, - this::parseContentType, - this::parseFramePayloadLength, - this::parseSuiteData, - this::parseHeaderTag, - this::parseComplete}; - break; - default: - throw new BadCiphertextException("Invalid version"); - } - - for (final ParsingStep step : steps) { - parsedBytes += step.parse(b, off + parsedBytes); - } - - } catch (final PartialParseException e) { - // this results when we do partial parsing and there aren't enough - // bytes to parse; ignore it and return the bytes parsed thus far. - parsedBytes += e.bytesParsed_; - } catch (final ParseException e) { - // this results when we do partial parsing and there aren't enough - // bytes to parse; ignore it and return the bytes parsed thus far. - } - - return parsedBytes; - } - - /** - * Serialize the header fields into a byte array. Note this method does not - * serialize the header nonce and tag. - * - * @return - * the serialized bytes of the header fields not including the - * header nonce and tag. - */ - public byte[] serializeAuthenticatedFields() { - try { - ByteArrayOutputStream outBytes = new ByteArrayOutputStream(); - DataOutputStream dataStream = new DataOutputStream(outBytes); - - dataStream.writeByte(version_); - - if (version_ == 1) { - dataStream.writeByte(typeVal_); - dataStream.writeShort(cryptoAlgoVal_); - dataStream.write(messageId_); - PrimitivesParser.writeUnsignedShort(dataStream, encryptionContextLen_); - if (encryptionContextLen_ > 0) { - dataStream.write(encryptionContext_); - } - - dataStream.writeShort(cipherKeyCount_); - for (int i = 0; i < cipherKeyCount_; i++) { - final byte[] cipherKeyBlobBytes = cipherKeyBlobs_.get(i).toByteArray(); - dataStream.write(cipherKeyBlobBytes); - } - - dataStream.writeByte(contentTypeVal_); - dataStream.writeInt(reservedField_); - - dataStream.writeByte(nonceLen_); - dataStream.writeInt(frameLength_); - } else if (version_ == 2){ - dataStream.writeShort(cryptoAlgoVal_); - dataStream.write(messageId_); - PrimitivesParser.writeUnsignedShort(dataStream, encryptionContextLen_); - if (encryptionContextLen_ > 0) { - dataStream.write(encryptionContext_); - } - - dataStream.writeShort(cipherKeyCount_); - for (int i = 0; i < cipherKeyCount_; i++) { - final byte[] cipherKeyBlobBytes = cipherKeyBlobs_.get(i).toByteArray(); - dataStream.write(cipherKeyBlobBytes); - } - - dataStream.writeByte(contentTypeVal_); - dataStream.writeInt(frameLength_); - dataStream.write(suiteData_); - } else { - throw new IllegalArgumentException("Unsupported version: " + version_); - } - dataStream.close(); - return outBytes.toByteArray(); - } catch (IOException e) { - throw new RuntimeException("Failed to serialize cipher text headers", e); - } - } - - /** - * Serialize the header fields into a byte array. This method serializes all - * the header fields including the header nonce and tag. - * - * @return - * the serialized bytes of the entire header. - */ - public byte[] toByteArray() { - if (headerNonce_ == null || headerTag_ == null) { - throw new AwsCryptoException("Header nonce and tag cannot be null."); - } - if (version_ == 2 && suiteData_ == null) { - throw new AwsCryptoException("Suite Data cannot be null in the v2 message format."); - } - - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - baos.write(serializeAuthenticatedFields()); - // The v1 header format includes the header nonce. - // In v2 this is specified by the crypto algorithm. - if (version_ == 1) { - baos.write(headerNonce_); - } - baos.write(headerTag_); - - return baos.toByteArray(); - } catch (IOException ex) { - throw new AwsCryptoException(ex); - } - } - - /** - * Return the version set in the header. - * - * @return - * the byte value representing the version. - */ - public byte getVersion() { - return version_; - } - - /** - * Return the type set in the header. - * - * @return - * the CiphertextType enum value representing the type set in the - * header. - */ - public CiphertextType getType() { - return CiphertextType.deserialize(typeVal_); - } - - /** - * Return the crypto algorithm identifier set in the header. - * - * @return - * the CryptoAlgorithm enum value representing the identifier set in - * the header. - */ - public CryptoAlgorithm getCryptoAlgoId() { - return CryptoAlgorithm.deserialize(version_, cryptoAlgoVal_); - } - - /** - * Return the length of the encryption context set in the header. - * - * @return - * the length of the encryption context set in the header. - */ - public int getEncryptionContextLen() { - return encryptionContextLen_; - } - - /** - * Return the encryption context set in the header. - * - * @return - * the bytes containing encryption context set in the header. - */ - public byte[] getEncryptionContext() { - return encryptionContext_.clone(); - } - - public Map getEncryptionContextMap() { - return EncryptionContextSerializer.deserialize(encryptionContext_); - } - - /** - * Return the count of the encrypted key blobs set in the header. - * - * @return - * the count of the encrypted key blobs set in the header. - */ - public int getEncryptedKeyBlobCount() { - return cipherKeyCount_; - } - - /** - * Return the encrypted key blobs set in the header. - * - * @return - * the KeyBlob objects representing the key blobs set in the header. - */ - public List getEncryptedKeyBlobs() { - return new ArrayList<>(cipherKeyBlobs_); - } - - /** - * Return the content type set in the header. - * - * @return - * the ContentType enum value representing the content type set in - * the header. - */ - public ContentType getContentType() { - return ContentType.deserialize(contentTypeVal_); - } - - /** - * Return the message identifier set in the header. - * - * @return - * the bytes containing the message identifier set in the header. - */ - public byte[] getMessageId() { - return messageId_ != null ? messageId_.clone() : null; - } - - /** - * Return the length of the nonce set in the header. - * - * @return - * the length of the nonce set in the header. - */ - public short getNonceLength() { - return nonceLen_; - } - - /** - * Return the length of the frame set in the header. - * - * @return - * the length of the frame set in the header. - */ - public int getFrameLength() { - return frameLength_; - } - - /** - * Return the header nonce set in the header. - * - * @return - * the bytes containing the header nonce set in the header. - */ - public byte[] getHeaderNonce() { - return headerNonce_ != null ? headerNonce_.clone() : null; - } - - /** - * Return the header tag set in the header. - * - * @return - * the header tag set in the header. - */ - public byte[] getHeaderTag() { - return headerTag_ != null ? headerTag_.clone() : null; - } - - /** - * Set the header nonce to use for authenticating the header data. - * - * @param headerNonce - * the header nonce to use. - */ - public void setHeaderNonce(final byte[] headerNonce) { - headerNonce_ = headerNonce.clone(); - } - - /** - * Set the header tag to use for authenticating the header data. - * - * @param headerTag - * the header tag to use. - */ - public void setHeaderTag(final byte[] headerTag) { - headerTag_ = headerTag.clone(); - } - - /** - * Return suite specific data. - * @return suiteData - */ - public byte[] getSuiteData() { - return suiteData_ != null ? suiteData_.clone() : null; - } - - /** - * Sets suite specific data - * @param suiteData - */ - public void setSuiteData(byte[] suiteData) { - suiteData_ = suiteData.clone(); - } - - private static class PartialParseException extends Exception { - private static final long serialVersionUID = 1L; - final int bytesParsed_; - - private PartialParseException(Throwable ex, int bytesParsed) { - super(ex); - bytesParsed_ = bytesParsed; - } - } - - /** - * Represents a single step in parsing a header. - * - * The following requirements apply: - *

    - *
  • It must be safe to call multiple times. This means that it knows if it has already parsed something and should be a NOP
  • - *
  • It returns how many bytes have been consumed. This will be 0 in the case of a NOP.
  • - *
  • If there are insufficient bytes and no bytes are consumed, it may throw either a {@link ParseException} - * or a {@link PartialParseException}.
  • - *
  • If there are insufficient bytes and some bytes are parsed then it must throw a {@link PartialParseException} - * indicating the number of bytes parsed.
  • - *
- */ - @FunctionalInterface - private interface ParsingStep { - int parse(byte[] b, int off) throws ParseException, PartialParseException; - } + } + currKeyBlobIndex_++; + } + } + } catch (final ParseException ex) { + throw new PartialParseException(ex, parsedBytes); + } + return parsedBytes; + } + + /** + * Parse the encrypted key blob. It delegates the parsing to the methods in the key blob class. + * + * @see {@link ParsingStep} + */ + private int parseEncryptedKeyBlob(final byte[] b, final int off) throws ParseException { + return cipherKeyBlobs_.get(currKeyBlobIndex_).deserialize(b, off); + } + + /** + * Parse the content type in the provided bytes. It looks for a single byte in the provided bytes + * starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseContentType(final byte[] b, final int off) throws ParseException { + if (contentTypeVal_ >= 0) { + return 0; + } + contentTypeVal_ = PrimitivesParser.parseByte(b, off); + if (ContentType.deserialize(contentTypeVal_) == null) { + throw new BadCiphertextException("Invalid content type in ciphertext."); + } + return 1; + } + + /** + * Parse reserved field in the provided bytes. It looks for 4 bytes representing an integer + * primitive type in the provided bytes starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseReservedField(final byte[] b, final int off) throws ParseException { + if (reservedField_ >= 0) { + return 0; + } + reservedField_ = PrimitivesParser.parseInt(b, off); + if (reservedField_ != 0) { + throw new BadCiphertextException("Invalid value for reserved field in ciphertext"); + } + return Integer.SIZE / Byte.SIZE; + } + + /** + * Parse the length of the nonce in the provided bytes. It looks for a single byte in the provided + * bytes starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseNonceLen(final byte[] b, final int off) throws ParseException { + if (nonceLen_ >= 0) { + return 0; + } + nonceLen_ = PrimitivesParser.parseByte(b, off); + if (nonceLen_ < 0) { + throw new BadCiphertextException("Invalid nonce length in ciphertext"); + } + return 1; + } + + /** + * Parse the frame payload length in the provided bytes. It looks for 4 bytes representing an + * integer primitive type in the provided bytes starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseFramePayloadLength(final byte[] b, final int off) throws ParseException { + if (frameLength_ >= 0) { + return 0; + } + frameLength_ = PrimitivesParser.parseInt(b, off); + if (frameLength_ < 0) { + throw new BadCiphertextException("Invalid frame length in ciphertext"); + } + return Integer.SIZE / Byte.SIZE; + } + + /** + * Parse the header nonce in the provided bytes. It looks for bytes of the size defined by the + * nonce length in the provided bytes starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseHeaderNonce(final byte[] b, final int off) throws ParseException { + if (nonceLen_ == 0 || headerNonce_ != null) { + return 0; + } + final int len = b.length - off; + if (len >= nonceLen_) { + headerNonce_ = Arrays.copyOfRange(b, off, off + nonceLen_); + return nonceLen_; + } else { + throw new ParseException("Not enough bytes to parse header nonce"); + } + } + + /** + * Parse the header tag in the provided bytes. It uses the crypto algorithm identifier to + * determine the length of the tag to parse. It looks for bytes of size defined by the tag length + * in the provided bytes starting at the specified off. + * + * @see {@link ParsingStep} + */ + private int parseHeaderTag(final byte[] b, final int off) throws ParseException { + if (headerTag_ != null) { + return 0; + } + final int len = b.length - off; + final CryptoAlgorithm cryptoAlgo = CryptoAlgorithm.deserialize(version_, cryptoAlgoVal_); + final int tagLen = cryptoAlgo.getTagLen(); + if (len >= tagLen) { + headerTag_ = Arrays.copyOfRange(b, off, off + tagLen); + return tagLen; + } else { + throw new ParseException("Not enough bytes to parse header tag"); + } + } + + /** + * Marks a deserialization operation as complete. This method always succeeds while consuming zero + * bytes. It sets {@link #isComplete_} to {@code true}. + * + * @see {@link ParsingStep} + */ + private int parseComplete(final byte[] b, final int off) throws ParseException { + isComplete_ = true; + return 0; + } + + /** + * Deserialize the provided bytes starting at the specified offset to construct an instance of + * this class. + * + *

This method parses the provided bytes for the individual fields in this class. This methods + * also supports partial parsing where not all the bytes required for parsing the fields + * successfully are available. + * + * @param b the byte array to deserialize. + * @param off the offset in the byte array to use for deserialization. + * @param maxEncryptedDataKeys the maximum number of EDKs to deserialize; zero indicates no + * maximum + * @return the number of bytes consumed in deserialization. + */ + public int deserialize(final byte[] b, final int off, int maxEncryptedDataKeys) + throws ParseException { + if (b == null) { + return 0; + } + + maxEncryptedDataKeys_ = maxEncryptedDataKeys; + + int parsedBytes = 0; + try { + parsedBytes += parseVersion(b, off + parsedBytes); + + final ParsingStep[] steps; + switch (version_) { + case 1: // Initial version + steps = + new ParsingStep[] { + this::configV1, + this::parseType, + this::parseAlgoId, + this::parseMessageId, + this::parseEncryptionContextLen, + this::parseEncryptionContext, + this::parseEncryptedDataKeyCount, + this::parseEncryptedKeyBlobList, + this::parseContentType, + this::parseReservedField, + this::parseNonceLen, + this::parseFramePayloadLength, + this::parseHeaderNonce, + this::parseHeaderTag, + this::parseComplete + }; + break; + case 2: + steps = + new ParsingStep[] { + this::parseAlgoId, + this::configV2, // Must come after we've parsed the algorithm + this::parseMessageId, + this::parseEncryptionContextLen, + this::parseEncryptionContext, + this::parseEncryptedDataKeyCount, + this::parseEncryptedKeyBlobList, + this::parseContentType, + this::parseFramePayloadLength, + this::parseSuiteData, + this::parseHeaderTag, + this::parseComplete + }; + break; + default: + throw new BadCiphertextException("Invalid version"); + } + + for (final ParsingStep step : steps) { + parsedBytes += step.parse(b, off + parsedBytes); + } + + } catch (final PartialParseException e) { + // this results when we do partial parsing and there aren't enough + // bytes to parse; ignore it and return the bytes parsed thus far. + parsedBytes += e.bytesParsed_; + } catch (final ParseException e) { + // this results when we do partial parsing and there aren't enough + // bytes to parse; ignore it and return the bytes parsed thus far. + } + + return parsedBytes; + } + + /** + * Serialize the header fields into a byte array. Note this method does not serialize the header + * nonce and tag. + * + * @return the serialized bytes of the header fields not including the header nonce and tag. + */ + public byte[] serializeAuthenticatedFields() { + try { + ByteArrayOutputStream outBytes = new ByteArrayOutputStream(); + DataOutputStream dataStream = new DataOutputStream(outBytes); + + dataStream.writeByte(version_); + + if (version_ == 1) { + dataStream.writeByte(typeVal_); + dataStream.writeShort(cryptoAlgoVal_); + dataStream.write(messageId_); + PrimitivesParser.writeUnsignedShort(dataStream, encryptionContextLen_); + if (encryptionContextLen_ > 0) { + dataStream.write(encryptionContext_); + } + + dataStream.writeShort(cipherKeyCount_); + for (int i = 0; i < cipherKeyCount_; i++) { + final byte[] cipherKeyBlobBytes = cipherKeyBlobs_.get(i).toByteArray(); + dataStream.write(cipherKeyBlobBytes); + } + + dataStream.writeByte(contentTypeVal_); + dataStream.writeInt(reservedField_); + + dataStream.writeByte(nonceLen_); + dataStream.writeInt(frameLength_); + } else if (version_ == 2) { + dataStream.writeShort(cryptoAlgoVal_); + dataStream.write(messageId_); + PrimitivesParser.writeUnsignedShort(dataStream, encryptionContextLen_); + if (encryptionContextLen_ > 0) { + dataStream.write(encryptionContext_); + } + + dataStream.writeShort(cipherKeyCount_); + for (int i = 0; i < cipherKeyCount_; i++) { + final byte[] cipherKeyBlobBytes = cipherKeyBlobs_.get(i).toByteArray(); + dataStream.write(cipherKeyBlobBytes); + } + + dataStream.writeByte(contentTypeVal_); + dataStream.writeInt(frameLength_); + dataStream.write(suiteData_); + } else { + throw new IllegalArgumentException("Unsupported version: " + version_); + } + dataStream.close(); + return outBytes.toByteArray(); + } catch (IOException e) { + throw new RuntimeException("Failed to serialize cipher text headers", e); + } + } + + /** + * Serialize the header fields into a byte array. This method serializes all the header fields + * including the header nonce and tag. + * + * @return the serialized bytes of the entire header. + */ + public byte[] toByteArray() { + if (headerNonce_ == null || headerTag_ == null) { + throw new AwsCryptoException("Header nonce and tag cannot be null."); + } + if (version_ == 2 && suiteData_ == null) { + throw new AwsCryptoException("Suite Data cannot be null in the v2 message format."); + } + + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + baos.write(serializeAuthenticatedFields()); + // The v1 header format includes the header nonce. + // In v2 this is specified by the crypto algorithm. + if (version_ == 1) { + baos.write(headerNonce_); + } + baos.write(headerTag_); + + return baos.toByteArray(); + } catch (IOException ex) { + throw new AwsCryptoException(ex); + } + } + + /** + * Return the version set in the header. + * + * @return the byte value representing the version. + */ + public byte getVersion() { + return version_; + } + + /** + * Return the type set in the header. + * + * @return the CiphertextType enum value representing the type set in the header. + */ + public CiphertextType getType() { + return CiphertextType.deserialize(typeVal_); + } + + /** + * Return the crypto algorithm identifier set in the header. + * + * @return the CryptoAlgorithm enum value representing the identifier set in the header. + */ + public CryptoAlgorithm getCryptoAlgoId() { + return CryptoAlgorithm.deserialize(version_, cryptoAlgoVal_); + } + + /** + * Return the length of the encryption context set in the header. + * + * @return the length of the encryption context set in the header. + */ + public int getEncryptionContextLen() { + return encryptionContextLen_; + } + + /** + * Return the encryption context set in the header. + * + * @return the bytes containing encryption context set in the header. + */ + public byte[] getEncryptionContext() { + return encryptionContext_.clone(); + } + + public Map getEncryptionContextMap() { + return EncryptionContextSerializer.deserialize(encryptionContext_); + } + + /** + * Return the count of the encrypted key blobs set in the header. + * + * @return the count of the encrypted key blobs set in the header. + */ + public int getEncryptedKeyBlobCount() { + return cipherKeyCount_; + } + + /** + * Return the encrypted key blobs set in the header. + * + * @return the KeyBlob objects representing the key blobs set in the header. + */ + public List getEncryptedKeyBlobs() { + return new ArrayList<>(cipherKeyBlobs_); + } + + /** + * Return the content type set in the header. + * + * @return the ContentType enum value representing the content type set in the header. + */ + public ContentType getContentType() { + return ContentType.deserialize(contentTypeVal_); + } + + /** + * Return the message identifier set in the header. + * + * @return the bytes containing the message identifier set in the header. + */ + public byte[] getMessageId() { + return messageId_ != null ? messageId_.clone() : null; + } + + /** + * Return the length of the nonce set in the header. + * + * @return the length of the nonce set in the header. + */ + public short getNonceLength() { + return nonceLen_; + } + + /** + * Return the length of the frame set in the header. + * + * @return the length of the frame set in the header. + */ + public int getFrameLength() { + return frameLength_; + } + + /** + * Return the header nonce set in the header. + * + * @return the bytes containing the header nonce set in the header. + */ + public byte[] getHeaderNonce() { + return headerNonce_ != null ? headerNonce_.clone() : null; + } + + /** + * Return the header tag set in the header. + * + * @return the header tag set in the header. + */ + public byte[] getHeaderTag() { + return headerTag_ != null ? headerTag_.clone() : null; + } + + /** + * Set the header nonce to use for authenticating the header data. + * + * @param headerNonce the header nonce to use. + */ + public void setHeaderNonce(final byte[] headerNonce) { + headerNonce_ = headerNonce.clone(); + } + + /** + * Set the header tag to use for authenticating the header data. + * + * @param headerTag the header tag to use. + */ + public void setHeaderTag(final byte[] headerTag) { + headerTag_ = headerTag.clone(); + } + + /** + * Return suite specific data. + * + * @return suiteData + */ + public byte[] getSuiteData() { + return suiteData_ != null ? suiteData_.clone() : null; + } + + /** + * Sets suite specific data + * + * @param suiteData + */ + public void setSuiteData(byte[] suiteData) { + suiteData_ = suiteData.clone(); + } + + private static class PartialParseException extends Exception { + private static final long serialVersionUID = 1L; + final int bytesParsed_; + + private PartialParseException(Throwable ex, int bytesParsed) { + super(ex); + bytesParsed_ = bytesParsed; + } + } + + /** + * Represents a single step in parsing a header. + * + *

The following requirements apply: + * + *

    + *
  • It must be safe to call multiple times. This means that it knows if it has already parsed + * something and should be a NOP + *
  • It returns how many bytes have been consumed. This will be 0 in the case of a NOP. + *
  • If there are insufficient bytes and no bytes are consumed, it may throw either a {@link + * ParseException} or a {@link PartialParseException}. + *
  • If there are insufficient bytes and some bytes are parsed then it must throw a {@link + * PartialParseException} indicating the number of bytes parsed. + *
+ */ + @FunctionalInterface + private interface ParsingStep { + int parse(byte[] b, int off) throws ParseException, PartialParseException; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextType.java b/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextType.java index aef722763..a6a1aecdb 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextType.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/CiphertextType.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -19,58 +19,54 @@ /** * This enum describes the supported types of ciphertext in this library. - * - *

- * Format: CiphertextType(byte value representing the type) + * + *

Format: CiphertextType(byte value representing the type) */ public enum CiphertextType { - CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA(128); + CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA(128); - private final byte value_; + private final byte value_; - /** - * Create a mapping between the CiphertextType object and its byte value. - * This is a static method so the map is created when the class is loaded. - * This enables fast lookups of the CiphertextType given a value. - */ - private static final Map ID_MAPPING = new HashMap<>(); - static { - for (final CiphertextType s : EnumSet.allOf(CiphertextType.class)) { - ID_MAPPING.put(s.value_, s); - } - } + /** + * Create a mapping between the CiphertextType object and its byte value. This is a static method + * so the map is created when the class is loaded. This enables fast lookups of the CiphertextType + * given a value. + */ + private static final Map ID_MAPPING = new HashMap<>(); - private CiphertextType(final int value) { - /* - * Java reads literals as integers. So we cast the integer value to byte - * here to avoid doing this in the enum definitions above. - */ - value_ = (byte) value; + static { + for (final CiphertextType s : EnumSet.allOf(CiphertextType.class)) { + ID_MAPPING.put(s.value_, s); } + } - /** - * Return the value used to encode this ciphertext type object in the - * ciphertext. - * - * @return - * the byte value used to encode this ciphertext type. + private CiphertextType(final int value) { + /* + * Java reads literals as integers. So we cast the integer value to byte + * here to avoid doing this in the enum definitions above. */ - public byte getValue() { - return value_; - } + value_ = (byte) value; + } - /** - * Deserialize the provided byte value by returning the CiphertextType - * object representing the byte value. - * - * @param value - * the byte representing the value of the CiphertextType object. - * @return - * the CiphertextType object representing the byte value. - */ - public static CiphertextType deserialize(final byte value) { - final Byte valueByte = Byte.valueOf(value); - final CiphertextType result = ID_MAPPING.get(valueByte); - return result; - } + /** + * Return the value used to encode this ciphertext type object in the ciphertext. + * + * @return the byte value used to encode this ciphertext type. + */ + public byte getValue() { + return value_; + } + + /** + * Deserialize the provided byte value by returning the CiphertextType object representing the + * byte value. + * + * @param value the byte representing the value of the CiphertextType object. + * @return the CiphertextType object representing the byte value. + */ + public static CiphertextType deserialize(final byte value) { + final Byte valueByte = Byte.valueOf(value); + final CiphertextType result = ID_MAPPING.get(valueByte); + return result; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/ContentType.java b/src/main/java/com/amazonaws/encryptionsdk/model/ContentType.java index bab998248..97bb82cfe 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/ContentType.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/ContentType.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -18,70 +18,64 @@ import java.util.Map; /** - * This enum describes the supported types for storing the encrypted content in - * the message format. There are two types current currently supported: single - * block and frames. - * - *

- * The single block format stores the encrypted content in a single block - * wrapped with headers containing the nonce, MAC tag, and the content length. - * - *

- * The frame format partitions the encrypted content in multiple frames of a - * specified frame length. Each frame is wrapped by an header containing the - * frame sequence number, nonce, and the MAC tag. - * - *

- * Format: ContentType(byte value representing the type) + * This enum describes the supported types for storing the encrypted content in the message format. + * There are two types current currently supported: single block and frames. + * + *

The single block format stores the encrypted content in a single block wrapped with headers + * containing the nonce, MAC tag, and the content length. + * + *

The frame format partitions the encrypted content in multiple frames of a specified frame + * length. Each frame is wrapped by an header containing the frame sequence number, nonce, and the + * MAC tag. + * + *

Format: ContentType(byte value representing the type) */ public enum ContentType { - SINGLEBLOCK(1), FRAME(2); + SINGLEBLOCK(1), + FRAME(2); - private final byte value_; + private final byte value_; - /** - * Create a mapping between the ContentType object and its byte value. This - * is a static method so the map is created when the class is loaded. This - * enables fast lookups of the ContentType given a value. - */ - private static final Map ID_MAPPING = new HashMap(); - static { - for (final ContentType s : EnumSet.allOf(ContentType.class)) { - ID_MAPPING.put(s.value_, s); - } - } + /** + * Create a mapping between the ContentType object and its byte value. This is a static method so + * the map is created when the class is loaded. This enables fast lookups of the ContentType given + * a value. + */ + private static final Map ID_MAPPING = new HashMap(); - private ContentType(final int value) { - /* - * Java reads literals as integers. So we cast the integer value to byte - * here to avoid doing this in the enum definitions above. - */ - value_ = (byte) value; + static { + for (final ContentType s : EnumSet.allOf(ContentType.class)) { + ID_MAPPING.put(s.value_, s); } + } - /** - * Return the value used to encode this content type object in the - * ciphertext. - * - * @return - * the byte value used to encode this content type. + private ContentType(final int value) { + /* + * Java reads literals as integers. So we cast the integer value to byte + * here to avoid doing this in the enum definitions above. */ - public byte getValue() { - return value_; - } + value_ = (byte) value; + } - /** - * Deserialize the provided byte value by returning the ContentType object - * representing the byte value. - * - * @param value - * the byte representing the value of the ContentType object. - * @return - * the ContentType object representing the byte value. - */ - public static ContentType deserialize(final byte value) { - final Byte valueByte = Byte.valueOf(value); - final ContentType result = ID_MAPPING.get(valueByte); - return result; - } + /** + * Return the value used to encode this content type object in the ciphertext. + * + * @return the byte value used to encode this content type. + */ + public byte getValue() { + return value_; + } + + /** + * Deserialize the provided byte value by returning the ContentType object representing the byte + * value. + * + * @param value the byte representing the value of the ContentType object. + * @return the ContentType object representing the byte value. + */ + public static ContentType deserialize(final byte value) { + final Byte valueByte = Byte.valueOf(value); + final ContentType result = ID_MAPPING.get(valueByte); + return result; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/DecryptionMaterials.java b/src/main/java/com/amazonaws/encryptionsdk/model/DecryptionMaterials.java index 94423b884..4f137d3fb 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/DecryptionMaterials.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/DecryptionMaterials.java @@ -1,65 +1,64 @@ package com.amazonaws.encryptionsdk.model; -import java.security.PublicKey; - import com.amazonaws.encryptionsdk.DataKey; +import java.security.PublicKey; public final class DecryptionMaterials { - private final DataKey dataKey; - private final PublicKey trailingSignatureKey; + private final DataKey dataKey; + private final PublicKey trailingSignatureKey; - private DecryptionMaterials(Builder b) { - dataKey = b.getDataKey(); - trailingSignatureKey = b.getTrailingSignatureKey(); - } + private DecryptionMaterials(Builder b) { + dataKey = b.getDataKey(); + trailingSignatureKey = b.getTrailingSignatureKey(); + } - public DataKey getDataKey() { - return dataKey; - } + public DataKey getDataKey() { + return dataKey; + } - public PublicKey getTrailingSignatureKey() { - return trailingSignatureKey; - } + public PublicKey getTrailingSignatureKey() { + return trailingSignatureKey; + } - public static Builder newBuilder() { - return new Builder(); - } + public static Builder newBuilder() { + return new Builder(); + } - public Builder toBuilder() { - return new Builder(this); - } + public Builder toBuilder() { + return new Builder(this); + } - public static final class Builder { - private DataKey dataKey; - private PublicKey trailingSignatureKey; + public static final class Builder { + private DataKey dataKey; + private PublicKey trailingSignatureKey; - private Builder(DecryptionMaterials result) { - this.dataKey = result.getDataKey(); - this.trailingSignatureKey = result.getTrailingSignatureKey(); - } + private Builder(DecryptionMaterials result) { + this.dataKey = result.getDataKey(); + this.trailingSignatureKey = result.getTrailingSignatureKey(); + } - private Builder() {} + private Builder() {} - public DataKey getDataKey() { - return dataKey; - } + public DataKey getDataKey() { + return dataKey; + } - public Builder setDataKey(DataKey dataKey) { - this.dataKey = dataKey; - return this; - } + public Builder setDataKey(DataKey dataKey) { + this.dataKey = dataKey; + return this; + } - public PublicKey getTrailingSignatureKey() { - return trailingSignatureKey; - } + public PublicKey getTrailingSignatureKey() { + return trailingSignatureKey; + } - public Builder setTrailingSignatureKey(PublicKey trailingSignatureKey) { - this.trailingSignatureKey = trailingSignatureKey; - return this; - } + public Builder setTrailingSignatureKey(PublicKey trailingSignatureKey) { + this.trailingSignatureKey = trailingSignatureKey; + return this; + } - public DecryptionMaterials build() { - return new DecryptionMaterials(this); - } + public DecryptionMaterials build() { + return new DecryptionMaterials(this); } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/DecryptionMaterialsRequest.java b/src/main/java/com/amazonaws/encryptionsdk/model/DecryptionMaterialsRequest.java index f1ce1247c..102a76e4a 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/DecryptionMaterialsRequest.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/DecryptionMaterialsRequest.java @@ -1,94 +1,93 @@ package com.amazonaws.encryptionsdk.model; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; - public final class DecryptionMaterialsRequest { - private final CryptoAlgorithm algorithm; - private final Map encryptionContext; - private final List encryptedDataKeys; - - private DecryptionMaterialsRequest(Builder b) { - this.algorithm = b.getAlgorithm(); - this.encryptionContext = b.getEncryptionContext(); - this.encryptedDataKeys = b.getEncryptedDataKeys(); + private final CryptoAlgorithm algorithm; + private final Map encryptionContext; + private final List encryptedDataKeys; + + private DecryptionMaterialsRequest(Builder b) { + this.algorithm = b.getAlgorithm(); + this.encryptionContext = b.getEncryptionContext(); + this.encryptedDataKeys = b.getEncryptedDataKeys(); + } + + public CryptoAlgorithm getAlgorithm() { + return algorithm; + } + + public Map getEncryptionContext() { + return encryptionContext; + } + + public List getEncryptedDataKeys() { + return encryptedDataKeys; + } + + public static Builder newBuilder() { + return new Builder(); + } + + public Builder toBuilder() { + return new Builder(this); + } + + public static DecryptionMaterialsRequest fromCiphertextHeaders(CiphertextHeaders headers) { + return newBuilder() + .setAlgorithm(headers.getCryptoAlgoId()) + .setEncryptionContext(headers.getEncryptionContextMap()) + .setEncryptedDataKeys(headers.getEncryptedKeyBlobs()) + .build(); + } + + public static final class Builder { + private CryptoAlgorithm algorithm; + private Map encryptionContext; + private List encryptedDataKeys; + + private Builder(DecryptionMaterialsRequest request) { + this.algorithm = request.getAlgorithm(); + this.encryptionContext = request.getEncryptionContext(); + this.encryptedDataKeys = request.getEncryptedDataKeys(); } - public CryptoAlgorithm getAlgorithm() { - return algorithm; + private Builder() {} + + public DecryptionMaterialsRequest build() { + return new DecryptionMaterialsRequest(this); } - public Map getEncryptionContext() { - return encryptionContext; + public CryptoAlgorithm getAlgorithm() { + return algorithm; } - public List getEncryptedDataKeys() { - return encryptedDataKeys; + public Builder setAlgorithm(CryptoAlgorithm algorithm) { + this.algorithm = algorithm; + return this; } - public static Builder newBuilder() { - return new Builder(); + public Map getEncryptionContext() { + return encryptionContext; } - public Builder toBuilder() { - return new Builder(this); + public Builder setEncryptionContext(Map encryptionContext) { + this.encryptionContext = Collections.unmodifiableMap(new HashMap<>(encryptionContext)); + return this; } - public static DecryptionMaterialsRequest fromCiphertextHeaders(CiphertextHeaders headers) { - return newBuilder() - .setAlgorithm(headers.getCryptoAlgoId()) - .setEncryptionContext(headers.getEncryptionContextMap()) - .setEncryptedDataKeys(headers.getEncryptedKeyBlobs()) - .build(); + public List getEncryptedDataKeys() { + return encryptedDataKeys; } - public static final class Builder { - private CryptoAlgorithm algorithm; - private Map encryptionContext; - private List encryptedDataKeys; - - private Builder(DecryptionMaterialsRequest request) { - this.algorithm = request.getAlgorithm(); - this.encryptionContext = request.getEncryptionContext(); - this.encryptedDataKeys = request.getEncryptedDataKeys(); - } - - private Builder() {} - - public DecryptionMaterialsRequest build() { - return new DecryptionMaterialsRequest(this); - } - - public CryptoAlgorithm getAlgorithm() { - return algorithm; - } - - public Builder setAlgorithm(CryptoAlgorithm algorithm) { - this.algorithm = algorithm; - return this; - } - - public Map getEncryptionContext() { - return encryptionContext; - } - - public Builder setEncryptionContext(Map encryptionContext) { - this.encryptionContext = Collections.unmodifiableMap(new HashMap<>(encryptionContext)); - return this; - } - - public List getEncryptedDataKeys() { - return encryptedDataKeys; - } - - public Builder setEncryptedDataKeys(List encryptedDataKeys) { - this.encryptedDataKeys = Collections.unmodifiableList(new ArrayList<>(encryptedDataKeys)); - return this; - } + public Builder setEncryptedDataKeys(List encryptedDataKeys) { + this.encryptedDataKeys = Collections.unmodifiableList(new ArrayList<>(encryptedDataKeys)); + return this; } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionCompletionListener.java b/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionCompletionListener.java index ff1247feb..9720e84b7 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionCompletionListener.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionCompletionListener.java @@ -2,11 +2,11 @@ @FunctionalInterface public interface EncryptionCompletionListener { - /** - * Invoked upon encryption completion; MaterialsManagers that need to know the size of the plaintext (e.g. to - * enforce caching policies) can make use of this. - * - * @param plaintextBytes Total number of plaintext bytes encrypted - */ - void onEncryptDone(long plaintextBytes); + /** + * Invoked upon encryption completion; MaterialsManagers that need to know the size of the + * plaintext (e.g. to enforce caching policies) can make use of this. + * + * @param plaintextBytes Total number of plaintext bytes encrypted + */ + void onEncryptDone(long plaintextBytes); } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionMaterials.java b/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionMaterials.java index 1a40d7c36..2d0156482 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionMaterials.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionMaterials.java @@ -1,6 +1,7 @@ package com.amazonaws.encryptionsdk.model; -import javax.crypto.SecretKey; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.MasterKey; import java.security.PrivateKey; import java.util.ArrayList; import java.util.Collections; @@ -8,181 +9,184 @@ import java.util.List; import java.util.Map; import java.util.Objects; - -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.MasterKey; +import javax.crypto.SecretKey; /** * Contains the cryptographic materials needed for an encryption operation. * - * @see com.amazonaws.encryptionsdk.CryptoMaterialsManager#getMaterialsForEncrypt(EncryptionMaterialsRequest) + * @see + * com.amazonaws.encryptionsdk.CryptoMaterialsManager#getMaterialsForEncrypt(EncryptionMaterialsRequest) */ public final class EncryptionMaterials { - private final CryptoAlgorithm algorithm; - private final Map encryptionContext; - private final List encryptedDataKeys; - private final SecretKey cleartextDataKey; - private final PrivateKey trailingSignatureKey; - private final List masterKeys; - - private EncryptionMaterials(Builder b) { - this.algorithm = b.algorithm; - this.encryptionContext = b.encryptionContext; - this.encryptedDataKeys = b.encryptedDataKeys; - this.cleartextDataKey = b.cleartextDataKey; - this.trailingSignatureKey = b.trailingSignatureKey; - this.masterKeys = b.getMasterKeys(); + private final CryptoAlgorithm algorithm; + private final Map encryptionContext; + private final List encryptedDataKeys; + private final SecretKey cleartextDataKey; + private final PrivateKey trailingSignatureKey; + private final List masterKeys; + + private EncryptionMaterials(Builder b) { + this.algorithm = b.algorithm; + this.encryptionContext = b.encryptionContext; + this.encryptedDataKeys = b.encryptedDataKeys; + this.cleartextDataKey = b.cleartextDataKey; + this.trailingSignatureKey = b.trailingSignatureKey; + this.masterKeys = b.getMasterKeys(); + } + + public Builder toBuilder() { + return new Builder(this); + } + + public static Builder newBuilder() { + return new Builder(); + } + + /** + * The algorithm to use for this encryption operation. Must match the algorithm in + * EncryptionMaterialsRequest, if that algorithm was non-null. + */ + public CryptoAlgorithm getAlgorithm() { + return algorithm; + } + + /** + * The encryption context to use for the encryption operation. Does not need to match the + * EncryptionMaterialsRequest. + */ + public Map getEncryptionContext() { + return encryptionContext; + } + + /** The KeyBlobs to serialize (in cleartext) into the encrypted message. */ + public List getEncryptedDataKeys() { + return encryptedDataKeys; + } + + /** + * The cleartext data key to use for encrypting this message. Note that this is the data key prior + * to any key derivation required by the crypto algorithm in use. + */ + public SecretKey getCleartextDataKey() { + return cleartextDataKey; + } + + /** + * The private key to be used to sign the message trailer. Must be present if any only if required + * by the crypto algorithm, and the key type must likewise match the algorithm in use. + * + *

Note that it's the {@link com.amazonaws.encryptionsdk.CryptoMaterialsManager}'s + * responsibility to find a place to put the public key; typically, this will be in the encryption + * context, to improve cross-compatibility, but this is not a strict requirement. + */ + public PrivateKey getTrailingSignatureKey() { + return trailingSignatureKey; + } + + /** Contains a list of all MasterKeys that could decrypt this message. */ + public List getMasterKeys() { + return masterKeys; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + EncryptionMaterials that = (EncryptionMaterials) o; + return algorithm == that.algorithm + && Objects.equals(encryptionContext, that.encryptionContext) + && Objects.equals(encryptedDataKeys, that.encryptedDataKeys) + && Objects.equals(cleartextDataKey, that.cleartextDataKey) + && Objects.equals(trailingSignatureKey, that.trailingSignatureKey) + && Objects.equals(masterKeys, that.masterKeys); + } + + @Override + public int hashCode() { + return Objects.hash( + algorithm, + encryptionContext, + encryptedDataKeys, + cleartextDataKey, + trailingSignatureKey, + masterKeys); + } + + public static class Builder { + private CryptoAlgorithm algorithm; + private Map encryptionContext = Collections.emptyMap(); + private List encryptedDataKeys = null; + private SecretKey cleartextDataKey; + private PrivateKey trailingSignatureKey; + private List masterKeys = Collections.emptyList(); + + private Builder() {} + + private Builder(EncryptionMaterials r) { + algorithm = r.algorithm; + encryptionContext = r.encryptionContext; + encryptedDataKeys = r.encryptedDataKeys; + cleartextDataKey = r.cleartextDataKey; + trailingSignatureKey = r.trailingSignatureKey; + setMasterKeys(r.masterKeys); } - public Builder toBuilder() { - return new Builder(this); + public EncryptionMaterials build() { + return new EncryptionMaterials(this); } - public static Builder newBuilder() { - return new Builder(); + public CryptoAlgorithm getAlgorithm() { + return algorithm; } - /** - * The algorithm to use for this encryption operation. Must match the algorithm in EncryptionMaterialsRequest, if that - * algorithm was non-null. - */ - public CryptoAlgorithm getAlgorithm() { - return algorithm; + public Builder setAlgorithm(CryptoAlgorithm algorithm) { + this.algorithm = algorithm; + return this; } - /** - * The encryption context to use for the encryption operation. Does not need to match the EncryptionMaterialsRequest. - */ public Map getEncryptionContext() { - return encryptionContext; + return encryptionContext; + } + + public Builder setEncryptionContext(Map encryptionContext) { + this.encryptionContext = Collections.unmodifiableMap(new HashMap<>(encryptionContext)); + return this; } - /** - * The KeyBlobs to serialize (in cleartext) into the encrypted message. - */ public List getEncryptedDataKeys() { - return encryptedDataKeys; + return encryptedDataKeys; + } + + public Builder setEncryptedDataKeys(List encryptedDataKeys) { + this.encryptedDataKeys = Collections.unmodifiableList(new ArrayList<>(encryptedDataKeys)); + return this; } - /** - * The cleartext data key to use for encrypting this message. Note that this is the data key prior to - * any key derivation required by the crypto algorithm in use. - */ public SecretKey getCleartextDataKey() { - return cleartextDataKey; + return cleartextDataKey; } - /** - * The private key to be used to sign the message trailer. Must be present if any only if required by the - * crypto algorithm, and the key type must likewise match the algorithm in use. - * - * Note that it's the {@link com.amazonaws.encryptionsdk.CryptoMaterialsManager}'s responsibility to find a place - * to put the public key; typically, this will be in the encryption context, to improve cross-compatibility, - * but this is not a strict requirement. - */ - public PrivateKey getTrailingSignatureKey() { - return trailingSignatureKey; + public Builder setCleartextDataKey(SecretKey cleartextDataKey) { + this.cleartextDataKey = cleartextDataKey; + return this; } - /** - * Contains a list of all MasterKeys that could decrypt this message. - */ - public List getMasterKeys() { - return masterKeys; + public PrivateKey getTrailingSignatureKey() { + return trailingSignatureKey; } - @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - EncryptionMaterials that = (EncryptionMaterials) o; - return algorithm == that.algorithm && - Objects.equals(encryptionContext, that.encryptionContext) && - Objects.equals(encryptedDataKeys, that.encryptedDataKeys) && - Objects.equals(cleartextDataKey, that.cleartextDataKey) && - Objects.equals(trailingSignatureKey, that.trailingSignatureKey) && - Objects.equals(masterKeys, that.masterKeys); + public Builder setTrailingSignatureKey(PrivateKey trailingSignatureKey) { + this.trailingSignatureKey = trailingSignatureKey; + return this; } - @Override public int hashCode() { - return Objects.hash(algorithm, encryptionContext, encryptedDataKeys, cleartextDataKey, trailingSignatureKey, - masterKeys); + public List getMasterKeys() { + return masterKeys; } - public static class Builder { - private CryptoAlgorithm algorithm; - private Map encryptionContext = Collections.emptyMap(); - private List encryptedDataKeys = null; - private SecretKey cleartextDataKey; - private PrivateKey trailingSignatureKey; - private List masterKeys = Collections.emptyList(); - - private Builder() {} - - private Builder(EncryptionMaterials r) { - algorithm = r.algorithm; - encryptionContext = r.encryptionContext; - encryptedDataKeys = r.encryptedDataKeys; - cleartextDataKey = r.cleartextDataKey; - trailingSignatureKey = r.trailingSignatureKey; - setMasterKeys(r.masterKeys); - } - - public EncryptionMaterials build() { - return new EncryptionMaterials(this); - } - - public CryptoAlgorithm getAlgorithm() { - return algorithm; - } - - public Builder setAlgorithm(CryptoAlgorithm algorithm) { - this.algorithm = algorithm; - return this; - } - - public Map getEncryptionContext() { - return encryptionContext; - } - - public Builder setEncryptionContext(Map encryptionContext) { - this.encryptionContext = Collections.unmodifiableMap(new HashMap<>(encryptionContext)); - return this; - } - - public List getEncryptedDataKeys() { - return encryptedDataKeys; - } - - public Builder setEncryptedDataKeys(List encryptedDataKeys) { - this.encryptedDataKeys = Collections.unmodifiableList(new ArrayList<>(encryptedDataKeys)); - return this; - } - - public SecretKey getCleartextDataKey() { - return cleartextDataKey; - } - - public Builder setCleartextDataKey(SecretKey cleartextDataKey) { - this.cleartextDataKey = cleartextDataKey; - return this; - } - - public PrivateKey getTrailingSignatureKey() { - return trailingSignatureKey; - } - - public Builder setTrailingSignatureKey(PrivateKey trailingSignatureKey) { - this.trailingSignatureKey = trailingSignatureKey; - return this; - } - - public List getMasterKeys() { - return masterKeys; - } - - public Builder setMasterKeys(List masterKeys) { - this.masterKeys = Collections.unmodifiableList(new ArrayList<>(masterKeys)); - return this; - } + public Builder setMasterKeys(List masterKeys) { + this.masterKeys = Collections.unmodifiableList(new ArrayList<>(masterKeys)); + return this; } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionMaterialsRequest.java b/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionMaterialsRequest.java index a13a0f3e4..28a3243fc 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionMaterialsRequest.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/EncryptionMaterialsRequest.java @@ -3,188 +3,186 @@ package com.amazonaws.encryptionsdk.model; +import com.amazonaws.encryptionsdk.CommitmentPolicy; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; -import com.amazonaws.encryptionsdk.CommitmentPolicy; -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; - -import com.amazonaws.encryptionsdk.CryptoAlgorithm; - /** * Contains the contextual information needed to prepare an encryption operation. * - * @see com.amazonaws.encryptionsdk.CryptoMaterialsManager#getMaterialsForEncrypt(EncryptionMaterialsRequest) + * @see + * com.amazonaws.encryptionsdk.CryptoMaterialsManager#getMaterialsForEncrypt(EncryptionMaterialsRequest) */ public final class EncryptionMaterialsRequest { - private final Map context; - private final CryptoAlgorithm requestedAlgorithm; - private final long plaintextSize; - private final byte[] plaintext; - private final CommitmentPolicy commitmentPolicy; - - private EncryptionMaterialsRequest(Builder builder) { - this.context = builder.context; - this.requestedAlgorithm = builder.requestedAlgorithm; - this.plaintextSize = builder.plaintextSize; - this.plaintext = builder.plaintext; - - if (builder.commitmentPolicy == null) { - throw new IllegalArgumentException("Cannot create EncryptionMaterialRequest without a " + - "CommitmentPolicy specified."); - } - this.commitmentPolicy = builder.commitmentPolicy; + private final Map context; + private final CryptoAlgorithm requestedAlgorithm; + private final long plaintextSize; + private final byte[] plaintext; + private final CommitmentPolicy commitmentPolicy; + + private EncryptionMaterialsRequest(Builder builder) { + this.context = builder.context; + this.requestedAlgorithm = builder.requestedAlgorithm; + this.plaintextSize = builder.plaintextSize; + this.plaintext = builder.plaintext; + + if (builder.commitmentPolicy == null) { + throw new IllegalArgumentException( + "Cannot create EncryptionMaterialRequest without a " + "CommitmentPolicy specified."); + } + this.commitmentPolicy = builder.commitmentPolicy; + } + + /** @return the encryption context (possibly an empty map) */ + public Map getContext() { + return context; + } + + /** + * @return If a specific encryption algorithm was requested by calling {@link + * com.amazonaws.encryptionsdk.AwsCrypto#setEncryptionAlgorithm(CryptoAlgorithm)}, the + * algorithm requested. Otherwise, returns null. + */ + public CryptoAlgorithm getRequestedAlgorithm() { + return requestedAlgorithm; + } + + /** @return The size of the plaintext if known, or -1 otherwise */ + public long getPlaintextSize() { + return plaintextSize; + } + + /** + * @return The entire input plaintext, if available (and not streaming). Note that for performance + * reason this is not a copy of the plaintext; you should never modify this buffer, + * lest the actual data being encrypted be modified. If the input plaintext is unavailable, + * this will be null. + */ + @SuppressFBWarnings("EI_EXPOSE_REP") + public byte[] getPlaintext() { + return plaintext; + } + + public CommitmentPolicy getCommitmentPolicy() { + return commitmentPolicy; + } + + public Builder toBuilder() { + return new Builder(this); + } + + public static Builder newBuilder() { + return new Builder(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + EncryptionMaterialsRequest request = (EncryptionMaterialsRequest) o; + return plaintextSize == request.plaintextSize + && Objects.equals(context, request.context) + && requestedAlgorithm == request.requestedAlgorithm + && Arrays.equals(plaintext, request.plaintext); + } + + @Override + public int hashCode() { + return Objects.hash(context, requestedAlgorithm, plaintextSize, plaintext); + } + + public static class Builder { + private Map context = Collections.emptyMap(); + private CryptoAlgorithm requestedAlgorithm = null; + private long plaintextSize = -1; + private byte[] plaintext = null; + private CommitmentPolicy commitmentPolicy = null; + + private Builder() {} + + private Builder(EncryptionMaterialsRequest request) { + this.context = request.getContext(); + this.requestedAlgorithm = request.getRequestedAlgorithm(); + this.plaintextSize = request.getPlaintextSize(); + this.plaintext = request.getPlaintext(); + this.commitmentPolicy = request.getCommitmentPolicy(); + } + + public EncryptionMaterialsRequest build() { + return new EncryptionMaterialsRequest(this); } - /** - * @return the encryption context (possibly an empty map) - */ public Map getContext() { - return context; + return context; + } + + public Builder setContext(Map context) { + this.context = Collections.unmodifiableMap(new HashMap<>(context)); + return this; } - /** - * @return If a specific encryption algorithm was requested by calling - * {@link com.amazonaws.encryptionsdk.AwsCrypto#setEncryptionAlgorithm(CryptoAlgorithm)}, the algorithm requested. - * Otherwise, returns null. - */ public CryptoAlgorithm getRequestedAlgorithm() { - return requestedAlgorithm; + return requestedAlgorithm; } - /** - * @return The size of the plaintext if known, or -1 otherwise - */ - public long getPlaintextSize() { - return plaintextSize; + public Builder setRequestedAlgorithm(CryptoAlgorithm requestedAlgorithm) { + this.requestedAlgorithm = requestedAlgorithm; + return this; } - /** - * @return The entire input plaintext, if available (and not streaming). Note that for performance reason this is - * not a copy of the plaintext; you should never modify this buffer, lest the actual data being encrypted be - * modified. If the input plaintext is unavailable, this will be null. - */ - @SuppressFBWarnings("EI_EXPOSE_REP") - public byte[] getPlaintext() { - return plaintext; + public long getPlaintextSize() { + return plaintextSize; } - public CommitmentPolicy getCommitmentPolicy() { return commitmentPolicy; } + public Builder setPlaintextSize(long plaintextSize) { + if (plaintextSize < -1) { + throw new IllegalArgumentException("Bad plaintext size"); + } - public Builder toBuilder() { - return new Builder(this); + this.plaintextSize = plaintextSize; + return this; } - public static Builder newBuilder() { - return new Builder(); + /** + * Please note that this does not make a defensive copy of the plaintext and so any + * modifications made to the backing array will be reflected in this Builder. + */ + @SuppressFBWarnings("EI_EXPOSE_REP") + public byte[] getPlaintext() { + return plaintext; } - @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - EncryptionMaterialsRequest request = (EncryptionMaterialsRequest) o; - return plaintextSize == request.plaintextSize && - Objects.equals(context, request.context) && - requestedAlgorithm == request.requestedAlgorithm && - Arrays.equals(plaintext, request.plaintext); + /** + * Sets the plaintext field of the request. + * + *

Please note that this does not make a defensive copy of the plaintext and so any + * modifications made to the backing array will be reflected in this Builder. + * + *

This method implicitly sets plaintext size as well. + */ + @SuppressFBWarnings("EI_EXPOSE_REP") + public Builder setPlaintext(byte[] plaintext) { + this.plaintext = plaintext; + + if (plaintext != null) { + return setPlaintextSize(plaintext.length); + } else { + return setPlaintextSize(-1); + } } - @Override public int hashCode() { - return Objects.hash(context, requestedAlgorithm, plaintextSize, plaintext); + public CommitmentPolicy getCommitmentPolicy() { + return commitmentPolicy; } - public static class Builder { - private Map context = Collections.emptyMap(); - private CryptoAlgorithm requestedAlgorithm = null; - private long plaintextSize = -1; - private byte[] plaintext = null; - private CommitmentPolicy commitmentPolicy = null; - - private Builder() { - - } - - private Builder(EncryptionMaterialsRequest request) { - this.context = request.getContext(); - this.requestedAlgorithm = request.getRequestedAlgorithm(); - this.plaintextSize = request.getPlaintextSize(); - this.plaintext = request.getPlaintext(); - this.commitmentPolicy = request.getCommitmentPolicy(); - } - - public EncryptionMaterialsRequest build() { - return new EncryptionMaterialsRequest(this); - } - - public Map getContext() { - return context; - } - - public Builder setContext(Map context) { - this.context = Collections.unmodifiableMap(new HashMap<>(context)); - return this; - } - - public CryptoAlgorithm getRequestedAlgorithm() { - return requestedAlgorithm; - } - - public Builder setRequestedAlgorithm(CryptoAlgorithm requestedAlgorithm) { - this.requestedAlgorithm = requestedAlgorithm; - return this; - } - - public long getPlaintextSize() { - return plaintextSize; - } - - public Builder setPlaintextSize(long plaintextSize) { - if (plaintextSize < -1) { - throw new IllegalArgumentException("Bad plaintext size"); - } - - this.plaintextSize = plaintextSize; - return this; - } - - /** - * Please note that this does not make a defensive copy of the plaintext and so any - * modifications made to the backing array will be reflected in this Builder. - */ - @SuppressFBWarnings("EI_EXPOSE_REP") - public byte[] getPlaintext() { - return plaintext; - } - - /** - * Sets the plaintext field of the request. - * - * Please note that this does not make a defensive copy of the plaintext and so any - * modifications made to the backing array will be reflected in this Builder. - * - * This method implicitly sets plaintext size as well. - */ - @SuppressFBWarnings("EI_EXPOSE_REP") - public Builder setPlaintext(byte[] plaintext) { - this.plaintext = plaintext; - - if (plaintext != null) { - return setPlaintextSize(plaintext.length); - } else { - return setPlaintextSize(-1); - } - } - - public CommitmentPolicy getCommitmentPolicy() { - return commitmentPolicy; - } - - public Builder setCommitmentPolicy(CommitmentPolicy commitmentPolicy) { - this.commitmentPolicy = commitmentPolicy; - return this; - } + public Builder setCommitmentPolicy(CommitmentPolicy commitmentPolicy) { + this.commitmentPolicy = commitmentPolicy; + return this; } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/KeyBlob.java b/src/main/java/com/amazonaws/encryptionsdk/model/KeyBlob.java index c44fd2f8f..5b21dc73c 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/KeyBlob.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/KeyBlob.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,692 +13,652 @@ package com.amazonaws.encryptionsdk.model; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; - import com.amazonaws.encryptionsdk.EncryptedDataKey; import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.ParseException; import com.amazonaws.encryptionsdk.internal.Constants; import com.amazonaws.encryptionsdk.internal.PrimitivesParser; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; /** - * This class implements the format of the key blob. The format contains the - * following fields in order: + * This class implements the format of the key blob. The format contains the following fields in + * order: + * *

    - *
  1. - * length of key provider
  2. - *
  3. - * key provider
  4. - *
  5. - * length of key provider info
  6. - *
  7. - * key provider info
  8. - *
  9. - * length of encrypted key
  10. - *
  11. - * encrypted key
  12. + *
  13. length of key provider + *
  14. key provider + *
  15. length of key provider info + *
  16. key provider info + *
  17. length of encrypted key + *
  18. encrypted key *
*/ -//@ nullable_by_default +// @ nullable_by_default public final class KeyBlob implements EncryptedDataKey { - private int keyProviderIdLen_ = -1; //@ in providerId; - private byte[] keyProviderId_; //@ in providerId; - private int keyProviderInfoLen_ = -1; //@ in providerInformation; - private byte[] keyProviderInfo_; //@ in providerInformation; - private int encryptedKeyLen_ = -1; //@ in encryptedDataKey; - private byte[] encryptedKey_; //@ in encryptedDataKey; - - //@ private invariant keyProviderIdLen_ <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ private invariant keyProviderInfoLen_ <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ private invariant encryptedKeyLen_ <= Constants.UNSIGNED_SHORT_MAX_VAL; - - //@// KeyBlob implements EncryptedDataKey, which defines three model fields. - //@// For a KeyBlob, these model fields correspond directly to some underlying - //@// Java fields, as expressed by the following "represents" declarations: - //@ private represents providerId = keyProviderId_; - //@ private represents providerInformation = keyProviderInfo_; - //@ private represents encryptedDataKey = encryptedKey_; - - //@// As mentioned in EncryptedDataKey, deserialization goes through some - //@// incomplete intermediate states. The ghost field "deserializing" keeps - //@// track of these states: - //@ private ghost int deserializing; - //@ private invariant 0 <= deserializing && deserializing < 4; - //@// The abstract "isDeserializing", defined in EncryptedDataKey, is represented - //@// as "true" whenever "deserializing" is non-0. - //@ private represents isDeserializing = deserializing != 0; - - //@// The fields of KeyBlob come in pairs, for example, "keyProviderId_" and - //@// "keyProviderIdLen_". Generally, the latter stores the length of the former. - //@// But this is not always so. For one, if the former is "null", then the latter - //@// is -1. Also, this relationship the two fields does not hold in one of the - //@// incomplete intermediate deserialization states. Therefore, the invariants - //@// about these fields are as follows: - - //@ private invariant deserializing == 1 || keyProviderIdLen_ == (keyProviderId_ == null ? -1 : keyProviderId_.length); - //@ private invariant deserializing == 2 || keyProviderInfoLen_ == (keyProviderInfo_ == null ? -1 : keyProviderInfo_.length); - //@ private invariant deserializing == 3 || encryptedKeyLen_ == (encryptedKey_ == null ? -1 : encryptedKey_.length); - - //@// In the incomplete intermediate states, other specific properties hold about the - //@// fields, as expressed in the following invariants: - - //@ private invariant deserializing == 1 ==> 0 <= keyProviderIdLen_ && keyProviderId_ == null; - //@ private invariant deserializing == 2 ==> 0 <= keyProviderIdLen_ && 0 <= keyProviderInfoLen_ && keyProviderInfo_ == null; - //@ private invariant deserializing == 3 ==> 0 <= keyProviderIdLen_ && 0 <= keyProviderInfoLen_ && 0 <= encryptedKeyLen_ && encryptedKey_ == null; - - //@// It is by querying the "isComplete()" method that a caller finds out if the - //@// deserialization is only partially done or is complete. The "isComplete()" - //@// method is defined later on and returns the value of the field "isComplete_". - //@// If postcondition of "deserialize()" and the following public invariant about - //@// "isComplete_" tell a client that the 3 abstract properties of the class have - //@// been initialized. Note that this invariant (and, indeed, the "isComplete()" - //@// method) does not tell a client anything useful unless "deserialize()" has been - //@// called. For example, if the 3 abstract properties of a KeyBlob have been - //@// initialized using the "set..." methods, then the result value of "isComplete()" - //@// is meaningless. - //@ spec_public - private boolean isComplete_ = false; - //@ public invariant isComplete_ && !isDeserializing ==> providerId != null && providerInformation != null && encryptedDataKey != null; - - /** - * Default constructor. - */ - //@ public normal_behavior - //@ ensures providerId == null && providerInformation == null && encryptedDataKey == null; - //@ ensures !isDeserializing; - //@ pure - public KeyBlob() { + private int keyProviderIdLen_ = -1; // @ in providerId; + private byte[] keyProviderId_; // @ in providerId; + private int keyProviderInfoLen_ = -1; // @ in providerInformation; + private byte[] keyProviderInfo_; // @ in providerInformation; + private int encryptedKeyLen_ = -1; // @ in encryptedDataKey; + private byte[] encryptedKey_; // @ in encryptedDataKey; + + // @ private invariant keyProviderIdLen_ <= Constants.UNSIGNED_SHORT_MAX_VAL; + // @ private invariant keyProviderInfoLen_ <= Constants.UNSIGNED_SHORT_MAX_VAL; + // @ private invariant encryptedKeyLen_ <= Constants.UNSIGNED_SHORT_MAX_VAL; + + // @// KeyBlob implements EncryptedDataKey, which defines three model fields. + // @// For a KeyBlob, these model fields correspond directly to some underlying + // @// Java fields, as expressed by the following "represents" declarations: + // @ private represents providerId = keyProviderId_; + // @ private represents providerInformation = keyProviderInfo_; + // @ private represents encryptedDataKey = encryptedKey_; + + // @// As mentioned in EncryptedDataKey, deserialization goes through some + // @// incomplete intermediate states. The ghost field "deserializing" keeps + // @// track of these states: + // @ private ghost int deserializing; + // @ private invariant 0 <= deserializing && deserializing < 4; + // @// The abstract "isDeserializing", defined in EncryptedDataKey, is represented + // @// as "true" whenever "deserializing" is non-0. + // @ private represents isDeserializing = deserializing != 0; + + // @// The fields of KeyBlob come in pairs, for example, "keyProviderId_" and + // @// "keyProviderIdLen_". Generally, the latter stores the length of the former. + // @// But this is not always so. For one, if the former is "null", then the latter + // @// is -1. Also, this relationship the two fields does not hold in one of the + // @// incomplete intermediate deserialization states. Therefore, the invariants + // @// about these fields are as follows: + + // @ private invariant deserializing == 1 || keyProviderIdLen_ == (keyProviderId_ == null ? -1 : + // keyProviderId_.length); + // @ private invariant deserializing == 2 || keyProviderInfoLen_ == (keyProviderInfo_ == null ? -1 + // : keyProviderInfo_.length); + // @ private invariant deserializing == 3 || encryptedKeyLen_ == (encryptedKey_ == null ? -1 : + // encryptedKey_.length); + + // @// In the incomplete intermediate states, other specific properties hold about the + // @// fields, as expressed in the following invariants: + + // @ private invariant deserializing == 1 ==> 0 <= keyProviderIdLen_ && keyProviderId_ == null; + // @ private invariant deserializing == 2 ==> 0 <= keyProviderIdLen_ && 0 <= keyProviderInfoLen_ + // && keyProviderInfo_ == null; + // @ private invariant deserializing == 3 ==> 0 <= keyProviderIdLen_ && 0 <= keyProviderInfoLen_ + // && 0 <= encryptedKeyLen_ && encryptedKey_ == null; + + // @// It is by querying the "isComplete()" method that a caller finds out if the + // @// deserialization is only partially done or is complete. The "isComplete()" + // @// method is defined later on and returns the value of the field "isComplete_". + // @// If postcondition of "deserialize()" and the following public invariant about + // @// "isComplete_" tell a client that the 3 abstract properties of the class have + // @// been initialized. Note that this invariant (and, indeed, the "isComplete()" + // @// method) does not tell a client anything useful unless "deserialize()" has been + // @// called. For example, if the 3 abstract properties of a KeyBlob have been + // @// initialized using the "set..." methods, then the result value of "isComplete()" + // @// is meaningless. + // @ spec_public + private boolean isComplete_ = false; + // @ public invariant isComplete_ && !isDeserializing ==> providerId != null && + // providerInformation != null && encryptedDataKey != null; + + /** Default constructor. */ + // @ public normal_behavior + // @ ensures providerId == null && providerInformation == null && encryptedDataKey == null; + // @ ensures !isDeserializing; + // @ pure + public KeyBlob() {} + + /** + * Construct a key blob using the provided key, key provider identifier, and key provider + * information. + * + * @param keyProviderId the key provider identifier string. + * @param keyProviderInfo the bytes containing the key provider info. + * @param encryptedDataKey the encrypted bytes of the data key. + */ + // @ public normal_behavior + // @ requires keyProviderId != null && EncryptedDataKey.s2ba(keyProviderId).length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ requires keyProviderInfo != null && keyProviderInfo.length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ requires encryptedDataKey != null && encryptedDataKey.length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ ensures \fresh(providerId); + // @ ensures Arrays.equalArrays(providerId, EncryptedDataKey.s2ba(keyProviderId)); + // @ ensures \fresh(providerInformation); + // @ ensures Arrays.equalArrays(providerInformation, keyProviderInfo); + // @ ensures \fresh(this.encryptedDataKey); + // @ ensures Arrays.equalArrays(this.encryptedDataKey, encryptedDataKey); + // @ ensures !isDeserializing; + // @ also + // @ public exceptional_behavior + // @ requires keyProviderId != null && keyProviderInfo != null && encryptedDataKey != null; + // @ requires Constants.UNSIGNED_SHORT_MAX_VAL < EncryptedDataKey.s2ba(keyProviderId).length || + // Constants.UNSIGNED_SHORT_MAX_VAL < keyProviderInfo.length || Constants.UNSIGNED_SHORT_MAX_VAL < + // encryptedDataKey.length; + // @ signals_only AwsCryptoException; + // @ pure + public KeyBlob( + final String keyProviderId, final byte[] keyProviderInfo, final byte[] encryptedDataKey) { + setEncryptedDataKey(encryptedDataKey); + setKeyProviderId(keyProviderId); + setKeyProviderInfo(keyProviderInfo); + } + + // @ public normal_behavior + // @ requires edk != null && !edk.isDeserializing; + // @ requires edk.providerId != null && EncryptedDataKey.ba2s2ba(edk.providerId).length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ requires edk.providerInformation != null && edk.providerInformation.length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ requires edk.encryptedDataKey != null && edk.encryptedDataKey.length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ ensures \fresh(providerId); + // @ ensures Arrays.equalArrays(providerId, EncryptedDataKey.ba2s2ba(edk.providerId)); + // @ ensures \fresh(providerInformation); + // @ ensures Arrays.equalArrays(providerInformation, edk.providerInformation); + // @ ensures \fresh(encryptedDataKey); + // @ ensures Arrays.equalArrays(encryptedDataKey, edk.encryptedDataKey); + // @ ensures !isDeserializing; + // @ also + // @ public exceptional_behavior + // @ requires edk != null && !edk.isDeserializing; + // @ requires edk.providerId != null && edk.providerInformation != null && edk.encryptedDataKey + // != null; + // @ requires Constants.UNSIGNED_SHORT_MAX_VAL < EncryptedDataKey.ba2s2ba(edk.providerId).length + // || Constants.UNSIGNED_SHORT_MAX_VAL < edk.providerInformation.length || + // Constants.UNSIGNED_SHORT_MAX_VAL < edk.encryptedDataKey.length; + // @ signals_only AwsCryptoException; + // @ pure + public KeyBlob(final EncryptedDataKey edk) { + setEncryptedDataKey(edk.getEncryptedDataKey()); + String s = edk.getProviderId(); + // @ set EncryptedDataKey.lemma_s2ba_depends_only_string_contents_only(s, + // EncryptedDataKey.ba2s(edk.providerId)); + setKeyProviderId(s); + setKeyProviderInfo(edk.getProviderInformation()); + } + + /** + * Parse the key provider identifier length in the provided bytes. It looks for 2 bytes + * representing a short primitive type in the provided bytes starting at the specified off. + * + *

If successful, it returns the size of the parsed bytes which is the size of the short + * primitive type. On failure, it throws a parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the size of the parsed bytes which is the size of the short primitive. + * @throws ParseException if there are not sufficient bytes to parse the identifier length. + */ + // @ private normal_behavior + // @ requires deserializing == 0 && keyProviderId_ == null; + // @ requires b != null && 0 <= off && off <= b.length - Short.BYTES; + // @ assignable keyProviderIdLen_, deserializing, isDeserializing; + // @ ensures \result == Short.BYTES && deserializing == 1; + // @ also + // @ private exceptional_behavior + // @ requires keyProviderId_ == null; + // @ requires b != null && 0 <= off && b.length - Short.BYTES < off; + // @ assignable \nothing; + // @ signals_only ParseException; + private int parseKeyProviderIdLen(final byte[] b, final int off) throws ParseException { + keyProviderIdLen_ = PrimitivesParser.parseUnsignedShort(b, off); + // @ set deserializing = 1; + return Short.SIZE / Byte.SIZE; + } + + /** + * Parse the key provider identifier in the provided bytes. It looks for bytes of size defined by + * the key provider identifier length in the provided bytes starting at the specified off. + * + *

If successful, it returns the size of the parsed bytes which is the key provider identifier + * length. On failure, it throws a parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the size of the parsed bytes which is the key provider identifier length. + * @throws ParseException if there are not sufficient bytes to parse the identifier. + */ + // @ private normal_behavior + // @ requires deserializing == 1 && b != null && 0 <= off && off <= b.length; + // @ requires keyProviderIdLen_ <= b.length - off; + // @ assignable keyProviderId_, deserializing, isDeserializing; + // @ ensures \result == keyProviderIdLen_ && deserializing == 0; + // @ ensures keyProviderId_ != null && keyProviderId_.length == keyProviderIdLen_; + // @ also + // @ private exceptional_behavior + // @ requires deserializing == 1 && b != null && 0 <= off && off <= b.length; + // @ requires b.length - off < keyProviderIdLen_; + // @ assignable \nothing; + // @ signals_only ParseException; + private int parseKeyProviderId(final byte[] b, final int off) throws ParseException { + final int bytesToParseLen = b.length - off; + if (bytesToParseLen >= keyProviderIdLen_) { + keyProviderId_ = Arrays.copyOfRange(b, off, off + keyProviderIdLen_); + // @ set deserializing = 0; + return keyProviderIdLen_; + } else { + throw new ParseException("Not enough bytes to parse key provider id"); } - - /** - * Construct a key blob using the provided key, key provider identifier, and - * key provider information. - * @param keyProviderId - * the key provider identifier string. - * @param keyProviderInfo - * the bytes containing the key provider info. - * @param encryptedDataKey - * the encrypted bytes of the data key. - */ - //@ public normal_behavior - //@ requires keyProviderId != null && EncryptedDataKey.s2ba(keyProviderId).length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ requires keyProviderInfo != null && keyProviderInfo.length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ requires encryptedDataKey != null && encryptedDataKey.length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ ensures \fresh(providerId); - //@ ensures Arrays.equalArrays(providerId, EncryptedDataKey.s2ba(keyProviderId)); - //@ ensures \fresh(providerInformation); - //@ ensures Arrays.equalArrays(providerInformation, keyProviderInfo); - //@ ensures \fresh(this.encryptedDataKey); - //@ ensures Arrays.equalArrays(this.encryptedDataKey, encryptedDataKey); - //@ ensures !isDeserializing; - //@ also - //@ public exceptional_behavior - //@ requires keyProviderId != null && keyProviderInfo != null && encryptedDataKey != null; - //@ requires Constants.UNSIGNED_SHORT_MAX_VAL < EncryptedDataKey.s2ba(keyProviderId).length || Constants.UNSIGNED_SHORT_MAX_VAL < keyProviderInfo.length || Constants.UNSIGNED_SHORT_MAX_VAL < encryptedDataKey.length; - //@ signals_only AwsCryptoException; - //@ pure - public KeyBlob(final String keyProviderId, final byte[] keyProviderInfo, final byte[] encryptedDataKey) { - setEncryptedDataKey(encryptedDataKey); - setKeyProviderId(keyProviderId); - setKeyProviderInfo(keyProviderInfo); + } + + /** + * Parse the key provider info length in the provided bytes. It looks for 2 bytes representing a + * short primitive type in the provided bytes starting at the specified off. + * + *

If successful, it returns the size of the parsed bytes which is the size of the short + * primitive type. On failure, it throws a parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the size of the parsed bytes which is the size of the short primitive type. + * @throws ParseException if there are not sufficient bytes to parse the provider info length. + */ + // @ private normal_behavior + // @ requires deserializing == 0 && 0 <= keyProviderIdLen_ && keyProviderInfo_ == null; + // @ requires b != null && 0 <= off && off <= b.length - Short.BYTES; + // @ assignable keyProviderInfoLen_, deserializing, isDeserializing; + // @ ensures \result == Short.BYTES && deserializing == 2; + // @ also + // @ private exceptional_behavior + // @ requires deserializing == 0 && 0 <= keyProviderIdLen_ && keyProviderInfo_ == null; + // @ requires b != null && 0 <= off && b.length - Short.BYTES < off; + // @ assignable \nothing; + // @ signals_only ParseException; + private int parseKeyProviderInfoLen(final byte[] b, final int off) throws ParseException { + keyProviderInfoLen_ = PrimitivesParser.parseUnsignedShort(b, off); + // @ set deserializing = 2; + return Short.SIZE / Byte.SIZE; + } + + /** + * Parse the key provider info in the provided bytes. It looks for bytes of size defined by the + * key provider info length in the provided bytes starting at the specified off. + * + *

If successful, it returns the size of the parsed bytes which is the key provider info + * length. On failure, it throws a parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the size of the parsed bytes which is the key provider info length. + * @throws ParseException if there are not sufficient bytes to parse the provider info. + */ + // @ private normal_behavior + // @ requires deserializing == 2 && b != null && 0 <= off && off <= b.length; + // @ requires keyProviderInfoLen_ <= b.length - off; + // @ assignable keyProviderInfo_, deserializing, isDeserializing; + // @ ensures \result == keyProviderInfoLen_ && deserializing == 0; + // @ ensures keyProviderInfo_ != null && keyProviderInfo_.length == keyProviderInfoLen_; + // @ also + // @ private exceptional_behavior + // @ requires deserializing == 2 && b != null && 0 <= off && off <= b.length; + // @ requires b.length - off < keyProviderInfoLen_; + // @ assignable \nothing; + // @ signals_only ParseException; + private int parseKeyProviderInfo(final byte[] b, final int off) throws ParseException { + final int bytesToParseLen = b.length - off; + if (bytesToParseLen >= keyProviderInfoLen_) { + keyProviderInfo_ = Arrays.copyOfRange(b, off, off + keyProviderInfoLen_); + // @ set deserializing = 0; + return keyProviderInfoLen_; + } else { + throw new ParseException("Not enough bytes to parse key provider info"); } - - //@ public normal_behavior - //@ requires edk != null && !edk.isDeserializing; - //@ requires edk.providerId != null && EncryptedDataKey.ba2s2ba(edk.providerId).length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ requires edk.providerInformation != null && edk.providerInformation.length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ requires edk.encryptedDataKey != null && edk.encryptedDataKey.length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ ensures \fresh(providerId); - //@ ensures Arrays.equalArrays(providerId, EncryptedDataKey.ba2s2ba(edk.providerId)); - //@ ensures \fresh(providerInformation); - //@ ensures Arrays.equalArrays(providerInformation, edk.providerInformation); - //@ ensures \fresh(encryptedDataKey); - //@ ensures Arrays.equalArrays(encryptedDataKey, edk.encryptedDataKey); - //@ ensures !isDeserializing; - //@ also - //@ public exceptional_behavior - //@ requires edk != null && !edk.isDeserializing; - //@ requires edk.providerId != null && edk.providerInformation != null && edk.encryptedDataKey != null; - //@ requires Constants.UNSIGNED_SHORT_MAX_VAL < EncryptedDataKey.ba2s2ba(edk.providerId).length || Constants.UNSIGNED_SHORT_MAX_VAL < edk.providerInformation.length || Constants.UNSIGNED_SHORT_MAX_VAL < edk.encryptedDataKey.length; - //@ signals_only AwsCryptoException; - //@ pure - public KeyBlob(final EncryptedDataKey edk) { - setEncryptedDataKey(edk.getEncryptedDataKey()); - String s = edk.getProviderId(); - //@ set EncryptedDataKey.lemma_s2ba_depends_only_string_contents_only(s, EncryptedDataKey.ba2s(edk.providerId)); - setKeyProviderId(s); - setKeyProviderInfo(edk.getProviderInformation()); + } + + /** + * Parse the key length in the provided bytes. It looks for 2 bytes representing a short primitive + * type in the provided bytes starting at the specified off. + * + *

If successful, it returns the size of the parsed bytes which is the size of the short + * primitive type. On failure, it throws a parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the size of the parsed bytes which is the size of the short primitive type. + * @throws ParseException if there are not sufficient bytes to parse the key length. + */ + // @ private normal_behavior + // @ requires deserializing == 0 && 0 <= keyProviderIdLen_ && 0 <= keyProviderInfoLen_ && + // encryptedKey_ == null; + // @ requires b != null && 0 <= off && off <= b.length - Short.BYTES; + // @ assignable encryptedKeyLen_, deserializing, isDeserializing; + // @ ensures \result == Short.BYTES && deserializing == 3; + // @ also + // @ private exceptional_behavior + // @ requires deserializing == 0 && 0 <= keyProviderIdLen_ && 0 <= keyProviderInfoLen_ && + // encryptedKey_ == null; + // @ requires b != null && 0 <= off && b.length - Short.BYTES < off; + // @ assignable \nothing; + // @ signals_only ParseException; + private int parseKeyLen(final byte[] b, final int off) throws ParseException { + encryptedKeyLen_ = PrimitivesParser.parseUnsignedShort(b, off); + // @ set deserializing = 3; + return Short.SIZE / Byte.SIZE; + } + + /** + * Parse the key in the provided bytes. It looks for bytes of size defined by the key length in + * the provided bytes starting at the specified off. + * + *

If successful, it returns the size of the parsed bytes which is the key length. On failure, + * it throws a parse exception. + * + * @param b the byte array to parse. + * @param off the offset in the byte array to use when parsing. + * @return the size of the parsed bytes which is the key length. + * @throws ParseException if there are not sufficient bytes to parse the key. + */ + // @ private normal_behavior + // @ requires deserializing == 3 && b != null && 0 <= off && off <= b.length; + // @ requires encryptedKeyLen_ <= b.length - off; + // @ assignable encryptedKey_, deserializing, isDeserializing; + // @ ensures \result == encryptedKeyLen_ && deserializing == 0; + // @ ensures encryptedKey_ != null && encryptedKey_.length == encryptedKeyLen_; + // @ also + // @ private exceptional_behavior + // @ requires deserializing == 3 && b != null && 0 <= off && off <= b.length; + // @ requires b.length - off < encryptedKeyLen_; + // @ assignable \nothing; + // @ signals_only ParseException; + private int parseKey(final byte[] b, final int off) throws ParseException { + final int bytesToParseLen = b.length - off; + if (bytesToParseLen >= encryptedKeyLen_) { + encryptedKey_ = Arrays.copyOfRange(b, off, off + encryptedKeyLen_); + // @ set deserializing = 0; + return encryptedKeyLen_; + } else { + throw new ParseException("Not enough bytes to parse key"); } - - /** - * Parse the key provider identifier length in the provided bytes. It looks - * for 2 bytes representing a short primitive type in the provided bytes - * starting at the specified off. - * - *

- * If successful, it returns the size of the parsed bytes which is the size - * of the short primitive type. On failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the size of the parsed bytes which is the size of the short - * primitive. - * @throws ParseException - * if there are not sufficient bytes to parse the identifier - * length. - */ - //@ private normal_behavior - //@ requires deserializing == 0 && keyProviderId_ == null; - //@ requires b != null && 0 <= off && off <= b.length - Short.BYTES; - //@ assignable keyProviderIdLen_, deserializing, isDeserializing; - //@ ensures \result == Short.BYTES && deserializing == 1; - //@ also - //@ private exceptional_behavior - //@ requires keyProviderId_ == null; - //@ requires b != null && 0 <= off && b.length - Short.BYTES < off; - //@ assignable \nothing; - //@ signals_only ParseException; - private int parseKeyProviderIdLen(final byte[] b, final int off) throws ParseException { - keyProviderIdLen_ = PrimitivesParser.parseUnsignedShort(b, off); - //@ set deserializing = 1; - return Short.SIZE / Byte.SIZE; + } + + /** + * Deserialize the provided bytes starting at the specified offset to construct an instance of + * this class. + * + *

This method parses the provided bytes for the individual fields in this class. This methods + * also supports partial parsing where not all the bytes required for parsing the fields + * successfully are available. + * + * @param b the byte array to deserialize. + * @param off the offset in the byte array to use for deserialization. + * @return the number of bytes consumed in deserialization. + */ + // @ public normal_behavior + // @ requires b == null; + // @ assignable \nothing; + // @ ensures \result == 0; + // @ also + // @ public normal_behavior + // @ requires !isComplete_; + // @ requires b != null && 0 <= off && off <= b.length; + // @ assignable this.*; + // @ ensures 0 <= \result && \result <= b.length - off; + // @ ensures isComplete_ ==> !isDeserializing; + public int deserialize(final byte[] b, final int off) { + if (b == null) { + return 0; } - /** - * Parse the key provider identifier in the provided bytes. It looks - * for bytes of size defined by the key provider identifier length in the - * provided bytes starting at the specified off. - * - *

- * If successful, it returns the size of the parsed bytes which is the key - * provider identifier length. On failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the size of the parsed bytes which is the key provider identifier - * length. - * @throws ParseException - * if there are not sufficient bytes to parse the identifier. - */ - //@ private normal_behavior - //@ requires deserializing == 1 && b != null && 0 <= off && off <= b.length; - //@ requires keyProviderIdLen_ <= b.length - off; - //@ assignable keyProviderId_, deserializing, isDeserializing; - //@ ensures \result == keyProviderIdLen_ && deserializing == 0; - //@ ensures keyProviderId_ != null && keyProviderId_.length == keyProviderIdLen_; - //@ also - //@ private exceptional_behavior - //@ requires deserializing == 1 && b != null && 0 <= off && off <= b.length; - //@ requires b.length - off < keyProviderIdLen_; - //@ assignable \nothing; - //@ signals_only ParseException; - private int parseKeyProviderId(final byte[] b, final int off) throws ParseException { - final int bytesToParseLen = b.length - off; - if (bytesToParseLen >= keyProviderIdLen_) { - keyProviderId_ = Arrays.copyOfRange(b, off, off + keyProviderIdLen_); - //@ set deserializing = 0; - return keyProviderIdLen_; - } else { - throw new ParseException("Not enough bytes to parse key provider id"); - } - } + int parsedBytes = 0; + try { + if (keyProviderIdLen_ < 0) { + parsedBytes += parseKeyProviderIdLen(b, off + parsedBytes); + } - /** - * Parse the key provider info length in the provided bytes. It looks - * for 2 bytes representing a short primitive type in the provided bytes - * starting at the specified off. - * - *

- * If successful, it returns the size of the parsed bytes which is the size - * of the short primitive type. On failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the size of the parsed bytes which is the size of the short - * primitive type. - * @throws ParseException - * if there are not sufficient bytes to parse the provider info - * length. - */ - //@ private normal_behavior - //@ requires deserializing == 0 && 0 <= keyProviderIdLen_ && keyProviderInfo_ == null; - //@ requires b != null && 0 <= off && off <= b.length - Short.BYTES; - //@ assignable keyProviderInfoLen_, deserializing, isDeserializing; - //@ ensures \result == Short.BYTES && deserializing == 2; - //@ also - //@ private exceptional_behavior - //@ requires deserializing == 0 && 0 <= keyProviderIdLen_ && keyProviderInfo_ == null; - //@ requires b != null && 0 <= off && b.length - Short.BYTES < off; - //@ assignable \nothing; - //@ signals_only ParseException; - private int parseKeyProviderInfoLen(final byte[] b, final int off) throws ParseException { - keyProviderInfoLen_ = PrimitivesParser.parseUnsignedShort(b, off); - //@ set deserializing = 2; - return Short.SIZE / Byte.SIZE; - } + if (keyProviderId_ == null) { + parsedBytes += parseKeyProviderId(b, off + parsedBytes); + } - /** - * Parse the key provider info in the provided bytes. It looks for bytes of - * size defined by the key provider info length in the provided bytes - * starting at the specified off. - * - *

- * If successful, it returns the size of the parsed bytes which is the key - * provider info length. On failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the size of the parsed bytes which is the key provider info - * length. - * @throws ParseException - * if there are not sufficient bytes to parse the provider info. - */ - //@ private normal_behavior - //@ requires deserializing == 2 && b != null && 0 <= off && off <= b.length; - //@ requires keyProviderInfoLen_ <= b.length - off; - //@ assignable keyProviderInfo_, deserializing, isDeserializing; - //@ ensures \result == keyProviderInfoLen_ && deserializing == 0; - //@ ensures keyProviderInfo_ != null && keyProviderInfo_.length == keyProviderInfoLen_; - //@ also - //@ private exceptional_behavior - //@ requires deserializing == 2 && b != null && 0 <= off && off <= b.length; - //@ requires b.length - off < keyProviderInfoLen_; - //@ assignable \nothing; - //@ signals_only ParseException; - private int parseKeyProviderInfo(final byte[] b, final int off) throws ParseException { - final int bytesToParseLen = b.length - off; - if (bytesToParseLen >= keyProviderInfoLen_) { - keyProviderInfo_ = Arrays.copyOfRange(b, off, off + keyProviderInfoLen_); - //@ set deserializing = 0; - return keyProviderInfoLen_; - } else { - throw new ParseException("Not enough bytes to parse key provider info"); - } - } + if (keyProviderInfoLen_ < 0) { + parsedBytes += parseKeyProviderInfoLen(b, off + parsedBytes); + } - /** - * Parse the key length in the provided bytes. It looks for 2 bytes - * representing a short primitive type in the provided bytes starting at the - * specified off. - * - *

- * If successful, it returns the size of the parsed bytes which is the size - * of the short primitive type. On failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the size of the parsed bytes which is the size of the short - * primitive type. - * @throws ParseException - * if there are not sufficient bytes to parse the key length. - */ - //@ private normal_behavior - //@ requires deserializing == 0 && 0 <= keyProviderIdLen_ && 0 <= keyProviderInfoLen_ && encryptedKey_ == null; - //@ requires b != null && 0 <= off && off <= b.length - Short.BYTES; - //@ assignable encryptedKeyLen_, deserializing, isDeserializing; - //@ ensures \result == Short.BYTES && deserializing == 3; - //@ also - //@ private exceptional_behavior - //@ requires deserializing == 0 && 0 <= keyProviderIdLen_ && 0 <= keyProviderInfoLen_ && encryptedKey_ == null; - //@ requires b != null && 0 <= off && b.length - Short.BYTES < off; - //@ assignable \nothing; - //@ signals_only ParseException; - private int parseKeyLen(final byte[] b, final int off) throws ParseException { - encryptedKeyLen_ = PrimitivesParser.parseUnsignedShort(b, off); - //@ set deserializing = 3; - return Short.SIZE / Byte.SIZE; - } + if (keyProviderInfo_ == null) { + parsedBytes += parseKeyProviderInfo(b, off + parsedBytes); + } - /** - * Parse the key in the provided bytes. It looks for bytes of size defined - * by the key length in the provided bytes starting at the specified off. - * - *

- * If successful, it returns the size of the parsed bytes which is the key - * length. On failure, it throws a parse exception. - * - * @param b - * the byte array to parse. - * @param off - * the offset in the byte array to use when parsing. - * @return - * the size of the parsed bytes which is the key length. - * @throws ParseException - * if there are not sufficient bytes to parse the key. - */ - //@ private normal_behavior - //@ requires deserializing == 3 && b != null && 0 <= off && off <= b.length; - //@ requires encryptedKeyLen_ <= b.length - off; - //@ assignable encryptedKey_, deserializing, isDeserializing; - //@ ensures \result == encryptedKeyLen_ && deserializing == 0; - //@ ensures encryptedKey_ != null && encryptedKey_.length == encryptedKeyLen_; - //@ also - //@ private exceptional_behavior - //@ requires deserializing == 3 && b != null && 0 <= off && off <= b.length; - //@ requires b.length - off < encryptedKeyLen_; - //@ assignable \nothing; - //@ signals_only ParseException; - private int parseKey(final byte[] b, final int off) throws ParseException { - final int bytesToParseLen = b.length - off; - if (bytesToParseLen >= encryptedKeyLen_) { - encryptedKey_ = Arrays.copyOfRange(b, off, off + encryptedKeyLen_); - //@ set deserializing = 0; - return encryptedKeyLen_; - } else { - throw new ParseException("Not enough bytes to parse key"); - } - } + if (encryptedKeyLen_ < 0) { + parsedBytes += parseKeyLen(b, off + parsedBytes); + } - /** - * Deserialize the provided bytes starting at the specified offset to - * construct an instance of this class. - * - *

- * This method parses the provided bytes for the individual fields in this - * class. This methods also supports partial parsing where not all the bytes - * required for parsing the fields successfully are available. - * - * @param b - * the byte array to deserialize. - * @param off - * the offset in the byte array to use for deserialization. - * @return - * the number of bytes consumed in deserialization. - * - */ - //@ public normal_behavior - //@ requires b == null; - //@ assignable \nothing; - //@ ensures \result == 0; - //@ also - //@ public normal_behavior - //@ requires !isComplete_; - //@ requires b != null && 0 <= off && off <= b.length; - //@ assignable this.*; - //@ ensures 0 <= \result && \result <= b.length - off; - //@ ensures isComplete_ ==> !isDeserializing; - public int deserialize(final byte[] b, final int off) { - if (b == null) { - return 0; - } - - int parsedBytes = 0; - try { - if (keyProviderIdLen_ < 0) { - parsedBytes += parseKeyProviderIdLen(b, off + parsedBytes); - } - - if (keyProviderId_ == null) { - parsedBytes += parseKeyProviderId(b, off + parsedBytes); - } - - if (keyProviderInfoLen_ < 0) { - parsedBytes += parseKeyProviderInfoLen(b, off + parsedBytes); - } - - if (keyProviderInfo_ == null) { - parsedBytes += parseKeyProviderInfo(b, off + parsedBytes); - } - - if (encryptedKeyLen_ < 0) { - parsedBytes += parseKeyLen(b, off + parsedBytes); - } - - if (encryptedKey_ == null) { - parsedBytes += parseKey(b, off + parsedBytes); - } - - isComplete_ = true; - } catch (ParseException e) { - // this results when we do partial parsing and there aren't enough - // bytes to parse; ignore it and return the bytes parsed thus far. - } - return parsedBytes; - } + if (encryptedKey_ == null) { + parsedBytes += parseKey(b, off + parsedBytes); + } - /** - * Serialize an instance of this class to a byte array. - * - * @return - * the serialized bytes of the instance. - */ - //@ public normal_behavior - //@ requires !isDeserializing; - //@ requires providerId != null; - //@ requires providerInformation != null; - //@ requires encryptedDataKey != null; - //@ assignable \nothing; - //@ ensures \fresh(\result); - //@ ensures \result.length == 3 * Short.BYTES + providerId.length + providerInformation.length + encryptedDataKey.length; - //@ code_java_math // necessary, or else casts to short are warnings - public byte[] toByteArray() { - final int outLen = 3 * (Short.SIZE / Byte.SIZE) + keyProviderIdLen_ + keyProviderInfoLen_ + encryptedKeyLen_; - final ByteBuffer out = ByteBuffer.allocate(outLen); - - out.putShort((short) keyProviderIdLen_); - out.put(keyProviderId_, 0, keyProviderIdLen_); - - out.putShort((short) keyProviderInfoLen_); - out.put(keyProviderInfo_, 0, keyProviderInfoLen_); - - out.putShort((short) encryptedKeyLen_); - out.put(encryptedKey_, 0, encryptedKeyLen_); - - return out.array(); + isComplete_ = true; + } catch (ParseException e) { + // this results when we do partial parsing and there aren't enough + // bytes to parse; ignore it and return the bytes parsed thus far. } - - /** - * Check if this object has all the header fields populated and available - * for reading. - * - * @return - * true if this object containing the single block header fields - * is complete; false otherwise. - */ - //@ public normal_behavior - //@ ensures \result == isComplete_; - //@ pure - public boolean isComplete() { - return isComplete_; + return parsedBytes; + } + + /** + * Serialize an instance of this class to a byte array. + * + * @return the serialized bytes of the instance. + */ + // @ public normal_behavior + // @ requires !isDeserializing; + // @ requires providerId != null; + // @ requires providerInformation != null; + // @ requires encryptedDataKey != null; + // @ assignable \nothing; + // @ ensures \fresh(\result); + // @ ensures \result.length == 3 * Short.BYTES + providerId.length + providerInformation.length + // + encryptedDataKey.length; + // @ code_java_math // necessary, or else casts to short are warnings + public byte[] toByteArray() { + final int outLen = + 3 * (Short.SIZE / Byte.SIZE) + keyProviderIdLen_ + keyProviderInfoLen_ + encryptedKeyLen_; + final ByteBuffer out = ByteBuffer.allocate(outLen); + + out.putShort((short) keyProviderIdLen_); + out.put(keyProviderId_, 0, keyProviderIdLen_); + + out.putShort((short) keyProviderInfoLen_); + out.put(keyProviderInfo_, 0, keyProviderInfoLen_); + + out.putShort((short) encryptedKeyLen_); + out.put(encryptedKey_, 0, encryptedKeyLen_); + + return out.array(); + } + + /** + * Check if this object has all the header fields populated and available for reading. + * + * @return true if this object containing the single block header fields is complete; false + * otherwise. + */ + // @ public normal_behavior + // @ ensures \result == isComplete_; + // @ pure + public boolean isComplete() { + return isComplete_; + } + + /** + * Return the length of the key provider identifier set in the header. + * + * @return the length of the key provider identifier. + */ + // @ public normal_behavior + // @ requires !isDeserializing; + // @ ensures providerId == null ==> \result < 0; + // @ ensures providerId != null ==> \result == providerId.length; + // @ pure + public int getKeyProviderIdLen() { + return keyProviderIdLen_; + } + + /** + * Return the key provider identifier set in the header. + * + * @return the string containing the key provider identifier. + */ + @Override + public String getProviderId() { + String s = new String(keyProviderId_, StandardCharsets.UTF_8); + // The following assume statement essentially says that different + // calls to the String constructor above, with the same parameters, + // result in strings with the same contents. The assumption is + // needed, because JML does not give a way to prove it. + // @ assume String.equals(s, EncryptedDataKey.ba2s(keyProviderId_)); + return s; + } + + /** + * Return the length of the key provider info set in the header. + * + * @return the length of the key provider info. + */ + // @ public normal_behavior + // @ requires !isDeserializing; + // @ ensures providerInformation == null ==> \result < 0; + // @ ensures providerInformation != null ==> \result == providerInformation.length; + // @ pure + public int getKeyProviderInfoLen() { + return keyProviderInfoLen_; + } + + /** + * Return the information on the key provider set in the header. + * + * @return the bytes containing information on the key provider. + */ + @Override + public byte[] getProviderInformation() { + return keyProviderInfo_.clone(); + } + + /** + * Return the length of the encrypted data key set in the header. + * + * @return the length of the encrypted data key. + */ + // @ public normal_behavior + // @ requires !isDeserializing; + // @ ensures encryptedDataKey == null ==> \result < 0; + // @ ensures encryptedDataKey != null ==> \result == encryptedDataKey.length; + // @ pure + public int getEncryptedDataKeyLen() { + return encryptedKeyLen_; + } + + /** + * Return the encrypted data key set in the header. + * + * @return the bytes containing the encrypted data key. + */ + @Override + public byte[] getEncryptedDataKey() { + return encryptedKey_.clone(); + } + + /** + * Set the key provider identifier. + * + * @param keyProviderId the key provider identifier. + */ + // @ public normal_behavior + // @ requires !isDeserializing; + // @ requires keyProviderId != null && EncryptedDataKey.s2ba(keyProviderId).length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ assignable providerId; + // @ ensures \fresh(providerId); + // @ ensures Arrays.equalArrays(providerId, EncryptedDataKey.s2ba(keyProviderId)); + // @ also + // @ private normal_behavior // TODO: this behavior is a temporary workaround + // @ requires !isDeserializing; + // @ requires keyProviderId != null && EncryptedDataKey.s2ba(keyProviderId).length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ assignable keyProviderId_, keyProviderIdLen_; + // @ also + // @ public exceptional_behavior + // @ requires !isDeserializing; + // @ requires keyProviderId != null && Constants.UNSIGNED_SHORT_MAX_VAL < + // EncryptedDataKey.s2ba(keyProviderId).length; + // @ assignable \nothing; + // @ signals_only AwsCryptoException; + public void setKeyProviderId(final String keyProviderId) { + final byte[] keyProviderIdBytes = keyProviderId.getBytes(StandardCharsets.UTF_8); + // @ assume Arrays.equalArrays(keyProviderIdBytes, EncryptedDataKey.s2ba(keyProviderId)); + if (keyProviderIdBytes.length > Constants.UNSIGNED_SHORT_MAX_VAL) { + throw new AwsCryptoException( + "Key provider identifier length exceeds the max value of an unsigned short primitive."); } - - /** - * Return the length of the key provider identifier set in the header. - * - * @return - * the length of the key provider identifier. - */ - //@ public normal_behavior - //@ requires !isDeserializing; - //@ ensures providerId == null ==> \result < 0; - //@ ensures providerId != null ==> \result == providerId.length; - //@ pure - public int getKeyProviderIdLen() { - return keyProviderIdLen_; - } - - /** - * Return the key provider identifier set in the header. - * - * @return - * the string containing the key provider identifier. - */ - @Override - public String getProviderId() { - String s = new String(keyProviderId_, StandardCharsets.UTF_8); - // The following assume statement essentially says that different - // calls to the String constructor above, with the same parameters, - // result in strings with the same contents. The assumption is - // needed, because JML does not give a way to prove it. - //@ assume String.equals(s, EncryptedDataKey.ba2s(keyProviderId_)); - return s; - } - - /** - * Return the length of the key provider info set in the header. - * - * @return - * the length of the key provider info. - */ - //@ public normal_behavior - //@ requires !isDeserializing; - //@ ensures providerInformation == null ==> \result < 0; - //@ ensures providerInformation != null ==> \result == providerInformation.length; - //@ pure - public int getKeyProviderInfoLen() { - return keyProviderInfoLen_; - } - - /** - * Return the information on the key provider set in the header. - * - * @return - * the bytes containing information on the key provider. - */ - @Override - public byte[] getProviderInformation() { - return keyProviderInfo_.clone(); + keyProviderId_ = keyProviderIdBytes; + keyProviderIdLen_ = keyProviderId_.length; + } + + /** + * Set the information on the key provider identifier. + * + * @param keyProviderInfo the bytes containing information on the key provider identifier. + */ + // @ public normal_behavior + // @ requires !isDeserializing; + // @ requires keyProviderInfo != null && keyProviderInfo.length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ assignable providerInformation; + // @ ensures \fresh(providerInformation); + // @ ensures Arrays.equalArrays(providerInformation, keyProviderInfo); + // @ also + // @ private normal_behavior // TODO: this behavior is a temporary workaround + // @ requires !isDeserializing; + // @ requires keyProviderInfo != null && keyProviderInfo.length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ assignable keyProviderInfo_, keyProviderInfoLen_; + // @ also private exceptional_behavior + // @ requires !isDeserializing; + // @ requires keyProviderInfo != null; + // @ requires keyProviderInfo.length > Constants.UNSIGNED_SHORT_MAX_VAL; + // @ assignable \nothing; + // @ signals_only AwsCryptoException; + public void setKeyProviderInfo(final byte[] keyProviderInfo) { + if (keyProviderInfo.length > Constants.UNSIGNED_SHORT_MAX_VAL) { + throw new AwsCryptoException( + "Key provider identifier information length exceeds the max value of an unsigned short primitive."); } - - /** - * Return the length of the encrypted data key set in the header. - * - * @return - * the length of the encrypted data key. - */ - //@ public normal_behavior - //@ requires !isDeserializing; - //@ ensures encryptedDataKey == null ==> \result < 0; - //@ ensures encryptedDataKey != null ==> \result == encryptedDataKey.length; - //@ pure - public int getEncryptedDataKeyLen() { - return encryptedKeyLen_; - } - - /** - * Return the encrypted data key set in the header. - * - * @return - * the bytes containing the encrypted data key. - */ - @Override - public byte[] getEncryptedDataKey() { - return encryptedKey_.clone(); - } - - /** - * Set the key provider identifier. - * - * @param keyProviderId - * the key provider identifier. - */ - //@ public normal_behavior - //@ requires !isDeserializing; - //@ requires keyProviderId != null && EncryptedDataKey.s2ba(keyProviderId).length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ assignable providerId; - //@ ensures \fresh(providerId); - //@ ensures Arrays.equalArrays(providerId, EncryptedDataKey.s2ba(keyProviderId)); - //@ also - //@ private normal_behavior // TODO: this behavior is a temporary workaround - //@ requires !isDeserializing; - //@ requires keyProviderId != null && EncryptedDataKey.s2ba(keyProviderId).length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ assignable keyProviderId_, keyProviderIdLen_; - //@ also - //@ public exceptional_behavior - //@ requires !isDeserializing; - //@ requires keyProviderId != null && Constants.UNSIGNED_SHORT_MAX_VAL < EncryptedDataKey.s2ba(keyProviderId).length; - //@ assignable \nothing; - //@ signals_only AwsCryptoException; - public void setKeyProviderId(final String keyProviderId) { - final byte[] keyProviderIdBytes = keyProviderId.getBytes(StandardCharsets.UTF_8); - //@ assume Arrays.equalArrays(keyProviderIdBytes, EncryptedDataKey.s2ba(keyProviderId)); - if (keyProviderIdBytes.length > Constants.UNSIGNED_SHORT_MAX_VAL) { - throw new AwsCryptoException( - "Key provider identifier length exceeds the max value of an unsigned short primitive."); - } - keyProviderId_ = keyProviderIdBytes; - keyProviderIdLen_ = keyProviderId_.length; - } - - /** - * Set the information on the key provider identifier. - * - * @param keyProviderInfo - * the bytes containing information on the key provider - * identifier. - */ - //@ public normal_behavior - //@ requires !isDeserializing; - //@ requires keyProviderInfo != null && keyProviderInfo.length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ assignable providerInformation; - //@ ensures \fresh(providerInformation); - //@ ensures Arrays.equalArrays(providerInformation, keyProviderInfo); - //@ also - //@ private normal_behavior // TODO: this behavior is a temporary workaround - //@ requires !isDeserializing; - //@ requires keyProviderInfo != null && keyProviderInfo.length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ assignable keyProviderInfo_, keyProviderInfoLen_; - //@ also private exceptional_behavior - //@ requires !isDeserializing; - //@ requires keyProviderInfo != null; - //@ requires keyProviderInfo.length > Constants.UNSIGNED_SHORT_MAX_VAL; - //@ assignable \nothing; - //@ signals_only AwsCryptoException; - public void setKeyProviderInfo(final byte[] keyProviderInfo) { - if (keyProviderInfo.length > Constants.UNSIGNED_SHORT_MAX_VAL) { - throw new AwsCryptoException( - "Key provider identifier information length exceeds the max value of an unsigned short primitive."); - } - keyProviderInfo_ = keyProviderInfo.clone(); - keyProviderInfoLen_ = keyProviderInfo.length; - } - - /** - * Set the encrypted data key. - * - * @param encryptedDataKey - * the bytes containing the encrypted data key. - */ - //@ public normal_behavior - //@ requires !isDeserializing; - //@ requires encryptedDataKey != null && encryptedDataKey.length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ assignable this.encryptedDataKey; - //@ ensures \fresh(this.encryptedDataKey); - //@ ensures Arrays.equalArrays(this.encryptedDataKey, encryptedDataKey); - //@ also - //@ private normal_behavior // TODO: this behavior is a temporary workaround - //@ requires !isDeserializing; - //@ requires encryptedDataKey != null && encryptedDataKey.length <= Constants.UNSIGNED_SHORT_MAX_VAL; - //@ assignable encryptedKey_, encryptedKeyLen_; - //@ also - //@ public exceptional_behavior - //@ requires !isDeserializing; - //@ requires encryptedDataKey != null; - //@ requires encryptedDataKey.length > Constants.UNSIGNED_SHORT_MAX_VAL; - //@ assignable \nothing; - //@ signals_only AwsCryptoException; - public void setEncryptedDataKey(final byte[] encryptedDataKey) { - if (encryptedDataKey.length > Constants.UNSIGNED_SHORT_MAX_VAL) { - throw new AwsCryptoException("Key length exceeds the max value of an unsigned short primitive."); - } - encryptedKey_ = encryptedDataKey.clone(); - encryptedKeyLen_ = encryptedKey_.length; + keyProviderInfo_ = keyProviderInfo.clone(); + keyProviderInfoLen_ = keyProviderInfo.length; + } + + /** + * Set the encrypted data key. + * + * @param encryptedDataKey the bytes containing the encrypted data key. + */ + // @ public normal_behavior + // @ requires !isDeserializing; + // @ requires encryptedDataKey != null && encryptedDataKey.length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ assignable this.encryptedDataKey; + // @ ensures \fresh(this.encryptedDataKey); + // @ ensures Arrays.equalArrays(this.encryptedDataKey, encryptedDataKey); + // @ also + // @ private normal_behavior // TODO: this behavior is a temporary workaround + // @ requires !isDeserializing; + // @ requires encryptedDataKey != null && encryptedDataKey.length <= + // Constants.UNSIGNED_SHORT_MAX_VAL; + // @ assignable encryptedKey_, encryptedKeyLen_; + // @ also + // @ public exceptional_behavior + // @ requires !isDeserializing; + // @ requires encryptedDataKey != null; + // @ requires encryptedDataKey.length > Constants.UNSIGNED_SHORT_MAX_VAL; + // @ assignable \nothing; + // @ signals_only AwsCryptoException; + public void setEncryptedDataKey(final byte[] encryptedDataKey) { + if (encryptedDataKey.length > Constants.UNSIGNED_SHORT_MAX_VAL) { + throw new AwsCryptoException( + "Key length exceeds the max value of an unsigned short primitive."); } + encryptedKey_ = encryptedDataKey.clone(); + encryptedKeyLen_ = encryptedKey_.length; + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/model/package-info.java b/src/main/java/com/amazonaws/encryptionsdk/model/package-info.java index 897a15f3b..38d97a804 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/model/package-info.java +++ b/src/main/java/com/amazonaws/encryptionsdk/model/package-info.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -14,23 +14,16 @@ /** * Contains the classes that implement the defined message format for storing the encrypted content * and the data key. - * + * *

    - *
  • - * the CiphertextHeaders class implements the format for the headers that wrap the - * (single-block/framed) encrypted content. The data key is stored in this header.
  • - * - *
  • - * the CipherBlockHeaders class implements the format for the headers that wrap the encrypted - * content stored as a single-block.
  • - * - *
  • - * the CipherFrameHeader class implements the format for the headers that wrap the encrypted content - * stored in frames.
  • - * - *
  • - * the KeyBlob class implements the format for storing the encrypted data key along with the headers - * that identify the key provider.
  • + *
  • the CiphertextHeaders class implements the format for the headers that wrap the + * (single-block/framed) encrypted content. The data key is stored in this header. + *
  • the CipherBlockHeaders class implements the format for the headers that wrap the encrypted + * content stored as a single-block. + *
  • the CipherFrameHeader class implements the format for the headers that wrap the encrypted + * content stored in frames. + *
  • the KeyBlob class implements the format for storing the encrypted data key along with the + * headers that identify the key provider. *
*/ package com.amazonaws.encryptionsdk.model; diff --git a/src/main/java/com/amazonaws/encryptionsdk/multi/MultipleProviderFactory.java b/src/main/java/com/amazonaws/encryptionsdk/multi/MultipleProviderFactory.java index 00b723d47..133f2e042 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/multi/MultipleProviderFactory.java +++ b/src/main/java/com/amazonaws/encryptionsdk/multi/MultipleProviderFactory.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,12 +13,6 @@ package com.amazonaws.encryptionsdk.multi; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; -import java.util.Map; - import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.DataKey; import com.amazonaws.encryptionsdk.EncryptedDataKey; @@ -29,133 +23,144 @@ import com.amazonaws.encryptionsdk.exception.NoSuchMasterKeyException; import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; import com.amazonaws.encryptionsdk.internal.Utils; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; /** - * Constructs {@link MasterKeyProvider}s which are backed by any number of other - * {@link MasterKeyProvider}s. The returned provider will have the following properties: + * Constructs {@link MasterKeyProvider}s which are backed by any number of other {@link + * MasterKeyProvider}s. The returned provider will have the following properties: * *
    - *
  • {@link MasterKeyProvider#getMasterKeysForEncryption(MasterKeyRequest)} will result in the - * union of all responses from the backing providers. Likewise, - *
  • {@link MasterKeyProvider#decryptDataKey(CryptoAlgorithm, Collection, Map)} will succeed if - * and only if at least one backing provider can successfully decrypt the {@link DataKey}s. - *
  • {@link MasterKeyProvider#getDefaultProviderId()} is delegated to the first backing provider. - *
  • {@link MasterKeyProvider#getMasterKey(String, String)} will attempt to find the appropriate - * backing provider to return a {@link MasterKey}. + *
  • {@link MasterKeyProvider#getMasterKeysForEncryption(MasterKeyRequest)} will result in the + * union of all responses from the backing providers. Likewise, + *
  • {@link MasterKeyProvider#decryptDataKey(CryptoAlgorithm, Collection, Map)} will succeed if + * and only if at least one backing provider can successfully decrypt the {@link DataKey}s. + *
  • {@link MasterKeyProvider#getDefaultProviderId()} is delegated to the first backing + * provider. + *
  • {@link MasterKeyProvider#getMasterKey(String, String)} will attempt to find the appropriate + * backing provider to return a {@link MasterKey}. *
* * All methods in this factory return identical results and exist only for different degrees of * type-safety. */ public class MultipleProviderFactory { - private MultipleProviderFactory() { - // Prevent instantiation - } + private MultipleProviderFactory() { + // Prevent instantiation + } - public static > MasterKeyProvider buildMultiProvider(final Class masterKeyClass, - final List> providers) { - return new MultiProvider(providers); - } + public static > MasterKeyProvider buildMultiProvider( + final Class masterKeyClass, + final List> providers) { + return new MultiProvider(providers); + } - @SafeVarargs - public static , P extends MasterKeyProvider> MasterKeyProvider buildMultiProvider( - final Class masterKeyClass, final P... providers) { - return buildMultiProvider(masterKeyClass, Arrays.asList(providers)); - } + @SafeVarargs + public static , P extends MasterKeyProvider> + MasterKeyProvider buildMultiProvider(final Class masterKeyClass, final P... providers) { + return buildMultiProvider(masterKeyClass, Arrays.asList(providers)); + } - @SuppressWarnings({ "rawtypes", "unchecked" }) - public static MasterKeyProvider buildMultiProvider(final List> providers) { - return new MultiProvider(providers); - } + @SuppressWarnings({"rawtypes", "unchecked"}) + public static MasterKeyProvider buildMultiProvider( + final List> providers) { + return new MultiProvider(providers); + } - @SafeVarargs - public static

> MasterKeyProvider buildMultiProvider(final P... providers) { - return buildMultiProvider(Arrays.asList(providers)); - } + @SafeVarargs + public static

> MasterKeyProvider buildMultiProvider( + final P... providers) { + return buildMultiProvider(Arrays.asList(providers)); + } - private static class MultiProvider> extends MasterKeyProvider { - private final List> providers_; + private static class MultiProvider> extends MasterKeyProvider { + private final List> providers_; - private MultiProvider(final List> providers) { - Utils.assertNonNull(providers, "providers"); - if (providers.isEmpty()) { - throw new IllegalArgumentException("providers must not be empty"); - } - providers_ = new ArrayList<>(providers); - } + private MultiProvider(final List> providers) { + Utils.assertNonNull(providers, "providers"); + if (providers.isEmpty()) { + throw new IllegalArgumentException("providers must not be empty"); + } + providers_ = new ArrayList<>(providers); + } - @Override - public String getDefaultProviderId() { - return providers_.get(0).getDefaultProviderId(); - } + @Override + public String getDefaultProviderId() { + return providers_.get(0).getDefaultProviderId(); + } - @Override - public K getMasterKey(final String keyId) throws UnsupportedProviderException, NoSuchMasterKeyException { - for (final MasterKeyProvider prov : providers_) { - try { - final K result = prov.getMasterKey(keyId); - if (result != null) { - return result; - } - } catch (final NoSuchMasterKeyException ex) { - // swallow and continue - } - } - throw new NoSuchMasterKeyException(); + @Override + public K getMasterKey(final String keyId) + throws UnsupportedProviderException, NoSuchMasterKeyException { + for (final MasterKeyProvider prov : providers_) { + try { + final K result = prov.getMasterKey(keyId); + if (result != null) { + return result; + } + } catch (final NoSuchMasterKeyException ex) { + // swallow and continue } + } + throw new NoSuchMasterKeyException(); + } - @Override - public K getMasterKey(final String provider, final String keyId) throws UnsupportedProviderException, - NoSuchMasterKeyException { - boolean foundProvider = false; - for (final MasterKeyProvider prov : providers_) { - if (prov.canProvide(provider)) { - foundProvider = true; - try { - final K result = prov.getMasterKey(provider, keyId); - if (result != null) { - return result; - } - } catch (final NoSuchMasterKeyException ex) { - // swallow and continue - } - } - } - if (foundProvider) { - throw new NoSuchMasterKeyException(); - } else { - throw new UnsupportedProviderException(provider); + @Override + public K getMasterKey(final String provider, final String keyId) + throws UnsupportedProviderException, NoSuchMasterKeyException { + boolean foundProvider = false; + for (final MasterKeyProvider prov : providers_) { + if (prov.canProvide(provider)) { + foundProvider = true; + try { + final K result = prov.getMasterKey(provider, keyId); + if (result != null) { + return result; } + } catch (final NoSuchMasterKeyException ex) { + // swallow and continue + } } + } + if (foundProvider) { + throw new NoSuchMasterKeyException(); + } else { + throw new UnsupportedProviderException(provider); + } + } - @Override - public List getMasterKeysForEncryption(final MasterKeyRequest request) { - final List result = new ArrayList<>(); - for (final MasterKeyProvider prov : providers_) { - result.addAll(prov.getMasterKeysForEncryption(request)); - } - return result; - } + @Override + public List getMasterKeysForEncryption(final MasterKeyRequest request) { + final List result = new ArrayList<>(); + for (final MasterKeyProvider prov : providers_) { + result.addAll(prov.getMasterKeysForEncryption(request)); + } + return result; + } - @SuppressWarnings("unchecked") - @Override - public DataKey decryptDataKey(final CryptoAlgorithm algorithm, - final Collection encryptedDataKeys, - final Map encryptionContext) - throws UnsupportedProviderException, AwsCryptoException { - final List exceptions = new ArrayList<>(); - for (final MasterKeyProvider prov : providers_) { - try { - final DataKey result = prov - .decryptDataKey(algorithm, encryptedDataKeys, encryptionContext); - if (result != null) { - return (DataKey) result; - } - } catch (final Exception ex) { - exceptions.add(ex); - } - } - throw buildCannotDecryptDksException(exceptions); + @SuppressWarnings("unchecked") + @Override + public DataKey decryptDataKey( + final CryptoAlgorithm algorithm, + final Collection encryptedDataKeys, + final Map encryptionContext) + throws UnsupportedProviderException, AwsCryptoException { + final List exceptions = new ArrayList<>(); + for (final MasterKeyProvider prov : providers_) { + try { + final DataKey result = + prov.decryptDataKey(algorithm, encryptedDataKeys, encryptionContext); + if (result != null) { + return (DataKey) result; + } + } catch (final Exception ex) { + exceptions.add(ex); } + } + throw buildCannotDecryptDksException(exceptions); } + } } diff --git a/src/main/java/com/amazonaws/encryptionsdk/multi/package-info.java b/src/main/java/com/amazonaws/encryptionsdk/multi/package-info.java index d5deac27a..8539f0e31 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/multi/package-info.java +++ b/src/main/java/com/amazonaws/encryptionsdk/multi/package-info.java @@ -1,18 +1,18 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. */ /** - * Contains logic necessary to create {@link com.amazonaws.encryptionsdk.MasterKeyProvider}s - * which are backed by multiple {@code MasterKeyProviders}. + * Contains logic necessary to create {@link com.amazonaws.encryptionsdk.MasterKeyProvider}s which + * are backed by multiple {@code MasterKeyProviders}. */ package com.amazonaws.encryptionsdk.multi; diff --git a/src/main/java/com/amazonaws/encryptionsdk/package-info.java b/src/main/java/com/amazonaws/encryptionsdk/package-info.java index f30a2b6ab..c4939e2a4 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/package-info.java +++ b/src/main/java/com/amazonaws/encryptionsdk/package-info.java @@ -1,19 +1,18 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. */ /** - *

- * Contains {@link com.amazonaws.encryptionsdk.AwsCrypto}, the primary entry-point to the Aws Encryption SDK. + * Contains {@link com.amazonaws.encryptionsdk.AwsCrypto}, the primary entry-point to the Aws + * Encryption SDK. */ package com.amazonaws.encryptionsdk; - diff --git a/src/test/java/com/amazonaws/crypto/examples/BasicEncryptionExampleTest.java b/src/test/java/com/amazonaws/crypto/examples/BasicEncryptionExampleTest.java index 3384e06ed..20f7c79e4 100644 --- a/src/test/java/com/amazonaws/crypto/examples/BasicEncryptionExampleTest.java +++ b/src/test/java/com/amazonaws/crypto/examples/BasicEncryptionExampleTest.java @@ -8,8 +8,8 @@ public class BasicEncryptionExampleTest { - @Test - public void testEncryptAndDecrypt() { - BasicEncryptionExample.encryptAndDecrypt(KMSTestFixtures.TEST_KEY_IDS[0]); - } + @Test + public void testEncryptAndDecrypt() { + BasicEncryptionExample.encryptAndDecrypt(KMSTestFixtures.TEST_KEY_IDS[0]); + } } diff --git a/src/test/java/com/amazonaws/crypto/examples/BasicMultiRegionKeyEncryptionExampleTest.java b/src/test/java/com/amazonaws/crypto/examples/BasicMultiRegionKeyEncryptionExampleTest.java index 4a5e567b6..634b2c47b 100644 --- a/src/test/java/com/amazonaws/crypto/examples/BasicMultiRegionKeyEncryptionExampleTest.java +++ b/src/test/java/com/amazonaws/crypto/examples/BasicMultiRegionKeyEncryptionExampleTest.java @@ -8,11 +8,10 @@ public class BasicMultiRegionKeyEncryptionExampleTest { - @Test - public void testEncryptAndDecrypt() { - BasicMultiRegionKeyEncryptionExample.encryptAndDecrypt( - KMSTestFixtures.US_EAST_1_MULTI_REGION_KEY_ID, - KMSTestFixtures.US_WEST_2_MULTI_REGION_KEY_ID - ); - } + @Test + public void testEncryptAndDecrypt() { + BasicMultiRegionKeyEncryptionExample.encryptAndDecrypt( + KMSTestFixtures.US_EAST_1_MULTI_REGION_KEY_ID, + KMSTestFixtures.US_WEST_2_MULTI_REGION_KEY_ID); + } } diff --git a/src/test/java/com/amazonaws/crypto/examples/DiscoveryDecryptionExampleTest.java b/src/test/java/com/amazonaws/crypto/examples/DiscoveryDecryptionExampleTest.java index 91c6abb8d..1df65105b 100644 --- a/src/test/java/com/amazonaws/crypto/examples/DiscoveryDecryptionExampleTest.java +++ b/src/test/java/com/amazonaws/crypto/examples/DiscoveryDecryptionExampleTest.java @@ -8,9 +8,9 @@ public class DiscoveryDecryptionExampleTest { - @Test - public void testEncryptAndDecrypt() { - DiscoveryDecryptionExample.encryptAndDecrypt(KMSTestFixtures.TEST_KEY_IDS[0], - KMSTestFixtures.PARTITION, KMSTestFixtures.ACCOUNT_ID); - } + @Test + public void testEncryptAndDecrypt() { + DiscoveryDecryptionExample.encryptAndDecrypt( + KMSTestFixtures.TEST_KEY_IDS[0], KMSTestFixtures.PARTITION, KMSTestFixtures.ACCOUNT_ID); + } } diff --git a/src/test/java/com/amazonaws/crypto/examples/DiscoveryMultiRegionDecryptionExampleTest.java b/src/test/java/com/amazonaws/crypto/examples/DiscoveryMultiRegionDecryptionExampleTest.java index 626d116db..8ebdae172 100644 --- a/src/test/java/com/amazonaws/crypto/examples/DiscoveryMultiRegionDecryptionExampleTest.java +++ b/src/test/java/com/amazonaws/crypto/examples/DiscoveryMultiRegionDecryptionExampleTest.java @@ -8,13 +8,12 @@ public class DiscoveryMultiRegionDecryptionExampleTest { - @Test - public void testEncryptAndDecrypt() { - DiscoveryMultiRegionDecryptionExample.encryptAndDecrypt( - KMSTestFixtures.US_EAST_1_MULTI_REGION_KEY_ID, - KMSTestFixtures.PARTITION, - KMSTestFixtures.ACCOUNT_ID, - KMSTestFixtures.US_WEST_2 - ); - } + @Test + public void testEncryptAndDecrypt() { + DiscoveryMultiRegionDecryptionExample.encryptAndDecrypt( + KMSTestFixtures.US_EAST_1_MULTI_REGION_KEY_ID, + KMSTestFixtures.PARTITION, + KMSTestFixtures.ACCOUNT_ID, + KMSTestFixtures.US_WEST_2); + } } diff --git a/src/test/java/com/amazonaws/crypto/examples/MultipleCmkEncryptExampleTest.java b/src/test/java/com/amazonaws/crypto/examples/MultipleCmkEncryptExampleTest.java index 0f83117ca..6ff8ee970 100644 --- a/src/test/java/com/amazonaws/crypto/examples/MultipleCmkEncryptExampleTest.java +++ b/src/test/java/com/amazonaws/crypto/examples/MultipleCmkEncryptExampleTest.java @@ -8,8 +8,9 @@ public class MultipleCmkEncryptExampleTest { - @Test - public void testEncryptAndDecrypt() { - MultipleCmkEncryptExample.encryptAndDecrypt(KMSTestFixtures.TEST_KEY_IDS[0], KMSTestFixtures.TEST_KEY_IDS[1]); - } + @Test + public void testEncryptAndDecrypt() { + MultipleCmkEncryptExample.encryptAndDecrypt( + KMSTestFixtures.TEST_KEY_IDS[0], KMSTestFixtures.TEST_KEY_IDS[1]); + } } diff --git a/src/test/java/com/amazonaws/crypto/examples/RestrictRegionExampleTest.java b/src/test/java/com/amazonaws/crypto/examples/RestrictRegionExampleTest.java index 7d3e2e7f1..0758378c2 100644 --- a/src/test/java/com/amazonaws/crypto/examples/RestrictRegionExampleTest.java +++ b/src/test/java/com/amazonaws/crypto/examples/RestrictRegionExampleTest.java @@ -8,9 +8,12 @@ public class RestrictRegionExampleTest { - @Test - public void testEncryptAndDecrypt() { - RestrictRegionExample.encryptAndDecrypt(KMSTestFixtures.US_WEST_2_KEY_ID, - KMSTestFixtures.PARTITION, KMSTestFixtures.ACCOUNT_ID, KMSTestFixtures.US_WEST_2); - } + @Test + public void testEncryptAndDecrypt() { + RestrictRegionExample.encryptAndDecrypt( + KMSTestFixtures.US_WEST_2_KEY_ID, + KMSTestFixtures.PARTITION, + KMSTestFixtures.ACCOUNT_ID, + KMSTestFixtures.US_WEST_2); + } } diff --git a/src/test/java/com/amazonaws/crypto/examples/SetCommitmentPolicyExampleTest.java b/src/test/java/com/amazonaws/crypto/examples/SetCommitmentPolicyExampleTest.java index b1d3b3f7f..60df04623 100644 --- a/src/test/java/com/amazonaws/crypto/examples/SetCommitmentPolicyExampleTest.java +++ b/src/test/java/com/amazonaws/crypto/examples/SetCommitmentPolicyExampleTest.java @@ -8,8 +8,8 @@ public class SetCommitmentPolicyExampleTest { - @Test - public void testEncryptAndDecrypt() { - SetCommitmentPolicyExample.encryptAndDecrypt(KMSTestFixtures.TEST_KEY_IDS[0]); - } + @Test + public void testEncryptAndDecrypt() { + SetCommitmentPolicyExample.encryptAndDecrypt(KMSTestFixtures.TEST_KEY_IDS[0]); + } } diff --git a/src/test/java/com/amazonaws/crypto/examples/SetEncryptionAlgorithmExampleTest.java b/src/test/java/com/amazonaws/crypto/examples/SetEncryptionAlgorithmExampleTest.java index 68250057d..fe88e583d 100644 --- a/src/test/java/com/amazonaws/crypto/examples/SetEncryptionAlgorithmExampleTest.java +++ b/src/test/java/com/amazonaws/crypto/examples/SetEncryptionAlgorithmExampleTest.java @@ -8,8 +8,8 @@ public class SetEncryptionAlgorithmExampleTest { - @Test - public void testEncryptAndDecrypt() { - SetEncryptionAlgorithmExample.encryptAndDecrypt(KMSTestFixtures.TEST_KEY_IDS[0]); - } + @Test + public void testEncryptAndDecrypt() { + SetEncryptionAlgorithmExample.encryptAndDecrypt(KMSTestFixtures.TEST_KEY_IDS[0]); + } } diff --git a/src/test/java/com/amazonaws/crypto/examples/SimpleDataKeyCachingExampleTest.java b/src/test/java/com/amazonaws/crypto/examples/SimpleDataKeyCachingExampleTest.java index 535f05f04..9bd2f0da1 100644 --- a/src/test/java/com/amazonaws/crypto/examples/SimpleDataKeyCachingExampleTest.java +++ b/src/test/java/com/amazonaws/crypto/examples/SimpleDataKeyCachingExampleTest.java @@ -3,25 +3,28 @@ package com.amazonaws.crypto.examples; -import com.amazonaws.encryptionsdk.ParsedCiphertext; -import com.amazonaws.encryptionsdk.kms.KMSTestFixtures; - -import org.junit.Test; - import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import com.amazonaws.encryptionsdk.ParsedCiphertext; +import com.amazonaws.encryptionsdk.kms.KMSTestFixtures; +import org.junit.Test; + public class SimpleDataKeyCachingExampleTest { - private static final int MAX_ENTRY_AGE = 100; - private static final int CACHE_CAPACITY = 1000; + private static final int MAX_ENTRY_AGE = 100; + private static final int CACHE_CAPACITY = 1000; - @Test - public void testEncryptWithCaching() { - final byte[] result = SimpleDataKeyCachingExample.encryptWithCaching(KMSTestFixtures.TEST_KEY_IDS[0], MAX_ENTRY_AGE, CACHE_CAPACITY); - assertNotNull(result); - final ParsedCiphertext parsedResult = new ParsedCiphertext(result); - assertEquals(1, parsedResult.getEncryptedKeyBlobs().size()); - assertArrayEquals(KMSTestFixtures.TEST_KEY_IDS[0].getBytes(), parsedResult.getEncryptedKeyBlobs().get(0).getProviderInformation()); - } + @Test + public void testEncryptWithCaching() { + final byte[] result = + SimpleDataKeyCachingExample.encryptWithCaching( + KMSTestFixtures.TEST_KEY_IDS[0], MAX_ENTRY_AGE, CACHE_CAPACITY); + assertNotNull(result); + final ParsedCiphertext parsedResult = new ParsedCiphertext(result); + assertEquals(1, parsedResult.getEncryptedKeyBlobs().size()); + assertArrayEquals( + KMSTestFixtures.TEST_KEY_IDS[0].getBytes(), + parsedResult.getEncryptedKeyBlobs().get(0).getProviderInformation()); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/AllTestsSuite.java b/src/test/java/com/amazonaws/encryptionsdk/AllTestsSuite.java index b53119f11..e6fd1534b 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/AllTestsSuite.java +++ b/src/test/java/com/amazonaws/encryptionsdk/AllTestsSuite.java @@ -3,94 +3,92 @@ package com.amazonaws.encryptionsdk; -import com.amazonaws.crypto.examples.SetCommitmentPolicyExampleTest; -import com.amazonaws.crypto.examples.SetEncryptionAlgorithmExampleTest; -import com.amazonaws.crypto.examples.SimpleDataKeyCachingExampleTest; -import com.amazonaws.encryptionsdk.internal.*; -import com.amazonaws.encryptionsdk.jce.JceMasterKeyTest; -import org.junit.runner.RunWith; -import org.junit.runners.Suite; - import com.amazonaws.crypto.examples.BasicEncryptionExampleTest; import com.amazonaws.crypto.examples.BasicMultiRegionKeyEncryptionExampleTest; import com.amazonaws.crypto.examples.DiscoveryDecryptionExampleTest; import com.amazonaws.crypto.examples.DiscoveryMultiRegionDecryptionExampleTest; import com.amazonaws.crypto.examples.MultipleCmkEncryptExampleTest; import com.amazonaws.crypto.examples.RestrictRegionExampleTest; +import com.amazonaws.crypto.examples.SetCommitmentPolicyExampleTest; +import com.amazonaws.crypto.examples.SetEncryptionAlgorithmExampleTest; +import com.amazonaws.crypto.examples.SimpleDataKeyCachingExampleTest; import com.amazonaws.encryptionsdk.caching.CacheIdentifierTests; import com.amazonaws.encryptionsdk.caching.CachingCryptoMaterialsManagerTest; import com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCacheTest; import com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCacheThreadStormTest; import com.amazonaws.encryptionsdk.caching.NullCryptoMaterialsCacheTest; +import com.amazonaws.encryptionsdk.internal.*; +import com.amazonaws.encryptionsdk.jce.JceMasterKeyTest; import com.amazonaws.encryptionsdk.jce.KeyStoreProviderTest; +import com.amazonaws.encryptionsdk.kms.AwsKmsMrkAwareMasterKeyProviderTest; +import com.amazonaws.encryptionsdk.kms.AwsKmsMrkAwareMasterKeyTest; +import com.amazonaws.encryptionsdk.kms.DiscoveryFilterTest; +import com.amazonaws.encryptionsdk.kms.KMSProviderBuilderMockTests; +import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProviderTest; +import com.amazonaws.encryptionsdk.kms.KmsMasterKeyTest; import com.amazonaws.encryptionsdk.model.CipherBlockHeadersTest; import com.amazonaws.encryptionsdk.model.CipherFrameHeadersTest; import com.amazonaws.encryptionsdk.model.CiphertextHeadersTest; -import com.amazonaws.encryptionsdk.kms.DiscoveryFilterTest; -import com.amazonaws.encryptionsdk.model.KeyBlobTest; import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequestTest; import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequestTest; +import com.amazonaws.encryptionsdk.model.KeyBlobTest; import com.amazonaws.encryptionsdk.multi.MultipleMasterKeyTest; -import com.amazonaws.encryptionsdk.kms.KMSProviderBuilderMockTests; -import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProviderTest; -import com.amazonaws.encryptionsdk.kms.KmsMasterKeyTest; -import com.amazonaws.encryptionsdk.kms.AwsKmsMrkAwareMasterKeyProviderTest; -import com.amazonaws.encryptionsdk.kms.AwsKmsMrkAwareMasterKeyTest; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; @RunWith(Suite.class) @Suite.SuiteClasses({ - CryptoAlgorithmTest.class, - CiphertextHeadersTest.class, - BlockDecryptionHandlerTest.class, - BlockEncryptionHandlerTest.class, - CipherHandlerTest.class, - DecryptionHandlerTest.class, - EncContextSerializerTest.class, - EncryptionHandlerTest.class, - FrameDecryptionHandlerTest.class, - FrameEncryptionHandlerTest.class, - PrimitivesParserTest.class, - KeyStoreProviderTest.class, - CipherBlockHeadersTest.class, - CipherFrameHeadersTest.class, - KeyBlobTest.class, - DecryptionMaterialsRequestTest.class, - MultipleMasterKeyTest.class, - AwsCryptoTest.class, - CryptoInputStreamTest.class, - CryptoOutputStreamTest.class, - TestVectorRunner.class, - XCompatDecryptTest.class, - DefaultCryptoMaterialsManagerTest.class, - NullCryptoMaterialsCacheTest.class, - AwsKmsCmkArnInfoTest.class, - CacheIdentifierTests.class, - CachingCryptoMaterialsManagerTest.class, - LocalCryptoMaterialsCacheTest.class, - LocalCryptoMaterialsCacheThreadStormTest.class, - UtilsTest.class, - MultipleMasterKeyTest.class, - KMSProviderBuilderMockTests.class, - JceMasterKeyTest.class, - KmsMasterKeyProviderTest.class, - KmsMasterKeyTest.class, - DiscoveryFilterTest.class, - CommittedKeyTest.class, - EncryptionMaterialsRequestTest.class, - CommitmentKATRunner.class, - BasicEncryptionExampleTest.class, - BasicMultiRegionKeyEncryptionExampleTest.class, - DiscoveryDecryptionExampleTest.class, - DiscoveryMultiRegionDecryptionExampleTest.class, - MultipleCmkEncryptExampleTest.class, - RestrictRegionExampleTest.class, - SimpleDataKeyCachingExampleTest.class, - SetEncryptionAlgorithmExampleTest.class, - SetCommitmentPolicyExampleTest.class, - ParsedCiphertextTest.class, - AwsKmsMrkAwareMasterKeyProviderTest.class, - AwsKmsMrkAwareMasterKeyTest.class, - VersionInfoTest.class, + CryptoAlgorithmTest.class, + CiphertextHeadersTest.class, + BlockDecryptionHandlerTest.class, + BlockEncryptionHandlerTest.class, + CipherHandlerTest.class, + DecryptionHandlerTest.class, + EncContextSerializerTest.class, + EncryptionHandlerTest.class, + FrameDecryptionHandlerTest.class, + FrameEncryptionHandlerTest.class, + PrimitivesParserTest.class, + KeyStoreProviderTest.class, + CipherBlockHeadersTest.class, + CipherFrameHeadersTest.class, + KeyBlobTest.class, + DecryptionMaterialsRequestTest.class, + MultipleMasterKeyTest.class, + AwsCryptoTest.class, + CryptoInputStreamTest.class, + CryptoOutputStreamTest.class, + TestVectorRunner.class, + XCompatDecryptTest.class, + DefaultCryptoMaterialsManagerTest.class, + NullCryptoMaterialsCacheTest.class, + AwsKmsCmkArnInfoTest.class, + CacheIdentifierTests.class, + CachingCryptoMaterialsManagerTest.class, + LocalCryptoMaterialsCacheTest.class, + LocalCryptoMaterialsCacheThreadStormTest.class, + UtilsTest.class, + MultipleMasterKeyTest.class, + KMSProviderBuilderMockTests.class, + JceMasterKeyTest.class, + KmsMasterKeyProviderTest.class, + KmsMasterKeyTest.class, + DiscoveryFilterTest.class, + CommittedKeyTest.class, + EncryptionMaterialsRequestTest.class, + CommitmentKATRunner.class, + BasicEncryptionExampleTest.class, + BasicMultiRegionKeyEncryptionExampleTest.class, + DiscoveryDecryptionExampleTest.class, + DiscoveryMultiRegionDecryptionExampleTest.class, + MultipleCmkEncryptExampleTest.class, + RestrictRegionExampleTest.class, + SimpleDataKeyCachingExampleTest.class, + SetEncryptionAlgorithmExampleTest.class, + SetCommitmentPolicyExampleTest.class, + ParsedCiphertextTest.class, + AwsKmsMrkAwareMasterKeyProviderTest.class, + AwsKmsMrkAwareMasterKeyTest.class, + VersionInfoTest.class, }) -public class AllTestsSuite { -} +public class AllTestsSuite {} diff --git a/src/test/java/com/amazonaws/encryptionsdk/AwsCryptoTest.java b/src/test/java/com/amazonaws/encryptionsdk/AwsCryptoTest.java index 564549ab5..9dd635740 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/AwsCryptoTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/AwsCryptoTest.java @@ -17,6 +17,21 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import com.amazonaws.encryptionsdk.caching.CachingCryptoMaterialsManager; +import com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCache; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; +import com.amazonaws.encryptionsdk.internal.StaticMasterKey; +import com.amazonaws.encryptionsdk.internal.TestIOUtils; +import com.amazonaws.encryptionsdk.internal.Utils; +import com.amazonaws.encryptionsdk.jce.JceMasterKey; +import com.amazonaws.encryptionsdk.model.CiphertextType; +import com.amazonaws.encryptionsdk.model.DecryptionMaterials; +import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; +import com.amazonaws.encryptionsdk.model.EncryptionMaterials; +import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; +import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory; +import com.amazonaws.util.IOUtils; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -30,1120 +45,1300 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; - -import com.amazonaws.encryptionsdk.jce.JceMasterKey; -import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory; -import com.amazonaws.util.IOUtils; import org.junit.Before; import org.junit.Test; -import com.amazonaws.encryptionsdk.caching.CachingCryptoMaterialsManager; -import com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCache; -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; -import com.amazonaws.encryptionsdk.internal.StaticMasterKey; -import com.amazonaws.encryptionsdk.internal.TestIOUtils; -import com.amazonaws.encryptionsdk.internal.Utils; -import com.amazonaws.encryptionsdk.model.CiphertextType; -import com.amazonaws.encryptionsdk.model.DecryptionMaterials; -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; -import com.amazonaws.encryptionsdk.model.EncryptionMaterials; -import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; - public class AwsCryptoTest { - private StaticMasterKey masterKeyProvider; - private AwsCrypto forbidCommitmentClient_; - private AwsCrypto encryptionClient_; - private AwsCrypto noMaxEdksClient_; - private AwsCrypto maxEdksClient_; - private static final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; - private static final int MESSAGE_FORMAT_MAX_EDKS = (1 << 16) - 1; - - List requireWriteCommitmentPolicies = Arrays.asList( - CommitmentPolicy.RequireEncryptAllowDecrypt, CommitmentPolicy.RequireEncryptRequireDecrypt); - - @Before - public void init() { - masterKeyProvider = spy(new StaticMasterKey("testmaterial")); - - forbidCommitmentClient_ = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt).build(); - forbidCommitmentClient_.setEncryptionAlgorithm(CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256); - encryptionClient_ = AwsCrypto.standard(); - encryptionClient_.setEncryptionAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY); - - noMaxEdksClient_ = AwsCrypto - .builder() - .withCommitmentPolicy(CommitmentPolicy.RequireEncryptAllowDecrypt) - .withEncryptionAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY) - .build(); - maxEdksClient_ = AwsCrypto - .builder() - .withMaxEncryptedDataKeys(3) - .withCommitmentPolicy(CommitmentPolicy.RequireEncryptAllowDecrypt) - .withEncryptionAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY) - .build(); + private StaticMasterKey masterKeyProvider; + private AwsCrypto forbidCommitmentClient_; + private AwsCrypto encryptionClient_; + private AwsCrypto noMaxEdksClient_; + private AwsCrypto maxEdksClient_; + private static final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; + private static final int MESSAGE_FORMAT_MAX_EDKS = (1 << 16) - 1; + + List requireWriteCommitmentPolicies = + Arrays.asList( + CommitmentPolicy.RequireEncryptAllowDecrypt, + CommitmentPolicy.RequireEncryptRequireDecrypt); + + @Before + public void init() { + masterKeyProvider = spy(new StaticMasterKey("testmaterial")); + + forbidCommitmentClient_ = + AwsCrypto.builder() + .withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) + .build(); + forbidCommitmentClient_.setEncryptionAlgorithm( + CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256); + encryptionClient_ = AwsCrypto.standard(); + encryptionClient_.setEncryptionAlgorithm( + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY); + + noMaxEdksClient_ = + AwsCrypto.builder() + .withCommitmentPolicy(CommitmentPolicy.RequireEncryptAllowDecrypt) + .withEncryptionAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY) + .build(); + maxEdksClient_ = + AwsCrypto.builder() + .withMaxEncryptedDataKeys(3) + .withCommitmentPolicy(CommitmentPolicy.RequireEncryptAllowDecrypt) + .withEncryptionAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY) + .build(); + } + + private void doEncryptDecrypt( + final CryptoAlgorithm cryptoAlg, final int byteSize, final int frameSize) { + final byte[] plaintextBytes = new byte[byteSize]; + + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC1", "Encrypt-decrypt test with %d" + byteSize); + + AwsCrypto client = cryptoAlg.isCommitting() ? encryptionClient_ : forbidCommitmentClient_; + client.setEncryptionAlgorithm(cryptoAlg); + client.setEncryptionFrameSize(frameSize); + + final byte[] cipherText = + client.encryptData(masterKeyProvider, plaintextBytes, encryptionContext).getResult(); + final byte[] decryptedText = client.decryptData(masterKeyProvider, cipherText).getResult(); + + assertArrayEquals("Bad encrypt/decrypt for " + cryptoAlg, plaintextBytes, decryptedText); + } + + private void doTamperedEncryptDecrypt( + final CryptoAlgorithm cryptoAlg, final int byteSize, final int frameSize) { + final byte[] plaintextBytes = new byte[byteSize]; + + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC1", "Encrypt-decrypt test with %d" + byteSize); + + AwsCrypto client = cryptoAlg.isCommitting() ? encryptionClient_ : forbidCommitmentClient_; + client.setEncryptionAlgorithm(cryptoAlg); + client.setEncryptionFrameSize(frameSize); + + final byte[] cipherText = + client.encryptData(masterKeyProvider, plaintextBytes, encryptionContext).getResult(); + cipherText[cipherText.length - 2] ^= (byte) 0xff; + try { + client.decryptData(masterKeyProvider, cipherText).getResult(); + fail("Expected BadCiphertextException"); + } catch (final BadCiphertextException ex) { + // Expected exception } - - private void doEncryptDecrypt(final CryptoAlgorithm cryptoAlg, final int byteSize, final int frameSize) { - final byte[] plaintextBytes = new byte[byteSize]; - - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC1", "Encrypt-decrypt test with %d" + byteSize); - - AwsCrypto client = cryptoAlg.isCommitting() ? encryptionClient_ : forbidCommitmentClient_; - client.setEncryptionAlgorithm(cryptoAlg); - client.setEncryptionFrameSize(frameSize); - - final byte[] cipherText = client.encryptData( - masterKeyProvider, - plaintextBytes, - encryptionContext).getResult(); - final byte[] decryptedText = client.decryptData( - masterKeyProvider, - cipherText - ).getResult(); - - assertArrayEquals("Bad encrypt/decrypt for " + cryptoAlg, plaintextBytes, decryptedText); + } + + private void doTruncatedEncryptDecrypt( + final CryptoAlgorithm cryptoAlg, final int byteSize, final int frameSize) { + final byte[] plaintextBytes = new byte[byteSize]; + + final Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC1", "Encrypt-decrypt test with %d" + byteSize); + + AwsCrypto client = cryptoAlg.isCommitting() ? encryptionClient_ : forbidCommitmentClient_; + client.setEncryptionAlgorithm(cryptoAlg); + client.setEncryptionFrameSize(frameSize); + + final byte[] cipherText = + client.encryptData(masterKeyProvider, plaintextBytes, encryptionContext).getResult(); + final byte[] truncatedCipherText = Arrays.copyOf(cipherText, cipherText.length - 1); + try { + client.decryptData(masterKeyProvider, truncatedCipherText).getResult(); + fail("Expected BadCiphertextException"); + } catch (final BadCiphertextException ex) { + // Expected exception } - - private void doTamperedEncryptDecrypt(final CryptoAlgorithm cryptoAlg, final int byteSize, final int frameSize) { - final byte[] plaintextBytes = new byte[byteSize]; - - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC1", "Encrypt-decrypt test with %d" + byteSize); - - AwsCrypto client = cryptoAlg.isCommitting() ? encryptionClient_ : forbidCommitmentClient_; - client.setEncryptionAlgorithm(cryptoAlg); - client.setEncryptionFrameSize(frameSize); - - final byte[] cipherText = client.encryptData( - masterKeyProvider, - plaintextBytes, - encryptionContext).getResult(); - cipherText[cipherText.length - 2] ^= (byte) 0xff; - try { - client.decryptData( - masterKeyProvider, - cipherText - ).getResult(); - fail("Expected BadCiphertextException"); - } catch (final BadCiphertextException ex) { - // Expected exception - } + } + + private void doEncryptDecryptWithParsedCiphertext( + final CryptoAlgorithm cryptoAlg, final int byteSize, final int frameSize) { + final byte[] plaintextBytes = new byte[byteSize]; + + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC1", "Encrypt-decrypt test with %d" + byteSize); + + AwsCrypto client = cryptoAlg.isCommitting() ? encryptionClient_ : forbidCommitmentClient_; + client.setEncryptionAlgorithm(cryptoAlg); + client.setEncryptionFrameSize(frameSize); + + final byte[] cipherText = + client.encryptData(masterKeyProvider, plaintextBytes, encryptionContext).getResult(); + ParsedCiphertext pCt = new ParsedCiphertext(cipherText); + assertEquals(client.getEncryptionAlgorithm(), pCt.getCryptoAlgoId()); + assertEquals(CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA, pCt.getType()); + assertEquals(1, pCt.getEncryptedKeyBlobCount()); + assertEquals(pCt.getEncryptedKeyBlobCount(), pCt.getEncryptedKeyBlobs().size()); + assertEquals( + masterKeyProvider.getProviderId(), pCt.getEncryptedKeyBlobs().get(0).getProviderId()); + for (Map.Entry e : encryptionContext.entrySet()) { + assertEquals(e.getValue(), pCt.getEncryptionContextMap().get(e.getKey())); } - private void doTruncatedEncryptDecrypt(final CryptoAlgorithm cryptoAlg, final int byteSize, final int frameSize) { - final byte[] plaintextBytes = new byte[byteSize]; - - final Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC1", "Encrypt-decrypt test with %d" + byteSize); - - AwsCrypto client = cryptoAlg.isCommitting() ? encryptionClient_ : forbidCommitmentClient_; - client.setEncryptionAlgorithm(cryptoAlg); - client.setEncryptionFrameSize(frameSize); - - final byte[] cipherText = client.encryptData( - masterKeyProvider, - plaintextBytes, - encryptionContext).getResult(); - final byte[] truncatedCipherText = Arrays.copyOf(cipherText, cipherText.length - 1); - try { - client.decryptData( - masterKeyProvider, - truncatedCipherText - ).getResult(); - fail("Expected BadCiphertextException"); - } catch (final BadCiphertextException ex) { - // Expected exception - } - } - - private void doEncryptDecryptWithParsedCiphertext(final CryptoAlgorithm cryptoAlg, final int byteSize, final int frameSize) { - final byte[] plaintextBytes = new byte[byteSize]; - - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC1", "Encrypt-decrypt test with %d" + byteSize); - - AwsCrypto client = cryptoAlg.isCommitting() ? encryptionClient_ : forbidCommitmentClient_; - client.setEncryptionAlgorithm(cryptoAlg); - client.setEncryptionFrameSize(frameSize); - - final byte[] cipherText = client.encryptData( - masterKeyProvider, - plaintextBytes, - encryptionContext).getResult(); - ParsedCiphertext pCt = new ParsedCiphertext(cipherText); - assertEquals(client.getEncryptionAlgorithm(), pCt.getCryptoAlgoId()); - assertEquals(CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA, pCt.getType()); - assertEquals(1, pCt.getEncryptedKeyBlobCount()); - assertEquals(pCt.getEncryptedKeyBlobCount(), pCt.getEncryptedKeyBlobs().size()); - assertEquals(masterKeyProvider.getProviderId(), pCt.getEncryptedKeyBlobs().get(0).getProviderId()); - for (Map.Entry e : encryptionContext.entrySet()) { - assertEquals(e.getValue(), pCt.getEncryptionContextMap().get(e.getKey())); - } - - final byte[] decryptedText = client.decryptData( - masterKeyProvider, - pCt - ).getResult(); - - assertArrayEquals(plaintextBytes, decryptedText); - } + final byte[] decryptedText = client.decryptData(masterKeyProvider, pCt).getResult(); + + assertArrayEquals(plaintextBytes, decryptedText); + } + + @Test + public void encryptDecrypt() { + for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { + // Only test with crypto algs without commitment, since those + // are the only ones we can encrypt with + if (cryptoAlg.getMessageFormatVersion() != 1) { + continue; + } + + final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); + + for (int i = 0; i < frameSizeToTest.length; i++) { + final int frameSize = frameSizeToTest[i]; + int[] bytesToTest = { + 0, + 1, + frameSize - 1, + frameSize, + frameSize + 1, + (int) (frameSize * 1.5), + frameSize * 2, + 1000000 + }; - @Test - public void encryptDecrypt() { - for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { - // Only test with crypto algs without commitment, since those - // are the only ones we can encrypt with - if (cryptoAlg.getMessageFormatVersion() != 1) { - continue; - } - - final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); - - for (int i = 0; i < frameSizeToTest.length; i++) { - final int frameSize = frameSizeToTest[i]; - int[] bytesToTest = { 0, 1, frameSize - 1, frameSize, frameSize + 1, (int) (frameSize * 1.5), - frameSize * 2, 1000000 }; - - for (int j = 0; j < bytesToTest.length; j++) { - final int byteSize = bytesToTest[j]; - - if (byteSize > 500_000 && isFastTestSuiteActive()) { - continue; - } - - if (byteSize >= 0) { - doEncryptDecrypt(cryptoAlg, byteSize, frameSize); - } - } - } - } - } + for (int j = 0; j < bytesToTest.length; j++) { + final int byteSize = bytesToTest[j]; - @Test - public void encryptDecryptWithBadSignature() { - for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { - // Only test with crypto algs without commitment, since those - // are the only ones we can encrypt with - if (cryptoAlg.getMessageFormatVersion() != 1) { - continue; - } - - if (cryptoAlg.getTrailingSignatureAlgo() == null) { - continue; - } - final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); - - for (int i = 0; i < frameSizeToTest.length; i++) { - final int frameSize = frameSizeToTest[i]; - int[] bytesToTest = { 0, 1, frameSize - 1, frameSize, frameSize + 1, (int) (frameSize * 1.5), - frameSize * 2, 1000000 }; - - for (int j = 0; j < bytesToTest.length; j++) { - final int byteSize = bytesToTest[j]; - - if (byteSize > 500_000 && isFastTestSuiteActive()) { - continue; - } - - if (byteSize >= 0) { - doTamperedEncryptDecrypt(cryptoAlg, byteSize, frameSize); - } - } - } - } - } + if (byteSize > 500_000 && isFastTestSuiteActive()) { + continue; + } - @Test - public void encryptDecryptWithTruncatedCiphertext() { - for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { - // Only test with crypto algs without commitment, since those - // are the only ones we can encrypt with - if (cryptoAlg.getMessageFormatVersion() != 1) { - continue; - } - - final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); - - for (int i = 0; i < frameSizeToTest.length; i++) { - final int frameSize = frameSizeToTest[i]; - int[] bytesToTest = { 0, 1, frameSize - 1, frameSize, frameSize + 1, (int) (frameSize * 1.5), - frameSize * 2, 1000000 }; - - for (int j = 0; j < bytesToTest.length; j++) { - final int byteSize = bytesToTest[j]; - - if (byteSize > 500_000) { - continue; - } - - if (byteSize >= 0) { - doTruncatedEncryptDecrypt(cryptoAlg, byteSize, frameSize); - } - } - } + if (byteSize >= 0) { + doEncryptDecrypt(cryptoAlg, byteSize, frameSize); + } } + } } + } + + @Test + public void encryptDecryptWithBadSignature() { + for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { + // Only test with crypto algs without commitment, since those + // are the only ones we can encrypt with + if (cryptoAlg.getMessageFormatVersion() != 1) { + continue; + } + + if (cryptoAlg.getTrailingSignatureAlgo() == null) { + continue; + } + final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); + + for (int i = 0; i < frameSizeToTest.length; i++) { + final int frameSize = frameSizeToTest[i]; + int[] bytesToTest = { + 0, + 1, + frameSize - 1, + frameSize, + frameSize + 1, + (int) (frameSize * 1.5), + frameSize * 2, + 1000000 + }; - @Test - public void encryptDecryptWithParsedCiphertext() { - for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { - final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); - - for (int i = 0; i < frameSizeToTest.length; i++) { - final int frameSize = frameSizeToTest[i]; - int[] bytesToTest = { 0, 1, frameSize - 1, frameSize, frameSize + 1, (int) (frameSize * 1.5), - frameSize * 2, 1000000 }; - - for (int j = 0; j < bytesToTest.length; j++) { - final int byteSize = bytesToTest[j]; + for (int j = 0; j < bytesToTest.length; j++) { + final int byteSize = bytesToTest[j]; - if (byteSize > 500_000 && isFastTestSuiteActive()) { - continue; - } + if (byteSize > 500_000 && isFastTestSuiteActive()) { + continue; + } - if (byteSize >= 0) { - doEncryptDecryptWithParsedCiphertext(cryptoAlg, byteSize, frameSize); - } - } - } + if (byteSize >= 0) { + doTamperedEncryptDecrypt(cryptoAlg, byteSize, frameSize); + } } + } } - - @Test - public void encryptDecryptWithCustomManager() throws Exception { - boolean[] didDecrypt = new boolean[] { false }; - - CryptoMaterialsManager manager = new CryptoMaterialsManager() { - @Override public EncryptionMaterials getMaterialsForEncrypt( - EncryptionMaterialsRequest request - ) { - request = request.toBuilder().setContext(singletonMap("foo", "bar")).build(); - - EncryptionMaterials encryptionMaterials = new DefaultCryptoMaterialsManager(masterKeyProvider) - .getMaterialsForEncrypt(request); - - return encryptionMaterials; - } - - @Override public DecryptionMaterials decryptMaterials( - DecryptionMaterialsRequest request - ) { - didDecrypt[0] = true; - return new DefaultCryptoMaterialsManager(masterKeyProvider).decryptMaterials(request); - } + } + + @Test + public void encryptDecryptWithTruncatedCiphertext() { + for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { + // Only test with crypto algs without commitment, since those + // are the only ones we can encrypt with + if (cryptoAlg.getMessageFormatVersion() != 1) { + continue; + } + + final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); + + for (int i = 0; i < frameSizeToTest.length; i++) { + final int frameSize = frameSizeToTest[i]; + int[] bytesToTest = { + 0, + 1, + frameSize - 1, + frameSize, + frameSize + 1, + (int) (frameSize * 1.5), + frameSize * 2, + 1000000 }; - byte[] plaintext = new byte[100]; - CryptoResult ciphertext = encryptionClient_.encryptData(manager, plaintext); - assertEquals("bar", ciphertext.getEncryptionContext().get("foo")); - - // TODO decrypt - assertFalse(didDecrypt[0]); - CryptoResult plaintextResult = encryptionClient_.decryptData(manager, ciphertext.getResult()); - assertArrayEquals(plaintext, plaintextResult.getResult()); - assertTrue(didDecrypt[0]); - } - - @Test - public void whenCustomCMMIgnoresAlgorithm_throws() throws Exception { - boolean[] didDecrypt = new boolean[] { false }; - - CryptoMaterialsManager manager = new CryptoMaterialsManager() { - @Override public EncryptionMaterials getMaterialsForEncrypt( - EncryptionMaterialsRequest request - ) { - request = request.toBuilder().setRequestedAlgorithm(null).build(); + for (int j = 0; j < bytesToTest.length; j++) { + final int byteSize = bytesToTest[j]; - EncryptionMaterials encryptionMaterials = new DefaultCryptoMaterialsManager(masterKeyProvider) - .getMaterialsForEncrypt(request); + if (byteSize > 500_000) { + continue; + } - return encryptionMaterials; - } - - @Override public DecryptionMaterials decryptMaterials( - DecryptionMaterialsRequest request - ) { - didDecrypt[0] = true; - return new DefaultCryptoMaterialsManager(masterKeyProvider).decryptMaterials(request); - } - }; - - encryptionClient_.setEncryptionAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY); - - byte[] plaintext = new byte[100]; - assertThrows(AwsCryptoException.class, - () -> encryptionClient_.encryptData(manager, plaintext)); - assertThrows(AwsCryptoException.class, - () -> encryptionClient_.estimateCiphertextSize(manager, 12345)); - assertThrows(AwsCryptoException.class, - () -> encryptionClient_.createEncryptingStream(manager, new ByteArrayOutputStream()).write(0)); - assertThrows(AwsCryptoException.class, - () -> encryptionClient_.createEncryptingStream(manager, new ByteArrayInputStream(new byte[1024*1024])).read()); + if (byteSize >= 0) { + doTruncatedEncryptDecrypt(cryptoAlg, byteSize, frameSize); + } + } + } } - - @Test - public void whenCustomCMMUsesCommittingAlgorithmWithForbidPolicy_throws() throws Exception { - CryptoMaterialsManager manager = new CryptoMaterialsManager() { - @Override public EncryptionMaterials getMaterialsForEncrypt( - EncryptionMaterialsRequest request - ) { - EncryptionMaterials encryptionMaterials = new DefaultCryptoMaterialsManager(masterKeyProvider) - .getMaterialsForEncrypt(request); - - return encryptionMaterials.toBuilder() - .setAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384) - .build(); - } - - @Override public DecryptionMaterials decryptMaterials( - DecryptionMaterialsRequest request - ) { - return new DefaultCryptoMaterialsManager(masterKeyProvider).decryptMaterials(request); - } + } + + @Test + public void encryptDecryptWithParsedCiphertext() { + for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { + final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); + + for (int i = 0; i < frameSizeToTest.length; i++) { + final int frameSize = frameSizeToTest[i]; + int[] bytesToTest = { + 0, + 1, + frameSize - 1, + frameSize, + frameSize + 1, + (int) (frameSize * 1.5), + frameSize * 2, + 1000000 }; - // create client with null encryption algorithm and ForbidEncrypt policy - final AwsCrypto client = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt).build(); - - byte[] plaintext = new byte[100]; - assertThrows(AwsCryptoException.class, - () -> client.encryptData(manager, plaintext)); - assertThrows(AwsCryptoException.class, - () -> client.estimateCiphertextSize(manager, 12345)); - assertThrows(AwsCryptoException.class, - () -> client.createEncryptingStream(manager, new ByteArrayOutputStream()).write(0)); - assertThrows(AwsCryptoException.class, - () -> client.createEncryptingStream(manager, new ByteArrayInputStream(new byte[1024*1024])).read()); - } - - @Test - public void whenDecrypting_invokesMKPOnce() throws Exception { - byte[] data = encryptionClient_.encryptData(masterKeyProvider, new byte[1]).getResult(); - - reset(masterKeyProvider); - - encryptionClient_.decryptData(masterKeyProvider, data); - - verify(masterKeyProvider, times(1)).decryptDataKey(any(), any(), any()); - } + for (int j = 0; j < bytesToTest.length; j++) { + final int byteSize = bytesToTest[j]; - private void doEstimateCiphertextSize(final CryptoAlgorithm cryptoAlg, final int inLen, final int frameSize) { - final byte[] plaintext = TestIOUtils.generateRandomPlaintext(inLen); + if (byteSize > 500_000 && isFastTestSuiteActive()) { + continue; + } - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC1", "Ciphertext size estimation test with " + inLen); - - AwsCrypto client = cryptoAlg.isCommitting() ? encryptionClient_ : forbidCommitmentClient_; - client.setEncryptionAlgorithm(cryptoAlg); - client.setEncryptionFrameSize(frameSize); - - final long estimatedCiphertextSize = client.estimateCiphertextSize( - masterKeyProvider, - inLen, - encryptionContext); - final byte[] cipherText = client.encryptData(masterKeyProvider, plaintext, - encryptionContext).getResult(); - - // The estimate should be close (within 16 bytes) and never less than reality - final String errMsg = "Bad estimation for " + cryptoAlg + " expected: <" + estimatedCiphertextSize - + "> but was: <" + cipherText.length + ">"; - assertTrue(errMsg, estimatedCiphertextSize - cipherText.length >= 0); - assertTrue(errMsg, estimatedCiphertextSize - cipherText.length <= 16); - } - - @Test - public void estimateCiphertextSize() { - for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { - // Only test with crypto algs without commitment, since those - // are the only ones we can encrypt with - if (cryptoAlg.getMessageFormatVersion() != 1) { - continue; - } - - final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); - - for (int i = 0; i < frameSizeToTest.length; i++) { - final int frameSize = frameSizeToTest[i]; - int[] bytesToTest = { 0, 1, frameSize - 1, frameSize, frameSize + 1, (int) (frameSize * 1.5), - frameSize * 2, 1000000 }; - - for (int j = 0; j < bytesToTest.length; j++) { - final int byteSize = bytesToTest[j]; - - if (byteSize > 500_000 && isFastTestSuiteActive()) { - continue; - } - - if (byteSize >= 0) { - doEstimateCiphertextSize(cryptoAlg, byteSize, frameSize); - } - } - } + if (byteSize >= 0) { + doEncryptDecryptWithParsedCiphertext(cryptoAlg, byteSize, frameSize); + } } + } } + } + + @Test + public void encryptDecryptWithCustomManager() throws Exception { + boolean[] didDecrypt = new boolean[] {false}; + + CryptoMaterialsManager manager = + new CryptoMaterialsManager() { + @Override + public EncryptionMaterials getMaterialsForEncrypt(EncryptionMaterialsRequest request) { + request = request.toBuilder().setContext(singletonMap("foo", "bar")).build(); + + EncryptionMaterials encryptionMaterials = + new DefaultCryptoMaterialsManager(masterKeyProvider) + .getMaterialsForEncrypt(request); + + return encryptionMaterials; + } + + @Override + public DecryptionMaterials decryptMaterials(DecryptionMaterialsRequest request) { + didDecrypt[0] = true; + return new DefaultCryptoMaterialsManager(masterKeyProvider).decryptMaterials(request); + } + }; - @Test - public void estimateCiphertextSizeWithoutEncContext() { - final int inLen = 1000000; - final byte[] plaintext = TestIOUtils.generateRandomPlaintext(inLen); - - encryptionClient_.setEncryptionFrameSize(AwsCrypto.getDefaultFrameSize()); - - final long estimatedCiphertextSize = encryptionClient_.estimateCiphertextSize(masterKeyProvider, inLen); - final byte[] cipherText = encryptionClient_.encryptData(masterKeyProvider, plaintext).getResult(); - - final String errMsg = "Bad estimation expected: <" + estimatedCiphertextSize - + "> but was: <" + cipherText.length + ">"; - assertTrue(errMsg, estimatedCiphertextSize - cipherText.length >= 0); - assertTrue(errMsg, estimatedCiphertextSize - cipherText.length <= 16); - } + byte[] plaintext = new byte[100]; + CryptoResult ciphertext = encryptionClient_.encryptData(manager, plaintext); + assertEquals("bar", ciphertext.getEncryptionContext().get("foo")); + + // TODO decrypt + assertFalse(didDecrypt[0]); + CryptoResult plaintextResult = + encryptionClient_.decryptData(manager, ciphertext.getResult()); + assertArrayEquals(plaintext, plaintextResult.getResult()); + assertTrue(didDecrypt[0]); + } + + @Test + public void whenCustomCMMIgnoresAlgorithm_throws() throws Exception { + boolean[] didDecrypt = new boolean[] {false}; + + CryptoMaterialsManager manager = + new CryptoMaterialsManager() { + @Override + public EncryptionMaterials getMaterialsForEncrypt(EncryptionMaterialsRequest request) { + request = request.toBuilder().setRequestedAlgorithm(null).build(); + + EncryptionMaterials encryptionMaterials = + new DefaultCryptoMaterialsManager(masterKeyProvider) + .getMaterialsForEncrypt(request); + + return encryptionMaterials; + } + + @Override + public DecryptionMaterials decryptMaterials(DecryptionMaterialsRequest request) { + didDecrypt[0] = true; + return new DefaultCryptoMaterialsManager(masterKeyProvider).decryptMaterials(request); + } + }; - @Test - public void estimateCiphertextSize_usesCachedKeys() throws Exception { - // Make sure estimateCiphertextSize works with cached CMMs - CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(masterKeyProvider)); - - CachingCryptoMaterialsManager cache = CachingCryptoMaterialsManager.newBuilder() - .withBackingMaterialsManager(cmm) - .withMaxAge(Long.MAX_VALUE, TimeUnit.SECONDS) - .withCache(new LocalCryptoMaterialsCache(1)) - .withMessageUseLimit(9999) - .withByteUseLimit(501) + encryptionClient_.setEncryptionAlgorithm( + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY); + + byte[] plaintext = new byte[100]; + assertThrows(AwsCryptoException.class, () -> encryptionClient_.encryptData(manager, plaintext)); + assertThrows( + AwsCryptoException.class, () -> encryptionClient_.estimateCiphertextSize(manager, 12345)); + assertThrows( + AwsCryptoException.class, + () -> + encryptionClient_ + .createEncryptingStream(manager, new ByteArrayOutputStream()) + .write(0)); + assertThrows( + AwsCryptoException.class, + () -> + encryptionClient_ + .createEncryptingStream(manager, new ByteArrayInputStream(new byte[1024 * 1024])) + .read()); + } + + @Test + public void whenCustomCMMUsesCommittingAlgorithmWithForbidPolicy_throws() throws Exception { + CryptoMaterialsManager manager = + new CryptoMaterialsManager() { + @Override + public EncryptionMaterials getMaterialsForEncrypt(EncryptionMaterialsRequest request) { + EncryptionMaterials encryptionMaterials = + new DefaultCryptoMaterialsManager(masterKeyProvider) + .getMaterialsForEncrypt(request); + + return encryptionMaterials.toBuilder() + .setAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384) .build(); + } - // These estimates should be cached, and should not consume any bytes from the byte use limit. - encryptionClient_.estimateCiphertextSize(cache, 500, new HashMap<>()); - encryptionClient_.estimateCiphertextSize(cache, 500, new HashMap<>()); - - encryptionClient_.encryptData(cache, new byte[500]); - - verify(cmm, times(1)).getMaterialsForEncrypt(any()); - } - - @Test - public void encryptDecryptWithoutEncContext() { - final int ptSize = 1000000; // 1MB - final byte[] plaintextBytes = TestIOUtils.generateRandomPlaintext(ptSize); - - final byte[] cipherText = encryptionClient_.encryptData(masterKeyProvider, plaintextBytes).getResult(); - final byte[] decryptedText = encryptionClient_.decryptData( - masterKeyProvider, - cipherText).getResult(); - - assertArrayEquals(plaintextBytes, decryptedText); - } - - @Test - public void encryptDecryptString() { - final int ptSize = 1000000; // 1MB - final String plaintextString = TestIOUtils.generateRandomString(ptSize); - - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC1", "Test Encryption Context"); - - final String ciphertext = encryptionClient_.encryptString( - masterKeyProvider, - plaintextString, - encryptionContext).getResult(); - final String decryptedText = encryptionClient_.decryptString( - masterKeyProvider, - ciphertext).getResult(); - - assertEquals(plaintextString, decryptedText); - } - - @Test - public void encryptDecryptStringWithoutEncContext() { - final int ptSize = 1000000; // 1MB - final String plaintextString = TestIOUtils.generateRandomString(ptSize); - - final String cipherText = encryptionClient_.encryptString(masterKeyProvider, plaintextString).getResult(); - final String decryptedText = encryptionClient_.decryptString( - masterKeyProvider, - cipherText).getResult(); - - assertEquals(plaintextString, decryptedText); - } - - @Test - public void encryptBytesDecryptString() { - final int ptSize = 1000000; // 1MB - final String plaintext = TestIOUtils.generateRandomString(ptSize); - - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC1", "Test Encryption Context"); - - final byte[] cipherText = encryptionClient_.encryptData( - masterKeyProvider, - plaintext.getBytes(StandardCharsets.UTF_8), - encryptionContext).getResult(); - final String decryptedText = encryptionClient_.decryptString( - masterKeyProvider, - Utils.encodeBase64String(cipherText)).getResult(); - - assertEquals(plaintext, decryptedText); - } - - @Test - public void encryptStringDecryptBytes() { - final int ptSize = 1000000; // 1MB - final byte[] plaintextBytes = TestIOUtils.generateRandomPlaintext(ptSize); - final String plaintextString = new String(plaintextBytes, StandardCharsets.UTF_8); - - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC1", "Test Encryption Context"); - - final String ciphertext = encryptionClient_.encryptString( - masterKeyProvider, - plaintextString, - encryptionContext).getResult(); - final byte[] decryptedText = encryptionClient_.decryptData( - masterKeyProvider, - Utils.decodeBase64String(ciphertext)).getResult(); - - assertArrayEquals(plaintextString.getBytes(StandardCharsets.UTF_8), decryptedText); - } - - @Test - public void emptyEncryptionContext() { - final int ptSize = 1000000; // 1MB - final byte[] plaintextBytes = TestIOUtils.generateRandomPlaintext(ptSize); - - final Map encryptionContext = new HashMap(0); - - final byte[] cipherText = encryptionClient_.encryptData( - masterKeyProvider, - plaintextBytes, - encryptionContext).getResult(); - final byte[] decryptedText = encryptionClient_.decryptData( - masterKeyProvider, - cipherText).getResult(); - - assertArrayEquals(plaintextBytes, decryptedText); - } - - @Test - public void decryptMessageWithKeyCommitment() { - final byte[] cipherText = Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64); - JceMasterKey masterKey = TestUtils.messageWithCommitKeyMasterKey; - final CryptoResult decryptedText = encryptionClient_.decryptData(masterKey, cipherText); - - assertEquals(TestUtils.messageWithCommitKeyCryptoAlgorithm, decryptedText.getCryptoAlgorithm()); - assertArrayEquals(Utils.decodeBase64String(TestUtils.messageWithCommitKeyMessageIdBase64), decryptedText.getHeaders().getMessageId()); - assertArrayEquals(Utils.decodeBase64String(TestUtils.messageWithCommitKeyCommitmentBase64), decryptedText.getHeaders().getSuiteData()); - assertArrayEquals(TestUtils.messageWithCommitKeyExpectedResult.getBytes(), (byte[])decryptedText.getResult()); - } - - @Test - public void decryptMessageWithInvalidKeyCommitment() { - final byte[] cipherText = Utils.decodeBase64String(TestUtils.invalidMessageWithCommitKeyBase64); - JceMasterKey masterKey = TestUtils.invalidMessageWithCommitKeyMasterKey; - assertThrows(BadCiphertextException.class, "Key commitment validation failed. Key identity does not " + - "match the identity asserted in the message. Halting processing of this message.", - () -> encryptionClient_.decryptData(masterKey, cipherText)); - } - - // Test that all the parameters that aren't allowed to be null (i.e. all of them) result in immediate NPEs if - // invoked with null args - @Test - public void assertNullChecks() throws Exception { - byte[] buf = new byte[1]; - HashMap context = new HashMap<>(); - MasterKeyProvider provider = masterKeyProvider; - CryptoMaterialsManager cmm = new DefaultCryptoMaterialsManager(masterKeyProvider); - InputStream is = new ByteArrayInputStream(new byte[0]); - OutputStream os = new ByteArrayOutputStream(); - - byte[] ciphertext = encryptionClient_.encryptData(cmm, buf).getResult(); - String stringCiphertext = encryptionClient_.encryptString(cmm, "hello, world").getResult(); - - TestUtils.assertNullChecks(encryptionClient_, "estimateCiphertextSize", - MasterKeyProvider.class, provider, - Integer.TYPE, 42, - Map.class, context - ); - TestUtils.assertNullChecks(encryptionClient_, "estimateCiphertextSize", - CryptoMaterialsManager.class, cmm, - Integer.TYPE, 42, - Map.class, context - ); - TestUtils.assertNullChecks(encryptionClient_, "estimateCiphertextSize", - MasterKeyProvider.class, provider, - Integer.TYPE, 42 - ); - TestUtils.assertNullChecks(encryptionClient_, "estimateCiphertextSize", - CryptoMaterialsManager.class, cmm, - Integer.TYPE, 42 - ); - - TestUtils.assertNullChecks(encryptionClient_, "encryptData", - MasterKeyProvider.class, provider, - byte[].class, buf, - Map.class, context - ); - TestUtils.assertNullChecks(encryptionClient_, "encryptData", - CryptoMaterialsManager.class, cmm, - byte[].class, buf, - Map.class, context - ); - TestUtils.assertNullChecks(encryptionClient_, "encryptData", - MasterKeyProvider.class, provider, - byte[].class, buf - ); - TestUtils.assertNullChecks(encryptionClient_, "encryptData", - CryptoMaterialsManager.class, cmm, - byte[].class, buf - ); - TestUtils.assertNullChecks(encryptionClient_, "encryptString", - MasterKeyProvider.class, provider, - String.class, "", - Map.class, context - ); - TestUtils.assertNullChecks(encryptionClient_, "encryptString", - CryptoMaterialsManager.class, cmm, - String.class, "", - Map.class, context - ); - TestUtils.assertNullChecks(encryptionClient_, "encryptString", - MasterKeyProvider.class, provider, - String.class, "" - ); - TestUtils.assertNullChecks(encryptionClient_, "encryptString", - CryptoMaterialsManager.class, cmm, - String.class, "" - ); - - TestUtils.assertNullChecks(encryptionClient_, "decryptData", - MasterKeyProvider.class, provider, - byte[].class, ciphertext - ); - TestUtils.assertNullChecks(encryptionClient_, "decryptData", - CryptoMaterialsManager.class, cmm, - byte[].class, ciphertext - ); - TestUtils.assertNullChecks(encryptionClient_, "decryptData", - MasterKeyProvider.class, provider, - ParsedCiphertext.class, new ParsedCiphertext(ciphertext) - ); - TestUtils.assertNullChecks(encryptionClient_, "decryptData", - CryptoMaterialsManager.class, cmm, - ParsedCiphertext.class, new ParsedCiphertext(ciphertext) - ); - TestUtils.assertNullChecks(encryptionClient_, "decryptString", - MasterKeyProvider.class, provider, - String.class, stringCiphertext - ); - TestUtils.assertNullChecks(encryptionClient_, "decryptString", - CryptoMaterialsManager.class, cmm, - String.class, stringCiphertext - ); - - TestUtils.assertNullChecks(encryptionClient_, "createEncryptingStream", - MasterKeyProvider.class, provider, - OutputStream.class, os, - Map.class, context - ); - TestUtils.assertNullChecks(encryptionClient_, "createEncryptingStream", - CryptoMaterialsManager.class, cmm, - OutputStream.class, os, - Map.class, context - ); - - TestUtils.assertNullChecks(encryptionClient_, "createEncryptingStream", - MasterKeyProvider.class, provider, - OutputStream.class, os - ); - TestUtils.assertNullChecks(encryptionClient_, "createEncryptingStream", - CryptoMaterialsManager.class, cmm, - OutputStream.class, os - ); - - TestUtils.assertNullChecks(encryptionClient_, "createEncryptingStream", - MasterKeyProvider.class, provider, - InputStream.class, is, - Map.class, context - ); - TestUtils.assertNullChecks(encryptionClient_, "createEncryptingStream", - CryptoMaterialsManager.class, cmm, - InputStream.class, is, - Map.class, context - ); - - TestUtils.assertNullChecks(encryptionClient_, "createEncryptingStream", - MasterKeyProvider.class, provider, - InputStream.class, is - ); - TestUtils.assertNullChecks(encryptionClient_, "createEncryptingStream", - CryptoMaterialsManager.class, cmm, - InputStream.class, is - ); - - TestUtils.assertNullChecks(encryptionClient_, "createDecryptingStream", - MasterKeyProvider.class, provider, - OutputStream.class, os - ); - TestUtils.assertNullChecks(encryptionClient_, "createDecryptingStream", - CryptoMaterialsManager.class, cmm, - OutputStream.class, os - ); - - TestUtils.assertNullChecks(encryptionClient_, "createDecryptingStream", - MasterKeyProvider.class, provider, - InputStream.class, is - ); - TestUtils.assertNullChecks(encryptionClient_, "createDecryptingStream", - CryptoMaterialsManager.class, cmm, - InputStream.class, is - ); - } - - @Test - public void setValidFrameSize() throws IOException { - final int setFrameSize = TestUtils.DEFAULT_TEST_CRYPTO_ALG.getBlockSize() * 2; - encryptionClient_.setEncryptionFrameSize(setFrameSize); - - final int getFrameSize = encryptionClient_.getEncryptionFrameSize(); - - assertEquals(setFrameSize, getFrameSize); - } - - @Test - public void unalignedFrameSizesAreAccepted() throws IOException { - final int frameSize = TestUtils.DEFAULT_TEST_CRYPTO_ALG.getBlockSize() - 1; - encryptionClient_.setEncryptionFrameSize(frameSize); - - assertEquals(frameSize, encryptionClient_.getEncryptionFrameSize()); - } - - @Test(expected = IllegalArgumentException.class) - public void setNegativeFrameSize() throws IOException { - encryptionClient_.setEncryptionFrameSize(-1); - } - - @Test - public void setCryptoAlgorithm() throws IOException { - final CryptoAlgorithm setCryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; - encryptionClient_.setEncryptionAlgorithm(setCryptoAlgorithm); - - final CryptoAlgorithm getCryptoAlgorithm = encryptionClient_.getEncryptionAlgorithm(); - - assertEquals(setCryptoAlgorithm, getCryptoAlgorithm); - } - - @Test(expected = NullPointerException.class) - public void buildWithNullCommitmentPolicy() throws IOException { - AwsCrypto.builder().withCommitmentPolicy(null).build(); - } - - @Test - public void forbidAndSetCommittingCryptoAlgorithm() throws IOException { - final CryptoAlgorithm setCryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; - - assertThrows(AwsCryptoException.class, () -> - AwsCrypto.builder() - .withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) - .build() - .setEncryptionAlgorithm(setCryptoAlgorithm)); - } - - @Test - public void requireAndSetNonCommittingCryptoAlgorithm() throws IOException { - final CryptoAlgorithm setCryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; - - // Default case - assertThrows(AwsCryptoException.class, () -> - AwsCrypto.standard().setEncryptionAlgorithm(setCryptoAlgorithm)); - - // Test explicitly for every relevant policy - for (CommitmentPolicy policy : requireWriteCommitmentPolicies) { - assertThrows(AwsCryptoException.class, () -> - AwsCrypto.builder() - .withCommitmentPolicy(policy) - .build() - .setEncryptionAlgorithm(setCryptoAlgorithm)); - - } - } - - @Test - public void forbidAndBuildWithCommittingCryptoAlgorithm() throws IOException { - final CryptoAlgorithm setCryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; - - assertThrows(AwsCryptoException.class, () -> - AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) - .withEncryptionAlgorithm(setCryptoAlgorithm) - .build()); - } - - @Test - public void requireAndBuildWithNonCommittingCryptoAlgorithm() throws IOException { - final CryptoAlgorithm setCryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; - - // Test default case - assertThrows(AwsCryptoException.class, () -> - AwsCrypto.builder().withEncryptionAlgorithm(setCryptoAlgorithm).build()); - - // Test explicitly for every relevant policy - for (CommitmentPolicy policy : requireWriteCommitmentPolicies) { - assertThrows(AwsCryptoException.class, () -> - AwsCrypto.builder() - .withCommitmentPolicy(policy) - .withEncryptionAlgorithm(setCryptoAlgorithm) - .build()); - } - } - - @Test - public void requireCommitmentOnDecryptFailsNonCommitting() throws IOException { - // Create non-committing ciphertext - forbidCommitmentClient_.setEncryptionAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384); - - final byte[] cipherText = forbidCommitmentClient_.encryptData( - masterKeyProvider, - new byte[1], - new HashMap<>()).getResult(); - - // Test explicit policy set - assertThrows(AwsCryptoException.class, () -> - AwsCrypto.builder() - .withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt) - .build() - .decryptData(masterKeyProvider, cipherText)); - - // Test default builder behavior - assertThrows(AwsCryptoException.class, () -> - AwsCrypto.builder() - .build() - .decryptData(masterKeyProvider, cipherText)); - - // Test input stream - assertThrows(AwsCryptoException.class, () -> - AwsCrypto.builder() - .build() - .createDecryptingStream(masterKeyProvider, new ByteArrayInputStream(cipherText)) - .read()); - - // Test output stream - assertThrows(AwsCryptoException.class, () -> - AwsCrypto.builder() - .build() - .createDecryptingStream(masterKeyProvider, new ByteArrayOutputStream()) - .write(cipherText)); - } - - @Test - public void whenCustomCMMUsesNonCommittingAlgorithmWithRequirePolicy_throws() throws Exception { - CryptoMaterialsManager manager = new CryptoMaterialsManager() { - @Override public EncryptionMaterials getMaterialsForEncrypt( - EncryptionMaterialsRequest request - ) { - EncryptionMaterials encryptionMaterials = new DefaultCryptoMaterialsManager(masterKeyProvider) - .getMaterialsForEncrypt(request); - - return encryptionMaterials.toBuilder() - .setAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384) - .build(); - } - - @Override public DecryptionMaterials decryptMaterials( - DecryptionMaterialsRequest request - ) { - return new DefaultCryptoMaterialsManager(masterKeyProvider).decryptMaterials(request); - } + @Override + public DecryptionMaterials decryptMaterials(DecryptionMaterialsRequest request) { + return new DefaultCryptoMaterialsManager(masterKeyProvider).decryptMaterials(request); + } }; + // create client with null encryption algorithm and ForbidEncrypt policy + final AwsCrypto client = + AwsCrypto.builder() + .withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) + .build(); + + byte[] plaintext = new byte[100]; + assertThrows(AwsCryptoException.class, () -> client.encryptData(manager, plaintext)); + assertThrows(AwsCryptoException.class, () -> client.estimateCiphertextSize(manager, 12345)); + assertThrows( + AwsCryptoException.class, + () -> client.createEncryptingStream(manager, new ByteArrayOutputStream()).write(0)); + assertThrows( + AwsCryptoException.class, + () -> + client + .createEncryptingStream(manager, new ByteArrayInputStream(new byte[1024 * 1024])) + .read()); + } + + @Test + public void whenDecrypting_invokesMKPOnce() throws Exception { + byte[] data = encryptionClient_.encryptData(masterKeyProvider, new byte[1]).getResult(); + + reset(masterKeyProvider); + + encryptionClient_.decryptData(masterKeyProvider, data); + + verify(masterKeyProvider, times(1)).decryptDataKey(any(), any(), any()); + } + + private void doEstimateCiphertextSize( + final CryptoAlgorithm cryptoAlg, final int inLen, final int frameSize) { + final byte[] plaintext = TestIOUtils.generateRandomPlaintext(inLen); + + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC1", "Ciphertext size estimation test with " + inLen); + + AwsCrypto client = cryptoAlg.isCommitting() ? encryptionClient_ : forbidCommitmentClient_; + client.setEncryptionAlgorithm(cryptoAlg); + client.setEncryptionFrameSize(frameSize); + + final long estimatedCiphertextSize = + client.estimateCiphertextSize(masterKeyProvider, inLen, encryptionContext); + final byte[] cipherText = + client.encryptData(masterKeyProvider, plaintext, encryptionContext).getResult(); + + // The estimate should be close (within 16 bytes) and never less than reality + final String errMsg = + "Bad estimation for " + + cryptoAlg + + " expected: <" + + estimatedCiphertextSize + + "> but was: <" + + cipherText.length + + ">"; + assertTrue(errMsg, estimatedCiphertextSize - cipherText.length >= 0); + assertTrue(errMsg, estimatedCiphertextSize - cipherText.length <= 16); + } + + @Test + public void estimateCiphertextSize() { + for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { + // Only test with crypto algs without commitment, since those + // are the only ones we can encrypt with + if (cryptoAlg.getMessageFormatVersion() != 1) { + continue; + } + + final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); + + for (int i = 0; i < frameSizeToTest.length; i++) { + final int frameSize = frameSizeToTest[i]; + int[] bytesToTest = { + 0, + 1, + frameSize - 1, + frameSize, + frameSize + 1, + (int) (frameSize * 1.5), + frameSize * 2, + 1000000 + }; - for (CommitmentPolicy policy : requireWriteCommitmentPolicies) { - // create client with null encryption algorithm and a policy that requires encryption - final AwsCrypto client = AwsCrypto.builder().withCommitmentPolicy(policy).build(); - - byte[] plaintext = new byte[100]; - assertThrows(AwsCryptoException.class, - () -> client.encryptData(manager, plaintext)); - assertThrows(AwsCryptoException.class, - () -> client.estimateCiphertextSize(manager, 12345)); - assertThrows(AwsCryptoException.class, - () -> client.createEncryptingStream(manager, new ByteArrayOutputStream()).write(0)); - assertThrows(AwsCryptoException.class, - () -> client.createEncryptingStream(manager, new ByteArrayInputStream(new byte[1024 * 1024])).read()); - } - } - - @Test - public void testDecryptMessageWithInvalidCommitment() { - for (final CryptoAlgorithm cryptoAlg : CryptoAlgorithm.values()) { - if (!cryptoAlg.isCommitting()) { - continue; - } - final Map encryptionContext = new HashMap(1); - encryptionContext.put("Commitment", "Commitment test for %s" + cryptoAlg); - encryptionClient_.setEncryptionAlgorithm(cryptoAlg); - byte[] plaintextBytes = new byte[16]; // Actual content doesn't matter - final byte[] cipherText = encryptionClient_.encryptData( - masterKeyProvider, - plaintextBytes, - encryptionContext).getResult(); - - // Find the commitment value - ParsedCiphertext parsed = new ParsedCiphertext(cipherText); - final int headerLength = parsed.getOffset(); - // The commitment value is immediately prior to the header tag for v2 encrypted messages - final int endOfCommitment = headerLength - parsed.getHeaderTag().length; - // The commitment is 32 bytes long, but if we just index one back from the endOfCommitment we know - // that we are within it. - cipherText[endOfCommitment - 1] ^= 0x01; // Tamper with the commitment value - - // Since commitment is verified prior to the header tag, we don't need to worry about actually - // creating a colliding tag but can just verify that the exception indicates an incorrect commitment - // value. - assertThrows(BadCiphertextException.class, "Key commitment validation failed. Key identity does " + - "not match the identity asserted in the message. Halting processing of this message.", - () -> encryptionClient_.decryptData(masterKeyProvider, cipherText)); - } - } - - @Test(expected = IllegalArgumentException.class) - public void setNegativeMaxEdks() { - AwsCrypto.builder().withMaxEncryptedDataKeys(-1); - } - - @Test(expected = IllegalArgumentException.class) - public void setZeroMaxEdks() { - AwsCrypto.builder().withMaxEncryptedDataKeys(0); - } + for (int j = 0; j < bytesToTest.length; j++) { + final int byteSize = bytesToTest[j]; - @Test - public void setValidMaxEdks() { - for (final int i : new int[]{1, 10, MESSAGE_FORMAT_MAX_EDKS, MESSAGE_FORMAT_MAX_EDKS + 1, Integer.MAX_VALUE}) { - AwsCrypto.builder().withMaxEncryptedDataKeys(i); - } - } + if (byteSize > 500_000 && isFastTestSuiteActive()) { + continue; + } - private MasterKeyProvider providerWithEdks(int numEdks) { - List> providers = new ArrayList<>(); - for (int i = 0; i < numEdks; i++) { - providers.add(masterKeyProvider); + if (byteSize >= 0) { + doEstimateCiphertextSize(cryptoAlg, byteSize, frameSize); + } } - return MultipleProviderFactory.buildMultiProvider(providers); - } - - @Test - public void encryptDecryptWithLessThanMaxEdks() { - MasterKeyProvider provider = providerWithEdks(2); - CryptoResult result = maxEdksClient_.encryptData(provider, new byte[] {1}); - ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult()); - assertEquals(ciphertext.getEncryptedKeyBlobCount(), 2); - maxEdksClient_.decryptData(provider, ciphertext); + } } - - @Test - public void encryptDecryptWithMaxEdks() { - MasterKeyProvider provider = providerWithEdks(3); - CryptoResult result = maxEdksClient_.encryptData(provider, new byte[] {1}); - ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult()); - assertEquals(ciphertext.getEncryptedKeyBlobCount(), 3); - maxEdksClient_.decryptData(provider, ciphertext); + } + + @Test + public void estimateCiphertextSizeWithoutEncContext() { + final int inLen = 1000000; + final byte[] plaintext = TestIOUtils.generateRandomPlaintext(inLen); + + encryptionClient_.setEncryptionFrameSize(AwsCrypto.getDefaultFrameSize()); + + final long estimatedCiphertextSize = + encryptionClient_.estimateCiphertextSize(masterKeyProvider, inLen); + final byte[] cipherText = + encryptionClient_.encryptData(masterKeyProvider, plaintext).getResult(); + + final String errMsg = + "Bad estimation expected: <" + + estimatedCiphertextSize + + "> but was: <" + + cipherText.length + + ">"; + assertTrue(errMsg, estimatedCiphertextSize - cipherText.length >= 0); + assertTrue(errMsg, estimatedCiphertextSize - cipherText.length <= 16); + } + + @Test + public void estimateCiphertextSize_usesCachedKeys() throws Exception { + // Make sure estimateCiphertextSize works with cached CMMs + CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(masterKeyProvider)); + + CachingCryptoMaterialsManager cache = + CachingCryptoMaterialsManager.newBuilder() + .withBackingMaterialsManager(cmm) + .withMaxAge(Long.MAX_VALUE, TimeUnit.SECONDS) + .withCache(new LocalCryptoMaterialsCache(1)) + .withMessageUseLimit(9999) + .withByteUseLimit(501) + .build(); + + // These estimates should be cached, and should not consume any bytes from the byte use limit. + encryptionClient_.estimateCiphertextSize(cache, 500, new HashMap<>()); + encryptionClient_.estimateCiphertextSize(cache, 500, new HashMap<>()); + + encryptionClient_.encryptData(cache, new byte[500]); + + verify(cmm, times(1)).getMaterialsForEncrypt(any()); + } + + @Test + public void encryptDecryptWithoutEncContext() { + final int ptSize = 1000000; // 1MB + final byte[] plaintextBytes = TestIOUtils.generateRandomPlaintext(ptSize); + + final byte[] cipherText = + encryptionClient_.encryptData(masterKeyProvider, plaintextBytes).getResult(); + final byte[] decryptedText = + encryptionClient_.decryptData(masterKeyProvider, cipherText).getResult(); + + assertArrayEquals(plaintextBytes, decryptedText); + } + + @Test + public void encryptDecryptString() { + final int ptSize = 1000000; // 1MB + final String plaintextString = TestIOUtils.generateRandomString(ptSize); + + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC1", "Test Encryption Context"); + + final String ciphertext = + encryptionClient_ + .encryptString(masterKeyProvider, plaintextString, encryptionContext) + .getResult(); + final String decryptedText = + encryptionClient_.decryptString(masterKeyProvider, ciphertext).getResult(); + + assertEquals(plaintextString, decryptedText); + } + + @Test + public void encryptDecryptStringWithoutEncContext() { + final int ptSize = 1000000; // 1MB + final String plaintextString = TestIOUtils.generateRandomString(ptSize); + + final String cipherText = + encryptionClient_.encryptString(masterKeyProvider, plaintextString).getResult(); + final String decryptedText = + encryptionClient_.decryptString(masterKeyProvider, cipherText).getResult(); + + assertEquals(plaintextString, decryptedText); + } + + @Test + public void encryptBytesDecryptString() { + final int ptSize = 1000000; // 1MB + final String plaintext = TestIOUtils.generateRandomString(ptSize); + + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC1", "Test Encryption Context"); + + final byte[] cipherText = + encryptionClient_ + .encryptData( + masterKeyProvider, plaintext.getBytes(StandardCharsets.UTF_8), encryptionContext) + .getResult(); + final String decryptedText = + encryptionClient_ + .decryptString(masterKeyProvider, Utils.encodeBase64String(cipherText)) + .getResult(); + + assertEquals(plaintext, decryptedText); + } + + @Test + public void encryptStringDecryptBytes() { + final int ptSize = 1000000; // 1MB + final byte[] plaintextBytes = TestIOUtils.generateRandomPlaintext(ptSize); + final String plaintextString = new String(plaintextBytes, StandardCharsets.UTF_8); + + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC1", "Test Encryption Context"); + + final String ciphertext = + encryptionClient_ + .encryptString(masterKeyProvider, plaintextString, encryptionContext) + .getResult(); + final byte[] decryptedText = + encryptionClient_ + .decryptData(masterKeyProvider, Utils.decodeBase64String(ciphertext)) + .getResult(); + + assertArrayEquals(plaintextString.getBytes(StandardCharsets.UTF_8), decryptedText); + } + + @Test + public void emptyEncryptionContext() { + final int ptSize = 1000000; // 1MB + final byte[] plaintextBytes = TestIOUtils.generateRandomPlaintext(ptSize); + + final Map encryptionContext = new HashMap(0); + + final byte[] cipherText = + encryptionClient_ + .encryptData(masterKeyProvider, plaintextBytes, encryptionContext) + .getResult(); + final byte[] decryptedText = + encryptionClient_.decryptData(masterKeyProvider, cipherText).getResult(); + + assertArrayEquals(plaintextBytes, decryptedText); + } + + @Test + public void decryptMessageWithKeyCommitment() { + final byte[] cipherText = Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64); + JceMasterKey masterKey = TestUtils.messageWithCommitKeyMasterKey; + final CryptoResult decryptedText = encryptionClient_.decryptData(masterKey, cipherText); + + assertEquals(TestUtils.messageWithCommitKeyCryptoAlgorithm, decryptedText.getCryptoAlgorithm()); + assertArrayEquals( + Utils.decodeBase64String(TestUtils.messageWithCommitKeyMessageIdBase64), + decryptedText.getHeaders().getMessageId()); + assertArrayEquals( + Utils.decodeBase64String(TestUtils.messageWithCommitKeyCommitmentBase64), + decryptedText.getHeaders().getSuiteData()); + assertArrayEquals( + TestUtils.messageWithCommitKeyExpectedResult.getBytes(), + (byte[]) decryptedText.getResult()); + } + + @Test + public void decryptMessageWithInvalidKeyCommitment() { + final byte[] cipherText = Utils.decodeBase64String(TestUtils.invalidMessageWithCommitKeyBase64); + JceMasterKey masterKey = TestUtils.invalidMessageWithCommitKeyMasterKey; + assertThrows( + BadCiphertextException.class, + "Key commitment validation failed. Key identity does not " + + "match the identity asserted in the message. Halting processing of this message.", + () -> encryptionClient_.decryptData(masterKey, cipherText)); + } + + // Test that all the parameters that aren't allowed to be null (i.e. all of them) result in + // immediate NPEs if + // invoked with null args + @Test + public void assertNullChecks() throws Exception { + byte[] buf = new byte[1]; + HashMap context = new HashMap<>(); + MasterKeyProvider provider = masterKeyProvider; + CryptoMaterialsManager cmm = new DefaultCryptoMaterialsManager(masterKeyProvider); + InputStream is = new ByteArrayInputStream(new byte[0]); + OutputStream os = new ByteArrayOutputStream(); + + byte[] ciphertext = encryptionClient_.encryptData(cmm, buf).getResult(); + String stringCiphertext = encryptionClient_.encryptString(cmm, "hello, world").getResult(); + + TestUtils.assertNullChecks( + encryptionClient_, + "estimateCiphertextSize", + MasterKeyProvider.class, + provider, + Integer.TYPE, + 42, + Map.class, + context); + TestUtils.assertNullChecks( + encryptionClient_, + "estimateCiphertextSize", + CryptoMaterialsManager.class, + cmm, + Integer.TYPE, + 42, + Map.class, + context); + TestUtils.assertNullChecks( + encryptionClient_, + "estimateCiphertextSize", + MasterKeyProvider.class, + provider, + Integer.TYPE, + 42); + TestUtils.assertNullChecks( + encryptionClient_, + "estimateCiphertextSize", + CryptoMaterialsManager.class, + cmm, + Integer.TYPE, + 42); + + TestUtils.assertNullChecks( + encryptionClient_, + "encryptData", + MasterKeyProvider.class, + provider, + byte[].class, + buf, + Map.class, + context); + TestUtils.assertNullChecks( + encryptionClient_, + "encryptData", + CryptoMaterialsManager.class, + cmm, + byte[].class, + buf, + Map.class, + context); + TestUtils.assertNullChecks( + encryptionClient_, "encryptData", MasterKeyProvider.class, provider, byte[].class, buf); + TestUtils.assertNullChecks( + encryptionClient_, "encryptData", CryptoMaterialsManager.class, cmm, byte[].class, buf); + TestUtils.assertNullChecks( + encryptionClient_, + "encryptString", + MasterKeyProvider.class, + provider, + String.class, + "", + Map.class, + context); + TestUtils.assertNullChecks( + encryptionClient_, + "encryptString", + CryptoMaterialsManager.class, + cmm, + String.class, + "", + Map.class, + context); + TestUtils.assertNullChecks( + encryptionClient_, "encryptString", MasterKeyProvider.class, provider, String.class, ""); + TestUtils.assertNullChecks( + encryptionClient_, "encryptString", CryptoMaterialsManager.class, cmm, String.class, ""); + + TestUtils.assertNullChecks( + encryptionClient_, + "decryptData", + MasterKeyProvider.class, + provider, + byte[].class, + ciphertext); + TestUtils.assertNullChecks( + encryptionClient_, + "decryptData", + CryptoMaterialsManager.class, + cmm, + byte[].class, + ciphertext); + TestUtils.assertNullChecks( + encryptionClient_, + "decryptData", + MasterKeyProvider.class, + provider, + ParsedCiphertext.class, + new ParsedCiphertext(ciphertext)); + TestUtils.assertNullChecks( + encryptionClient_, + "decryptData", + CryptoMaterialsManager.class, + cmm, + ParsedCiphertext.class, + new ParsedCiphertext(ciphertext)); + TestUtils.assertNullChecks( + encryptionClient_, + "decryptString", + MasterKeyProvider.class, + provider, + String.class, + stringCiphertext); + TestUtils.assertNullChecks( + encryptionClient_, + "decryptString", + CryptoMaterialsManager.class, + cmm, + String.class, + stringCiphertext); + + TestUtils.assertNullChecks( + encryptionClient_, + "createEncryptingStream", + MasterKeyProvider.class, + provider, + OutputStream.class, + os, + Map.class, + context); + TestUtils.assertNullChecks( + encryptionClient_, + "createEncryptingStream", + CryptoMaterialsManager.class, + cmm, + OutputStream.class, + os, + Map.class, + context); + + TestUtils.assertNullChecks( + encryptionClient_, + "createEncryptingStream", + MasterKeyProvider.class, + provider, + OutputStream.class, + os); + TestUtils.assertNullChecks( + encryptionClient_, + "createEncryptingStream", + CryptoMaterialsManager.class, + cmm, + OutputStream.class, + os); + + TestUtils.assertNullChecks( + encryptionClient_, + "createEncryptingStream", + MasterKeyProvider.class, + provider, + InputStream.class, + is, + Map.class, + context); + TestUtils.assertNullChecks( + encryptionClient_, + "createEncryptingStream", + CryptoMaterialsManager.class, + cmm, + InputStream.class, + is, + Map.class, + context); + + TestUtils.assertNullChecks( + encryptionClient_, + "createEncryptingStream", + MasterKeyProvider.class, + provider, + InputStream.class, + is); + TestUtils.assertNullChecks( + encryptionClient_, + "createEncryptingStream", + CryptoMaterialsManager.class, + cmm, + InputStream.class, + is); + + TestUtils.assertNullChecks( + encryptionClient_, + "createDecryptingStream", + MasterKeyProvider.class, + provider, + OutputStream.class, + os); + TestUtils.assertNullChecks( + encryptionClient_, + "createDecryptingStream", + CryptoMaterialsManager.class, + cmm, + OutputStream.class, + os); + + TestUtils.assertNullChecks( + encryptionClient_, + "createDecryptingStream", + MasterKeyProvider.class, + provider, + InputStream.class, + is); + TestUtils.assertNullChecks( + encryptionClient_, + "createDecryptingStream", + CryptoMaterialsManager.class, + cmm, + InputStream.class, + is); + } + + @Test + public void setValidFrameSize() throws IOException { + final int setFrameSize = TestUtils.DEFAULT_TEST_CRYPTO_ALG.getBlockSize() * 2; + encryptionClient_.setEncryptionFrameSize(setFrameSize); + + final int getFrameSize = encryptionClient_.getEncryptionFrameSize(); + + assertEquals(setFrameSize, getFrameSize); + } + + @Test + public void unalignedFrameSizesAreAccepted() throws IOException { + final int frameSize = TestUtils.DEFAULT_TEST_CRYPTO_ALG.getBlockSize() - 1; + encryptionClient_.setEncryptionFrameSize(frameSize); + + assertEquals(frameSize, encryptionClient_.getEncryptionFrameSize()); + } + + @Test(expected = IllegalArgumentException.class) + public void setNegativeFrameSize() throws IOException { + encryptionClient_.setEncryptionFrameSize(-1); + } + + @Test + public void setCryptoAlgorithm() throws IOException { + final CryptoAlgorithm setCryptoAlgorithm = + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; + encryptionClient_.setEncryptionAlgorithm(setCryptoAlgorithm); + + final CryptoAlgorithm getCryptoAlgorithm = encryptionClient_.getEncryptionAlgorithm(); + + assertEquals(setCryptoAlgorithm, getCryptoAlgorithm); + } + + @Test(expected = NullPointerException.class) + public void buildWithNullCommitmentPolicy() throws IOException { + AwsCrypto.builder().withCommitmentPolicy(null).build(); + } + + @Test + public void forbidAndSetCommittingCryptoAlgorithm() throws IOException { + final CryptoAlgorithm setCryptoAlgorithm = + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; + + assertThrows( + AwsCryptoException.class, + () -> + AwsCrypto.builder() + .withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) + .build() + .setEncryptionAlgorithm(setCryptoAlgorithm)); + } + + @Test + public void requireAndSetNonCommittingCryptoAlgorithm() throws IOException { + final CryptoAlgorithm setCryptoAlgorithm = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; + + // Default case + assertThrows( + AwsCryptoException.class, + () -> AwsCrypto.standard().setEncryptionAlgorithm(setCryptoAlgorithm)); + + // Test explicitly for every relevant policy + for (CommitmentPolicy policy : requireWriteCommitmentPolicies) { + assertThrows( + AwsCryptoException.class, + () -> + AwsCrypto.builder() + .withCommitmentPolicy(policy) + .build() + .setEncryptionAlgorithm(setCryptoAlgorithm)); } - - @Test - public void noEncryptWithMoreThanMaxEdks() { - MasterKeyProvider provider = providerWithEdks(4); - assertThrows(AwsCryptoException.class, "Encrypted data keys exceed maxEncryptedDataKeys", () -> - maxEdksClient_.encryptData(provider, new byte[] {1})); - } - - @Test - public void noDecryptWithMoreThanMaxEdks() { - MasterKeyProvider provider = providerWithEdks(4); - CryptoResult result = noMaxEdksClient_.encryptData(provider, new byte[] {1}); - ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult()); - assertThrows(AwsCryptoException.class, "Ciphertext encrypted data keys exceed maxEncryptedDataKeys", () -> - maxEdksClient_.decryptData(provider, ciphertext)); - } - - @Test - public void encryptDecryptWithNoMaxEdks() { - MasterKeyProvider provider = providerWithEdks(MESSAGE_FORMAT_MAX_EDKS); - CryptoResult result = noMaxEdksClient_.encryptData(provider, new byte[] {1}); - ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult()); - assertEquals(ciphertext.getEncryptedKeyBlobCount(), MESSAGE_FORMAT_MAX_EDKS); - noMaxEdksClient_.decryptData(provider, ciphertext); + } + + @Test + public void forbidAndBuildWithCommittingCryptoAlgorithm() throws IOException { + final CryptoAlgorithm setCryptoAlgorithm = + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; + + assertThrows( + AwsCryptoException.class, + () -> + AwsCrypto.builder() + .withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) + .withEncryptionAlgorithm(setCryptoAlgorithm) + .build()); + } + + @Test + public void requireAndBuildWithNonCommittingCryptoAlgorithm() throws IOException { + final CryptoAlgorithm setCryptoAlgorithm = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; + + // Test default case + assertThrows( + AwsCryptoException.class, + () -> AwsCrypto.builder().withEncryptionAlgorithm(setCryptoAlgorithm).build()); + + // Test explicitly for every relevant policy + for (CommitmentPolicy policy : requireWriteCommitmentPolicies) { + assertThrows( + AwsCryptoException.class, + () -> + AwsCrypto.builder() + .withCommitmentPolicy(policy) + .withEncryptionAlgorithm(setCryptoAlgorithm) + .build()); } + } + + @Test + public void requireCommitmentOnDecryptFailsNonCommitting() throws IOException { + // Create non-committing ciphertext + forbidCommitmentClient_.setEncryptionAlgorithm( + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384); + + final byte[] cipherText = + forbidCommitmentClient_ + .encryptData(masterKeyProvider, new byte[1], new HashMap<>()) + .getResult(); + + // Test explicit policy set + assertThrows( + AwsCryptoException.class, + () -> + AwsCrypto.builder() + .withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt) + .build() + .decryptData(masterKeyProvider, cipherText)); + + // Test default builder behavior + assertThrows( + AwsCryptoException.class, + () -> AwsCrypto.builder().build().decryptData(masterKeyProvider, cipherText)); + + // Test input stream + assertThrows( + AwsCryptoException.class, + () -> + AwsCrypto.builder() + .build() + .createDecryptingStream(masterKeyProvider, new ByteArrayInputStream(cipherText)) + .read()); + + // Test output stream + assertThrows( + AwsCryptoException.class, + () -> + AwsCrypto.builder() + .build() + .createDecryptingStream(masterKeyProvider, new ByteArrayOutputStream()) + .write(cipherText)); + } + + @Test + public void whenCustomCMMUsesNonCommittingAlgorithmWithRequirePolicy_throws() throws Exception { + CryptoMaterialsManager manager = + new CryptoMaterialsManager() { + @Override + public EncryptionMaterials getMaterialsForEncrypt(EncryptionMaterialsRequest request) { + EncryptionMaterials encryptionMaterials = + new DefaultCryptoMaterialsManager(masterKeyProvider) + .getMaterialsForEncrypt(request); + + return encryptionMaterials.toBuilder() + .setAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384) + .build(); + } - @Test - public void encryptDecryptStreamWithLessThanMaxEdks() throws IOException { - MasterKeyProvider provider = providerWithEdks(2); - ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - CryptoOutputStream encryptStream = maxEdksClient_.createEncryptingStream(provider, byteArrayOutputStream); - IOUtils.copy(new ByteArrayInputStream(new byte[] {1}), encryptStream); - encryptStream.close(); - - byte[] ciphertext = byteArrayOutputStream.toByteArray(); - assertEquals(new ParsedCiphertext(ciphertext).getEncryptedKeyBlobCount(), 2); - - byteArrayOutputStream.reset(); - CryptoOutputStream decryptStream = maxEdksClient_.createDecryptingStream(provider, byteArrayOutputStream); - IOUtils.copy(new ByteArrayInputStream(ciphertext), decryptStream); - decryptStream.close(); - } + @Override + public DecryptionMaterials decryptMaterials(DecryptionMaterialsRequest request) { + return new DefaultCryptoMaterialsManager(masterKeyProvider).decryptMaterials(request); + } + }; - @Test - public void encryptDecryptStreamWithMaxEdks() throws IOException { - MasterKeyProvider provider = providerWithEdks(3); - ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - CryptoOutputStream encryptStream = maxEdksClient_.createEncryptingStream(provider, byteArrayOutputStream); - IOUtils.copy(new ByteArrayInputStream(new byte[] {1}), encryptStream); - encryptStream.close(); - - byte[] ciphertext = byteArrayOutputStream.toByteArray(); - assertEquals(new ParsedCiphertext(ciphertext).getEncryptedKeyBlobCount(), 3); - - byteArrayOutputStream.reset(); - CryptoOutputStream decryptStream = maxEdksClient_.createDecryptingStream(provider, byteArrayOutputStream); - IOUtils.copy(new ByteArrayInputStream(ciphertext), decryptStream); - decryptStream.close(); + for (CommitmentPolicy policy : requireWriteCommitmentPolicies) { + // create client with null encryption algorithm and a policy that requires encryption + final AwsCrypto client = AwsCrypto.builder().withCommitmentPolicy(policy).build(); + + byte[] plaintext = new byte[100]; + assertThrows(AwsCryptoException.class, () -> client.encryptData(manager, plaintext)); + assertThrows(AwsCryptoException.class, () -> client.estimateCiphertextSize(manager, 12345)); + assertThrows( + AwsCryptoException.class, + () -> client.createEncryptingStream(manager, new ByteArrayOutputStream()).write(0)); + assertThrows( + AwsCryptoException.class, + () -> + client + .createEncryptingStream(manager, new ByteArrayInputStream(new byte[1024 * 1024])) + .read()); } - - @Test - public void noEncryptStreamWithMoreThanMaxEdks() { - ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - CryptoOutputStream encryptStream = maxEdksClient_.createEncryptingStream(providerWithEdks(4), byteArrayOutputStream); - assertThrows(AwsCryptoException.class, "Encrypted data keys exceed maxEncryptedDataKeys", () -> - IOUtils.copy(new ByteArrayInputStream(new byte[] {1}), encryptStream)); + } + + @Test + public void testDecryptMessageWithInvalidCommitment() { + for (final CryptoAlgorithm cryptoAlg : CryptoAlgorithm.values()) { + if (!cryptoAlg.isCommitting()) { + continue; + } + final Map encryptionContext = new HashMap(1); + encryptionContext.put("Commitment", "Commitment test for %s" + cryptoAlg); + encryptionClient_.setEncryptionAlgorithm(cryptoAlg); + byte[] plaintextBytes = new byte[16]; // Actual content doesn't matter + final byte[] cipherText = + encryptionClient_ + .encryptData(masterKeyProvider, plaintextBytes, encryptionContext) + .getResult(); + + // Find the commitment value + ParsedCiphertext parsed = new ParsedCiphertext(cipherText); + final int headerLength = parsed.getOffset(); + // The commitment value is immediately prior to the header tag for v2 encrypted messages + final int endOfCommitment = headerLength - parsed.getHeaderTag().length; + // The commitment is 32 bytes long, but if we just index one back from the endOfCommitment we + // know + // that we are within it. + cipherText[endOfCommitment - 1] ^= 0x01; // Tamper with the commitment value + + // Since commitment is verified prior to the header tag, we don't need to worry about actually + // creating a colliding tag but can just verify that the exception indicates an incorrect + // commitment + // value. + assertThrows( + BadCiphertextException.class, + "Key commitment validation failed. Key identity does " + + "not match the identity asserted in the message. Halting processing of this message.", + () -> encryptionClient_.decryptData(masterKeyProvider, cipherText)); } - - @Test - public void noDecryptStreamWithMoreThanMaxEdks() throws IOException { - MasterKeyProvider provider = providerWithEdks(4); - ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - CryptoOutputStream encryptStream = noMaxEdksClient_.createEncryptingStream(provider, byteArrayOutputStream); - IOUtils.copy(new ByteArrayInputStream(new byte[] {1}), encryptStream); - encryptStream.close(); - - byte[] ciphertext = byteArrayOutputStream.toByteArray(); - - byteArrayOutputStream.reset(); - CryptoOutputStream decryptStream = maxEdksClient_.createDecryptingStream(provider, byteArrayOutputStream); - assertThrows(AwsCryptoException.class, "Ciphertext encrypted data keys exceed maxEncryptedDataKeys", () -> - IOUtils.copy(new ByteArrayInputStream(ciphertext), decryptStream)); + } + + @Test(expected = IllegalArgumentException.class) + public void setNegativeMaxEdks() { + AwsCrypto.builder().withMaxEncryptedDataKeys(-1); + } + + @Test(expected = IllegalArgumentException.class) + public void setZeroMaxEdks() { + AwsCrypto.builder().withMaxEncryptedDataKeys(0); + } + + @Test + public void setValidMaxEdks() { + for (final int i : + new int[] { + 1, 10, MESSAGE_FORMAT_MAX_EDKS, MESSAGE_FORMAT_MAX_EDKS + 1, Integer.MAX_VALUE + }) { + AwsCrypto.builder().withMaxEncryptedDataKeys(i); } + } - @Test - public void encryptDecryptStreamWithNoMaxEdks() throws IOException { - MasterKeyProvider provider = providerWithEdks(MESSAGE_FORMAT_MAX_EDKS); - ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - CryptoOutputStream encryptStream = noMaxEdksClient_.createEncryptingStream(provider, byteArrayOutputStream); - IOUtils.copy(new ByteArrayInputStream(new byte[] {1}), encryptStream); - encryptStream.close(); - - byte[] ciphertext = byteArrayOutputStream.toByteArray(); - assertEquals(new ParsedCiphertext(ciphertext).getEncryptedKeyBlobCount(), MESSAGE_FORMAT_MAX_EDKS); - - byteArrayOutputStream.reset(); - CryptoOutputStream decryptStream = noMaxEdksClient_.createDecryptingStream(provider, byteArrayOutputStream); - IOUtils.copy(new ByteArrayInputStream(ciphertext), decryptStream); - decryptStream.close(); + private MasterKeyProvider providerWithEdks(int numEdks) { + List> providers = new ArrayList<>(); + for (int i = 0; i < numEdks; i++) { + providers.add(masterKeyProvider); } + return MultipleProviderFactory.buildMultiProvider(providers); + } + + @Test + public void encryptDecryptWithLessThanMaxEdks() { + MasterKeyProvider provider = providerWithEdks(2); + CryptoResult result = maxEdksClient_.encryptData(provider, new byte[] {1}); + ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult()); + assertEquals(ciphertext.getEncryptedKeyBlobCount(), 2); + maxEdksClient_.decryptData(provider, ciphertext); + } + + @Test + public void encryptDecryptWithMaxEdks() { + MasterKeyProvider provider = providerWithEdks(3); + CryptoResult result = maxEdksClient_.encryptData(provider, new byte[] {1}); + ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult()); + assertEquals(ciphertext.getEncryptedKeyBlobCount(), 3); + maxEdksClient_.decryptData(provider, ciphertext); + } + + @Test + public void noEncryptWithMoreThanMaxEdks() { + MasterKeyProvider provider = providerWithEdks(4); + assertThrows( + AwsCryptoException.class, + "Encrypted data keys exceed maxEncryptedDataKeys", + () -> maxEdksClient_.encryptData(provider, new byte[] {1})); + } + + @Test + public void noDecryptWithMoreThanMaxEdks() { + MasterKeyProvider provider = providerWithEdks(4); + CryptoResult result = noMaxEdksClient_.encryptData(provider, new byte[] {1}); + ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult()); + assertThrows( + AwsCryptoException.class, + "Ciphertext encrypted data keys exceed maxEncryptedDataKeys", + () -> maxEdksClient_.decryptData(provider, ciphertext)); + } + + @Test + public void encryptDecryptWithNoMaxEdks() { + MasterKeyProvider provider = providerWithEdks(MESSAGE_FORMAT_MAX_EDKS); + CryptoResult result = noMaxEdksClient_.encryptData(provider, new byte[] {1}); + ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult()); + assertEquals(ciphertext.getEncryptedKeyBlobCount(), MESSAGE_FORMAT_MAX_EDKS); + noMaxEdksClient_.decryptData(provider, ciphertext); + } + + @Test + public void encryptDecryptStreamWithLessThanMaxEdks() throws IOException { + MasterKeyProvider provider = providerWithEdks(2); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + CryptoOutputStream encryptStream = + maxEdksClient_.createEncryptingStream(provider, byteArrayOutputStream); + IOUtils.copy(new ByteArrayInputStream(new byte[] {1}), encryptStream); + encryptStream.close(); + + byte[] ciphertext = byteArrayOutputStream.toByteArray(); + assertEquals(new ParsedCiphertext(ciphertext).getEncryptedKeyBlobCount(), 2); + + byteArrayOutputStream.reset(); + CryptoOutputStream decryptStream = + maxEdksClient_.createDecryptingStream(provider, byteArrayOutputStream); + IOUtils.copy(new ByteArrayInputStream(ciphertext), decryptStream); + decryptStream.close(); + } + + @Test + public void encryptDecryptStreamWithMaxEdks() throws IOException { + MasterKeyProvider provider = providerWithEdks(3); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + CryptoOutputStream encryptStream = + maxEdksClient_.createEncryptingStream(provider, byteArrayOutputStream); + IOUtils.copy(new ByteArrayInputStream(new byte[] {1}), encryptStream); + encryptStream.close(); + + byte[] ciphertext = byteArrayOutputStream.toByteArray(); + assertEquals(new ParsedCiphertext(ciphertext).getEncryptedKeyBlobCount(), 3); + + byteArrayOutputStream.reset(); + CryptoOutputStream decryptStream = + maxEdksClient_.createDecryptingStream(provider, byteArrayOutputStream); + IOUtils.copy(new ByteArrayInputStream(ciphertext), decryptStream); + decryptStream.close(); + } + + @Test + public void noEncryptStreamWithMoreThanMaxEdks() { + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + CryptoOutputStream encryptStream = + maxEdksClient_.createEncryptingStream(providerWithEdks(4), byteArrayOutputStream); + assertThrows( + AwsCryptoException.class, + "Encrypted data keys exceed maxEncryptedDataKeys", + () -> IOUtils.copy(new ByteArrayInputStream(new byte[] {1}), encryptStream)); + } + + @Test + public void noDecryptStreamWithMoreThanMaxEdks() throws IOException { + MasterKeyProvider provider = providerWithEdks(4); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + CryptoOutputStream encryptStream = + noMaxEdksClient_.createEncryptingStream(provider, byteArrayOutputStream); + IOUtils.copy(new ByteArrayInputStream(new byte[] {1}), encryptStream); + encryptStream.close(); + + byte[] ciphertext = byteArrayOutputStream.toByteArray(); + + byteArrayOutputStream.reset(); + CryptoOutputStream decryptStream = + maxEdksClient_.createDecryptingStream(provider, byteArrayOutputStream); + assertThrows( + AwsCryptoException.class, + "Ciphertext encrypted data keys exceed maxEncryptedDataKeys", + () -> IOUtils.copy(new ByteArrayInputStream(ciphertext), decryptStream)); + } + + @Test + public void encryptDecryptStreamWithNoMaxEdks() throws IOException { + MasterKeyProvider provider = providerWithEdks(MESSAGE_FORMAT_MAX_EDKS); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + CryptoOutputStream encryptStream = + noMaxEdksClient_.createEncryptingStream(provider, byteArrayOutputStream); + IOUtils.copy(new ByteArrayInputStream(new byte[] {1}), encryptStream); + encryptStream.close(); + + byte[] ciphertext = byteArrayOutputStream.toByteArray(); + assertEquals( + new ParsedCiphertext(ciphertext).getEncryptedKeyBlobCount(), MESSAGE_FORMAT_MAX_EDKS); + + byteArrayOutputStream.reset(); + CryptoOutputStream decryptStream = + noMaxEdksClient_.createDecryptingStream(provider, byteArrayOutputStream); + IOUtils.copy(new ByteArrayInputStream(ciphertext), decryptStream); + decryptStream.close(); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/CommitmentKATRunner.java b/src/test/java/com/amazonaws/encryptionsdk/CommitmentKATRunner.java index 6c315dee5..da69fcc2c 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/CommitmentKATRunner.java +++ b/src/test/java/com/amazonaws/encryptionsdk/CommitmentKATRunner.java @@ -3,18 +3,6 @@ package com.amazonaws.encryptionsdk; -import java.io.File; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Map; - -import javax.crypto.spec.SecretKeySpec; - -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; -import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider; - import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertTrue; @@ -22,156 +10,167 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; +import com.amazonaws.encryptionsdk.internal.Utils; +import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import javax.crypto.spec.SecretKeySpec; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; -import com.amazonaws.encryptionsdk.internal.Utils; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.core.type.TypeReference; - @RunWith(Parameterized.class) public class CommitmentKATRunner { - private String comment; - private String keyringType; - private byte[] ciphertext; - private byte[] commitment; - private byte[] plaintext; - private byte[] decryptedDEK; - private byte[] messageId; - private byte[] header; - private Map encryptionContext; - private boolean status; - - private static final String TEST_VECTOR_RESOURCE_PATH = "/commitment-test-vectors.json"; - - public CommitmentKATRunner( - String comment, - String keyringType, - byte[] ciphertext, - byte[] commitment, - byte[] plaintext, - byte[] decryptedDEK, - byte[] messageId, - byte[] header, - Map encryptionContext, - boolean status - ) throws Exception { - this.comment = comment; - this.keyringType = keyringType; - this.ciphertext = ciphertext; - this.commitment = commitment; - this.plaintext = plaintext; - this.decryptedDEK = decryptedDEK; - this.messageId = messageId; - this.header = header; - this.encryptionContext = encryptionContext; - this.status = status; + private String comment; + private String keyringType; + private byte[] ciphertext; + private byte[] commitment; + private byte[] plaintext; + private byte[] decryptedDEK; + private byte[] messageId; + private byte[] header; + private Map encryptionContext; + private boolean status; + + private static final String TEST_VECTOR_RESOURCE_PATH = "/commitment-test-vectors.json"; + + public CommitmentKATRunner( + String comment, + String keyringType, + byte[] ciphertext, + byte[] commitment, + byte[] plaintext, + byte[] decryptedDEK, + byte[] messageId, + byte[] header, + Map encryptionContext, + boolean status) + throws Exception { + this.comment = comment; + this.keyringType = keyringType; + this.ciphertext = ciphertext; + this.commitment = commitment; + this.plaintext = plaintext; + this.decryptedDEK = decryptedDEK; + this.messageId = messageId; + this.header = header; + this.encryptionContext = encryptionContext; + this.status = status; + } + + @Parameters( + name = "{index}: testDecryptCommitment({0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9})") + public static Collection data() throws Exception { + final String testVectorsName = + CommitmentKATRunner.class.getResource(TEST_VECTOR_RESOURCE_PATH).getPath(); + final File ciphertextManifestFile = new File(testVectorsName); + + final List testCases_ = new ArrayList(); + + if (!ciphertextManifestFile.exists()) { + throw new IllegalStateException( + "Missing commitment test vectors file from src/test/resources."); } - @Parameters(name = "{index}: testDecryptCommitment({0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9})") - public static Collection data() throws Exception { - final String testVectorsName = CommitmentKATRunner.class.getResource(TEST_VECTOR_RESOURCE_PATH).getPath(); - final File ciphertextManifestFile = new File(testVectorsName); - - final List testCases_ = new ArrayList(); - - if (!ciphertextManifestFile.exists()) { - throw new IllegalStateException("Missing commitment test vectors file from src/test/resources."); - } - - final ObjectMapper ciphertextManifestMapper = new ObjectMapper(); - final Map ciphertextManifest = ciphertextManifestMapper.readValue( - ciphertextManifestFile, - new TypeReference>() { - } - ); - - final List> testCases = (List>) ciphertextManifest.get("tests"); - for (Map test : testCases) { - final String keyringType = (String) test.get("keyring-type"); - final byte[] decryptedDEK = Utils.decodeBase64String((String) test.get("decrypted-dek")); - final byte[] ciphertext = Utils.decodeBase64String((String) test.get("ciphertext")); - final byte[] commitment = Utils.decodeBase64String((String) test.get("commitment")); - final byte[] messageId = Utils.decodeBase64String((String) test.get("message-id")); - final byte[] header = Utils.decodeBase64String((String) test.get("header")); - final boolean status = (boolean) test.get("status"); - final String comment = (String) test.get("comment"); - final Map encryptionContext = (Map) test.get("encryption-context"); - - // plaintext is available for cases which succeed decryption - byte[] plaintext = null; - if (status) { - final List plaintextFrames = (List) test.get("plaintext-frames"); - plaintext = String.join("", plaintextFrames).getBytes(StandardCharsets.UTF_8); - } - - testCases_.add(new Object[]{ - comment, - keyringType, - ciphertext, - commitment, - plaintext, - decryptedDEK, - messageId, - header, - encryptionContext, - status - }); - } - return testCases_; + final ObjectMapper ciphertextManifestMapper = new ObjectMapper(); + final Map ciphertextManifest = + ciphertextManifestMapper.readValue( + ciphertextManifestFile, new TypeReference>() {}); + + final List> testCases = + (List>) ciphertextManifest.get("tests"); + for (Map test : testCases) { + final String keyringType = (String) test.get("keyring-type"); + final byte[] decryptedDEK = Utils.decodeBase64String((String) test.get("decrypted-dek")); + final byte[] ciphertext = Utils.decodeBase64String((String) test.get("ciphertext")); + final byte[] commitment = Utils.decodeBase64String((String) test.get("commitment")); + final byte[] messageId = Utils.decodeBase64String((String) test.get("message-id")); + final byte[] header = Utils.decodeBase64String((String) test.get("header")); + final boolean status = (boolean) test.get("status"); + final String comment = (String) test.get("comment"); + final Map encryptionContext = + (Map) test.get("encryption-context"); + + // plaintext is available for cases which succeed decryption + byte[] plaintext = null; + if (status) { + final List plaintextFrames = (List) test.get("plaintext-frames"); + plaintext = String.join("", plaintextFrames).getBytes(StandardCharsets.UTF_8); + } + + testCases_.add( + new Object[] { + comment, + keyringType, + ciphertext, + commitment, + plaintext, + decryptedDEK, + messageId, + header, + encryptionContext, + status + }); + } + return testCases_; + } + + @Test + public void testDecryptCommitment() throws Exception { + final AwsCrypto crypto = + AwsCrypto.builder() + .withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) + .build(); + + // Determine what MKP to test with based on whether the test case was create using kms or not. + // If it was, we can test decryption of the message completely end to end by using a + // Discovery KMS MKP. + // Otherwise, we must mock out a provider that just returns the data encryption key associated + // with the test case. + final MasterKeyProvider mkp; + switch (keyringType) { + case "aws-kms": + mkp = KmsMasterKeyProvider.builder().buildDiscovery(); + break; + case "static": + default: + mkp = mock(MasterKeyProvider.class); + DataKey dataKey = + new DataKey( + // All committing algorithms use HkdfSHA512 for + // the kdf. If this changes, the test vectors + // will need to indicate what algorithm suite + // was used in order for this test to + // appropriately set the secret key spec's algorithm + new SecretKeySpec(decryptedDEK, "HkdfSHA512"), new byte[0], new byte[0], null); + when(mkp.decryptDataKey(any(), any(), any())).thenReturn(dataKey); + break; } - @Test - public void testDecryptCommitment() throws Exception { - final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt).build(); - - // Determine what MKP to test with based on whether the test case was create using kms or not. - // If it was, we can test decryption of the message completely end to end by using a - // Discovery KMS MKP. - // Otherwise, we must mock out a provider that just returns the data encryption key associated - // with the test case. - final MasterKeyProvider mkp; - switch (keyringType) { - case "aws-kms": - mkp = KmsMasterKeyProvider.builder().buildDiscovery(); - break; - case "static": - default: - mkp = mock(MasterKeyProvider.class); - DataKey dataKey = new DataKey( - // All committing algorithms use HkdfSHA512 for - // the kdf. If this changes, the test vectors - // will need to indicate what algorithm suite - // was used in order for this test to - // appropriately set the secret key spec's algorithm - new SecretKeySpec(decryptedDEK, "HkdfSHA512"), - new byte[0], - new byte[0], - null - ); - when(mkp.decryptDataKey(any(), any(), any())).thenReturn(dataKey); - break; - } - - // Ensure tests that are expected to fail do so with the right exception and error message - if (!status) { - assertThrows(BadCiphertextException.class, "Key commitment validation failed. Key identity does not " + - "match the identity asserted in the message. Halting processing of this message.", - () -> crypto.decryptData(mkp, ciphertext) - ); - return; - } - - // Otherwise ensure our result matches the expected commitment data - final CryptoResult decryptResult = crypto.decryptData( - mkp, - ciphertext); - assertArrayEquals(decryptResult.getHeaders().getSuiteData(), commitment); - assertArrayEquals(decryptResult.getHeaders().getMessageId(), messageId); - assertArrayEquals(decryptResult.getHeaders().toByteArray(), header); - assertTrue(decryptResult.getEncryptionContext().equals(encryptionContext)); - assertArrayEquals((byte[]) decryptResult.getResult(), plaintext); + // Ensure tests that are expected to fail do so with the right exception and error message + if (!status) { + assertThrows( + BadCiphertextException.class, + "Key commitment validation failed. Key identity does not " + + "match the identity asserted in the message. Halting processing of this message.", + () -> crypto.decryptData(mkp, ciphertext)); + return; } + + // Otherwise ensure our result matches the expected commitment data + final CryptoResult decryptResult = crypto.decryptData(mkp, ciphertext); + assertArrayEquals(decryptResult.getHeaders().getSuiteData(), commitment); + assertArrayEquals(decryptResult.getHeaders().getMessageId(), messageId); + assertArrayEquals(decryptResult.getHeaders().toByteArray(), header); + assertTrue(decryptResult.getEncryptionContext().equals(encryptionContext)); + assertArrayEquals((byte[]) decryptResult.getResult(), plaintext); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/CryptoAlgorithmTest.java b/src/test/java/com/amazonaws/encryptionsdk/CryptoAlgorithmTest.java index 0fe45b93c..fba98ae26 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/CryptoAlgorithmTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/CryptoAlgorithmTest.java @@ -14,115 +14,147 @@ import com.amazonaws.encryptionsdk.model.CiphertextHeaders; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; -import org.junit.Test; - -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; import java.security.InvalidKeyException; import java.util.Collections; import java.util.Map; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import org.junit.Test; public class CryptoAlgorithmTest { - @Test - public void testDeserialization() { - for (CryptoAlgorithm algorithm : CryptoAlgorithm.values()) { - assertEquals(algorithm.toString(), - algorithm, - CryptoAlgorithm.deserialize(algorithm.getMessageFormatVersion(), algorithm.getValue())); - } - } - - @Test - public void testGetCommittedEncryptionKey() throws InvalidKeyException { - CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; - SecretKeySpec secretKey = new SecretKeySpec(Utils.decodeBase64String(TestUtils.messageWithCommitKeyDEKBase64), algorithm.getDataKeyAlgo()); - CiphertextHeaders headers = new ParsedCiphertext(Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64)); - SecretKey key = algorithm.getEncryptionKeyFromDataKey(secretKey, headers); - assertNotNull(key); - assertEquals(algorithm.getKeyAlgo(), key.getAlgorithm()); - } - - @Test - public void testGetCommittedEncryptionKeyIncorrectCommitment() throws InvalidKeyException { - CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; - SecretKeySpec secretKey = new SecretKeySpec(Utils.decodeBase64String(TestUtils.messageWithCommitKeyDEKBase64), algorithm.getDataKeyAlgo()); - CiphertextHeaders headers = new ParsedCiphertext(Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64)); - - // Set header to an incorrect commitment value - headers.setSuiteData(new byte[algorithm.getSuiteDataLength()]); - assertThrows(BadCiphertextException.class, "Key commitment validation failed. Key identity does not match the " + - "identity asserted in the message. Halting processing of this message.", - () -> algorithm.getEncryptionKeyFromDataKey(secretKey, headers)); - } - - @Test - public void testGetCommittedEncryptionKeyIncorrectKeySpecAlgorithm() throws InvalidKeyException { - CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; - SecretKeySpec secretKey = new SecretKeySpec(new byte[algorithm.getDataKeyLength()], "incorrectAlgorithm"); - CiphertextHeaders headers = new ParsedCiphertext(Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64)); - assertThrows(InvalidKeyException.class, "DataKey of incorrect algorithm.", - () -> algorithm.getEncryptionKeyFromDataKey(secretKey, headers)); - } - - @Test - public void testGetCommittedEncryptionKeyIncorrectLength() throws InvalidKeyException { - CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; - SecretKeySpec secretKey = new SecretKeySpec(new byte[1], "HkdfSHA512"); - CiphertextHeaders headers = new ParsedCiphertext(Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64)); - assertThrows(IllegalArgumentException.class, "DataKey of incorrect length.", - () -> algorithm.getEncryptionKeyFromDataKey(secretKey, headers)); - } - - @Test - public void testGetUnCommittedEncryptionKey() throws InvalidKeyException { - CryptoAlgorithm algo = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - SecretKeySpec secretKey = new SecretKeySpec(new byte[algo.getDataKeyLength()], algo.getDataKeyAlgo()); - CiphertextHeaders headers = getTestHeaders(algo); - SecretKey key = algo.getEncryptionKeyFromDataKey(secretKey, headers); - assertNotNull(key); - assertEquals(algo.getKeyAlgo(), key.getAlgorithm()); - } - - @Test - public void testGetUnCommittedEncryptionKeyFromDataKeyIncorrectKeySpecAlgorithm() throws InvalidKeyException { - CryptoAlgorithm algo = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - SecretKeySpec secretKey = new SecretKeySpec(new byte[algo.getDataKeyLength()], "incorrectAlgorithm"); - CiphertextHeaders headers = getTestHeaders(algo); - assertThrows(InvalidKeyException.class, "DataKey of incorrect algorithm.", - () -> TestUtils.messageWithCommitKeyCryptoAlgorithm.getEncryptionKeyFromDataKey(secretKey, headers)); - } - - @Test - public void testGetUnCommittedEncryptionKeyIncorrectLength() throws InvalidKeyException { - SecretKeySpec secretKey = new SecretKeySpec(new byte[1], "HkdfSHA512"); - CiphertextHeaders headers = new ParsedCiphertext(Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64)); - assertThrows(IllegalArgumentException.class, "DataKey of incorrect length.", - () -> TestUtils.messageWithCommitKeyCryptoAlgorithm.getEncryptionKeyFromDataKey(secretKey, headers)); - } - - private ParsedCiphertext getTestHeaders(CryptoAlgorithm algo) { - // Generate test headers - final int frameSize_ = AwsCrypto.getDefaultFrameSize(); - final Map encryptionContext = Collections.emptyMap(); - final CommitmentPolicy policy = CommitmentPolicy.ForbidEncryptAllowDecrypt; - - final EncryptionMaterialsRequest encryptionMaterialsRequest = EncryptionMaterialsRequest.newBuilder() - .setContext(encryptionContext) - .setRequestedAlgorithm(algo) - .setCommitmentPolicy(policy) - .build(); - - final StaticMasterKey masterKeyProvider = new StaticMasterKey("mock"); - - final EncryptionMaterials encryptionMaterials = new DefaultCryptoMaterialsManager(masterKeyProvider) - .getMaterialsForEncrypt(encryptionMaterialsRequest); - - final EncryptionHandler encryptionHandler = new EncryptionHandler(frameSize_, encryptionMaterials, policy); - - final byte[] in = new byte[0]; - final int ciphertextLen = encryptionHandler.estimateOutputSize(in.length); - final byte[] ciphertext = new byte[ciphertextLen]; - encryptionHandler.processBytes(in, 0, in.length, ciphertext, 0); - return new ParsedCiphertext(ciphertext); + @Test + public void testDeserialization() { + for (CryptoAlgorithm algorithm : CryptoAlgorithm.values()) { + assertEquals( + algorithm.toString(), + algorithm, + CryptoAlgorithm.deserialize(algorithm.getMessageFormatVersion(), algorithm.getValue())); } + } + + @Test + public void testGetCommittedEncryptionKey() throws InvalidKeyException { + CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; + SecretKeySpec secretKey = + new SecretKeySpec( + Utils.decodeBase64String(TestUtils.messageWithCommitKeyDEKBase64), + algorithm.getDataKeyAlgo()); + CiphertextHeaders headers = + new ParsedCiphertext(Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64)); + SecretKey key = algorithm.getEncryptionKeyFromDataKey(secretKey, headers); + assertNotNull(key); + assertEquals(algorithm.getKeyAlgo(), key.getAlgorithm()); + } + + @Test + public void testGetCommittedEncryptionKeyIncorrectCommitment() throws InvalidKeyException { + CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; + SecretKeySpec secretKey = + new SecretKeySpec( + Utils.decodeBase64String(TestUtils.messageWithCommitKeyDEKBase64), + algorithm.getDataKeyAlgo()); + CiphertextHeaders headers = + new ParsedCiphertext(Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64)); + + // Set header to an incorrect commitment value + headers.setSuiteData(new byte[algorithm.getSuiteDataLength()]); + assertThrows( + BadCiphertextException.class, + "Key commitment validation failed. Key identity does not match the " + + "identity asserted in the message. Halting processing of this message.", + () -> algorithm.getEncryptionKeyFromDataKey(secretKey, headers)); + } + + @Test + public void testGetCommittedEncryptionKeyIncorrectKeySpecAlgorithm() throws InvalidKeyException { + CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; + SecretKeySpec secretKey = + new SecretKeySpec(new byte[algorithm.getDataKeyLength()], "incorrectAlgorithm"); + CiphertextHeaders headers = + new ParsedCiphertext(Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64)); + assertThrows( + InvalidKeyException.class, + "DataKey of incorrect algorithm.", + () -> algorithm.getEncryptionKeyFromDataKey(secretKey, headers)); + } + + @Test + public void testGetCommittedEncryptionKeyIncorrectLength() throws InvalidKeyException { + CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; + SecretKeySpec secretKey = new SecretKeySpec(new byte[1], "HkdfSHA512"); + CiphertextHeaders headers = + new ParsedCiphertext(Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64)); + assertThrows( + IllegalArgumentException.class, + "DataKey of incorrect length.", + () -> algorithm.getEncryptionKeyFromDataKey(secretKey, headers)); + } + + @Test + public void testGetUnCommittedEncryptionKey() throws InvalidKeyException { + CryptoAlgorithm algo = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + SecretKeySpec secretKey = + new SecretKeySpec(new byte[algo.getDataKeyLength()], algo.getDataKeyAlgo()); + CiphertextHeaders headers = getTestHeaders(algo); + SecretKey key = algo.getEncryptionKeyFromDataKey(secretKey, headers); + assertNotNull(key); + assertEquals(algo.getKeyAlgo(), key.getAlgorithm()); + } + + @Test + public void testGetUnCommittedEncryptionKeyFromDataKeyIncorrectKeySpecAlgorithm() + throws InvalidKeyException { + CryptoAlgorithm algo = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + SecretKeySpec secretKey = + new SecretKeySpec(new byte[algo.getDataKeyLength()], "incorrectAlgorithm"); + CiphertextHeaders headers = getTestHeaders(algo); + assertThrows( + InvalidKeyException.class, + "DataKey of incorrect algorithm.", + () -> + TestUtils.messageWithCommitKeyCryptoAlgorithm.getEncryptionKeyFromDataKey( + secretKey, headers)); + } + + @Test + public void testGetUnCommittedEncryptionKeyIncorrectLength() throws InvalidKeyException { + SecretKeySpec secretKey = new SecretKeySpec(new byte[1], "HkdfSHA512"); + CiphertextHeaders headers = + new ParsedCiphertext(Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64)); + assertThrows( + IllegalArgumentException.class, + "DataKey of incorrect length.", + () -> + TestUtils.messageWithCommitKeyCryptoAlgorithm.getEncryptionKeyFromDataKey( + secretKey, headers)); + } + + private ParsedCiphertext getTestHeaders(CryptoAlgorithm algo) { + // Generate test headers + final int frameSize_ = AwsCrypto.getDefaultFrameSize(); + final Map encryptionContext = Collections.emptyMap(); + final CommitmentPolicy policy = CommitmentPolicy.ForbidEncryptAllowDecrypt; + + final EncryptionMaterialsRequest encryptionMaterialsRequest = + EncryptionMaterialsRequest.newBuilder() + .setContext(encryptionContext) + .setRequestedAlgorithm(algo) + .setCommitmentPolicy(policy) + .build(); + + final StaticMasterKey masterKeyProvider = new StaticMasterKey("mock"); + + final EncryptionMaterials encryptionMaterials = + new DefaultCryptoMaterialsManager(masterKeyProvider) + .getMaterialsForEncrypt(encryptionMaterialsRequest); + + final EncryptionHandler encryptionHandler = + new EncryptionHandler(frameSize_, encryptionMaterials, policy); + + final byte[] in = new byte[0]; + final int ciphertextLen = encryptionHandler.estimateOutputSize(in.length); + final byte[] ciphertext = new byte[ciphertextLen]; + encryptionHandler.processBytes(in, 0, in.length, ciphertext, 0); + return new ParsedCiphertext(ciphertext); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/CryptoInputStreamTest.java b/src/test/java/com/amazonaws/encryptionsdk/CryptoInputStreamTest.java index 32ae93f80..342b1bf72 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/CryptoInputStreamTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/CryptoInputStreamTest.java @@ -13,7 +13,10 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import javax.crypto.spec.SecretKeySpec; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; +import com.amazonaws.encryptionsdk.internal.TestIOUtils; +import com.amazonaws.encryptionsdk.jce.JceMasterKey; +import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -26,7 +29,7 @@ import java.util.HashMap; import java.util.Map; import java.util.function.Consumer; - +import javax.crypto.spec.SecretKeySpec; import org.bouncycastle.util.Arrays; import org.junit.Before; import org.junit.Test; @@ -35,667 +38,624 @@ import org.junit.runners.Parameterized; import org.mockito.ArgumentCaptor; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; -import com.amazonaws.encryptionsdk.internal.TestIOUtils; -import com.amazonaws.encryptionsdk.jce.JceMasterKey; -import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; - @RunWith(Enclosed.class) public class CryptoInputStreamTest { - private static final SecureRandom RND = new SecureRandom(); - private static final MasterKey customerMasterKey; - private static final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; - - static { - byte[] rawKey = new byte[16]; - RND.nextBytes(rawKey); - - customerMasterKey = JceMasterKey.getInstance( - new SecretKeySpec(rawKey, "AES"), - "mockProvider", - "mockKey", - "AES/GCM/NoPadding" - ); + private static final SecureRandom RND = new SecureRandom(); + private static final MasterKey customerMasterKey; + private static final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; + + static { + byte[] rawKey = new byte[16]; + RND.nextBytes(rawKey); + + customerMasterKey = + JceMasterKey.getInstance( + new SecretKeySpec(rawKey, "AES"), "mockProvider", "mockKey", "AES/GCM/NoPadding"); + } + + private static void testRoundTrip( + int dataSize, + Consumer customizer, + Callback onEncrypt, + Callback onDecrypt, + CommitmentPolicy commitmentPolicy) + throws Exception { + AwsCrypto awsCrypto = AwsCrypto.builder().withCommitmentPolicy(commitmentPolicy).build(); + customizer.accept(awsCrypto); + + byte[] plaintext = insecureRandomBytes(dataSize); + + ByteArrayInputStream inputStream = new ByteArrayInputStream(plaintext); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + onEncrypt.process(awsCrypto, inputStream, outputStream); + + inputStream = new ByteArrayInputStream(outputStream.toByteArray()); + outputStream = new ByteArrayOutputStream(); + + onDecrypt.process(awsCrypto, inputStream, outputStream); + + assertArrayEquals(getSha256Hash(plaintext), getSha256Hash(outputStream.toByteArray())); + } + + private interface Callback { + void process(AwsCrypto crypto, InputStream inStream, OutputStream outStream) throws Exception; + } + + private static Callback encryptWithContext(Map encryptionContext) { + return (awsCrypto, inStream, outStream) -> { + final InputStream cryptoStream = + awsCrypto.createEncryptingStream(customerMasterKey, inStream, encryptionContext); + + TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); + }; + } + + private static Callback encryptWithoutContext() { + return (awsCrypto, inStream, outStream) -> { + final InputStream cryptoStream = + awsCrypto.createEncryptingStream(customerMasterKey, inStream); + + TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); + }; + } + + private static Callback basicDecrypt(int readLen) { + return (awsCrypto, inStream, outStream) -> { + final InputStream cryptoStream = + awsCrypto.createDecryptingStream(customerMasterKey, inStream); + + TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream, readLen); + }; + } + + private static Callback basicDecrypt() { + return (awsCrypto, inStream, outStream) -> { + final InputStream cryptoStream = + awsCrypto.createDecryptingStream(customerMasterKey, inStream); + + TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); + }; + } + + @RunWith(Parameterized.class) + public static class ParameterizedEncryptDecryptTest { + private final CryptoAlgorithm cryptoAlg; + private final int byteSize, frameSize, readLen; + + public ParameterizedEncryptDecryptTest( + CryptoAlgorithm cryptoAlg, int byteSize, int frameSize, int readLen) { + this.cryptoAlg = cryptoAlg; + this.byteSize = byteSize; + this.frameSize = frameSize; + this.readLen = readLen; } - private static void testRoundTrip( - int dataSize, - Consumer customizer, - Callback onEncrypt, - Callback onDecrypt, - CommitmentPolicy commitmentPolicy - ) throws Exception { - AwsCrypto awsCrypto = AwsCrypto.builder().withCommitmentPolicy(commitmentPolicy).build(); - customizer.accept(awsCrypto); - - byte[] plaintext = insecureRandomBytes(dataSize); - - ByteArrayInputStream inputStream = new ByteArrayInputStream(plaintext); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - onEncrypt.process(awsCrypto, inputStream, outputStream); + @Parameterized.Parameters( + name = "{index}: encryptDecrypt(algorithm={0}, byteSize={1}, frameSize={2}, readLen={3})") + public static Collection encryptDecryptParams() { + ArrayList testCases = new ArrayList<>(); + + // We'll run more exhaustive tests on the first algorithm, then go lighter weight on the rest. + boolean firstAlgorithm = true; + + for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { + final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); + + // Our bytesToTest and readLenVals arrays tend to have the bigger numbers towards the end - + // we'll chop off + // the last few as they take the longest and don't really add that much more coverage. + int skipLastNSizes; + if (!FastTestsOnlySuite.isFastTestSuiteActive()) { + skipLastNSizes = 0; + } else if (firstAlgorithm) { + // We'll run more tests for the first algorithm in the list - but not go quite so far as + // running the + // 1MB tests. + skipLastNSizes = 1; + } else { + skipLastNSizes = 2; + } - inputStream = new ByteArrayInputStream(outputStream.toByteArray()); - outputStream = new ByteArrayOutputStream(); + // iterate over frame size to test + for (final int frameSize : frameSizeToTest) { + int[] bytesToTest = { + 0, + 1, + frameSize - 1, + frameSize, + frameSize + 1, + (int) (frameSize * 1.5), + frameSize * 2, + 1000000 + }; + + bytesToTest = Arrays.copyOfRange(bytesToTest, 0, bytesToTest.length - skipLastNSizes); + + // iterate over byte size to test + for (final int byteSize : bytesToTest) { + int[] readLenVals = {1, byteSize - 1, byteSize, byteSize + 1, byteSize * 2, 1000000}; + + readLenVals = Arrays.copyOfRange(readLenVals, 0, readLenVals.length - skipLastNSizes); + + // iterate over read lengths to test + for (final int readLen : readLenVals) { + if (byteSize >= 0 && readLen > 0) { + testCases.add(new Object[] {cryptoAlg, byteSize, frameSize, readLen}); + } + } + } + } - onDecrypt.process(awsCrypto, inputStream, outputStream); + firstAlgorithm = false; + } - assertArrayEquals(getSha256Hash(plaintext), getSha256Hash(outputStream.toByteArray())); + return testCases; } - private interface Callback { - void process(AwsCrypto crypto, InputStream inStream, OutputStream outStream) throws Exception; + @Test + public void encryptDecrypt() throws Exception { + final CommitmentPolicy commitmentPolicy = + cryptoAlg.isCommitting() + ? CommitmentPolicy.RequireEncryptRequireDecrypt + : CommitmentPolicy.ForbidEncryptAllowDecrypt; + testRoundTrip( + byteSize, + awsCrypto -> { + awsCrypto.setEncryptionAlgorithm(cryptoAlg); + awsCrypto.setEncryptionFrameSize(frameSize); + }, + encryptWithoutContext(), + basicDecrypt(readLen), + commitmentPolicy); } + } - private static Callback encryptWithContext(Map encryptionContext) { - return (awsCrypto, inStream, outStream) -> { - final InputStream cryptoStream = awsCrypto.createEncryptingStream( - customerMasterKey, - inStream, - encryptionContext); + public static class NonParameterized { + private AwsCrypto encryptionClient_; - TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - }; + @Before + public void setup() throws IOException { + encryptionClient_ = AwsCrypto.standard(); } - private static Callback encryptWithoutContext() { - return (awsCrypto, inStream, outStream) -> { - final InputStream cryptoStream = awsCrypto.createEncryptingStream( - customerMasterKey, - inStream - ); + @Test + public void doEncryptDecryptWithoutEncContext() throws Exception { + testRoundTrip( + 1_000_000, + awsCrypto -> {}, + encryptWithoutContext(), + basicDecrypt(), + CommitmentPolicy.RequireEncryptRequireDecrypt); + } - TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - }; + @Test + public void encryptBytesDecryptStream() throws Exception { + Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC", "encryptBytesDecryptStream"); + + testRoundTrip( + 1_000_000, + awsCrypto -> {}, + (AwsCrypto awsCrypto, InputStream inStream, OutputStream outStream) -> { + ByteArrayOutputStream inbuf = new ByteArrayOutputStream(); + TestIOUtils.copyInStreamToOutStream(inStream, inbuf); + + CryptoResult ciphertext = + awsCrypto.encryptData(customerMasterKey, inbuf.toByteArray(), encryptionContext); + + outStream.write(ciphertext.getResult()); + }, + basicDecrypt(), + CommitmentPolicy.RequireEncryptRequireDecrypt); } - private static Callback basicDecrypt(int readLen) { - return (awsCrypto, inStream, outStream) -> { - final InputStream cryptoStream = awsCrypto.createDecryptingStream( - customerMasterKey, - inStream); + @Test + public void encryptStreamDecryptBytes() throws Exception { + Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC", "encryptStreamDecryptBytes"); + testRoundTrip( + 1_000_000, + awsCrypto -> {}, + encryptWithContext(encryptionContext), + (AwsCrypto awsCrypto, InputStream inStream, OutputStream outStream) -> { + ByteArrayOutputStream inbuf = new ByteArrayOutputStream(); + TestIOUtils.copyInStreamToOutStream(inStream, inbuf); + + CryptoResult ciphertext = + awsCrypto.decryptData(customerMasterKey, inbuf.toByteArray()); + + outStream.write(ciphertext.getResult()); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); + } - TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream, readLen); - }; + @Test + public void encryptOSDecryptIS() throws Exception { + Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC", "encryptOSDecryptIS"); + + testRoundTrip( + 1_000_000, + awsCrypto -> {}, + (awsCrypto, inStream, outStream) -> { + OutputStream cryptoOS = + awsCrypto.createEncryptingStream(customerMasterKey, outStream, encryptionContext); + TestIOUtils.copyInStreamToOutStream(inStream, cryptoOS); + }, + basicDecrypt(), + CommitmentPolicy.RequireEncryptRequireDecrypt); } - private static Callback basicDecrypt() { - return (awsCrypto, inStream, outStream) -> { - final InputStream cryptoStream = awsCrypto.createDecryptingStream( - customerMasterKey, - inStream); + private void singleByteCopyLoop(InputStream is, OutputStream os) throws Exception { + int rv; + while (-1 != (rv = is.read())) { + os.write(rv); + } - TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - }; + is.close(); + os.close(); } - @RunWith(Parameterized.class) - public static class ParameterizedEncryptDecryptTest { - private final CryptoAlgorithm cryptoAlg; - private final int byteSize, frameSize, readLen; - - public ParameterizedEncryptDecryptTest( - CryptoAlgorithm cryptoAlg, int byteSize, int frameSize, int readLen - ) { - this.cryptoAlg = cryptoAlg; - this.byteSize = byteSize; - this.frameSize = frameSize; - this.readLen = readLen; - } + @Test + public void singleByteRead() throws Exception { + Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC", "singleByteRead"); + + testRoundTrip( + 1_000_000, + awsCrypto -> {}, + (awsCrypto, inStream, outStream) -> { + InputStream is = + awsCrypto.createEncryptingStream(customerMasterKey, inStream, encryptionContext); + singleByteCopyLoop(is, outStream); + }, + (awsCrypto, inStream, outStream) -> { + InputStream is = awsCrypto.createDecryptingStream(customerMasterKey, inStream); + singleByteCopyLoop(is, outStream); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); + } - @Parameterized.Parameters(name = "{index}: encryptDecrypt(algorithm={0}, byteSize={1}, frameSize={2}, readLen={3})") - public static Collection encryptDecryptParams() { - ArrayList testCases = new ArrayList<>(); - - // We'll run more exhaustive tests on the first algorithm, then go lighter weight on the rest. - boolean firstAlgorithm = true; - - for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { - final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); - - // Our bytesToTest and readLenVals arrays tend to have the bigger numbers towards the end - we'll chop off - // the last few as they take the longest and don't really add that much more coverage. - int skipLastNSizes; - if (!FastTestsOnlySuite.isFastTestSuiteActive()) { - skipLastNSizes = 0; - } else if (firstAlgorithm) { - // We'll run more tests for the first algorithm in the list - but not go quite so far as running the - // 1MB tests. - skipLastNSizes = 1; - } else { - skipLastNSizes = 2; - } - - // iterate over frame size to test - for (final int frameSize : frameSizeToTest) { - int[] bytesToTest = { - 0, 1, frameSize - 1, frameSize, frameSize + 1, (int) (frameSize * 1.5), - frameSize * 2, 1000000 - }; - - bytesToTest = Arrays.copyOfRange(bytesToTest, 0, bytesToTest.length - skipLastNSizes); - - // iterate over byte size to test - for (final int byteSize : bytesToTest) { - int[] readLenVals = {1, byteSize - 1, byteSize, byteSize + 1, byteSize * 2, 1000000}; - - readLenVals = Arrays.copyOfRange(readLenVals, 0, readLenVals.length - skipLastNSizes); - - // iterate over read lengths to test - for (final int readLen : readLenVals) { - if (byteSize >= 0 && readLen > 0) { - testCases.add(new Object[]{cryptoAlg, byteSize, frameSize, readLen}); - } - } - } - } - - firstAlgorithm = false; - } + @SuppressWarnings({"ConstantConditions", "ResultOfMethodCallIgnored"}) + @Test(expected = NullPointerException.class) + public void whenNullBufferPassed_andNoOffsetArgs_readThrowsNPE() + throws BadCiphertextException, IOException { + final Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC", "nullReadBuffer"); - return testCases; - } + final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); + final InputStream encryptionInStream = + encryptionClient_.createEncryptingStream(customerMasterKey, inStream, encryptionContext); - @Test - public void encryptDecrypt() throws Exception { - final CommitmentPolicy commitmentPolicy = cryptoAlg.isCommitting() ? - CommitmentPolicy.RequireEncryptRequireDecrypt : CommitmentPolicy.ForbidEncryptAllowDecrypt; - testRoundTrip( - byteSize, - awsCrypto -> { - awsCrypto.setEncryptionAlgorithm(cryptoAlg); - awsCrypto.setEncryptionFrameSize(frameSize); - }, - encryptWithoutContext(), - basicDecrypt(readLen), - commitmentPolicy - ); - } + encryptionInStream.read(null); } - public static class NonParameterized { - private AwsCrypto encryptionClient_; + @SuppressWarnings({"ConstantConditions", "ResultOfMethodCallIgnored"}) + @Test(expected = NullPointerException.class) + public void whenNullBufferPassed_andOffsetArgsPassed_readThrowsNPE() + throws BadCiphertextException, IOException { + final Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC", "nullReadBuffer2"); - @Before - public void setup() throws IOException { - encryptionClient_ = AwsCrypto.standard(); - } + final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); + final InputStream encryptionInStream = + encryptionClient_.createEncryptingStream(customerMasterKey, inStream, encryptionContext); - @Test - public void doEncryptDecryptWithoutEncContext() throws Exception { - testRoundTrip( - 1_000_000, - awsCrypto -> { - }, - encryptWithoutContext(), - basicDecrypt(), - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + encryptionInStream.read(null, 0, 0); + } - @Test - public void encryptBytesDecryptStream() throws Exception { - Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC", "encryptBytesDecryptStream"); - - testRoundTrip( - 1_000_000, - awsCrypto -> { - }, - (AwsCrypto awsCrypto, InputStream inStream, OutputStream outStream) -> { - ByteArrayOutputStream inbuf = new ByteArrayOutputStream(); - TestIOUtils.copyInStreamToOutStream(inStream, inbuf); - - CryptoResult ciphertext = awsCrypto.encryptData( - customerMasterKey, - inbuf.toByteArray(), - encryptionContext - ); - - outStream.write(ciphertext.getResult()); - }, - basicDecrypt(), - CommitmentPolicy.RequireEncryptRequireDecrypt - ); + @Test + public void zeroReadLen() throws BadCiphertextException, IOException { + final Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC", "zeroReadLen"); - } + final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); + final InputStream encryptionInStream = + encryptionClient_.createEncryptingStream(customerMasterKey, inStream, encryptionContext); - @Test - public void encryptStreamDecryptBytes() throws Exception { - Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC", "encryptStreamDecryptBytes"); - testRoundTrip( - 1_000_000, - awsCrypto -> { - }, - encryptWithContext(encryptionContext), - (AwsCrypto awsCrypto, InputStream inStream, OutputStream outStream) -> { - ByteArrayOutputStream inbuf = new ByteArrayOutputStream(); - TestIOUtils.copyInStreamToOutStream(inStream, inbuf); - - CryptoResult ciphertext = awsCrypto.decryptData( - customerMasterKey, - inbuf.toByteArray() - ); - - outStream.write(ciphertext.getResult()); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); + final byte[] tempBytes = new byte[0]; + final int readLen = encryptionInStream.read(tempBytes); + assertEquals(readLen, 0); + } - } + @SuppressWarnings("ResultOfMethodCallIgnored") + @Test(expected = IllegalArgumentException.class) + public void negativeReadLen() throws BadCiphertextException, IOException { + final Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC", "negativeReadLen"); - @Test - public void encryptOSDecryptIS() throws Exception { - Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC", "encryptOSDecryptIS"); - - testRoundTrip( - 1_000_000, - awsCrypto -> { - }, - (awsCrypto, inStream, outStream) -> { - OutputStream cryptoOS - = awsCrypto.createEncryptingStream(customerMasterKey, outStream, encryptionContext); - TestIOUtils.copyInStreamToOutStream(inStream, cryptoOS); - }, - basicDecrypt(), - CommitmentPolicy.RequireEncryptRequireDecrypt - ); + final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); + final InputStream encryptionInStream = + encryptionClient_.createEncryptingStream(customerMasterKey, inStream, encryptionContext); - } + final byte[] tempBytes = new byte[1]; + encryptionInStream.read(tempBytes, 0, -1); + } - private void singleByteCopyLoop(InputStream is, OutputStream os) throws Exception { - int rv; - while (-1 != (rv = is.read())) { - os.write(rv); - } + @SuppressWarnings("ResultOfMethodCallIgnored") + @Test(expected = IllegalArgumentException.class) + public void negativeReadOffset() throws BadCiphertextException, IOException { + final Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC", "negativeReadOffset"); - is.close(); - os.close(); - } + final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); + final InputStream encryptionInStream = + encryptionClient_.createEncryptingStream(customerMasterKey, inStream, encryptionContext); - @Test - public void singleByteRead() throws Exception { - Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC", "singleByteRead"); - - testRoundTrip( - 1_000_000, - awsCrypto -> { - }, - (awsCrypto, inStream, outStream) -> { - InputStream is = awsCrypto.createEncryptingStream(customerMasterKey, inStream, - encryptionContext); - singleByteCopyLoop(is, outStream); - }, - (awsCrypto, inStream, outStream) -> { - InputStream is = awsCrypto.createDecryptingStream(customerMasterKey, inStream); - singleByteCopyLoop(is, outStream); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + byte[] tempBytes = new byte[1]; + encryptionInStream.read(tempBytes, -1, tempBytes.length); + } - @SuppressWarnings({"ConstantConditions", "ResultOfMethodCallIgnored"}) - @Test(expected = NullPointerException.class) - public void whenNullBufferPassed_andNoOffsetArgs_readThrowsNPE() throws BadCiphertextException, IOException { - final Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC", "nullReadBuffer"); + @SuppressWarnings("ResultOfMethodCallIgnored") + @Test(expected = ArrayIndexOutOfBoundsException.class) + public void invalidReadOffset() throws BadCiphertextException, IOException { + final Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC", "invalidReadOffset"); - final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); - final InputStream encryptionInStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - inStream, - encryptionContext); + final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); + final InputStream encryptionInStream = + encryptionClient_.createEncryptingStream(customerMasterKey, inStream, encryptionContext); - encryptionInStream.read(null); - } + final byte[] tempBytes = new byte[100]; + encryptionInStream.read(tempBytes, tempBytes.length + 1, tempBytes.length); + } - @SuppressWarnings({"ConstantConditions", "ResultOfMethodCallIgnored"}) - @Test(expected = NullPointerException.class) - public void whenNullBufferPassed_andOffsetArgsPassed_readThrowsNPE() throws BadCiphertextException, IOException { - final Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC", "nullReadBuffer2"); + @Test + public void noOpStream() throws IOException { + final Map encryptionContext = new HashMap<>(1); + encryptionContext.put("ENC", "noOpStream"); - final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); - final InputStream encryptionInStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - inStream, - encryptionContext); + final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); + final InputStream encryptionInStream = + encryptionClient_.createEncryptingStream(customerMasterKey, inStream, encryptionContext); - encryptionInStream.read(null, 0, 0); - } + encryptionInStream.close(); + } - @Test - public void zeroReadLen() throws BadCiphertextException, IOException { - final Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC", "zeroReadLen"); + @Test + public void decryptEmptyFile() throws IOException { + final InputStream inStream = new ByteArrayInputStream(new byte[0]); + final InputStream decryptionInStream = + encryptionClient_.createDecryptingStream(customerMasterKey, inStream); + final ByteArrayOutputStream outStream = new ByteArrayOutputStream(); - final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); - final InputStream encryptionInStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - inStream, - encryptionContext); + TestIOUtils.copyInStreamToOutStream(decryptionInStream, outStream); - final byte[] tempBytes = new byte[0]; - final int readLen = encryptionInStream.read(tempBytes); - assertEquals(readLen, 0); - } + assertEquals(0, outStream.size()); + } - @SuppressWarnings("ResultOfMethodCallIgnored") - @Test(expected = IllegalArgumentException.class) - public void negativeReadLen() throws BadCiphertextException, IOException { - final Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC", "negativeReadLen"); + @Test + public void checkEncContext() throws Exception { + Map setEncryptionContext = new HashMap<>(1); + setEncryptionContext.put("ENC", "checkEncContext"); + + testRoundTrip( + 4096, + awsCrypto -> {}, + encryptWithContext(setEncryptionContext), + (crypto, inStream, outStream) -> { + CryptoInputStream cis = crypto.createDecryptingStream(customerMasterKey, inStream); + TestIOUtils.copyInStreamToOutStream(cis, outStream); + + // Note that the crypto result might have additional entries in its context, so only + // check that + // the entries we set were present, not that the entire map is equal + CryptoResult, ?> cryptoResult = cis.getCryptoResult(); + setEncryptionContext.forEach( + (k, v) -> assertEquals(v, cryptoResult.getEncryptionContext().get(k))); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); + } - final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); - final InputStream encryptionInStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - inStream, - encryptionContext); + @Test + public void checkKeyId() throws Exception { + testRoundTrip( + 4096, + awsCrypto -> {}, + encryptWithoutContext(), + (crypto, inStream, outStream) -> { + CryptoInputStream cis = crypto.createDecryptingStream(customerMasterKey, inStream); + TestIOUtils.copyInStreamToOutStream(cis, outStream); + + CryptoResult, ?> cryptoResult = cis.getCryptoResult(); + final String returnedKeyId = cryptoResult.getMasterKeys().get(0).getKeyId(); + + assertEquals("mockKey", returnedKeyId); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); + } - final byte[] tempBytes = new byte[1]; - encryptionInStream.read(tempBytes, 0, -1); - } + @Test + public void checkAvailable() throws IOException { + final int byteSize = 128; + final byte[] inBytes = TestIOUtils.generateRandomPlaintext(byteSize); + final InputStream inStream = new ByteArrayInputStream(inBytes); - @SuppressWarnings("ResultOfMethodCallIgnored") - @Test(expected = IllegalArgumentException.class) - public void negativeReadOffset() throws BadCiphertextException, IOException { - final Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC", "negativeReadOffset"); + final int frameSize = AwsCrypto.getDefaultFrameSize(); + encryptionClient_.setEncryptionFrameSize(frameSize); - final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); - final InputStream encryptionInStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - inStream, - encryptionContext); + Map setEncryptionContext = new HashMap<>(1); + setEncryptionContext.put("ENC", "Streaming Test"); - byte[] tempBytes = new byte[1]; - encryptionInStream.read(tempBytes, -1, tempBytes.length); - } + // encryption + final InputStream encryptionInStream = + encryptionClient_.createEncryptingStream( + customerMasterKey, inStream, setEncryptionContext); - @SuppressWarnings("ResultOfMethodCallIgnored") - @Test(expected = ArrayIndexOutOfBoundsException.class) - public void invalidReadOffset() throws BadCiphertextException, IOException { - final Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC", "invalidReadOffset"); + assertEquals(byteSize, encryptionInStream.available()); + } - final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); - final InputStream encryptionInStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - inStream, - encryptionContext); + @Test + public void whenGetResultCalledTooEarly_noExceptionThrown() throws Exception { + testRoundTrip( + 1024, + awsCrypto -> {}, + (awsCrypto, inStream, outStream) -> { + final CryptoInputStream cryptoStream = + awsCrypto.createEncryptingStream(customerMasterKey, inStream); - final byte[] tempBytes = new byte[100]; - encryptionInStream.read(tempBytes, tempBytes.length + 1, tempBytes.length); - } + // can invoke at any time on encrypt + cryptoStream.getCryptoResult(); - @Test - public void noOpStream() throws IOException { - final Map encryptionContext = new HashMap<>(1); - encryptionContext.put("ENC", "noOpStream"); + TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - final InputStream inStream = new ByteArrayInputStream(TestUtils.insecureRandomBytes(2048)); - final InputStream encryptionInStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - inStream, - encryptionContext); + cryptoStream.getCryptoResult(); + }, + (awsCrypto, inStream, outStream) -> { + final CryptoInputStream cryptoStream = + awsCrypto.createDecryptingStream(customerMasterKey, inStream); - encryptionInStream.close(); - } + // this will implicitly read the crypto headers + cryptoStream.getCryptoResult(); - @Test - public void decryptEmptyFile() throws IOException { - final InputStream inStream = new ByteArrayInputStream(new byte[0]); - final InputStream decryptionInStream = encryptionClient_.createDecryptingStream( - customerMasterKey, - inStream); - final ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - TestIOUtils.copyInStreamToOutStream(decryptionInStream, outStream); + // still works + cryptoStream.getCryptoResult(); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); + } - assertEquals(0, outStream.size()); - } + @Test(expected = BadCiphertextException.class) + public void whenGetResultInvokedOnEmptyStream_exceptionThrown() throws IOException { + final CryptoInputStream cryptoStream = + encryptionClient_.createDecryptingStream( + customerMasterKey, new ByteArrayInputStream(new byte[0])); - @Test - public void checkEncContext() throws Exception { - Map setEncryptionContext = new HashMap<>(1); - setEncryptionContext.put("ENC", "checkEncContext"); - - testRoundTrip( - 4096, - awsCrypto -> { - }, - encryptWithContext(setEncryptionContext), - (crypto, inStream, outStream) -> { - CryptoInputStream cis = crypto.createDecryptingStream(customerMasterKey, inStream); - TestIOUtils.copyInStreamToOutStream(cis, outStream); - - // Note that the crypto result might have additional entries in its context, so only check that - // the entries we set were present, not that the entire map is equal - CryptoResult, ?> cryptoResult = cis.getCryptoResult(); - setEncryptionContext.forEach( - (k, v) -> assertEquals(v, cryptoResult.getEncryptionContext().get(k)) - ); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + cryptoStream.getCryptoResult(); + } - @Test - public void checkKeyId() throws Exception { - testRoundTrip( - 4096, - awsCrypto -> { - }, - encryptWithoutContext(), - (crypto, inStream, outStream) -> { - CryptoInputStream cis = crypto.createDecryptingStream(customerMasterKey, inStream); - TestIOUtils.copyInStreamToOutStream(cis, outStream); - - CryptoResult, ?> cryptoResult = cis.getCryptoResult(); - final String returnedKeyId = cryptoResult.getMasterKeys().get(0).getKeyId(); - - assertEquals("mockKey", returnedKeyId); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + @Test() + public void encryptUsingCryptoMaterialsManager() throws Exception { + RecordingMaterialsManager cmm = new RecordingMaterialsManager(customerMasterKey); + testRoundTrip( + 1024, + awsCrypto -> {}, + (crypto, inStream, outStream) -> { + final CryptoInputStream cryptoStream = crypto.createEncryptingStream(cmm, inStream); - @Test - public void checkAvailable() throws IOException { - final int byteSize = 128; - final byte[] inBytes = TestIOUtils.generateRandomPlaintext(byteSize); - final InputStream inStream = new ByteArrayInputStream(inBytes); + TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - final int frameSize = AwsCrypto.getDefaultFrameSize(); - encryptionClient_.setEncryptionFrameSize(frameSize); + assertEquals("bar", cryptoStream.getCryptoResult().getEncryptionContext().get("foo")); + }, + basicDecrypt(), + commitmentPolicy); + } - Map setEncryptionContext = new HashMap<>(1); - setEncryptionContext.put("ENC", "Streaming Test"); + @Test + public void decryptUsingCryptoMaterialsManager() throws Exception { + RecordingMaterialsManager cmm = new RecordingMaterialsManager(customerMasterKey); - // encryption - final InputStream encryptionInStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - inStream, - setEncryptionContext); + testRoundTrip( + 1024, + awsCrypto -> {}, + encryptWithoutContext(), + (crypto, inStream, outStream) -> { + final CryptoInputStream cryptoStream = crypto.createDecryptingStream(cmm, inStream); - assertEquals(byteSize, encryptionInStream.available()); - } + assertFalse(cmm.didDecrypt); - @Test - public void whenGetResultCalledTooEarly_noExceptionThrown() throws Exception { - testRoundTrip(1024, - awsCrypto -> {}, - (awsCrypto, inStream, outStream) -> { - final CryptoInputStream cryptoStream = awsCrypto.createEncryptingStream( - customerMasterKey, inStream - ); - - // can invoke at any time on encrypt - cryptoStream.getCryptoResult(); - - TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - - cryptoStream.getCryptoResult(); - }, - (awsCrypto, inStream, outStream) -> { - final CryptoInputStream cryptoStream = awsCrypto.createDecryptingStream( - customerMasterKey, inStream - ); - - // this will implicitly read the crypto headers - cryptoStream.getCryptoResult(); - - TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - - // still works - cryptoStream.getCryptoResult(); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - @Test(expected = BadCiphertextException.class) - public void whenGetResultInvokedOnEmptyStream_exceptionThrown() throws IOException { - final CryptoInputStream cryptoStream = encryptionClient_.createDecryptingStream( - customerMasterKey, - new ByteArrayInputStream(new byte[0]) - ); + assertTrue(cmm.didDecrypt); + }, + commitmentPolicy); + } - cryptoStream.getCryptoResult(); - } + @Test + public void whenStreamSizeSetEarly_streamSizePassedToCMM() throws Exception { + CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); - @Test() - public void encryptUsingCryptoMaterialsManager() throws Exception { - RecordingMaterialsManager cmm = new RecordingMaterialsManager(customerMasterKey); - testRoundTrip( - 1024, - awsCrypto -> {}, - (crypto, inStream, outStream) -> { - final CryptoInputStream cryptoStream = crypto.createEncryptingStream(cmm, inStream); - - TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - - assertEquals("bar", cryptoStream.getCryptoResult().getEncryptionContext().get("foo")); - }, - basicDecrypt(), - commitmentPolicy - ); - } + CryptoInputStream is = + AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayInputStream(new byte[1])); - @Test - public void decryptUsingCryptoMaterialsManager() throws Exception { - RecordingMaterialsManager cmm = new RecordingMaterialsManager(customerMasterKey); + is.setMaxInputLength(1); - testRoundTrip( - 1024, - awsCrypto -> {}, - encryptWithoutContext(), - (crypto, inStream, outStream) -> { - final CryptoInputStream cryptoStream = crypto.createDecryptingStream(cmm, inStream); + is.read(); - assertFalse(cmm.didDecrypt); + ArgumentCaptor captor = + ArgumentCaptor.forClass(EncryptionMaterialsRequest.class); + verify(cmm).getMaterialsForEncrypt(captor.capture()); - TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); + assertEquals(1L, captor.getValue().getPlaintextSize()); + } - assertTrue(cmm.didDecrypt); - }, - commitmentPolicy - ); - } + @Test + public void whenStreamSizeSetEarly_andExceeded_exceptionThrown() throws Exception { + CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); - @Test - public void whenStreamSizeSetEarly_streamSizePassedToCMM() throws Exception { - CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); + CryptoInputStream is = + AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayInputStream(new byte[2])); - CryptoInputStream is - = AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayInputStream(new byte[1])); + is.setMaxInputLength(1); - is.setMaxInputLength(1); + assertThrows(() -> is.read(new byte[65536])); + } - is.read(); + @Test + public void whenStreamSizeSetLate_andExceeded_exceptionThrown() throws Exception { + CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); - ArgumentCaptor captor = ArgumentCaptor.forClass(EncryptionMaterialsRequest.class); - verify(cmm).getMaterialsForEncrypt(captor.capture()); + CryptoInputStream is = + AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayInputStream(new byte[2])); - assertEquals(1L, captor.getValue().getPlaintextSize()); - } + assertThrows( + () -> { + is.read(); + is.setMaxInputLength(1); + is.read(new byte[65536]); + }); + } - @Test - public void whenStreamSizeSetEarly_andExceeded_exceptionThrown() throws Exception { - CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); + @Test + public void whenStreamSizeSet_afterBeingExceeded_exceptionThrown() throws Exception { + CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); - CryptoInputStream is - = AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayInputStream(new byte[2])); + CryptoInputStream is = + AwsCrypto.standard() + .createEncryptingStream(cmm, new ByteArrayInputStream(new byte[1024 * 1024])); + assertThrows( + () -> { + is.read(); is.setMaxInputLength(1); + }); + } - assertThrows(()->is.read(new byte[65536])); - } - @Test - public void whenStreamSizeSetLate_andExceeded_exceptionThrown() throws Exception { - CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); - - CryptoInputStream is - = AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayInputStream(new byte[2])); - - assertThrows(() -> { - is.read(); - is.setMaxInputLength(1); - is.read(new byte[65536]); - }); - } + @Test + public void whenStreamSizeNegative_setSizeThrows() throws Exception { + CryptoInputStream is = + AwsCrypto.standard() + .createEncryptingStream(customerMasterKey, new ByteArrayInputStream(new byte[0])); - @Test - public void whenStreamSizeSet_afterBeingExceeded_exceptionThrown() throws Exception { - CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); + assertThrows(() -> is.setMaxInputLength(-1)); + } - CryptoInputStream is - = AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayInputStream(new byte[1024*1024])); + @Test + public void whenStreamSizeSet_roundTripSucceeds() throws Exception { + testRoundTrip( + 1024, + awsCrypto -> {}, + (awsCrypto, inStream, outStream) -> { + final CryptoInputStream cryptoStream = + awsCrypto.createEncryptingStream(customerMasterKey, inStream); - assertThrows(() -> { - is.read(); - is.setMaxInputLength(1); - }); - } + // we happen to know inStream is a ByteArrayInputStream which will give an accurate + // number + // of bytes remaining on .available() + cryptoStream.setMaxInputLength(inStream.available()); - @Test - public void whenStreamSizeNegative_setSizeThrows() throws Exception { - CryptoInputStream is - = AwsCrypto.standard().createEncryptingStream(customerMasterKey, new ByteArrayInputStream(new byte[0])); + TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); + }, + (awsCrypto, inStream, outStream) -> { + final CryptoInputStream cryptoStream = + awsCrypto.createDecryptingStream(customerMasterKey, inStream); - assertThrows(() -> is.setMaxInputLength(-1)); - } + cryptoStream.setMaxInputLength(inStream.available()); - @Test - public void whenStreamSizeSet_roundTripSucceeds() throws Exception { - testRoundTrip( - 1024, - awsCrypto -> {}, - (awsCrypto, inStream, outStream) -> { - final CryptoInputStream cryptoStream = awsCrypto.createEncryptingStream( - customerMasterKey, - inStream - ); - - // we happen to know inStream is a ByteArrayInputStream which will give an accurate number - // of bytes remaining on .available() - cryptoStream.setMaxInputLength(inStream.available()); - - TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - }, - (awsCrypto, inStream, outStream) -> { - final CryptoInputStream cryptoStream = awsCrypto.createDecryptingStream( - customerMasterKey, - inStream); - - cryptoStream.setMaxInputLength(inStream.available()); - - TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + TestIOUtils.copyInStreamToOutStream(cryptoStream, outStream); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/CryptoOutputStreamTest.java b/src/test/java/com/amazonaws/encryptionsdk/CryptoOutputStreamTest.java index 61adb4ede..466f75655 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/CryptoOutputStreamTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/CryptoOutputStreamTest.java @@ -16,7 +16,10 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import javax.crypto.spec.SecretKeySpec; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; +import com.amazonaws.encryptionsdk.internal.TestIOUtils; +import com.amazonaws.encryptionsdk.jce.JceMasterKey; +import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -32,7 +35,7 @@ import java.util.UUID; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; - +import javax.crypto.spec.SecretKeySpec; import org.junit.Before; import org.junit.Test; import org.junit.experimental.runners.Enclosed; @@ -40,645 +43,627 @@ import org.junit.runners.Parameterized; import org.mockito.ArgumentCaptor; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; -import com.amazonaws.encryptionsdk.internal.TestIOUtils; -import com.amazonaws.encryptionsdk.jce.JceMasterKey; -import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; - @RunWith(Enclosed.class) public class CryptoOutputStreamTest { - private static final SecureRandom RND = new SecureRandom(); - private static final MasterKey customerMasterKey; - private static final AtomicReference RANDOM_BUFFER = new AtomicReference<>(new byte[0]); - private static final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; - - static { - byte[] rawKey = new byte[16]; - RND.nextBytes(rawKey); - customerMasterKey = JceMasterKey.getInstance(new SecretKeySpec(rawKey, "AES"), "mockProvider", "mockKey", - "AES/GCM/NoPadding"); + private static final SecureRandom RND = new SecureRandom(); + private static final MasterKey customerMasterKey; + private static final AtomicReference RANDOM_BUFFER = new AtomicReference<>(new byte[0]); + private static final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; + + static { + byte[] rawKey = new byte[16]; + RND.nextBytes(rawKey); + customerMasterKey = + JceMasterKey.getInstance( + new SecretKeySpec(rawKey, "AES"), "mockProvider", "mockKey", "AES/GCM/NoPadding"); + } + + private static void testRoundTrip( + int dataSize, + Consumer customizer, + Callback onEncrypt, + Callback onDecrypt, + CommitmentPolicy commitmentPolicy) + throws Exception { + AwsCrypto awsCrypto = AwsCrypto.builder().withCommitmentPolicy(commitmentPolicy).build(); + customizer.accept(awsCrypto); + + byte[] plaintext = insecureRandomBytes(dataSize); + + ByteArrayInputStream inputStream = new ByteArrayInputStream(plaintext); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + onEncrypt.process(awsCrypto, inputStream, outputStream); + + inputStream = new ByteArrayInputStream(outputStream.toByteArray()); + outputStream = new ByteArrayOutputStream(); + + onDecrypt.process(awsCrypto, inputStream, outputStream); + + assertArrayEquals(getSha256Hash(plaintext), getSha256Hash(outputStream.toByteArray())); + } + + private interface Callback { + void process(AwsCrypto crypto, InputStream inStream, OutputStream outStream) throws Exception; + } + + private static Callback encryptWithContext(Map encryptionContext) { + return (awsCrypto, inStream, outStream) -> { + final OutputStream encryptionOutStream = + awsCrypto.createEncryptingStream(customerMasterKey, outStream, encryptionContext); + + TestIOUtils.copyInStreamToOutStream(inStream, encryptionOutStream); + }; + } + + private static Callback encryptWithoutContext() { + return (awsCrypto, inStream, outStream) -> { + final OutputStream encryptionOutStream = + awsCrypto.createEncryptingStream(customerMasterKey, outStream); + + TestIOUtils.copyInStreamToOutStream(inStream, encryptionOutStream); + }; + } + + private static Callback basicDecrypt(int readLen) { + return (awsCrypto, inStream, outStream) -> { + final OutputStream decryptionOutStream = + awsCrypto.createDecryptingStream(customerMasterKey, outStream); + + TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream, readLen); + }; + } + + private static Callback basicDecrypt() { + return (awsCrypto, inStream, outStream) -> { + final OutputStream decryptionOutStream = + awsCrypto.createDecryptingStream(customerMasterKey, outStream); + + TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream); + }; + } + + @RunWith(Parameterized.class) + public static class ParameterizedEncryptDecryptTest { + private final CryptoAlgorithm cryptoAlg; + private final int byteSize, frameSize, readLen; + + public ParameterizedEncryptDecryptTest( + CryptoAlgorithm cryptoAlg, int byteSize, int frameSize, int readLen) { + this.cryptoAlg = cryptoAlg; + this.byteSize = byteSize; + this.frameSize = frameSize; + this.readLen = readLen; + } + + @Parameterized.Parameters( + name = "{index}: encryptDecrypt(algorithm={0}, byteSize={1}, frameSize={2}, readLen={3})") + public static Collection encryptDecryptParams() { + ArrayList cases = new ArrayList<>(); + + for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { + final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); + + // iterate over frame size to test + for (int i = 0; i < frameSizeToTest.length; i++) { + final int frameSize = frameSizeToTest[i]; + int[] bytesToTest = { + 0, + 1, + frameSize - 1, + frameSize, + frameSize + 1, + (int) (frameSize * 1.5), + frameSize * 2, + 1000000 + }; + + if (isFastTestSuiteActive()) { + // Exclude the last two sizes, as they're the slowest + bytesToTest = Arrays.copyOfRange(bytesToTest, 0, bytesToTest.length - 2); + } + + // iterate over byte size to test + for (int j = 0; j < bytesToTest.length; j++) { + final int byteSize = bytesToTest[j]; + int[] readLenVals = {byteSize - 1, byteSize, byteSize + 1, byteSize * 2, 1000000}; + + if (isFastTestSuiteActive()) { + // Only test one read() call buffer length in the fast tests. This greatly cuts down + // on + // the combinatorial explosion of test cases here. + readLenVals = Arrays.copyOfRange(readLenVals, 0, 1); + } + + // iterate over read lengths to test + for (int k = 0; k < readLenVals.length; k++) { + final int readLen = readLenVals[k]; + if (byteSize >= 0 && readLen > 0) { + cases.add(new Object[] {cryptoAlg, byteSize, frameSize, readLen}); + } + } + } + } + } + + return cases; + } + + @Test + public void encryptDecrypt() throws Exception { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "Streaming Test"); + final CommitmentPolicy commitmentPolicy = + cryptoAlg.isCommitting() + ? CommitmentPolicy.RequireEncryptRequireDecrypt + : CommitmentPolicy.ForbidEncryptAllowDecrypt; + + testRoundTrip( + byteSize, + awsCrypto -> { + awsCrypto.setEncryptionFrameSize(frameSize); + awsCrypto.setEncryptionAlgorithm(cryptoAlg); + }, + encryptWithContext(encryptionContext), + basicDecrypt(readLen), + commitmentPolicy); + } + } + + public static class NonParameterized { + private AwsCrypto encryptionClient_; + + public NonParameterized() {} + + @Before + public void setup() throws IOException { + encryptionClient_ = AwsCrypto.standard(); } - private static void testRoundTrip( - int dataSize, - Consumer customizer, - Callback onEncrypt, - Callback onDecrypt, - CommitmentPolicy commitmentPolicy - ) throws Exception { - AwsCrypto awsCrypto = AwsCrypto.builder().withCommitmentPolicy(commitmentPolicy).build(); - customizer.accept(awsCrypto); + @Test + public void singleByteWrite() throws Exception { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "Streaming Test"); - byte[] plaintext = insecureRandomBytes(dataSize); + testRoundTrip( + 10_000, + awsCrypto -> {}, + (awsCrypto, inStream, outStream) -> { + final OutputStream encryptionOutStream = + awsCrypto.createEncryptingStream(customerMasterKey, outStream, encryptionContext); - ByteArrayInputStream inputStream = new ByteArrayInputStream(plaintext); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - onEncrypt.process(awsCrypto, inputStream, outputStream); + // write a single plaintext byte at a time + final byte[] writeBytes = new byte[2048]; + int read_len = 0; + while (read_len >= 0) { + read_len = inStream.read(writeBytes); + if (read_len > 0) { + for (int i = 0; i < read_len; i++) { + encryptionOutStream.write(writeBytes[i]); + } + } + } - inputStream = new ByteArrayInputStream(outputStream.toByteArray()); - outputStream = new ByteArrayOutputStream(); + encryptionOutStream.close(); + }, + (awsCrypto, inStream, outStream) -> { + final OutputStream decryptionOutStream = + awsCrypto.createDecryptingStream(customerMasterKey, outStream); - onDecrypt.process(awsCrypto, inputStream, outputStream); + // write a single decrypted byte at a time + final byte[] writeBytes = new byte[2048]; + int read_len = 0; + while (read_len >= 0) { + read_len = inStream.read(writeBytes); + if (read_len > 0) { + for (int i = 0; i < read_len; i++) { + decryptionOutStream.write(writeBytes[i]); + } + } + } - assertArrayEquals(getSha256Hash(plaintext), getSha256Hash(outputStream.toByteArray())); + decryptionOutStream.close(); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); } - private interface Callback { - void process(AwsCrypto crypto, InputStream inStream, OutputStream outStream) throws Exception; + @Test + public void doEncryptDecryptWithoutEncContext() throws Exception { + testRoundTrip( + 1_000_000, + awsCrypto -> {}, + encryptWithoutContext(), + basicDecrypt(), + CommitmentPolicy.RequireEncryptRequireDecrypt); } - private static Callback encryptWithContext(Map encryptionContext) { - return (awsCrypto, inStream, outStream) -> { - final OutputStream encryptionOutStream = awsCrypto.createEncryptingStream( - customerMasterKey, - outStream, - encryptionContext); + @Test + public void doEncryptDecryptWithContext() throws Exception { + Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "Streaming Test: inputStreamCompatiblilty"); + + testRoundTrip( + 1_000_000, + awsCrypto -> awsCrypto.setEncryptionFrameSize(getDefaultFrameSize()), + encryptWithContext(encryptionContext), + basicDecrypt(), + CommitmentPolicy.RequireEncryptRequireDecrypt); + } - TestIOUtils.copyInStreamToOutStream(inStream, encryptionOutStream); - }; + @Test + public void encryptOneShotDecryptStream() throws Exception { + Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "encryptAPICompatibility"); + + testRoundTrip( + 1_000_000, + awsCrypto -> {}, + (crypto, inStream, outStream) -> { + outStream.write( + encryptionClient_ + .encryptData(customerMasterKey, toByteArray(inStream), encryptionContext) + .getResult()); + }, + (crypto, inStream, outStream) -> { + final OutputStream decryptionOutStream = + encryptionClient_.createDecryptingStream(customerMasterKey, outStream); + + decryptionOutStream.write(toByteArray(inStream)); + decryptionOutStream.close(); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); } - private static Callback encryptWithoutContext() { - return (awsCrypto, inStream, outStream) -> { - final OutputStream encryptionOutStream = awsCrypto.createEncryptingStream( - customerMasterKey, - outStream); + @Test + public void encryptStreamDecryptOneShot() throws Exception { + Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "decryptAPICompatibility"); + + testRoundTrip( + 1_000_000, + awsCrypto -> {}, + (crypto, inStream, outStream) -> { + final OutputStream encryptionOutStream = + encryptionClient_.createEncryptingStream( + customerMasterKey, outStream, encryptionContext); TestIOUtils.copyInStreamToOutStream(inStream, encryptionOutStream); - }; + }, + (crypto, inStream, outStream) -> { + outStream.write( + encryptionClient_ + .decryptData(customerMasterKey, toByteArray(inStream)) + .getResult()); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); } - private static Callback basicDecrypt(int readLen) { - return (awsCrypto, inStream, outStream) -> { - final OutputStream decryptionOutStream = awsCrypto.createDecryptingStream( - customerMasterKey, - outStream); + @Test(expected = IllegalArgumentException.class) + public void nullWrite() throws IOException { + final OutputStream outStream = new ByteArrayOutputStream(); + final OutputStream encryptionOutStream = + encryptionClient_.createEncryptingStream(customerMasterKey, outStream); - TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream, readLen); - }; + encryptionOutStream.write(null); } - private static Callback basicDecrypt() { - return (awsCrypto, inStream, outStream) -> { - final OutputStream decryptionOutStream = awsCrypto.createDecryptingStream( - customerMasterKey, - outStream); + @Test(expected = IllegalArgumentException.class) + public void nullWrite2() throws BadCiphertextException, IOException { + final OutputStream outStream = new ByteArrayOutputStream(); + final OutputStream encryptionOutStream = + encryptionClient_.createEncryptingStream(customerMasterKey, outStream); - TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream); - }; + encryptionOutStream.write(null, 0, 0); } - @RunWith(Parameterized.class) - public static class ParameterizedEncryptDecryptTest { - private final CryptoAlgorithm cryptoAlg; - private final int byteSize, frameSize, readLen; - - public ParameterizedEncryptDecryptTest( - CryptoAlgorithm cryptoAlg, int byteSize, int frameSize, int readLen - ) { - this.cryptoAlg = cryptoAlg; - this.byteSize = byteSize; - this.frameSize = frameSize; - this.readLen = readLen; - } + @Test(expected = IllegalArgumentException.class) + public void negativeWriteLen() throws BadCiphertextException, IOException { + final OutputStream outStream = new ByteArrayOutputStream(); + final OutputStream encryptionOutStream = + encryptionClient_.createEncryptingStream(customerMasterKey, outStream); - @Parameterized.Parameters(name="{index}: encryptDecrypt(algorithm={0}, byteSize={1}, frameSize={2}, readLen={3})") - public static Collection encryptDecryptParams() { - ArrayList cases = new ArrayList<>(); - - for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { - final int[] frameSizeToTest = TestUtils.getFrameSizesToTest(cryptoAlg); - - // iterate over frame size to test - for (int i = 0; i < frameSizeToTest.length; i++) { - final int frameSize = frameSizeToTest[i]; - int[] bytesToTest = { 0, 1, frameSize - 1, frameSize, frameSize + 1, (int) (frameSize * 1.5), - frameSize * 2, 1000000 }; - - if (isFastTestSuiteActive()) { - // Exclude the last two sizes, as they're the slowest - bytesToTest = Arrays.copyOfRange(bytesToTest, 0, bytesToTest.length - 2); - } - - // iterate over byte size to test - for (int j = 0; j < bytesToTest.length; j++) { - final int byteSize = bytesToTest[j]; - int[] readLenVals = { byteSize - 1, byteSize, byteSize + 1, byteSize * 2, 1000000 }; - - if (isFastTestSuiteActive()) { - // Only test one read() call buffer length in the fast tests. This greatly cuts down on - // the combinatorial explosion of test cases here. - readLenVals = Arrays.copyOfRange(readLenVals, 0, 1); - } - - // iterate over read lengths to test - for (int k = 0; k < readLenVals.length; k++) { - final int readLen = readLenVals[k]; - if (byteSize >= 0 && readLen > 0) { - cases.add(new Object[] { cryptoAlg, byteSize, frameSize, readLen }); - } - } - } - } - } + final byte[] writeBytes = new byte[0]; + encryptionOutStream.write(writeBytes, 0, -1); + } - return cases; - } + @Test(expected = IllegalArgumentException.class) + public void negativeWriteOffset() throws BadCiphertextException, IOException { + final OutputStream outStream = new ByteArrayOutputStream(); + final OutputStream encryptionOutStream = + encryptionClient_.createEncryptingStream(customerMasterKey, outStream); - @Test - public void encryptDecrypt() throws Exception { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "Streaming Test"); - final CommitmentPolicy commitmentPolicy = cryptoAlg.isCommitting() ? CommitmentPolicy.RequireEncryptRequireDecrypt : CommitmentPolicy.ForbidEncryptAllowDecrypt; - - testRoundTrip(byteSize, - awsCrypto -> { - awsCrypto.setEncryptionFrameSize(frameSize); - awsCrypto.setEncryptionAlgorithm(cryptoAlg); - }, - encryptWithContext(encryptionContext), - basicDecrypt(readLen), - commitmentPolicy - ); - } + final byte[] writeBytes = new byte[2048]; + encryptionOutStream.write(writeBytes, -1, writeBytes.length); } - public static class NonParameterized { - private AwsCrypto encryptionClient_; - - public NonParameterized() {} + @Test + public void checkInvalidValues() throws Exception { + // test for the two formats - single-block and frame. + final int[] frameSizeToTest = {0, getDefaultFrameSize()}; + + // iterate over frame size to test + for (int i = 0; i < frameSizeToTest.length; i++) { + final int frameSize = frameSizeToTest[i]; + invalidWriteLen(frameSize); + invalidWriteOffset(frameSize); + noOpStream(frameSize); + } + } - @Before - public void setup() throws IOException { - encryptionClient_ = AwsCrypto.standard(); - } + private void invalidWriteLen(final int frameSize) throws BadCiphertextException, IOException { + AwsCrypto awsCrypto = AwsCrypto.standard(); - @Test - public void singleByteWrite() throws Exception { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "Streaming Test"); - - testRoundTrip(10_000, - awsCrypto -> { - }, - (awsCrypto, inStream, outStream) -> { - final OutputStream encryptionOutStream = awsCrypto.createEncryptingStream( - customerMasterKey, - outStream, - encryptionContext); - - // write a single plaintext byte at a time - final byte[] writeBytes = new byte[2048]; - int read_len = 0; - while (read_len >= 0) { - read_len = inStream.read(writeBytes); - if (read_len > 0) { - for (int i = 0; i < read_len; i++) { - encryptionOutStream.write(writeBytes[i]); - } - } - } - - encryptionOutStream.close(); - }, - (awsCrypto, inStream, outStream) -> { - final OutputStream decryptionOutStream = awsCrypto.createDecryptingStream( - customerMasterKey, - outStream); - - // write a single decrypted byte at a time - final byte[] writeBytes = new byte[2048]; - int read_len = 0; - while (read_len >= 0) { - read_len = inStream.read(writeBytes); - if (read_len > 0) { - for (int i = 0; i < read_len; i++) { - decryptionOutStream.write(writeBytes[i]); - } - } - } - - decryptionOutStream.close(); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + awsCrypto.setEncryptionFrameSize(frameSize); - @Test - public void doEncryptDecryptWithoutEncContext() throws Exception { - testRoundTrip(1_000_000, - awsCrypto -> { - }, - encryptWithoutContext(), - basicDecrypt(), - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + final OutputStream outStream = new ByteArrayOutputStream(); + final OutputStream encryptionOutStream = + awsCrypto.createEncryptingStream(customerMasterKey, outStream); - @Test - public void doEncryptDecryptWithContext() throws Exception { - Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "Streaming Test: inputStreamCompatiblilty"); - - testRoundTrip(1_000_000, - awsCrypto -> awsCrypto.setEncryptionFrameSize(getDefaultFrameSize()), - encryptWithContext(encryptionContext), - basicDecrypt(), - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + final byte[] writeBytes = new byte[2048]; - @Test - public void encryptOneShotDecryptStream() throws Exception { - Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "encryptAPICompatibility"); - - testRoundTrip(1_000_000, - awsCrypto -> { - }, - (crypto, inStream, outStream) -> { - outStream.write(encryptionClient_.encryptData( - customerMasterKey, - toByteArray(inStream), - encryptionContext).getResult()); - }, - (crypto, inStream, outStream) -> { - final OutputStream decryptionOutStream = encryptionClient_.createDecryptingStream( - customerMasterKey, - outStream); - - decryptionOutStream.write(toByteArray(inStream)); - decryptionOutStream.close(); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + assertThrows( + IndexOutOfBoundsException.class, + () -> encryptionOutStream.write(writeBytes, 0, 2 * writeBytes.length)); + } - @Test - public void encryptStreamDecryptOneShot() throws Exception { - Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "decryptAPICompatibility"); - - testRoundTrip(1_000_000, - awsCrypto -> { - }, - (crypto, inStream, outStream) -> { - final OutputStream encryptionOutStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - outStream, - encryptionContext); - - TestIOUtils.copyInStreamToOutStream(inStream, encryptionOutStream); - }, - (crypto, inStream, outStream) -> { - outStream.write( - encryptionClient_.decryptData(customerMasterKey, - toByteArray(inStream)).getResult()); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + private void invalidWriteOffset(final int frameSize) + throws BadCiphertextException, IOException { + AwsCrypto awsCrypto = AwsCrypto.standard(); - @Test(expected = IllegalArgumentException.class) - public void nullWrite() throws IOException { - final OutputStream outStream = new ByteArrayOutputStream(); - final OutputStream encryptionOutStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - outStream); + awsCrypto.setEncryptionFrameSize(frameSize); - encryptionOutStream.write(null); - } + final OutputStream outStream = new ByteArrayOutputStream(); + final OutputStream encryptionOutStream = + awsCrypto.createEncryptingStream(customerMasterKey, outStream); - @Test(expected = IllegalArgumentException.class) - public void nullWrite2() throws BadCiphertextException, IOException { - final OutputStream outStream = new ByteArrayOutputStream(); - final OutputStream encryptionOutStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - outStream); + final byte[] writeBytes = new byte[2048]; - encryptionOutStream.write(null, 0, 0); - } + assertThrows( + IndexOutOfBoundsException.class, + () -> encryptionOutStream.write(writeBytes, writeBytes.length + 1, writeBytes.length)); + } - @Test(expected = IllegalArgumentException.class) - public void negativeWriteLen() throws BadCiphertextException, IOException { - final OutputStream outStream = new ByteArrayOutputStream(); - final OutputStream encryptionOutStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - outStream); + private void noOpStream(final int frameSize) throws Exception { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "noOpStream size " + frameSize); - final byte[] writeBytes = new byte[0]; - encryptionOutStream.write(writeBytes, 0, -1); - } + final ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + final OutputStream encryptionOutStream = + encryptionClient_.createEncryptingStream(customerMasterKey, outStream, encryptionContext); - @Test(expected = IllegalArgumentException.class) - public void negativeWriteOffset() throws BadCiphertextException, IOException { - final OutputStream outStream = new ByteArrayOutputStream(); - final OutputStream encryptionOutStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - outStream); + encryptionOutStream.close(); - final byte[] writeBytes = new byte[2048]; - encryptionOutStream.write(writeBytes, -1, writeBytes.length); - } + testRoundTrip( + 0, + crypto -> crypto.setEncryptionFrameSize(frameSize), + encryptWithContext(encryptionContext), + basicDecrypt(), + CommitmentPolicy.RequireEncryptRequireDecrypt); + } - @Test - public void checkInvalidValues() throws Exception { - // test for the two formats - single-block and frame. - final int[] frameSizeToTest = {0, getDefaultFrameSize()}; - - // iterate over frame size to test - for (int i = 0; i < frameSizeToTest.length; i++) { - final int frameSize = frameSizeToTest[i]; - invalidWriteLen(frameSize); - invalidWriteOffset(frameSize); - noOpStream(frameSize); - } - } + @Test + public void decryptEmptyFile() throws IOException { + final InputStream inStream = new ByteArrayInputStream(new byte[0]); + final ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + final OutputStream decryptionOutStream = + encryptionClient_.createDecryptingStream(customerMasterKey, outStream); - private void invalidWriteLen(final int frameSize) throws BadCiphertextException, IOException { - AwsCrypto awsCrypto = AwsCrypto.standard(); + TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream); + inStream.close(); + decryptionOutStream.close(); - awsCrypto.setEncryptionFrameSize(frameSize); + assertEquals(0, outStream.size()); + } - final OutputStream outStream = new ByteArrayOutputStream(); - final OutputStream encryptionOutStream = awsCrypto.createEncryptingStream( - customerMasterKey, - outStream); + @Test + public void checkEncContext() throws Exception { + Map context = new HashMap(1); + context.put("ENC", "Streaming Test " + UUID.randomUUID()); - final byte[] writeBytes = new byte[2048]; + testRoundTrip( + 1, + awsCrypto -> {}, + encryptWithContext(context), + (crypto, inStream, outStream) -> { + final CryptoOutputStream decryptionOutStream = + encryptionClient_.createDecryptingStream(customerMasterKey, outStream); - assertThrows(IndexOutOfBoundsException.class, - () -> encryptionOutStream.write(writeBytes, 0, 2 * writeBytes.length)); + TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream); - } + Map getEncryptionContext = + decryptionOutStream.getCryptoResult().getEncryptionContext(); - private void invalidWriteOffset(final int frameSize) throws BadCiphertextException, IOException { - AwsCrypto awsCrypto = AwsCrypto.standard(); + // Since more values may have been added, we need to check to ensure that all + // of setEncryptionContext is present, not that there is nothing else + for (final Map.Entry e : context.entrySet()) { + assertEquals(e.getValue(), getEncryptionContext.get(e.getKey())); + } + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); + } - awsCrypto.setEncryptionFrameSize(frameSize); + @Test + public void checkKeyId() throws Exception { + Map context = new HashMap(1); + context.put("ENC", "Streaming Test " + UUID.randomUUID()); - final OutputStream outStream = new ByteArrayOutputStream(); - final OutputStream encryptionOutStream = awsCrypto.createEncryptingStream( - customerMasterKey, - outStream); + testRoundTrip( + 1, + awsCrypto -> {}, + encryptWithContext(context), + (crypto, inStream, outStream) -> { + final CryptoOutputStream decryptionOutStream = + encryptionClient_.createDecryptingStream(customerMasterKey, outStream); - final byte[] writeBytes = new byte[2048]; + TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream); - assertThrows(IndexOutOfBoundsException.class, - () -> encryptionOutStream.write(writeBytes, writeBytes.length + 1, writeBytes.length)); - } + final String returnedKeyId = + decryptionOutStream.getCryptoResult().getMasterKeys().get(0).getKeyId(); - private void noOpStream(final int frameSize) throws Exception { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "noOpStream size " + frameSize); + assertEquals("mockKey", returnedKeyId); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); + } - final ByteArrayOutputStream outStream = new ByteArrayOutputStream(); - final OutputStream encryptionOutStream = encryptionClient_.createEncryptingStream( - customerMasterKey, - outStream, - encryptionContext); + @Test + public void whenGetResultCalledTooEarly_exceptionThrown() throws Exception { + testRoundTrip( + 1024, + awsCrypto -> {}, + (awsCrypto2, inStream, outStream) -> { + final CryptoOutputStream encryptionOutStream = + awsCrypto2.createEncryptingStream(customerMasterKey, outStream); - encryptionOutStream.close(); + // we should be able to get the cryptoResult on the encrypt path immediately + encryptionOutStream.getCryptoResult(); - testRoundTrip(0, - crypto -> crypto.setEncryptionFrameSize(frameSize), - encryptWithContext(encryptionContext), - basicDecrypt(), - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + TestIOUtils.copyInStreamToOutStream(inStream, encryptionOutStream); + }, + (awsCrypto1, inStream, outStream) -> { + final CryptoOutputStream decryptionOutStream = + awsCrypto1.createDecryptingStream(customerMasterKey, outStream); - @Test - public void decryptEmptyFile() throws IOException { - final InputStream inStream = new ByteArrayInputStream(new byte[0]); - final ByteArrayOutputStream outStream = new ByteArrayOutputStream(); - final OutputStream decryptionOutStream = encryptionClient_.createDecryptingStream( - customerMasterKey, - outStream); + // Can't get headers until we write them to the outstream + assertThrows(IllegalStateException.class, decryptionOutStream::getCryptoResult); TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream); - inStream.close(); - decryptionOutStream.close(); - assertEquals(0, outStream.size()); - } + // Now we can get headers + decryptionOutStream.getCryptoResult(); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); + } - @Test - public void checkEncContext() throws Exception { - Map context = new HashMap(1); - context.put("ENC", "Streaming Test " + UUID.randomUUID()); - - testRoundTrip(1, - awsCrypto -> { - }, - encryptWithContext(context), - (crypto, inStream, outStream) -> { - final CryptoOutputStream decryptionOutStream - = encryptionClient_.createDecryptingStream( - customerMasterKey, - outStream); - - TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream); - - Map getEncryptionContext = decryptionOutStream.getCryptoResult() - .getEncryptionContext(); - - // Since more values may have been added, we need to check to ensure that all - // of setEncryptionContext is present, not that there is nothing else - for (final Map.Entry e : context.entrySet()) { - assertEquals(e.getValue(), getEncryptionContext.get(e.getKey())); - } - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + @Test + public void encryptUsingCryptoMaterialsManager() throws Exception { + RecordingMaterialsManager cmm = new RecordingMaterialsManager(customerMasterKey); - @Test - public void checkKeyId() throws Exception { - Map context = new HashMap(1); - context.put("ENC", "Streaming Test " + UUID.randomUUID()); - - testRoundTrip(1, - awsCrypto -> { - }, - encryptWithContext(context), - (crypto, inStream, outStream) -> { - final CryptoOutputStream decryptionOutStream - = encryptionClient_.createDecryptingStream( - customerMasterKey, - outStream); - - TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream); - - final String returnedKeyId = decryptionOutStream.getCryptoResult() - .getMasterKeys() - .get(0) - .getKeyId(); - - assertEquals("mockKey", returnedKeyId); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + testRoundTrip( + 1024, + awsCrypto -> {}, + (crypto, inStream, outStream) -> { + final CryptoOutputStream cryptoStream = + crypto.createEncryptingStream(cmm, outStream); - @Test - public void whenGetResultCalledTooEarly_exceptionThrown() throws Exception { - testRoundTrip( - 1024, - awsCrypto -> {}, - (awsCrypto2, inStream, outStream) -> { - final CryptoOutputStream encryptionOutStream = awsCrypto2.createEncryptingStream( - customerMasterKey, outStream - ); - - // we should be able to get the cryptoResult on the encrypt path immediately - encryptionOutStream.getCryptoResult(); - - TestIOUtils.copyInStreamToOutStream(inStream, encryptionOutStream); - }, - (awsCrypto1, inStream, outStream) -> { - final CryptoOutputStream decryptionOutStream = awsCrypto1.createDecryptingStream( - customerMasterKey, outStream - ); - - // Can't get headers until we write them to the outstream - assertThrows(IllegalStateException.class, decryptionOutStream::getCryptoResult); - - TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream); - - // Now we can get headers - decryptionOutStream.getCryptoResult(); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + TestIOUtils.copyInStreamToOutStream(inStream, cryptoStream); - @Test - public void encryptUsingCryptoMaterialsManager() throws Exception { - RecordingMaterialsManager cmm = new RecordingMaterialsManager(customerMasterKey); + assertEquals("bar", cryptoStream.getCryptoResult().getEncryptionContext().get("foo")); + }, + basicDecrypt(), + commitmentPolicy); + } - testRoundTrip( - 1024, - awsCrypto -> {}, - (crypto, inStream, outStream) -> { - final CryptoOutputStream cryptoStream = crypto.createEncryptingStream(cmm, outStream); + @Test + public void decryptUsingCryptoMaterialsManager() throws Exception { + RecordingMaterialsManager cmm = new RecordingMaterialsManager(customerMasterKey); - TestIOUtils.copyInStreamToOutStream(inStream, cryptoStream); + testRoundTrip( + 1024, + awsCrypto -> {}, + encryptWithoutContext(), + (crypto, inStream, outStream) -> { + final CryptoOutputStream cryptoStream = + crypto.createDecryptingStream(cmm, outStream); - assertEquals("bar", cryptoStream.getCryptoResult().getEncryptionContext().get("foo")); - }, - basicDecrypt(), - commitmentPolicy - ); - } + assertFalse(cmm.didDecrypt); - @Test - public void decryptUsingCryptoMaterialsManager() throws Exception { - RecordingMaterialsManager cmm = new RecordingMaterialsManager(customerMasterKey); + TestIOUtils.copyInStreamToOutStream(inStream, cryptoStream); - testRoundTrip( - 1024, - awsCrypto -> {}, - encryptWithoutContext(), - (crypto, inStream, outStream) -> { - final CryptoOutputStream cryptoStream = crypto.createDecryptingStream(cmm, outStream); + assertTrue(cmm.didDecrypt); + }, + commitmentPolicy); + } - assertFalse(cmm.didDecrypt); + @Test + public void whenStreamSizeSetEarly_streamSizePassedToCMM() throws Exception { + CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); - TestIOUtils.copyInStreamToOutStream(inStream, cryptoStream); + CryptoOutputStream os = + AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayOutputStream()); - assertTrue(cmm.didDecrypt); - }, - commitmentPolicy - ); - } + os.setMaxInputLength(1); - @Test - public void whenStreamSizeSetEarly_streamSizePassedToCMM() throws Exception { - CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); + os.write(0); - CryptoOutputStream os = AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayOutputStream()); + ArgumentCaptor captor = + ArgumentCaptor.forClass(EncryptionMaterialsRequest.class); + verify(cmm).getMaterialsForEncrypt(captor.capture()); - os.setMaxInputLength(1); + assertEquals(1L, captor.getValue().getPlaintextSize()); + } - os.write(0); + @Test + public void whenStreamSizeSetEarly_andExceeded_exceptionThrown() throws Exception { + CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); - ArgumentCaptor captor = ArgumentCaptor.forClass(EncryptionMaterialsRequest.class); - verify(cmm).getMaterialsForEncrypt(captor.capture()); + CryptoOutputStream os = + AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayOutputStream()); - assertEquals(1L, captor.getValue().getPlaintextSize()); - } + os.setMaxInputLength(1); + os.write(0); - @Test - public void whenStreamSizeSetEarly_andExceeded_exceptionThrown() throws Exception { - CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); + assertThrows(() -> os.write(0)); + } - CryptoOutputStream os = AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayOutputStream()); + @Test + public void whenStreamSizeSetLate_andExceeded_exceptionThrown() throws Exception { + CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); - os.setMaxInputLength(1); - os.write(0); + CryptoOutputStream os = + AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayOutputStream()); - assertThrows(()-> os.write(0)); - } - @Test - public void whenStreamSizeSetLate_andExceeded_exceptionThrown() throws Exception { - CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); + os.write(0); + os.setMaxInputLength(1); - CryptoOutputStream os = AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayOutputStream()); + assertThrows(() -> os.write(0)); + } - os.write(0); - os.setMaxInputLength(1); + @Test + public void whenStreamSizeSet_afterBeingExceeded_exceptionThrown() throws Exception { + CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); - assertThrows(() -> os.write(0)); - } + CryptoOutputStream os = + AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayOutputStream()); - @Test - public void whenStreamSizeSet_afterBeingExceeded_exceptionThrown() throws Exception { - CryptoMaterialsManager cmm = spy(new DefaultCryptoMaterialsManager(customerMasterKey)); + os.write(0); + os.write(0); - CryptoOutputStream os = AwsCrypto.standard().createEncryptingStream(cmm, new ByteArrayOutputStream()); + assertThrows(() -> os.setMaxInputLength(1)); + } - os.write(0); - os.write(0); + @Test + public void whenStreamSizeNegative_setSizeThrows() throws Exception { + CryptoOutputStream is = + AwsCrypto.standard() + .createEncryptingStream(customerMasterKey, new ByteArrayOutputStream()); - assertThrows(() -> os.setMaxInputLength(1)); - } + assertThrows(() -> is.setMaxInputLength(-1)); + } - @Test - public void whenStreamSizeNegative_setSizeThrows() throws Exception { - CryptoOutputStream is - = AwsCrypto.standard().createEncryptingStream(customerMasterKey, new ByteArrayOutputStream()); + @Test + public void whenStreamSizeSet_roundTripSucceeds() throws Exception { + testRoundTrip( + 1024, + ignored -> {}, + (awsCrypto, inStream, outStream) -> { + final CryptoOutputStream encryptionOutStream = + awsCrypto.createEncryptingStream(customerMasterKey, outStream); - assertThrows(() -> is.setMaxInputLength(-1)); - } + encryptionOutStream.setMaxInputLength(1024); - @Test - public void whenStreamSizeSet_roundTripSucceeds() throws Exception { - testRoundTrip( - 1024, - ignored -> {}, - (awsCrypto, inStream, outStream) -> { - final CryptoOutputStream encryptionOutStream = awsCrypto.createEncryptingStream( - customerMasterKey, - outStream); - - encryptionOutStream.setMaxInputLength(1024); - - TestIOUtils.copyInStreamToOutStream(inStream, encryptionOutStream); - }, - (awsCrypto, inStream, outStream) -> { - final CryptoOutputStream decryptionOutStream = awsCrypto.createDecryptingStream( - customerMasterKey, - outStream); - - // we happen to know inStream is a ByteArrayInputStream which will give an accurate number - // of bytes remaining on .available() - decryptionOutStream.setMaxInputLength(inStream.available()); - - TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream); - }, - CommitmentPolicy.RequireEncryptRequireDecrypt - ); - } + TestIOUtils.copyInStreamToOutStream(inStream, encryptionOutStream); + }, + (awsCrypto, inStream, outStream) -> { + final CryptoOutputStream decryptionOutStream = + awsCrypto.createDecryptingStream(customerMasterKey, outStream); + + // we happen to know inStream is a ByteArrayInputStream which will give an accurate + // number + // of bytes remaining on .available() + decryptionOutStream.setMaxInputLength(inStream.available()); + + TestIOUtils.copyInStreamToOutStream(inStream, decryptionOutStream); + }, + CommitmentPolicy.RequireEncryptRequireDecrypt); } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/DecryptionMethod.java b/src/test/java/com/amazonaws/encryptionsdk/DecryptionMethod.java index 54bb4a4aa..7091fa6f6 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/DecryptionMethod.java +++ b/src/test/java/com/amazonaws/encryptionsdk/DecryptionMethod.java @@ -2,180 +2,215 @@ import com.amazonaws.encryptionsdk.internal.SignaturePolicy; import com.amazonaws.encryptionsdk.internal.TestIOUtils; -import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider; - import java.io.*; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.Callable; -import java.util.stream.Collectors; -import java.util.stream.Stream; enum DecryptionMethod { - OneShot { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - return crypto.decryptData(masterKeyProvider, ciphertext).getResult(); - } - }, - // Note for the record that changing the readLen parameter of copyInStreamToOutStream has minimal - // effect on the actual data flow when copying from a CryptoInputStream: it will always read from the - // underlying input stream with a fixed chunk size (4096 bytes at the time of writing this), independently - // of how many bytes its asked to read of the decryption result. It's still useful to vary the length to - // ensure the buffering in the CryptoInputStream works correctly though. - InputStreamSingleByteChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = crypto.createDecryptingStream(masterKeyProvider, new ByteArrayInputStream(ciphertext)); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - TestIOUtils.copyInStreamToOutStream(in, out, 1); - return out.toByteArray(); - } - }, - InputStreamSmallByteChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = crypto.createDecryptingStream(masterKeyProvider, new ByteArrayInputStream(ciphertext)); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - TestIOUtils.copyInStreamToOutStream(in, out, SMALL_CHUNK_SIZE); - return out.toByteArray(); - } - }, - InputStreamWholeMessageChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = crypto.createDecryptingStream(masterKeyProvider, new ByteArrayInputStream(ciphertext)); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - TestIOUtils.copyInStreamToOutStream(in, out, ciphertext.length); - return out.toByteArray(); - } - }, - UnsignedMessageInputStreamSingleByteChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = crypto.createUnsignedMessageDecryptingStream(masterKeyProvider, new ByteArrayInputStream(ciphertext)); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - TestIOUtils.copyInStreamToOutStream(in, out, 1); - return out.toByteArray(); - } + OneShot { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + return crypto.decryptData(masterKeyProvider, ciphertext).getResult(); + } + }, + // Note for the record that changing the readLen parameter of copyInStreamToOutStream has minimal + // effect on the actual data flow when copying from a CryptoInputStream: it will always read from + // the + // underlying input stream with a fixed chunk size (4096 bytes at the time of writing this), + // independently + // of how many bytes its asked to read of the decryption result. It's still useful to vary the + // length to + // ensure the buffering in the CryptoInputStream works correctly though. + InputStreamSingleByteChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = + crypto.createDecryptingStream(masterKeyProvider, new ByteArrayInputStream(ciphertext)); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + TestIOUtils.copyInStreamToOutStream(in, out, 1); + return out.toByteArray(); + } + }, + InputStreamSmallByteChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = + crypto.createDecryptingStream(masterKeyProvider, new ByteArrayInputStream(ciphertext)); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + TestIOUtils.copyInStreamToOutStream(in, out, SMALL_CHUNK_SIZE); + return out.toByteArray(); + } + }, + InputStreamWholeMessageChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = + crypto.createDecryptingStream(masterKeyProvider, new ByteArrayInputStream(ciphertext)); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + TestIOUtils.copyInStreamToOutStream(in, out, ciphertext.length); + return out.toByteArray(); + } + }, + UnsignedMessageInputStreamSingleByteChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = + crypto.createUnsignedMessageDecryptingStream( + masterKeyProvider, new ByteArrayInputStream(ciphertext)); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + TestIOUtils.copyInStreamToOutStream(in, out, 1); + return out.toByteArray(); + } - @Override - public SignaturePolicy signaturePolicy() { - return SignaturePolicy.AllowEncryptForbidDecrypt; - } - }, - UnsignedMessageInputStreamSmallByteChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = crypto.createUnsignedMessageDecryptingStream(masterKeyProvider, new ByteArrayInputStream(ciphertext)); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - TestIOUtils.copyInStreamToOutStream(in, out, SMALL_CHUNK_SIZE); - return out.toByteArray(); - } + @Override + public SignaturePolicy signaturePolicy() { + return SignaturePolicy.AllowEncryptForbidDecrypt; + } + }, + UnsignedMessageInputStreamSmallByteChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = + crypto.createUnsignedMessageDecryptingStream( + masterKeyProvider, new ByteArrayInputStream(ciphertext)); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + TestIOUtils.copyInStreamToOutStream(in, out, SMALL_CHUNK_SIZE); + return out.toByteArray(); + } - @Override - public SignaturePolicy signaturePolicy() { - return SignaturePolicy.AllowEncryptForbidDecrypt; - } - }, - UnsignedMessageInputStreamWholeMessageChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = crypto.createUnsignedMessageDecryptingStream(masterKeyProvider, new ByteArrayInputStream(ciphertext)); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - TestIOUtils.copyInStreamToOutStream(in, out, ciphertext.length); - return out.toByteArray(); - } + @Override + public SignaturePolicy signaturePolicy() { + return SignaturePolicy.AllowEncryptForbidDecrypt; + } + }, + UnsignedMessageInputStreamWholeMessageChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = + crypto.createUnsignedMessageDecryptingStream( + masterKeyProvider, new ByteArrayInputStream(ciphertext)); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + TestIOUtils.copyInStreamToOutStream(in, out, ciphertext.length); + return out.toByteArray(); + } - @Override - public SignaturePolicy signaturePolicy() { - return SignaturePolicy.AllowEncryptForbidDecrypt; - } - }, - OutputStreamSingleByteChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = new ByteArrayInputStream(ciphertext); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - OutputStream decryptingOut = crypto.createDecryptingStream(masterKeyProvider, out); - TestIOUtils.copyInStreamToOutStream(in, decryptingOut, 1); - return out.toByteArray(); - } - }, - OutputStreamSmallByteChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = new ByteArrayInputStream(ciphertext); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - OutputStream decryptingOut = crypto.createDecryptingStream(masterKeyProvider, out); - TestIOUtils.copyInStreamToOutStream(in, decryptingOut, SMALL_CHUNK_SIZE); - return out.toByteArray(); - } - }, - OutputStreamWholeMessageChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = new ByteArrayInputStream(ciphertext); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - OutputStream decryptingOut = crypto.createDecryptingStream(masterKeyProvider, out); - TestIOUtils.copyInStreamToOutStream(in, decryptingOut, ciphertext.length); - return out.toByteArray(); - } - }, - UnsignedMessageOutputStreamSingleByteChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = new ByteArrayInputStream(ciphertext); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - OutputStream decryptingOut = crypto.createUnsignedMessageDecryptingStream(masterKeyProvider, out); - TestIOUtils.copyInStreamToOutStream(in, decryptingOut, 1); - return out.toByteArray(); - } + @Override + public SignaturePolicy signaturePolicy() { + return SignaturePolicy.AllowEncryptForbidDecrypt; + } + }, + OutputStreamSingleByteChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = new ByteArrayInputStream(ciphertext); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + OutputStream decryptingOut = crypto.createDecryptingStream(masterKeyProvider, out); + TestIOUtils.copyInStreamToOutStream(in, decryptingOut, 1); + return out.toByteArray(); + } + }, + OutputStreamSmallByteChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = new ByteArrayInputStream(ciphertext); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + OutputStream decryptingOut = crypto.createDecryptingStream(masterKeyProvider, out); + TestIOUtils.copyInStreamToOutStream(in, decryptingOut, SMALL_CHUNK_SIZE); + return out.toByteArray(); + } + }, + OutputStreamWholeMessageChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = new ByteArrayInputStream(ciphertext); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + OutputStream decryptingOut = crypto.createDecryptingStream(masterKeyProvider, out); + TestIOUtils.copyInStreamToOutStream(in, decryptingOut, ciphertext.length); + return out.toByteArray(); + } + }, + UnsignedMessageOutputStreamSingleByteChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = new ByteArrayInputStream(ciphertext); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + OutputStream decryptingOut = + crypto.createUnsignedMessageDecryptingStream(masterKeyProvider, out); + TestIOUtils.copyInStreamToOutStream(in, decryptingOut, 1); + return out.toByteArray(); + } - @Override - public SignaturePolicy signaturePolicy() { - return SignaturePolicy.AllowEncryptForbidDecrypt; - } - }, - UnsignedMessageOutputStreamSmallByteChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = new ByteArrayInputStream(ciphertext); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - OutputStream decryptingOut = crypto.createUnsignedMessageDecryptingStream(masterKeyProvider, out); - TestIOUtils.copyInStreamToOutStream(in, decryptingOut, SMALL_CHUNK_SIZE); - return out.toByteArray(); - } + @Override + public SignaturePolicy signaturePolicy() { + return SignaturePolicy.AllowEncryptForbidDecrypt; + } + }, + UnsignedMessageOutputStreamSmallByteChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = new ByteArrayInputStream(ciphertext); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + OutputStream decryptingOut = + crypto.createUnsignedMessageDecryptingStream(masterKeyProvider, out); + TestIOUtils.copyInStreamToOutStream(in, decryptingOut, SMALL_CHUNK_SIZE); + return out.toByteArray(); + } - @Override - public SignaturePolicy signaturePolicy() { - return SignaturePolicy.AllowEncryptForbidDecrypt; - } - }, - UnsignedMessageOutputStreamWholeMessageChunks { - @Override - public byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) throws IOException { - InputStream in = new ByteArrayInputStream(ciphertext); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - OutputStream decryptingOut = crypto.createUnsignedMessageDecryptingStream(masterKeyProvider, out); - TestIOUtils.copyInStreamToOutStream(in, decryptingOut, ciphertext.length); - return out.toByteArray(); - } + @Override + public SignaturePolicy signaturePolicy() { + return SignaturePolicy.AllowEncryptForbidDecrypt; + } + }, + UnsignedMessageOutputStreamWholeMessageChunks { + @Override + public byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException { + InputStream in = new ByteArrayInputStream(ciphertext); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + OutputStream decryptingOut = + crypto.createUnsignedMessageDecryptingStream(masterKeyProvider, out); + TestIOUtils.copyInStreamToOutStream(in, decryptingOut, ciphertext.length); + return out.toByteArray(); + } - @Override - public SignaturePolicy signaturePolicy() { - return SignaturePolicy.AllowEncryptForbidDecrypt; - } - }; + @Override + public SignaturePolicy signaturePolicy() { + return SignaturePolicy.AllowEncryptForbidDecrypt; + } + }; - // A semi-arbitrary chunk size just to have at least one non-boundary input, and something - // that will span at least some message segments. - private static final int SMALL_CHUNK_SIZE = 100; + // A semi-arbitrary chunk size just to have at least one non-boundary input, and something + // that will span at least some message segments. + private static final int SMALL_CHUNK_SIZE = 100; - public abstract byte[] decryptMessage(AwsCrypto crypto, MasterKeyProvider masterKeyProvider, - byte[] ciphertext) throws IOException; + public abstract byte[] decryptMessage( + AwsCrypto crypto, MasterKeyProvider masterKeyProvider, byte[] ciphertext) + throws IOException; - public SignaturePolicy signaturePolicy() { - return SignaturePolicy.AllowEncryptAllowDecrypt; - } + public SignaturePolicy signaturePolicy() { + return SignaturePolicy.AllowEncryptAllowDecrypt; + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/DefaultCryptoMaterialsManagerTest.java b/src/test/java/com/amazonaws/encryptionsdk/DefaultCryptoMaterialsManagerTest.java index 79128d18a..1ea43507c 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/DefaultCryptoMaterialsManagerTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/DefaultCryptoMaterialsManagerTest.java @@ -18,6 +18,17 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; +import com.amazonaws.encryptionsdk.exception.NoSuchMasterKeyException; +import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; +import com.amazonaws.encryptionsdk.internal.Constants; +import com.amazonaws.encryptionsdk.internal.StaticMasterKey; +import com.amazonaws.encryptionsdk.internal.TrailingSignatureAlgorithm; +import com.amazonaws.encryptionsdk.model.DecryptionMaterials; +import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; +import com.amazonaws.encryptionsdk.model.EncryptionMaterials; +import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; import java.nio.charset.StandardCharsets; import java.security.Signature; import java.util.Arrays; @@ -27,330 +38,334 @@ import java.util.Map; import java.util.Objects; import java.util.function.Consumer; - -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; -import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; -import com.amazonaws.encryptionsdk.exception.NoSuchMasterKeyException; -import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; -import com.amazonaws.encryptionsdk.model.DecryptionMaterials; -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; -import com.amazonaws.encryptionsdk.model.EncryptionMaterials; -import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; import org.junit.Test; -import com.amazonaws.encryptionsdk.internal.Constants; -import com.amazonaws.encryptionsdk.internal.StaticMasterKey; -import com.amazonaws.encryptionsdk.internal.TrailingSignatureAlgorithm; - public class DefaultCryptoMaterialsManagerTest { - private static final MasterKey mk1 = new StaticMasterKey("mk1"); - private static final MasterKey mk2 = new StaticMasterKey("mk2"); - private static final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; - - @Test - public void encrypt_testBasicFunctionality() throws Exception { - EncryptionMaterialsRequest req = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(commitmentPolicy) - .build(); - EncryptionMaterials result = new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); - - assertNotNull(result.getCleartextDataKey()); - assertNotNull(result.getEncryptionContext()); - assertEquals(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, result.getAlgorithm()); - assertEquals(1, result.getEncryptedDataKeys().size()); - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk1, result.getMasterKeys().get(0)); + private static final MasterKey mk1 = new StaticMasterKey("mk1"); + private static final MasterKey mk2 = new StaticMasterKey("mk2"); + private static final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; + + @Test + public void encrypt_testBasicFunctionality() throws Exception { + EncryptionMaterialsRequest req = + EncryptionMaterialsRequest.newBuilder().setCommitmentPolicy(commitmentPolicy).build(); + EncryptionMaterials result = new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); + + assertNotNull(result.getCleartextDataKey()); + assertNotNull(result.getEncryptionContext()); + assertEquals( + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, result.getAlgorithm()); + assertEquals(1, result.getEncryptedDataKeys().size()); + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk1, result.getMasterKeys().get(0)); + } + + @Test + public void encrypt_testNonCommittingDefaultAlgorithm() throws Exception { + EncryptionMaterialsRequest req = + EncryptionMaterialsRequest.newBuilder() + .setCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) + .build(); + EncryptionMaterials result = new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); + assertEquals( + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, result.getAlgorithm()); + } + + @Test + public void encrypt_testCommittingDefaultAlgorithm() throws Exception { + final List requireWritePolicies = + Arrays.asList( + CommitmentPolicy.RequireEncryptRequireDecrypt, + CommitmentPolicy.RequireEncryptAllowDecrypt); + for (CommitmentPolicy policy : requireWritePolicies) { + EncryptionMaterialsRequest req = + EncryptionMaterialsRequest.newBuilder().setCommitmentPolicy(policy).build(); + EncryptionMaterials result = + new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); + assertEquals( + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, result.getAlgorithm()); } - - @Test - public void encrypt_testNonCommittingDefaultAlgorithm() throws Exception { - EncryptionMaterialsRequest req = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) - .build(); - EncryptionMaterials result = new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); - assertEquals(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, result.getAlgorithm()); - } - - @Test - public void encrypt_testCommittingDefaultAlgorithm() throws Exception { - final List requireWritePolicies = Arrays.asList( - CommitmentPolicy.RequireEncryptRequireDecrypt, CommitmentPolicy.RequireEncryptAllowDecrypt); - for (CommitmentPolicy policy : requireWritePolicies) { - EncryptionMaterialsRequest req = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(policy) - .build(); - EncryptionMaterials result = new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); - assertEquals(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, result.getAlgorithm()); - } - } - - @Test - public void encrypt_noSignatureKeyOnUnsignedAlgo() throws Exception { - CryptoAlgorithm[] algorithms = new CryptoAlgorithm[] { - CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256, - CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_NO_KDF, - CryptoAlgorithm.ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA256, - CryptoAlgorithm.ALG_AES_192_GCM_IV12_TAG16_NO_KDF, - CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256, - CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, - CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY + } + + @Test + public void encrypt_noSignatureKeyOnUnsignedAlgo() throws Exception { + CryptoAlgorithm[] algorithms = + new CryptoAlgorithm[] { + CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256, + CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_NO_KDF, + CryptoAlgorithm.ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA256, + CryptoAlgorithm.ALG_AES_192_GCM_IV12_TAG16_NO_KDF, + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256, + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY }; - for (CryptoAlgorithm algo : algorithms) { - EncryptionMaterialsRequest req = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(commitmentPolicy) - .setRequestedAlgorithm(algo) - .build(); - EncryptionMaterials result = new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); - - assertNull(result.getTrailingSignatureKey()); - assertEquals(0, result.getEncryptionContext().size()); - assertEquals(algo, result.getAlgorithm()); - } + for (CryptoAlgorithm algo : algorithms) { + EncryptionMaterialsRequest req = + EncryptionMaterialsRequest.newBuilder() + .setCommitmentPolicy(commitmentPolicy) + .setRequestedAlgorithm(algo) + .build(); + EncryptionMaterials result = + new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); + + assertNull(result.getTrailingSignatureKey()); + assertEquals(0, result.getEncryptionContext().size()); + assertEquals(algo, result.getAlgorithm()); } - - @Test - public void encrypt_hasSignatureKeyForSignedAlgo() throws Exception { - CryptoAlgorithm[] algorithms = new CryptoAlgorithm[] { - CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, - CryptoAlgorithm.ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, - CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, - CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384 + } + + @Test + public void encrypt_hasSignatureKeyForSignedAlgo() throws Exception { + CryptoAlgorithm[] algorithms = + new CryptoAlgorithm[] { + CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, + CryptoAlgorithm.ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384 }; - for (CryptoAlgorithm algo : algorithms) { + for (CryptoAlgorithm algo : algorithms) { - EncryptionMaterialsRequest req = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(commitmentPolicy) - .setRequestedAlgorithm(algo) - .build(); - EncryptionMaterials result = new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); - - assertNotNull(result.getTrailingSignatureKey()); - assertEquals(1, result.getEncryptionContext().size()); - assertNotNull(result.getEncryptionContext().get(Constants.EC_PUBLIC_KEY_FIELD)); - assertEquals(algo, result.getAlgorithm()); - } - } + EncryptionMaterialsRequest req = + EncryptionMaterialsRequest.newBuilder() + .setCommitmentPolicy(commitmentPolicy) + .setRequestedAlgorithm(algo) + .build(); + EncryptionMaterials result = + new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); - @Test - public void encrypt_dispatchesMultipleMasterKeys() throws Exception { - MasterKey mk1_spy = spy(mk1); - MasterKey mk2_spy = spy(mk2); - - DataKey[] mk1_datakey = new DataKey[1]; - - doAnswer( - invocation -> { - Object dk = invocation.callRealMethod(); - mk1_datakey[0] = (DataKey)dk; - - return dk; - } - ).when(mk1_spy).generateDataKey(any(), any()); - - MasterKeyProvider mkp = buildMultiProvider(mk1_spy, mk2_spy); - - EncryptionMaterialsRequest req = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(commitmentPolicy) - .setContext(singletonMap("foo", "bar")) - .build(); - - EncryptionMaterials result = new DefaultCryptoMaterialsManager(mkp).getMaterialsForEncrypt(req); - - //noinspection unchecked - verify(mk1_spy).generateDataKey( - any(), - // there's a weird generics issue here without downcasting to (Map) - (Map)argThat((Map m) -> Objects.equals(m.get("foo"), "bar")) - ); - - //noinspection unchecked - verify(mk2_spy).encryptDataKey( - any(), - (Map)argThat((Map m) -> Objects.equals(m.get("foo"), "bar")), - same(mk1_datakey[0]) - ); - - assertArrayEquals( - mk1_datakey[0].getKey().getEncoded(), - result.getCleartextDataKey().getEncoded() - ); + assertNotNull(result.getTrailingSignatureKey()); + assertEquals(1, result.getEncryptionContext().size()); + assertNotNull(result.getEncryptionContext().get(Constants.EC_PUBLIC_KEY_FIELD)); + assertEquals(algo, result.getAlgorithm()); } - - @Test - public void encrypt_forwardsPlaintextWhenAvailable() throws Exception { - MasterKey mk1_spy = spy(mk1); - - EncryptionMaterialsRequest request = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(commitmentPolicy) - .setPlaintext(new byte[1]) - .build(); - new DefaultCryptoMaterialsManager(mk1_spy).getMaterialsForEncrypt(request); - - verify(mk1_spy).getMasterKeysForEncryption( - argThat( - req -> Arrays.equals(req.getPlaintext(), new byte[1]) && !req.isStreaming() - ) - ); - } - - @Test - public void encrypt_forwardsPlaintextSizeWhenAvailable() throws Exception { - MasterKey mk1_spy = spy(mk1); - - EncryptionMaterialsRequest request = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(commitmentPolicy) - .setPlaintextSize(1) - .build(); - new DefaultCryptoMaterialsManager(mk1_spy).getMaterialsForEncrypt(request); - - verify(mk1_spy).getMasterKeysForEncryption( - argThat( - req -> req.getSize() == 1 && !req.isStreaming() - ) - ); - } - - @Test - public void encrypt_setsStreamingWhenNoSizeAvailable() throws Exception { - MasterKey mk1_spy = spy(mk1); - - EncryptionMaterialsRequest request = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(commitmentPolicy) - .build(); - new DefaultCryptoMaterialsManager(mk1_spy).getMaterialsForEncrypt(request); - - verify(mk1_spy).getMasterKeysForEncryption( - argThat(MasterKeyRequest::isStreaming) - ); - } - - @Test(expected = IllegalArgumentException.class) - public void encrypt_whenECContextKeyPresent_throws() throws Exception { - EncryptionMaterialsRequest req = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(commitmentPolicy) - .setRequestedAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384) - .setContext(singletonMap(Constants.EC_PUBLIC_KEY_FIELD, "some EC key")) - .build(); - - new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); - } - - @Test(expected = IllegalArgumentException.class) - public void encrypt_whenNoMasterKeys_throws() throws Exception { - EncryptionMaterialsRequest req = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(commitmentPolicy) - .build(); - new DefaultCryptoMaterialsManager(new MasterKeyProvider() { - @Override public String getDefaultProviderId() { + } + + @Test + public void encrypt_dispatchesMultipleMasterKeys() throws Exception { + MasterKey mk1_spy = spy(mk1); + MasterKey mk2_spy = spy(mk2); + + DataKey[] mk1_datakey = new DataKey[1]; + + doAnswer( + invocation -> { + Object dk = invocation.callRealMethod(); + mk1_datakey[0] = (DataKey) dk; + + return dk; + }) + .when(mk1_spy) + .generateDataKey(any(), any()); + + MasterKeyProvider mkp = buildMultiProvider(mk1_spy, mk2_spy); + + EncryptionMaterialsRequest req = + EncryptionMaterialsRequest.newBuilder() + .setCommitmentPolicy(commitmentPolicy) + .setContext(singletonMap("foo", "bar")) + .build(); + + EncryptionMaterials result = new DefaultCryptoMaterialsManager(mkp).getMaterialsForEncrypt(req); + + //noinspection unchecked + verify(mk1_spy) + .generateDataKey( + any(), + // there's a weird generics issue here without downcasting to (Map) + (Map) argThat((Map m) -> Objects.equals(m.get("foo"), "bar"))); + + //noinspection unchecked + verify(mk2_spy) + .encryptDataKey( + any(), + (Map) argThat((Map m) -> Objects.equals(m.get("foo"), "bar")), + same(mk1_datakey[0])); + + assertArrayEquals( + mk1_datakey[0].getKey().getEncoded(), result.getCleartextDataKey().getEncoded()); + } + + @Test + public void encrypt_forwardsPlaintextWhenAvailable() throws Exception { + MasterKey mk1_spy = spy(mk1); + + EncryptionMaterialsRequest request = + EncryptionMaterialsRequest.newBuilder() + .setCommitmentPolicy(commitmentPolicy) + .setPlaintext(new byte[1]) + .build(); + new DefaultCryptoMaterialsManager(mk1_spy).getMaterialsForEncrypt(request); + + verify(mk1_spy) + .getMasterKeysForEncryption( + argThat(req -> Arrays.equals(req.getPlaintext(), new byte[1]) && !req.isStreaming())); + } + + @Test + public void encrypt_forwardsPlaintextSizeWhenAvailable() throws Exception { + MasterKey mk1_spy = spy(mk1); + + EncryptionMaterialsRequest request = + EncryptionMaterialsRequest.newBuilder() + .setCommitmentPolicy(commitmentPolicy) + .setPlaintextSize(1) + .build(); + new DefaultCryptoMaterialsManager(mk1_spy).getMaterialsForEncrypt(request); + + verify(mk1_spy) + .getMasterKeysForEncryption(argThat(req -> req.getSize() == 1 && !req.isStreaming())); + } + + @Test + public void encrypt_setsStreamingWhenNoSizeAvailable() throws Exception { + MasterKey mk1_spy = spy(mk1); + + EncryptionMaterialsRequest request = + EncryptionMaterialsRequest.newBuilder().setCommitmentPolicy(commitmentPolicy).build(); + new DefaultCryptoMaterialsManager(mk1_spy).getMaterialsForEncrypt(request); + + verify(mk1_spy).getMasterKeysForEncryption(argThat(MasterKeyRequest::isStreaming)); + } + + @Test(expected = IllegalArgumentException.class) + public void encrypt_whenECContextKeyPresent_throws() throws Exception { + EncryptionMaterialsRequest req = + EncryptionMaterialsRequest.newBuilder() + .setCommitmentPolicy(commitmentPolicy) + .setRequestedAlgorithm( + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384) + .setContext(singletonMap(Constants.EC_PUBLIC_KEY_FIELD, "some EC key")) + .build(); + + new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(req); + } + + @Test(expected = IllegalArgumentException.class) + public void encrypt_whenNoMasterKeys_throws() throws Exception { + EncryptionMaterialsRequest req = + EncryptionMaterialsRequest.newBuilder().setCommitmentPolicy(commitmentPolicy).build(); + new DefaultCryptoMaterialsManager( + new MasterKeyProvider() { + @Override + public String getDefaultProviderId() { return "provider ID"; - } + } - @Override public MasterKey getMasterKey(String provider, String keyId) throws UnsupportedProviderException, - NoSuchMasterKeyException { + @Override + public MasterKey getMasterKey(String provider, String keyId) + throws UnsupportedProviderException, NoSuchMasterKeyException { throw new NoSuchMasterKeyException(); - } + } - @Override public List getMasterKeysForEncryption(MasterKeyRequest request) { + @Override + public List getMasterKeysForEncryption(MasterKeyRequest request) { return Collections.emptyList(); - } + } - @Override public DataKey decryptDataKey( - CryptoAlgorithm algorithm, Collection encryptedDataKeys, Map encryptionContext - ) throws UnsupportedProviderException, AwsCryptoException { + @Override + public DataKey decryptDataKey( + CryptoAlgorithm algorithm, Collection encryptedDataKeys, Map encryptionContext) + throws UnsupportedProviderException, AwsCryptoException { return null; - } - }).getMaterialsForEncrypt(req); - } - - private EncryptionMaterials easyGenMaterials(Consumer customizer) { - EncryptionMaterialsRequest.Builder request = EncryptionMaterialsRequest.newBuilder().setCommitmentPolicy(commitmentPolicy); - - customizer.accept(request); - - return new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(request.build()); + } + }) + .getMaterialsForEncrypt(req); + } + + private EncryptionMaterials easyGenMaterials( + Consumer customizer) { + EncryptionMaterialsRequest.Builder request = + EncryptionMaterialsRequest.newBuilder().setCommitmentPolicy(commitmentPolicy); + + customizer.accept(request); + + return new DefaultCryptoMaterialsManager(mk1).getMaterialsForEncrypt(request.build()); + } + + private DecryptionMaterialsRequest decryptReqFromMaterials(EncryptionMaterials result) { + return DecryptionMaterialsRequest.newBuilder() + .setEncryptionContext(result.getEncryptionContext()) + .setEncryptedDataKeys(result.getEncryptedDataKeys()) + .setAlgorithm(result.getAlgorithm()) + .build(); + } + + @Test + public void decrypt_testSimpleRoundTrip() throws Exception { + for (CryptoAlgorithm algorithm : CryptoAlgorithm.values()) { + CommitmentPolicy policy = + algorithm.isCommitting() + ? CommitmentPolicy.RequireEncryptRequireDecrypt + : CommitmentPolicy.ForbidEncryptAllowDecrypt; + EncryptionMaterials encryptMaterials = + easyGenMaterials(builder -> builder.setRequestedAlgorithm(algorithm)); + + DecryptionMaterials decryptMaterials = + new DefaultCryptoMaterialsManager(mk1) + .decryptMaterials(decryptReqFromMaterials(encryptMaterials)); + + assertArrayEquals( + decryptMaterials.getDataKey().getKey().getEncoded(), + encryptMaterials.getCleartextDataKey().getEncoded()); + + if (encryptMaterials.getTrailingSignatureKey() == null) { + assertNull(decryptMaterials.getTrailingSignatureKey()); + } else { + Signature sig = + Signature.getInstance( + TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).getHashAndSignAlgorithm()); + + sig.initSign(encryptMaterials.getTrailingSignatureKey()); + + byte[] data = "hello world".getBytes(StandardCharsets.UTF_8); + + sig.update(data); + byte[] signature = sig.sign(); + + sig.initVerify(decryptMaterials.getTrailingSignatureKey()); + + sig.update(data); + sig.verify(signature); + } } - - private DecryptionMaterialsRequest decryptReqFromMaterials(EncryptionMaterials result) { - return DecryptionMaterialsRequest.newBuilder() - .setEncryptionContext(result.getEncryptionContext()) - .setEncryptedDataKeys(result.getEncryptedDataKeys()) - .setAlgorithm(result.getAlgorithm()) - .build(); - } - - @Test - public void decrypt_testSimpleRoundTrip() throws Exception { - for (CryptoAlgorithm algorithm : CryptoAlgorithm.values()) { - CommitmentPolicy policy = algorithm.isCommitting() ? CommitmentPolicy.RequireEncryptRequireDecrypt : CommitmentPolicy.ForbidEncryptAllowDecrypt; - EncryptionMaterials encryptMaterials = easyGenMaterials( - builder -> builder.setRequestedAlgorithm(algorithm) - ); - - DecryptionMaterials decryptMaterials = new DefaultCryptoMaterialsManager(mk1) - .decryptMaterials(decryptReqFromMaterials(encryptMaterials)); - - assertArrayEquals(decryptMaterials.getDataKey().getKey().getEncoded(), - encryptMaterials.getCleartextDataKey().getEncoded()); - - if (encryptMaterials.getTrailingSignatureKey() == null) { - assertNull(decryptMaterials.getTrailingSignatureKey()); - } else { - Signature sig = Signature.getInstance( - TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).getHashAndSignAlgorithm() - ); - - sig.initSign(encryptMaterials.getTrailingSignatureKey()); - - byte[] data = "hello world".getBytes(StandardCharsets.UTF_8); - - sig.update(data); - byte[] signature = sig.sign(); - - sig.initVerify(decryptMaterials.getTrailingSignatureKey()); - - sig.update(data); - sig.verify(signature); - } - } - } - - @Test(expected = CannotUnwrapDataKeyException.class) - public void decrypt_onDecryptFailure() throws Exception { - new DefaultCryptoMaterialsManager(mock(MasterKeyProvider.class)).decryptMaterials( - decryptReqFromMaterials(easyGenMaterials(ignored -> {})) - ); - } - - @Test - public void decrypt_whenTrailingSigMissing_throwsException() throws Exception { - for (CryptoAlgorithm algorithm : CryptoAlgorithm.values()) { - // Only test algorithms without key commitment - if (algorithm.getMessageFormatVersion() != 1) { - continue; - } - if (algorithm.getTrailingSignatureLength() == 0) { - continue; - } - - EncryptionMaterials encryptMaterials = easyGenMaterials( - builder -> builder.setRequestedAlgorithm(algorithm) - ); - - DecryptionMaterialsRequest request = DecryptionMaterialsRequest.newBuilder() - .setEncryptedDataKeys(encryptMaterials.getEncryptedDataKeys()) - .setAlgorithm(algorithm) - .setEncryptionContext(Collections.emptyMap()) - .build(); - - try { - new DefaultCryptoMaterialsManager(mk1).decryptMaterials(request); - fail("expected exception"); - } catch (AwsCryptoException e) { - // ok - continue; - } - } + } + + @Test(expected = CannotUnwrapDataKeyException.class) + public void decrypt_onDecryptFailure() throws Exception { + new DefaultCryptoMaterialsManager(mock(MasterKeyProvider.class)) + .decryptMaterials(decryptReqFromMaterials(easyGenMaterials(ignored -> {}))); + } + + @Test + public void decrypt_whenTrailingSigMissing_throwsException() throws Exception { + for (CryptoAlgorithm algorithm : CryptoAlgorithm.values()) { + // Only test algorithms without key commitment + if (algorithm.getMessageFormatVersion() != 1) { + continue; + } + if (algorithm.getTrailingSignatureLength() == 0) { + continue; + } + + EncryptionMaterials encryptMaterials = + easyGenMaterials(builder -> builder.setRequestedAlgorithm(algorithm)); + + DecryptionMaterialsRequest request = + DecryptionMaterialsRequest.newBuilder() + .setEncryptedDataKeys(encryptMaterials.getEncryptedDataKeys()) + .setAlgorithm(algorithm) + .setEncryptionContext(Collections.emptyMap()) + .build(); + + try { + new DefaultCryptoMaterialsManager(mk1).decryptMaterials(request); + fail("expected exception"); + } catch (AwsCryptoException e) { + // ok + continue; + } } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/FastTestsOnlySuite.java b/src/test/java/com/amazonaws/encryptionsdk/FastTestsOnlySuite.java index 528ae273e..de54686a3 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/FastTestsOnlySuite.java +++ b/src/test/java/com/amazonaws/encryptionsdk/FastTestsOnlySuite.java @@ -1,7 +1,6 @@ package com.amazonaws.encryptionsdk; import java.util.concurrent.TimeUnit; - import org.junit.ClassRule; import org.junit.experimental.categories.Categories; import org.junit.rules.TestRule; @@ -15,75 +14,78 @@ import org.junit.runners.model.Statement; /** - * This test suite is intended to assist in rapid development; it filters out some of the slower, more exhaustive tests - * in the overall test suite to allow for a rapid edit-test cycle. + * This test suite is intended to assist in rapid development; it filters out some of the slower, + * more exhaustive tests in the overall test suite to allow for a rapid edit-test cycle. */ @RunWith(FastTestsOnlySuite.CustomRunner.class) -@Suite.SuiteClasses({ - AllTestsSuite.class -}) +@Suite.SuiteClasses({AllTestsSuite.class}) @Categories.ExcludeCategory(SlowTestCategory.class) public class FastTestsOnlySuite { - private static InheritableThreadLocal IS_FAST_TEST_SUITE_ACTIVE = new InheritableThreadLocal() { - @Override protected Boolean initialValue() { - return false; + private static InheritableThreadLocal IS_FAST_TEST_SUITE_ACTIVE = + new InheritableThreadLocal() { + @Override + protected Boolean initialValue() { + return false; } - }; + }; - // This method is used to adjust DataProviders to provide a smaller subset of their test cases when the fast tests - // are selected - public static boolean isFastTestSuiteActive() { - return IS_FAST_TEST_SUITE_ACTIVE.get(); - } + // This method is used to adjust DataProviders to provide a smaller subset of their test cases + // when the fast tests + // are selected + public static boolean isFastTestSuiteActive() { + return IS_FAST_TEST_SUITE_ACTIVE.get(); + } - // Require that this fast suite completes relatively quickly. If you're seeing this timeout get hit, it's time to - // pare down tests some more. As a general rule of thumb, we should avoid any single test taking more than 10s, and - // try to keep the number of such slow tests to a minimum. - @ClassRule - public static Timeout timeout = new Timeout(2, TimeUnit.MINUTES); + // Require that this fast suite completes relatively quickly. If you're seeing this timeout get + // hit, it's time to + // pare down tests some more. As a general rule of thumb, we should avoid any single test taking + // more than 10s, and + // try to keep the number of such slow tests to a minimum. + @ClassRule public static Timeout timeout = new Timeout(2, TimeUnit.MINUTES); - @ClassRule - public static EnableFastSuite enableFastSuite = new EnableFastSuite(); + @ClassRule public static EnableFastSuite enableFastSuite = new EnableFastSuite(); - // TestRules run over the execution of tests, but not over the generation of parameterized test data... - private static class EnableFastSuite implements TestRule { - @Override public Statement apply( - Statement base, Description description - ) { - return new Statement() { - @Override public void evaluate() throws Throwable { - Boolean oldValue = IS_FAST_TEST_SUITE_ACTIVE.get(); + // TestRules run over the execution of tests, but not over the generation of parameterized test + // data... + private static class EnableFastSuite implements TestRule { + @Override + public Statement apply(Statement base, Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + Boolean oldValue = IS_FAST_TEST_SUITE_ACTIVE.get(); - try { - IS_FAST_TEST_SUITE_ACTIVE.set(true); - base.evaluate(); - } finally { - IS_FAST_TEST_SUITE_ACTIVE.set(oldValue); - } - } - }; + try { + IS_FAST_TEST_SUITE_ACTIVE.set(true); + base.evaluate(); + } finally { + IS_FAST_TEST_SUITE_ACTIVE.set(oldValue); + } } + }; } + } - // ... so we also need a custom TestRunner that will pass the flag on to the parameterized test data generators. - public static class CustomRunner extends Categories { - public CustomRunner(Class klass, RunnerBuilder builder) throws InitializationError { - super( - klass, - new RunnerBuilder() { - @Override public Runner runnerForClass(Class testClass) throws Throwable { - Boolean oldValue = IS_FAST_TEST_SUITE_ACTIVE.get(); + // ... so we also need a custom TestRunner that will pass the flag on to the parameterized test + // data generators. + public static class CustomRunner extends Categories { + public CustomRunner(Class klass, RunnerBuilder builder) throws InitializationError { + super( + klass, + new RunnerBuilder() { + @Override + public Runner runnerForClass(Class testClass) throws Throwable { + Boolean oldValue = IS_FAST_TEST_SUITE_ACTIVE.get(); - try { - IS_FAST_TEST_SUITE_ACTIVE.set(true); - Runner r = builder.runnerForClass(testClass); - return r; - } finally { - IS_FAST_TEST_SUITE_ACTIVE.set(oldValue); - } - } - } - ); - } + try { + IS_FAST_TEST_SUITE_ACTIVE.set(true); + Runner r = builder.runnerForClass(testClass); + return r; + } finally { + IS_FAST_TEST_SUITE_ACTIVE.set(oldValue); + } + } + }); } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/IntegrationTestSuite.java b/src/test/java/com/amazonaws/encryptionsdk/IntegrationTestSuite.java index efa5b7a18..aa6ea32ee 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/IntegrationTestSuite.java +++ b/src/test/java/com/amazonaws/encryptionsdk/IntegrationTestSuite.java @@ -1,17 +1,15 @@ package com.amazonaws.encryptionsdk; +import com.amazonaws.encryptionsdk.kms.KMSProviderBuilderIntegrationTests; import com.amazonaws.encryptionsdk.kms.MaxEncryptedDataKeysIntegrationTest; +import com.amazonaws.encryptionsdk.kms.XCompatKmsDecryptTest; import org.junit.runner.RunWith; import org.junit.runners.Suite; -import com.amazonaws.encryptionsdk.kms.KMSProviderBuilderIntegrationTests; -import com.amazonaws.encryptionsdk.kms.XCompatKmsDecryptTest; - @RunWith(Suite.class) @Suite.SuiteClasses({ - XCompatKmsDecryptTest.class, - KMSProviderBuilderIntegrationTests.class, - MaxEncryptedDataKeysIntegrationTest.class, + XCompatKmsDecryptTest.class, + KMSProviderBuilderIntegrationTests.class, + MaxEncryptedDataKeysIntegrationTest.class, }) -public class IntegrationTestSuite { -} +public class IntegrationTestSuite {} diff --git a/src/test/java/com/amazonaws/encryptionsdk/ParsedCiphertextTest.java b/src/test/java/com/amazonaws/encryptionsdk/ParsedCiphertextTest.java index 1a3267358..2de6a4a24 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/ParsedCiphertextTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/ParsedCiphertextTest.java @@ -3,137 +3,140 @@ package com.amazonaws.encryptionsdk; +import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; +import static org.junit.Assert.*; +import static org.mockito.Mockito.spy; + import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.internal.StaticMasterKey; import com.amazonaws.encryptionsdk.model.CiphertextHeaders; import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory; -import com.amazonaws.util.IOUtils; -import org.junit.Before; -import org.junit.Test; - -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; - -import java.io.ByteArrayInputStream; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Arrays; - -import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; -import static org.junit.Assert.*; -import static org.mockito.Mockito.spy; +import org.junit.Before; +import org.junit.Test; public class ParsedCiphertextTest extends CiphertextHeaders { - private static final int MESSAGE_FORMAT_MAX_EDKS = (1 << 16) - 1; - private StaticMasterKey masterKeyProvider; - private AwsCrypto encryptionClient_; - - @Before - public void init() { - masterKeyProvider = spy(new StaticMasterKey("testmaterial")); - - encryptionClient_ = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt).build(); - encryptionClient_.setEncryptionAlgorithm(CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256); - } - - @Test() - public void goodParsedCiphertext() { - final int byteSize = 0; - final int frameSize = 0; - final byte[] plaintextBytes = new byte[byteSize]; - - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC1", "ParsedCiphertext test with %d" + byteSize); - - encryptionClient_.setEncryptionFrameSize(frameSize); - - final byte[] cipherText = encryptionClient_.encryptData( - masterKeyProvider, - plaintextBytes, - encryptionContext).getResult(); - final ParsedCiphertext pCt = new ParsedCiphertext(cipherText); - - assertNotNull(pCt.getCiphertext()); - assertTrue(pCt.getOffset() > 0); - } - - @Test(expected = BadCiphertextException.class) - public void incompleteZeroByteCiphertext() { - final byte[] cipherText = {}; - ParsedCiphertext pCt = new ParsedCiphertext(cipherText); - } - - @Test(expected = BadCiphertextException.class) - public void incompleteSingleByteCiphertext() { - final byte[] cipherText = {1 /* Original ciphertext version number */}; - ParsedCiphertext pCt = new ParsedCiphertext(cipherText); - } - - @Test(expected = BadCiphertextException.class) - public void incompleteCiphertext() { - final int byteSize = 0; - final int frameSize = 0; - final byte[] plaintextBytes = new byte[byteSize]; - - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC1", "ParsedCiphertext test with %d" + byteSize); - - encryptionClient_.setEncryptionFrameSize(frameSize); - - final byte[] cipherText = encryptionClient_.encryptData( - masterKeyProvider, - plaintextBytes, - encryptionContext).getResult(); - ParsedCiphertext pCt = new ParsedCiphertext(cipherText); - - byte[] incompleteCiphertext = Arrays.copyOf(pCt.getCiphertext(), pCt.getOffset() - 1); - ParsedCiphertext badPCt = new ParsedCiphertext(incompleteCiphertext); - } - - private MasterKeyProvider providerWithEdks(int numEdks) { - List> providers = new ArrayList<>(); - for (int i = 0; i < numEdks; i++) { - providers.add(masterKeyProvider); - } - return MultipleProviderFactory.buildMultiProvider(providers); - } - - @Test - public void lessThanMaxEdks() { - MasterKeyProvider provider = providerWithEdks(2); - CryptoResult result = encryptionClient_.encryptData(provider, new byte[] {1}); - ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult(), 3); - assertEquals(ciphertext.getEncryptedKeyBlobCount(), 2); - } - - @Test - public void equalToMaxEdks() { - MasterKeyProvider provider = providerWithEdks(3); - CryptoResult result = encryptionClient_.encryptData(provider, new byte[] {1}); - ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult(), 3); - assertEquals(ciphertext.getEncryptedKeyBlobCount(), 3); - } - - @Test - public void failMoreThanMaxEdks() { - MasterKeyProvider provider = providerWithEdks(4); - CryptoResult result = encryptionClient_.encryptData(provider, new byte[] {1}); - assertThrows(AwsCryptoException.class, "Ciphertext encrypted data keys exceed maxEncryptedDataKeys", () -> - new ParsedCiphertext(result.getResult(), 3)); - } - - @Test - public void noMaxEdks() { - MasterKeyProvider provider = providerWithEdks(MESSAGE_FORMAT_MAX_EDKS); - CryptoResult result = encryptionClient_.encryptData(provider, new byte[] {1}); - - // explicit no-max - ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult(), CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - assertEquals(ciphertext.getEncryptedKeyBlobCount(), MESSAGE_FORMAT_MAX_EDKS); - - // implicit no-max - ciphertext = new ParsedCiphertext(result.getResult()); - assertEquals(ciphertext.getEncryptedKeyBlobCount(), MESSAGE_FORMAT_MAX_EDKS); + private static final int MESSAGE_FORMAT_MAX_EDKS = (1 << 16) - 1; + private StaticMasterKey masterKeyProvider; + private AwsCrypto encryptionClient_; + + @Before + public void init() { + masterKeyProvider = spy(new StaticMasterKey("testmaterial")); + + encryptionClient_ = + AwsCrypto.builder() + .withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) + .build(); + encryptionClient_.setEncryptionAlgorithm( + CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256); + } + + @Test() + public void goodParsedCiphertext() { + final int byteSize = 0; + final int frameSize = 0; + final byte[] plaintextBytes = new byte[byteSize]; + + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC1", "ParsedCiphertext test with %d" + byteSize); + + encryptionClient_.setEncryptionFrameSize(frameSize); + + final byte[] cipherText = + encryptionClient_ + .encryptData(masterKeyProvider, plaintextBytes, encryptionContext) + .getResult(); + final ParsedCiphertext pCt = new ParsedCiphertext(cipherText); + + assertNotNull(pCt.getCiphertext()); + assertTrue(pCt.getOffset() > 0); + } + + @Test(expected = BadCiphertextException.class) + public void incompleteZeroByteCiphertext() { + final byte[] cipherText = {}; + ParsedCiphertext pCt = new ParsedCiphertext(cipherText); + } + + @Test(expected = BadCiphertextException.class) + public void incompleteSingleByteCiphertext() { + final byte[] cipherText = {1 /* Original ciphertext version number */}; + ParsedCiphertext pCt = new ParsedCiphertext(cipherText); + } + + @Test(expected = BadCiphertextException.class) + public void incompleteCiphertext() { + final int byteSize = 0; + final int frameSize = 0; + final byte[] plaintextBytes = new byte[byteSize]; + + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC1", "ParsedCiphertext test with %d" + byteSize); + + encryptionClient_.setEncryptionFrameSize(frameSize); + + final byte[] cipherText = + encryptionClient_ + .encryptData(masterKeyProvider, plaintextBytes, encryptionContext) + .getResult(); + ParsedCiphertext pCt = new ParsedCiphertext(cipherText); + + byte[] incompleteCiphertext = Arrays.copyOf(pCt.getCiphertext(), pCt.getOffset() - 1); + ParsedCiphertext badPCt = new ParsedCiphertext(incompleteCiphertext); + } + + private MasterKeyProvider providerWithEdks(int numEdks) { + List> providers = new ArrayList<>(); + for (int i = 0; i < numEdks; i++) { + providers.add(masterKeyProvider); } + return MultipleProviderFactory.buildMultiProvider(providers); + } + + @Test + public void lessThanMaxEdks() { + MasterKeyProvider provider = providerWithEdks(2); + CryptoResult result = encryptionClient_.encryptData(provider, new byte[] {1}); + ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult(), 3); + assertEquals(ciphertext.getEncryptedKeyBlobCount(), 2); + } + + @Test + public void equalToMaxEdks() { + MasterKeyProvider provider = providerWithEdks(3); + CryptoResult result = encryptionClient_.encryptData(provider, new byte[] {1}); + ParsedCiphertext ciphertext = new ParsedCiphertext(result.getResult(), 3); + assertEquals(ciphertext.getEncryptedKeyBlobCount(), 3); + } + + @Test + public void failMoreThanMaxEdks() { + MasterKeyProvider provider = providerWithEdks(4); + CryptoResult result = encryptionClient_.encryptData(provider, new byte[] {1}); + assertThrows( + AwsCryptoException.class, + "Ciphertext encrypted data keys exceed maxEncryptedDataKeys", + () -> new ParsedCiphertext(result.getResult(), 3)); + } + + @Test + public void noMaxEdks() { + MasterKeyProvider provider = providerWithEdks(MESSAGE_FORMAT_MAX_EDKS); + CryptoResult result = encryptionClient_.encryptData(provider, new byte[] {1}); + + // explicit no-max + ParsedCiphertext ciphertext = + new ParsedCiphertext(result.getResult(), CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + assertEquals(ciphertext.getEncryptedKeyBlobCount(), MESSAGE_FORMAT_MAX_EDKS); + + // implicit no-max + ciphertext = new ParsedCiphertext(result.getResult()); + assertEquals(ciphertext.getEncryptedKeyBlobCount(), MESSAGE_FORMAT_MAX_EDKS); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/RecordingMaterialsManager.java b/src/test/java/com/amazonaws/encryptionsdk/RecordingMaterialsManager.java index bb61ae59a..acfacca05 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/RecordingMaterialsManager.java +++ b/src/test/java/com/amazonaws/encryptionsdk/RecordingMaterialsManager.java @@ -3,43 +3,38 @@ package com.amazonaws.encryptionsdk; -import java.util.Collections; - -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.DecryptionMaterials; -import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; +import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; +import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; +import java.util.Collections; public class RecordingMaterialsManager implements CryptoMaterialsManager { - private final CryptoMaterialsManager delegate; - private final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; + private final CryptoMaterialsManager delegate; + private final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; - public boolean didDecrypt = false; + public boolean didDecrypt = false; - public RecordingMaterialsManager(CryptoMaterialsManager delegate) { - this.delegate = delegate; - } + public RecordingMaterialsManager(CryptoMaterialsManager delegate) { + this.delegate = delegate; + } - public RecordingMaterialsManager(MasterKeyProvider delegate) { - this.delegate = new DefaultCryptoMaterialsManager(delegate); - } + public RecordingMaterialsManager(MasterKeyProvider delegate) { + this.delegate = new DefaultCryptoMaterialsManager(delegate); + } - @Override public EncryptionMaterials getMaterialsForEncrypt( - EncryptionMaterialsRequest request - ) { - request = request.toBuilder().setContext( - Collections.singletonMap("foo", "bar") - ).build(); + @Override + public EncryptionMaterials getMaterialsForEncrypt(EncryptionMaterialsRequest request) { + request = request.toBuilder().setContext(Collections.singletonMap("foo", "bar")).build(); - EncryptionMaterials result = delegate.getMaterialsForEncrypt(request); + EncryptionMaterials result = delegate.getMaterialsForEncrypt(request); - return result; - } + return result; + } - @Override public DecryptionMaterials decryptMaterials( - DecryptionMaterialsRequest request - ) { - didDecrypt = true; - return delegate.decryptMaterials(request); - } + @Override + public DecryptionMaterials decryptMaterials(DecryptionMaterialsRequest request) { + didDecrypt = true; + return delegate.decryptMaterials(request); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/SlowTestCategory.java b/src/test/java/com/amazonaws/encryptionsdk/SlowTestCategory.java index 9814b0948..58657e97f 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/SlowTestCategory.java +++ b/src/test/java/com/amazonaws/encryptionsdk/SlowTestCategory.java @@ -1,14 +1,12 @@ package com.amazonaws.encryptionsdk; /** - * JUnit category marking tests to be excluded from the FastTestsOnlySuite. Usage: - * + * JUnit category marking tests to be excluded from the FastTestsOnlySuite. Usage: * @Category(SlowTestCategory.class) * @Test * public void mySlowTest() { * // encrypt a couple terabytes of test data * } * - * */ public interface SlowTestCategory {} diff --git a/src/test/java/com/amazonaws/encryptionsdk/TestUtils.java b/src/test/java/com/amazonaws/encryptionsdk/TestUtils.java index a9d8ee020..682af4cc2 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/TestUtils.java +++ b/src/test/java/com/amazonaws/encryptionsdk/TestUtils.java @@ -3,14 +3,11 @@ package com.amazonaws.encryptionsdk; -import com.amazonaws.encryptionsdk.jce.JceMasterKey; - -import javax.crypto.spec.SecretKeySpec; - import static java.lang.String.format; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.amazonaws.encryptionsdk.jce.JceMasterKey; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; @@ -20,248 +17,266 @@ import java.util.Arrays; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicReference; +import javax.crypto.spec.SecretKeySpec; public class TestUtils { - public static final CryptoAlgorithm DEFAULT_TEST_CRYPTO_ALG = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384; - public static final CommitmentPolicy DEFAULT_TEST_COMMITMENT_POLICY = CommitmentPolicy.RequireEncryptRequireDecrypt; - - // Handcrafted message for testing decryption of messages with committed keys - public static final String messageWithCommitKeyBase64 = "AgR4TfvRMU2dVZJbgXIyxeNtbj" + - "5eIw8BiTDiwsHyQ/Z9wXkAAAABAAxQcm92aWRlck5hbWUAGUtleUlkAAAAgAAAAAz45sc3cDvJZ7D4P3sAM" + - "KE7d/w8ziQt2C0qHsy1Qu2E2q92eIGE/kLnF/Y003HKvTxx7xv2Zv83YuOdwHML5QIAABAAF88I9zPbUQSf" + - "OlzLXv+uIY2+m/E6j2PMsbgeHVH/L0wLqQlY+5CL0z3xnNOMIZae/////wAAAAEAAAAAAAAAAAAAAAEAAAA" + - "OSZBKHHRpTwXOFTQVGapXXj5CwXBMouBB2ucaIJVm"; - public static final JceMasterKey messageWithCommitKeyMasterKey = JceMasterKey.getInstance( - new SecretKeySpec(new byte[32], "AES"), "ProviderName", "KeyId", "AES/GCM/NoPadding"); - public static final String messageWithCommitKeyMessageIdBase64 = "TfvRMU2dVZJbgXIyxeNtbj5eIw8BiTDiwsHyQ/Z9wXk="; - public static final String messageWithCommitKeyCommitmentBase64 = "F88I9zPbUQSfOlzLXv+uIY2+m/E6j2PMsbgeHVH/L0w="; - public static final String messageWithCommitKeyDEKBase64 = "+p6+whPVw9kOrYLZFMRBJ2n6Vli6T/7TkjDouS+25s0="; - public static final CryptoAlgorithm messageWithCommitKeyCryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; - public static final String messageWithCommitKeyExpectedResult = "GoodCommitment"; - - // Handcrafted message for testing decryption of messages with invalid committed keys - public static final String invalidMessageWithCommitKeyBase64 = "AgR4b1/73X5ErILpj0aSQIx6wNnH" + - "LEcNLxPzA0m6vYRr7kAAAAABAAxQcm92aWRlck5hbWUAGUtleUlkAAAAgAAAAAypJmXwyizUr3/pyvIAMHL" + - "U/i5GhZlGayeYC5w/CjUobyGwN4QpeMB0XpNDGTM0f1Zx72V4uM2H5wMjy/hm2wIAABAAAAECAwQFBgcICQ" + - "oLDA0ODxAREhMUFRYXGBkaGxwdHh/pQM2VSvliz2Qgi5JZf2ta/////wAAAAEAAAAAAAAAAAAAAAEAAAANS" + - "4Id4+dVHhPrvuJHEiOswo6YGSRjSGX3VDrt+0s="; - public static final JceMasterKey invalidMessageWithCommitKeyMasterKey = JceMasterKey.getInstance( - new SecretKeySpec(new byte[32], "AES"), "ProviderName", "KeyId", "AES/GCM/NoPadding"); - - // avoid spending time generating random data on every test case by caching some random test vectors - private static final AtomicReference RANDOM_CACHE = new AtomicReference<>(new byte[0]); - - private static byte[] ensureRandomCached(int length) { - byte[] buf = RANDOM_CACHE.get(); - if (buf.length >= length) { - return buf; - } + public static final CryptoAlgorithm DEFAULT_TEST_CRYPTO_ALG = + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384; + public static final CommitmentPolicy DEFAULT_TEST_COMMITMENT_POLICY = + CommitmentPolicy.RequireEncryptRequireDecrypt; + + // Handcrafted message for testing decryption of messages with committed keys + public static final String messageWithCommitKeyBase64 = + "AgR4TfvRMU2dVZJbgXIyxeNtbj" + + "5eIw8BiTDiwsHyQ/Z9wXkAAAABAAxQcm92aWRlck5hbWUAGUtleUlkAAAAgAAAAAz45sc3cDvJZ7D4P3sAM" + + "KE7d/w8ziQt2C0qHsy1Qu2E2q92eIGE/kLnF/Y003HKvTxx7xv2Zv83YuOdwHML5QIAABAAF88I9zPbUQSf" + + "OlzLXv+uIY2+m/E6j2PMsbgeHVH/L0wLqQlY+5CL0z3xnNOMIZae/////wAAAAEAAAAAAAAAAAAAAAEAAAA" + + "OSZBKHHRpTwXOFTQVGapXXj5CwXBMouBB2ucaIJVm"; + public static final JceMasterKey messageWithCommitKeyMasterKey = + JceMasterKey.getInstance( + new SecretKeySpec(new byte[32], "AES"), "ProviderName", "KeyId", "AES/GCM/NoPadding"); + public static final String messageWithCommitKeyMessageIdBase64 = + "TfvRMU2dVZJbgXIyxeNtbj5eIw8BiTDiwsHyQ/Z9wXk="; + public static final String messageWithCommitKeyCommitmentBase64 = + "F88I9zPbUQSfOlzLXv+uIY2+m/E6j2PMsbgeHVH/L0w="; + public static final String messageWithCommitKeyDEKBase64 = + "+p6+whPVw9kOrYLZFMRBJ2n6Vli6T/7TkjDouS+25s0="; + public static final CryptoAlgorithm messageWithCommitKeyCryptoAlgorithm = + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; + public static final String messageWithCommitKeyExpectedResult = "GoodCommitment"; + + // Handcrafted message for testing decryption of messages with invalid committed keys + public static final String invalidMessageWithCommitKeyBase64 = + "AgR4b1/73X5ErILpj0aSQIx6wNnH" + + "LEcNLxPzA0m6vYRr7kAAAAABAAxQcm92aWRlck5hbWUAGUtleUlkAAAAgAAAAAypJmXwyizUr3/pyvIAMHL" + + "U/i5GhZlGayeYC5w/CjUobyGwN4QpeMB0XpNDGTM0f1Zx72V4uM2H5wMjy/hm2wIAABAAAAECAwQFBgcICQ" + + "oLDA0ODxAREhMUFRYXGBkaGxwdHh/pQM2VSvliz2Qgi5JZf2ta/////wAAAAEAAAAAAAAAAAAAAAEAAAANS" + + "4Id4+dVHhPrvuJHEiOswo6YGSRjSGX3VDrt+0s="; + public static final JceMasterKey invalidMessageWithCommitKeyMasterKey = + JceMasterKey.getInstance( + new SecretKeySpec(new byte[32], "AES"), "ProviderName", "KeyId", "AES/GCM/NoPadding"); + + // avoid spending time generating random data on every test case by caching some random test + // vectors + private static final AtomicReference RANDOM_CACHE = new AtomicReference<>(new byte[0]); + + private static byte[] ensureRandomCached(int length) { + byte[] buf = RANDOM_CACHE.get(); + if (buf.length >= length) { + return buf; + } - byte[] newBuf = new byte[length]; - ThreadLocalRandom.current().nextBytes(newBuf); + byte[] newBuf = new byte[length]; + ThreadLocalRandom.current().nextBytes(newBuf); - return RANDOM_CACHE.updateAndGet(oldBuf -> { - if (oldBuf.length < newBuf.length) { - return newBuf; - } else { - return oldBuf; - } + return RANDOM_CACHE.updateAndGet( + oldBuf -> { + if (oldBuf.length < newBuf.length) { + return newBuf; + } else { + return oldBuf; + } }); - } + } - @FunctionalInterface - public interface ThrowingRunnable { - void run() throws Throwable; + @FunctionalInterface + public interface ThrowingRunnable { + void run() throws Throwable; + } + + public static void assertThrows( + Class throwableClass, ThrowingRunnable callback) { + try { + callback.run(); + } catch (Throwable t) { + if (throwableClass.isAssignableFrom(t.getClass())) { + // ok + return; + } } - public static void assertThrows(Class throwableClass, ThrowingRunnable callback) { - try { - callback.run(); - } catch (Throwable t) { - if (throwableClass.isAssignableFrom(t.getClass())) { - // ok - return; - } - } + fail("Expected exception of type " + throwableClass); + } - fail("Expected exception of type " + throwableClass); + /** + * Asserts that calling {@code callback} results in a {@code throwableClass} (or sub-class) being + * thrown which has {@link Throwable#getMessage()} containing {@code message}. + */ + public static void assertThrows( + Class throwableClass, String message, ThrowingRunnable callback) { + try { + callback.run(); + fail("Expected exception of type " + throwableClass); + } catch (Throwable t) { + assertTrue( + format("Exception of wrong type. Was %s but expected %s", t.getClass(), throwableClass), + throwableClass.isAssignableFrom(t.getClass())); + assertTrue( + format( + "Exception did not contain the expected message. Actual: \"%s\" did not contain \"%s\"", + t.getMessage(), message), + t.getMessage().contains(message)); } + } - /** - * Asserts that calling {@code callback} results in a {@code throwableClass} (or sub-class) being thrown - * which has {@link Throwable#getMessage()} containing {@code message}. - */ - public static void assertThrows(Class throwableClass, String message, ThrowingRunnable callback) { - try { - callback.run(); - fail("Expected exception of type " + throwableClass); - } catch (Throwable t) { - assertTrue( - format("Exception of wrong type. Was %s but expected %s", t.getClass(), throwableClass), - throwableClass.isAssignableFrom(t.getClass())); - assertTrue( - format("Exception did not contain the expected message. Actual: \"%s\" did not contain \"%s\"", t.getMessage(), message), - t.getMessage().contains(message)); - } - } + public static void assertThrows(ThrowingRunnable callback) { + assertThrows(Throwable.class, callback); + } - public static void assertThrows(ThrowingRunnable callback) { - assertThrows(Throwable.class, callback); + /** + * Asserts that substituting any argument with null causes a NPE to be thrown. + * + *

Usage: + * + *

{@code
+   * assertNullChecks(
+   *   myAwsCrypto,
+   *   "createDecryptingStream",
+   *   CryptoMaterialsManager.class, myCMM,
+   *   InputStream.class, myIS
+   * );
+   * }
+ * + * @param callee + * @param methodName + * @param args + * @throws Exception + */ + public static void assertNullChecks( + Object callee, + String methodName, + // Class, value + Object... args) + throws Exception { + ArrayList parameterTypes = new ArrayList<>(); + for (int i = 0; i < args.length; i += 2) { + parameterTypes.add((Class) args[i]); } - /** - * Asserts that substituting any argument with null causes a NPE to be thrown. - * - * Usage: - * {@code - * - * assertNullChecks( - * myAwsCrypto, - * "createDecryptingStream", - * CryptoMaterialsManager.class, myCMM, - * InputStream.class, myIS - * ); - * } - * @param callee - * @param methodName - * @param args - * @throws Exception - */ - public static void assertNullChecks( - Object callee, - String methodName, - // Class, value - Object... args - ) throws Exception { - ArrayList parameterTypes = new ArrayList<>(); - for (int i = 0; i < args.length; i += 2) { - parameterTypes.add((Class)args[i]); - } + Method m = callee.getClass().getMethod(methodName, parameterTypes.toArray(new Class[0])); - Method m = callee.getClass().getMethod(methodName, parameterTypes.toArray(new Class[0])); - - for (int i = 0; i < args.length / 2; i++) { - if (args[i * 2 + 1] == null) { - // already null, which means null is ok here - continue; - } - - if (parameterTypes.get(i).isPrimitive()) { - // can't be null - continue; - } - - Object[] modifiedArgs = new Object[args.length/2]; - for (int j = 0; j < args.length / 2; j++) { - modifiedArgs[j] = args[j * 2 + 1]; - if (j == i) { - modifiedArgs[j] = null; - } - } - - try { - m.invoke(callee, modifiedArgs); - fail("Expected NullPointerException"); - } catch (InvocationTargetException e) { - if (e.getCause().getClass() == NullPointerException.class) { - continue; - } - - fail("Expected NullPointerException, got: " + e.getCause()); - } - } - } + for (int i = 0; i < args.length / 2; i++) { + if (args[i * 2 + 1] == null) { + // already null, which means null is ok here + continue; + } - public static byte[] toByteArray(InputStream is) throws IOException { - byte[] buffer = new byte[4096]; + if (parameterTypes.get(i).isPrimitive()) { + // can't be null + continue; + } - int offset = 0; - int rv; - while (true) { - rv = is.read(buffer, offset, buffer.length - offset); - if (rv <= 0) { - break; - } + Object[] modifiedArgs = new Object[args.length / 2]; + for (int j = 0; j < args.length / 2; j++) { + modifiedArgs[j] = args[j * 2 + 1]; + if (j == i) { + modifiedArgs[j] = null; + } + } + + try { + m.invoke(callee, modifiedArgs); + fail("Expected NullPointerException"); + } catch (InvocationTargetException e) { + if (e.getCause().getClass() == NullPointerException.class) { + continue; + } - offset += rv; + fail("Expected NullPointerException, got: " + e.getCause()); + } + } + } + + public static byte[] toByteArray(InputStream is) throws IOException { + byte[] buffer = new byte[4096]; - if (offset == buffer.length) { - if (buffer.length == Integer.MAX_VALUE) { - throw new IOException("Input data exceeds maximum array size"); - } + int offset = 0; + int rv; + while (true) { + rv = is.read(buffer, offset, buffer.length - offset); + if (rv <= 0) { + break; + } - int newSize = Math.toIntExact(Math.min(Integer.MAX_VALUE, 2L * buffer.length)); + offset += rv; - byte[] newBuffer = new byte[newSize]; - System.arraycopy(buffer, 0, newBuffer, 0, buffer.length); - buffer = newBuffer; - } + if (offset == buffer.length) { + if (buffer.length == Integer.MAX_VALUE) { + throw new IOException("Input data exceeds maximum array size"); } - return Arrays.copyOfRange(buffer, 0, offset); + int newSize = Math.toIntExact(Math.min(Integer.MAX_VALUE, 2L * buffer.length)); + + byte[] newBuffer = new byte[newSize]; + System.arraycopy(buffer, 0, newBuffer, 0, buffer.length); + buffer = newBuffer; + } } - public static byte[] insecureRandomBytes(int length) { - byte[] buf = new byte[length]; + return Arrays.copyOfRange(buffer, 0, offset); + } - System.arraycopy(ensureRandomCached(length), 0, buf, 0, length); + public static byte[] insecureRandomBytes(int length) { + byte[] buf = new byte[length]; - return buf; - } + System.arraycopy(ensureRandomCached(length), 0, buf, 0, length); - public static ByteArrayInputStream insecureRandomStream(int length) { - return new ByteArrayInputStream(ensureRandomCached(length), 0, length); - } + return buf; + } - public static int[] getFrameSizesToTest(final CryptoAlgorithm cryptoAlg) { - final int blockSize = cryptoAlg.getBlockSize(); - final int[] frameSizeToTest = { - 0, - blockSize - 1, - blockSize, - blockSize + 1, - blockSize * 2, - blockSize * 10, - blockSize * 10 + 1, - AwsCrypto.getDefaultFrameSize() - }; - return frameSizeToTest; + public static ByteArrayInputStream insecureRandomStream(int length) { + return new ByteArrayInputStream(ensureRandomCached(length), 0, length); } - /** - * Converts an array of unsigned bytes (represented as int values between 0 and 255 inclusive) - * to an array of Java primitive type byte, which are by definition signed. - * - * @param unsignedBytes An array on unsigned bytes - * @return An array of signed bytes - */ - public static byte[] unsignedBytesToSignedBytes(final int[] unsignedBytes) { - byte[] signedBytes = new byte[unsignedBytes.length]; - - for (int i = 0; i < unsignedBytes.length; i++) { - if (unsignedBytes[i] > 255) { - throw new IllegalArgumentException("Encountered unsigned byte value > 255"); - } - signedBytes[i] = (byte) (unsignedBytes[i] & 0xff); - } + public static int[] getFrameSizesToTest(final CryptoAlgorithm cryptoAlg) { + final int blockSize = cryptoAlg.getBlockSize(); + final int[] frameSizeToTest = { + 0, + blockSize - 1, + blockSize, + blockSize + 1, + blockSize * 2, + blockSize * 10, + blockSize * 10 + 1, + AwsCrypto.getDefaultFrameSize() + }; + return frameSizeToTest; + } - return signedBytes; + /** + * Converts an array of unsigned bytes (represented as int values between 0 and 255 inclusive) to + * an array of Java primitive type byte, which are by definition signed. + * + * @param unsignedBytes An array on unsigned bytes + * @return An array of signed bytes + */ + public static byte[] unsignedBytesToSignedBytes(final int[] unsignedBytes) { + byte[] signedBytes = new byte[unsignedBytes.length]; + + for (int i = 0; i < unsignedBytes.length; i++) { + if (unsignedBytes[i] > 255) { + throw new IllegalArgumentException("Encountered unsigned byte value > 255"); + } + signedBytes[i] = (byte) (unsignedBytes[i] & 0xff); } - /** - * Converts an array of Java primitive type bytes (which are by definition signed) to - * an array of unsigned bytes (represented as int values between 0 and 255 inclusive). - * - * @param signedBytes An array of signed bytes - * @return An array of unsigned bytes - */ - public static int[] signedBytesToUnsignedBytes(final byte[] signedBytes) { - int[] unsignedBytes = new int[signedBytes.length]; - - for (int i = 0; i < signedBytes.length; i++) { - unsignedBytes[i] = ((int) signedBytes[i]) & 0xff; - } + return signedBytes; + } - return unsignedBytes; + /** + * Converts an array of Java primitive type bytes (which are by definition signed) to an array of + * unsigned bytes (represented as int values between 0 and 255 inclusive). + * + * @param signedBytes An array of signed bytes + * @return An array of unsigned bytes + */ + public static int[] signedBytesToUnsignedBytes(final byte[] signedBytes) { + int[] unsignedBytes = new int[signedBytes.length]; + + for (int i = 0; i < signedBytes.length; i++) { + unsignedBytes[i] = ((int) signedBytes[i]) & 0xff; } + + return unsignedBytes; + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java b/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java index e3110a36f..cc2d9efbd 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java +++ b/src/test/java/com/amazonaws/encryptionsdk/TestVectorRunner.java @@ -3,6 +3,8 @@ package com.amazonaws.encryptionsdk; +import static java.lang.String.format; + import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.encryptionsdk.internal.SignaturePolicy; import com.amazonaws.encryptionsdk.jce.JceMasterKey; @@ -13,15 +15,6 @@ import com.amazonaws.util.IOUtils; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import org.bouncycastle.util.encoders.Base64; -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; - -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; import java.io.IOException; import java.io.InputStream; import java.net.JarURLConnection; @@ -39,353 +32,403 @@ import java.util.function.Supplier; import java.util.jar.JarFile; import java.util.zip.ZipEntry; - -import static java.lang.String.format; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import org.bouncycastle.util.encoders.Base64; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; @RunWith(Parameterized.class) public class TestVectorRunner { - private static final int MANIFEST_VERSION = 2; - - // We save the files in memory to avoid repeatedly retrieving them. This won't work if the plaintexts are too - // large or numerous - private static final Map cachedData = new HashMap<>(); - - private final String testName; - private final TestCase testCase; - private final DecryptionMethod decryptionMethod; - - public TestVectorRunner(final String testName, TestCase testCase, DecryptionMethod decryptionMethod) { - this.testName = testName; - this.testCase = testCase; - this.decryptionMethod = decryptionMethod; - } - - @Test - public void decrypt() throws Exception { - AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt).build(); - Callable decryptor = () -> decryptionMethod.decryptMessage(crypto, testCase.mkpSupplier.get(), cachedData.get(testCase.ciphertextPath)); - testCase.matcher.Match(decryptor); + private static final int MANIFEST_VERSION = 2; + + // We save the files in memory to avoid repeatedly retrieving them. This won't work if the + // plaintexts are too + // large or numerous + private static final Map cachedData = new HashMap<>(); + + private final String testName; + private final TestCase testCase; + private final DecryptionMethod decryptionMethod; + + public TestVectorRunner( + final String testName, TestCase testCase, DecryptionMethod decryptionMethod) { + this.testName = testName; + this.testCase = testCase; + this.decryptionMethod = decryptionMethod; + } + + @Test + public void decrypt() throws Exception { + AwsCrypto crypto = + AwsCrypto.builder() + .withCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) + .build(); + Callable decryptor = + () -> + decryptionMethod.decryptMessage( + crypto, testCase.mkpSupplier.get(), cachedData.get(testCase.ciphertextPath)); + testCase.matcher.Match(decryptor); + } + + @Parameterized.Parameters(name = "Compatibility Test: {0} - {2}") + @SuppressWarnings("unchecked") + public static Collection data() throws Exception { + final String zipPath = System.getProperty("testVectorZip"); + if (zipPath == null) { + return Collections.emptyList(); } - @Parameterized.Parameters(name="Compatibility Test: {0} - {2}") - @SuppressWarnings("unchecked") - public static Collection data() throws Exception { - final String zipPath = System.getProperty("testVectorZip"); - if (zipPath == null) { - return Collections.emptyList(); + final JarURLConnection jarConnection = + (JarURLConnection) new URL("jar:" + zipPath + "!/").openConnection(); + + try (JarFile jar = jarConnection.getJarFile()) { + final Map manifest = readJsonMapFromJar(jar, "manifest.json"); + + final Map metaData = (Map) manifest.get("manifest"); + + // We only support "awses-decrypt" type manifests right now + if (!"awses-decrypt".equals(metaData.get("type"))) { + throw new IllegalArgumentException("Unsupported manifest type: " + metaData.get("type")); + } + + if (!Integer.valueOf(MANIFEST_VERSION).equals(metaData.get("version"))) { + throw new IllegalArgumentException( + "Unsupported manifest version: " + metaData.get("version")); + } + + final Map keys = + parseKeyManifest(readJsonMapFromJar(jar, (String) manifest.get("keys"))); + + final KmsMasterKeyProvider kmsProv = + KmsMasterKeyProvider.builder() + .withCredentials(new DefaultAWSCredentialsProviderChain()) + .buildDiscovery(); + + List testCases = new ArrayList<>(); + for (Map.Entry> testEntry : + ((Map>) manifest.get("tests")).entrySet()) { + String testName = testEntry.getKey(); + TestCase testCase = parseTest(testEntry.getKey(), testEntry.getValue(), keys, jar, kmsProv); + for (DecryptionMethod decryptionMethod : DecryptionMethod.values()) { + if (testCase.signaturePolicy.equals(decryptionMethod.signaturePolicy())) { + testCases.add(new Object[] {testName, testCase, decryptionMethod}); + } } - - final JarURLConnection jarConnection = (JarURLConnection) new URL("jar:" + zipPath + "!/").openConnection(); - - try (JarFile jar = jarConnection.getJarFile()) { - final Map manifest = readJsonMapFromJar(jar, "manifest.json"); - - final Map metaData = (Map) manifest.get("manifest"); - - // We only support "awses-decrypt" type manifests right now - if (!"awses-decrypt".equals(metaData.get("type"))) { - throw new IllegalArgumentException("Unsupported manifest type: " + metaData.get("type")); - } - - if (!Integer.valueOf(MANIFEST_VERSION).equals(metaData.get("version"))) { - throw new IllegalArgumentException("Unsupported manifest version: " + metaData.get("version")); - } - - final Map keys = parseKeyManifest(readJsonMapFromJar(jar, (String) manifest.get("keys"))); - - final KmsMasterKeyProvider kmsProv = KmsMasterKeyProvider - .builder() - .withCredentials(new DefaultAWSCredentialsProviderChain()) - .buildDiscovery(); - - List testCases = new ArrayList<>(); - for (Map.Entry> testEntry : - ((Map>) manifest.get("tests")).entrySet()) { - String testName = testEntry.getKey(); - TestCase testCase = parseTest(testEntry.getKey(), testEntry.getValue(), keys, jar, kmsProv); - for (DecryptionMethod decryptionMethod : DecryptionMethod.values()) { - if (testCase.signaturePolicy.equals(decryptionMethod.signaturePolicy())) { - testCases.add(new Object[]{testName, testCase, decryptionMethod}); - } - } - } - return testCases; - } - } - - @AfterClass - public static void teardown() { - cachedData.clear(); + } + return testCases; } + } - private static byte[] readBytesFromJar(JarFile jar, String fileName) throws IOException { - try (InputStream is = readFromJar(jar, fileName)) { - return IOUtils.toByteArray(is); - } - } + @AfterClass + public static void teardown() { + cachedData.clear(); + } - private static Map readJsonMapFromJar(JarFile jar, String fileName) throws IOException { - try (InputStream is = readFromJar(jar, fileName)) { - final ObjectMapper mapper = new ObjectMapper(); - return mapper.readValue(is, new TypeReference>() {}); - } + private static byte[] readBytesFromJar(JarFile jar, String fileName) throws IOException { + try (InputStream is = readFromJar(jar, fileName)) { + return IOUtils.toByteArray(is); } + } - private static InputStream readFromJar(JarFile jar, String name) throws IOException { - // Our manifest URIs incorrectly start with file:// rather than just file: so we need to strip this - ZipEntry entry = jar.getEntry(name.replaceFirst("^file://(?!/)", "")); - return jar.getInputStream(entry); + private static Map readJsonMapFromJar(JarFile jar, String fileName) + throws IOException { + try (InputStream is = readFromJar(jar, fileName)) { + final ObjectMapper mapper = new ObjectMapper(); + return mapper.readValue(is, new TypeReference>() {}); } - - private static void cacheData(JarFile jar, String url) throws IOException { - if (!cachedData.containsKey(url)) { - cachedData.put(url, readBytesFromJar(jar, url)); - } + } + + private static InputStream readFromJar(JarFile jar, String name) throws IOException { + // Our manifest URIs incorrectly start with file:// rather than just file: so we need to strip + // this + ZipEntry entry = jar.getEntry(name.replaceFirst("^file://(?!/)", "")); + return jar.getInputStream(entry); + } + + private static void cacheData(JarFile jar, String url) throws IOException { + if (!cachedData.containsKey(url)) { + cachedData.put(url, readBytesFromJar(jar, url)); } - - @SuppressWarnings("unchecked") - private static TestCase parseTest(String testName, Map data, Map keys, - JarFile jar, KmsMasterKeyProvider kmsProv) throws IOException { - final String ciphertextURL = (String) data.get("ciphertext"); - cacheData(jar, ciphertextURL); - - Supplier> mkpSupplier = () -> { - - @SuppressWarnings("generic") - final List> mks = new ArrayList<>(); - - for (Map mkEntry : (List>) data.get("master-keys")) { - final String type = (String) mkEntry.get("type"); - final String keyName =(String) mkEntry.get("key"); - final KeyEntry key = keys.get(keyName); - - if ("aws-kms".equals(type)) { - mks.add(kmsProv.getMasterKey(key.keyId)); - } else if ("aws-kms-mrk-aware".equals(type)) { - AwsKmsMrkAwareMasterKeyProvider provider = AwsKmsMrkAwareMasterKeyProvider.builder().buildStrict(key.keyId); - mks.add(provider.getMasterKey(key.keyId)); - } else if ("aws-kms-mrk-aware-discovery".equals(type)) { - final String defaultMrkRegion = (String) mkEntry.get("default-mrk-region"); - final Map discoveryFilterSpec = (Map) mkEntry.get("aws-kms-discovery-filter"); - final DiscoveryFilter discoveryFilter; - if (discoveryFilterSpec != null) { - discoveryFilter = new DiscoveryFilter((String) discoveryFilterSpec.get("partition"), - (List) discoveryFilterSpec.get("account-ids")); - } else { - discoveryFilter = null; - } - return AwsKmsMrkAwareMasterKeyProvider.builder() - .withDiscoveryMrkRegion(defaultMrkRegion) - .buildDiscovery(discoveryFilter); - } else if ("raw".equals(type)) { - final String provId = (String) mkEntry.get("provider-id"); - final String algorithm = (String) mkEntry.get("encryption-algorithm"); - if ("aes".equals(algorithm)) { - mks.add(JceMasterKey.getInstance((SecretKey) key.key, provId, key.keyId, "AES/GCM/NoPadding")); - } else if ("rsa".equals(algorithm)) { - String transformation = "RSA/ECB/"; - final String padding = (String) mkEntry.get("padding-algorithm"); - if ("pkcs1".equals(padding)) { - transformation += "PKCS1Padding"; - } else if ("oaep-mgf1".equals(padding)) { - final String hashName = ((String) mkEntry.get("padding-hash")) - .replace("sha", "sha-") - .toUpperCase(); - transformation += "OAEPWith" + hashName + "AndMGF1Padding"; - } else { - throw new IllegalArgumentException("Unsupported padding:" + padding); - } - final PublicKey wrappingKey; - final PrivateKey unwrappingKey; - if (key.key instanceof PublicKey) { - wrappingKey = (PublicKey) key.key; - unwrappingKey = null; - } else { - wrappingKey = null; - unwrappingKey = (PrivateKey) key.key; - } - mks.add(JceMasterKey.getInstance(wrappingKey, unwrappingKey, provId, key.keyId, transformation)); - } else { - throw new IllegalArgumentException("Unsupported algorithm: " + algorithm); - } + } + + @SuppressWarnings("unchecked") + private static TestCase parseTest( + String testName, + Map data, + Map keys, + JarFile jar, + KmsMasterKeyProvider kmsProv) + throws IOException { + final String ciphertextURL = (String) data.get("ciphertext"); + cacheData(jar, ciphertextURL); + + Supplier> mkpSupplier = + () -> { + @SuppressWarnings("generic") + final List> mks = new ArrayList<>(); + + for (Map mkEntry : (List>) data.get("master-keys")) { + final String type = (String) mkEntry.get("type"); + final String keyName = (String) mkEntry.get("key"); + final KeyEntry key = keys.get(keyName); + + if ("aws-kms".equals(type)) { + mks.add(kmsProv.getMasterKey(key.keyId)); + } else if ("aws-kms-mrk-aware".equals(type)) { + AwsKmsMrkAwareMasterKeyProvider provider = + AwsKmsMrkAwareMasterKeyProvider.builder().buildStrict(key.keyId); + mks.add(provider.getMasterKey(key.keyId)); + } else if ("aws-kms-mrk-aware-discovery".equals(type)) { + final String defaultMrkRegion = (String) mkEntry.get("default-mrk-region"); + final Map discoveryFilterSpec = + (Map) mkEntry.get("aws-kms-discovery-filter"); + final DiscoveryFilter discoveryFilter; + if (discoveryFilterSpec != null) { + discoveryFilter = + new DiscoveryFilter( + (String) discoveryFilterSpec.get("partition"), + (List) discoveryFilterSpec.get("account-ids")); + } else { + discoveryFilter = null; + } + return AwsKmsMrkAwareMasterKeyProvider.builder() + .withDiscoveryMrkRegion(defaultMrkRegion) + .buildDiscovery(discoveryFilter); + } else if ("raw".equals(type)) { + final String provId = (String) mkEntry.get("provider-id"); + final String algorithm = (String) mkEntry.get("encryption-algorithm"); + if ("aes".equals(algorithm)) { + mks.add( + JceMasterKey.getInstance( + (SecretKey) key.key, provId, key.keyId, "AES/GCM/NoPadding")); + } else if ("rsa".equals(algorithm)) { + String transformation = "RSA/ECB/"; + final String padding = (String) mkEntry.get("padding-algorithm"); + if ("pkcs1".equals(padding)) { + transformation += "PKCS1Padding"; + } else if ("oaep-mgf1".equals(padding)) { + final String hashName = + ((String) mkEntry.get("padding-hash")).replace("sha", "sha-").toUpperCase(); + transformation += "OAEPWith" + hashName + "AndMGF1Padding"; + } else { + throw new IllegalArgumentException("Unsupported padding:" + padding); + } + final PublicKey wrappingKey; + final PrivateKey unwrappingKey; + if (key.key instanceof PublicKey) { + wrappingKey = (PublicKey) key.key; + unwrappingKey = null; } else { - throw new IllegalArgumentException("Unsupported Key Type: " + type); + wrappingKey = null; + unwrappingKey = (PrivateKey) key.key; } + mks.add( + JceMasterKey.getInstance( + wrappingKey, unwrappingKey, provId, key.keyId, transformation)); + } else { + throw new IllegalArgumentException("Unsupported algorithm: " + algorithm); + } + } else { + throw new IllegalArgumentException("Unsupported Key Type: " + type); } + } - return MultipleProviderFactory.buildMultiProvider(mks); + return MultipleProviderFactory.buildMultiProvider(mks); }; - @SuppressWarnings("unchecked") - final Map resultSpec = (Map) data.get("result"); - final ResultMatcher matcher = parseResultMatcher(jar, resultSpec); + @SuppressWarnings("unchecked") + final Map resultSpec = (Map) data.get("result"); + final ResultMatcher matcher = parseResultMatcher(jar, resultSpec); + + String decryptionMethodSpec = (String) data.get("decryption-method"); + SignaturePolicy signaturePolicy = SignaturePolicy.AllowEncryptAllowDecrypt; + if (decryptionMethodSpec != null) { + if ("streaming-unsigned-only".equals(decryptionMethodSpec)) { + signaturePolicy = SignaturePolicy.AllowEncryptForbidDecrypt; + } else { + throw new IllegalArgumentException( + "Unsupported Decryption Method: " + decryptionMethodSpec); + } + } - String decryptionMethodSpec = (String) data.get("decryption-method"); - SignaturePolicy signaturePolicy = SignaturePolicy.AllowEncryptAllowDecrypt; - if (decryptionMethodSpec != null) { - if ("streaming-unsigned-only".equals(decryptionMethodSpec)) { - signaturePolicy = SignaturePolicy.AllowEncryptForbidDecrypt; - } else { - throw new IllegalArgumentException("Unsupported Decryption Method: " + decryptionMethodSpec); - } - } + return new TestCase(testName, ciphertextURL, mkpSupplier, matcher, signaturePolicy); + } - return new TestCase(testName, ciphertextURL, mkpSupplier, matcher, signaturePolicy); + private static ResultMatcher parseResultMatcher( + final JarFile jar, final Map result) throws IOException { + if (result.size() != 1) { + throw new IllegalArgumentException("Unsupported result specification: " + result); } - - private static ResultMatcher parseResultMatcher(final JarFile jar, final Map result) throws IOException { - if (result.size() != 1) { - throw new IllegalArgumentException("Unsupported result specification: " + result); - } - Map.Entry pair = result.entrySet().iterator().next(); - if (pair.getKey().equals("output")) { - Map outputSpec = (Map) pair.getValue(); - String plaintextUrl = outputSpec.get("plaintext"); - cacheData(jar, plaintextUrl); - return new OutputResultMatcher(plaintextUrl); - } else if (pair.getKey().equals("error")) { - Map errorSpec = (Map) pair.getValue(); - String errorDescription = errorSpec.get("error-description"); - return new ErrorResultMatcher(errorDescription); - } else { - throw new IllegalArgumentException("Unsupported result specification: " + result); - } + Map.Entry pair = result.entrySet().iterator().next(); + if (pair.getKey().equals("output")) { + Map outputSpec = (Map) pair.getValue(); + String plaintextUrl = outputSpec.get("plaintext"); + cacheData(jar, plaintextUrl); + return new OutputResultMatcher(plaintextUrl); + } else if (pair.getKey().equals("error")) { + Map errorSpec = (Map) pair.getValue(); + String errorDescription = errorSpec.get("error-description"); + return new ErrorResultMatcher(errorDescription); + } else { + throw new IllegalArgumentException("Unsupported result specification: " + result); } - - @SuppressWarnings("unchecked") - private static Map parseKeyManifest(final Map keysManifest) throws GeneralSecurityException { - // check our type - final Map metaData = (Map) keysManifest.get("manifest"); - if (!"keys".equals(metaData.get("type"))) { - throw new IllegalArgumentException("Invalid manifest type: " + metaData.get("type")); - } - if (!Integer.valueOf(3).equals(metaData.get("version"))) { - throw new IllegalArgumentException("Invalid manifest version: " + metaData.get("version")); - } - - final Map result = new HashMap<>(); - - Map keys = (Map) keysManifest.get("keys"); - for (Map.Entry entry : keys.entrySet()) { - final String name = entry.getKey(); - final Map data = (Map) entry.getValue(); - - final String keyType = (String) data.get("type"); - final String encoding = (String) data.get("encoding"); - final String keyId = (String) data.get("key-id"); - final String material = (String) data.get("material"); // May be null - final String algorithm = (String) data.get("algorithm"); // May be null - - final KeyEntry keyEntry; - - final KeyFactory kf; - switch (keyType) { - case "symmetric": - if (!"base64".equals(encoding)) { - throw new IllegalArgumentException(format("Key %s is symmetric but has encoding %s", keyId, encoding)); - } - keyEntry = new KeyEntry(name, keyId, keyType, - new SecretKeySpec(Base64.decode(material), algorithm.toUpperCase())); - break; - case "private": - kf = KeyFactory.getInstance(algorithm); - if (!"pem".equals(encoding)) { - throw new IllegalArgumentException(format("Key %s is private but has encoding %s", keyId, encoding)); - } - byte[] pkcs8Key = parsePem(material); - keyEntry = new KeyEntry(name, keyId, keyType, - kf.generatePrivate(new PKCS8EncodedKeySpec(pkcs8Key))); - break; - case "public": - kf = KeyFactory.getInstance(algorithm); - if (!"pem".equals(encoding)) { - throw new IllegalArgumentException(format("Key %s is private but has encoding %s", keyId, encoding)); - } - byte[] x509Key = parsePem(material); - keyEntry = new KeyEntry(name, keyId, keyType, - kf.generatePublic(new X509EncodedKeySpec(x509Key))); - break; - case "aws-kms": - keyEntry = new KeyEntry(name, keyId, keyType, null); - break; - default: - throw new IllegalArgumentException("Unsupported key type: " + keyType); - } - - result.put(name, keyEntry); - } - - return result; + } + + @SuppressWarnings("unchecked") + private static Map parseKeyManifest(final Map keysManifest) + throws GeneralSecurityException { + // check our type + final Map metaData = (Map) keysManifest.get("manifest"); + if (!"keys".equals(metaData.get("type"))) { + throw new IllegalArgumentException("Invalid manifest type: " + metaData.get("type")); } - - private static byte[] parsePem(String pem) { - final String stripped = pem.replaceAll("-+[A-Z ]+-+", ""); - return Base64.decode(stripped); + if (!Integer.valueOf(3).equals(metaData.get("version"))) { + throw new IllegalArgumentException("Invalid manifest version: " + metaData.get("version")); } - private static class KeyEntry { - final String name; - final String keyId; - final String type; - final Key key; - - private KeyEntry(String name, String keyId, String type, Key key) { - this.name = name; - this.keyId = keyId; - this.type = type; - this.key = key; - } + final Map result = new HashMap<>(); + + Map keys = (Map) keysManifest.get("keys"); + for (Map.Entry entry : keys.entrySet()) { + final String name = entry.getKey(); + final Map data = (Map) entry.getValue(); + + final String keyType = (String) data.get("type"); + final String encoding = (String) data.get("encoding"); + final String keyId = (String) data.get("key-id"); + final String material = (String) data.get("material"); // May be null + final String algorithm = (String) data.get("algorithm"); // May be null + + final KeyEntry keyEntry; + + final KeyFactory kf; + switch (keyType) { + case "symmetric": + if (!"base64".equals(encoding)) { + throw new IllegalArgumentException( + format("Key %s is symmetric but has encoding %s", keyId, encoding)); + } + keyEntry = + new KeyEntry( + name, + keyId, + keyType, + new SecretKeySpec(Base64.decode(material), algorithm.toUpperCase())); + break; + case "private": + kf = KeyFactory.getInstance(algorithm); + if (!"pem".equals(encoding)) { + throw new IllegalArgumentException( + format("Key %s is private but has encoding %s", keyId, encoding)); + } + byte[] pkcs8Key = parsePem(material); + keyEntry = + new KeyEntry( + name, keyId, keyType, kf.generatePrivate(new PKCS8EncodedKeySpec(pkcs8Key))); + break; + case "public": + kf = KeyFactory.getInstance(algorithm); + if (!"pem".equals(encoding)) { + throw new IllegalArgumentException( + format("Key %s is private but has encoding %s", keyId, encoding)); + } + byte[] x509Key = parsePem(material); + keyEntry = + new KeyEntry( + name, keyId, keyType, kf.generatePublic(new X509EncodedKeySpec(x509Key))); + break; + case "aws-kms": + keyEntry = new KeyEntry(name, keyId, keyType, null); + break; + default: + throw new IllegalArgumentException("Unsupported key type: " + keyType); + } + + result.put(name, keyEntry); } - private static class TestCase { - private final String name; - private final String ciphertextPath; - private final ResultMatcher matcher; - private final Supplier> mkpSupplier; - private final SignaturePolicy signaturePolicy; - - private TestCase(String name, String ciphertextPath, Supplier> mkpSupplier, ResultMatcher matcher, SignaturePolicy signaturePolicy) { - this.name = name; - this.ciphertextPath = ciphertextPath; - this.matcher = matcher; - this.mkpSupplier = mkpSupplier; - this.signaturePolicy = signaturePolicy; - } + return result; + } + + private static byte[] parsePem(String pem) { + final String stripped = pem.replaceAll("-+[A-Z ]+-+", ""); + return Base64.decode(stripped); + } + + private static class KeyEntry { + final String name; + final String keyId; + final String type; + final Key key; + + private KeyEntry(String name, String keyId, String type, Key key) { + this.name = name; + this.keyId = keyId; + this.type = type; + this.key = key; } - - private interface ResultMatcher { - void Match(Callable decryptor) throws Exception; + } + + private static class TestCase { + private final String name; + private final String ciphertextPath; + private final ResultMatcher matcher; + private final Supplier> mkpSupplier; + private final SignaturePolicy signaturePolicy; + + private TestCase( + String name, + String ciphertextPath, + Supplier> mkpSupplier, + ResultMatcher matcher, + SignaturePolicy signaturePolicy) { + this.name = name; + this.ciphertextPath = ciphertextPath; + this.matcher = matcher; + this.mkpSupplier = mkpSupplier; + this.signaturePolicy = signaturePolicy; } + } - private static class OutputResultMatcher implements ResultMatcher { + private interface ResultMatcher { + void Match(Callable decryptor) throws Exception; + } - private final String plaintextPath; + private static class OutputResultMatcher implements ResultMatcher { - private OutputResultMatcher(String plaintextPath) { - this.plaintextPath = plaintextPath; - } + private final String plaintextPath; - @Override - public void Match(Callable decryptor) throws Exception { - final byte[] plaintext = decryptor.call(); - final byte[] expectedPlaintext = cachedData.get(plaintextPath); - Assert.assertArrayEquals(expectedPlaintext, plaintext); - } + private OutputResultMatcher(String plaintextPath) { + this.plaintextPath = plaintextPath; } - private static class ErrorResultMatcher implements ResultMatcher { + @Override + public void Match(Callable decryptor) throws Exception { + final byte[] plaintext = decryptor.call(); + final byte[] expectedPlaintext = cachedData.get(plaintextPath); + Assert.assertArrayEquals(expectedPlaintext, plaintext); + } + } - private final String errorDescription; + private static class ErrorResultMatcher implements ResultMatcher { - private ErrorResultMatcher(String errorDescription) { - this.errorDescription = errorDescription; - } + private final String errorDescription; - @Override - public void Match(Callable decryptor) { - Assert.assertThrows("Decryption expected to fail (" + errorDescription + ") but succeeded", - Exception.class, decryptor::call); - } + private ErrorResultMatcher(String errorDescription) { + this.errorDescription = errorDescription; + } + + @Override + public void Match(Callable decryptor) { + Assert.assertThrows( + "Decryption expected to fail (" + errorDescription + ") but succeeded", + Exception.class, + decryptor::call); } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/XCompatDecryptTest.java b/src/test/java/com/amazonaws/encryptionsdk/XCompatDecryptTest.java index 6cb9d2da1..84d4df574 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/XCompatDecryptTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/XCompatDecryptTest.java @@ -3,9 +3,15 @@ package com.amazonaws.encryptionsdk; +import static org.junit.Assert.assertArrayEquals; + +import com.amazonaws.encryptionsdk.internal.Utils; +import com.amazonaws.encryptionsdk.jce.JceMasterKey; +import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import java.io.File; import java.io.StringReader; -import java.lang.IllegalArgumentException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; @@ -18,199 +24,183 @@ import java.util.HashMap; import java.util.List; import java.util.Map; - import javax.crypto.spec.SecretKeySpec; - import org.apache.commons.lang3.StringUtils; - import org.bouncycastle.util.io.pem.PemReader; - -import static org.junit.Assert.assertArrayEquals; - import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; -import com.amazonaws.encryptionsdk.internal.Utils; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.core.type.TypeReference; - -import com.amazonaws.encryptionsdk.jce.JceMasterKey; -import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory; - @RunWith(Parameterized.class) public class XCompatDecryptTest { - private static final String STATIC_XCOMPAT_NAME = "static-aws-xcompat"; - private static final String AES_GCM = "AES/GCM/NoPadding"; - private static final byte XCOMPAT_MESSAGE_VERSION = 1; - - private String plaintextFileName; - private String ciphertextFileName; - private MasterKeyProvider masterKeyProvider; - - public XCompatDecryptTest( - String plaintextFileName, - String ciphertextFileName, - MasterKeyProvider masterKeyProvider - ) throws Exception { - this.plaintextFileName = plaintextFileName; - this.ciphertextFileName = ciphertextFileName; - this.masterKeyProvider = masterKeyProvider; + private static final String STATIC_XCOMPAT_NAME = "static-aws-xcompat"; + private static final String AES_GCM = "AES/GCM/NoPadding"; + private static final byte XCOMPAT_MESSAGE_VERSION = 1; + + private String plaintextFileName; + private String ciphertextFileName; + private MasterKeyProvider masterKeyProvider; + + public XCompatDecryptTest( + String plaintextFileName, String ciphertextFileName, MasterKeyProvider masterKeyProvider) + throws Exception { + this.plaintextFileName = plaintextFileName; + this.ciphertextFileName = ciphertextFileName; + this.masterKeyProvider = masterKeyProvider; + } + + @Parameters(name = "{index}: testDecryptFromFile({0}, {1}, {2})") + public static Collection data() throws Exception { + String baseDirName; + baseDirName = System.getProperty("staticCompatibilityResourcesDir"); + if (baseDirName == null) { + baseDirName = + XCompatDecryptTest.class.getProtectionDomain().getCodeSource().getLocation().getPath() + + "aws_encryption_sdk_resources"; } - @Parameters(name="{index}: testDecryptFromFile({0}, {1}, {2})") - public static Collection data() throws Exception{ - String baseDirName; - baseDirName = System.getProperty("staticCompatibilityResourcesDir"); - if (baseDirName == null) { - baseDirName = - XCompatDecryptTest.class.getProtectionDomain().getCodeSource().getLocation().getPath() + - "aws_encryption_sdk_resources"; - } + List testCases_ = new ArrayList(); - List testCases_ = new ArrayList(); + String ciphertextManifestName = + StringUtils.join( + new String[] {baseDirName, "manifests", "ciphertext.manifest"}, File.separator); + File ciphertextManifestFile = new File(ciphertextManifestName); - String ciphertextManifestName = StringUtils.join( - new String[]{ - baseDirName, - "manifests", - "ciphertext.manifest" - }, - File.separator - ); - File ciphertextManifestFile = new File(ciphertextManifestName); - - if (!ciphertextManifestFile.exists()) { - return Collections.emptySet(); - } - - ObjectMapper ciphertextManifestMapper = new ObjectMapper(); - Map ciphertextManifest = ciphertextManifestMapper.readValue( - ciphertextManifestFile, - new TypeReference>(){} - ); - - HashMap> staticKeyMap = new HashMap>(); - - Map testKeys = (Map)ciphertextManifest.get("test_keys"); - - for (Map.Entry keyType : testKeys.entrySet()) { - Map keys = (Map)keyType.getValue(); - HashMap thisKeyType = new HashMap(); - for (Map.Entry key : keys.entrySet()) { - Map thisKey = (Map)key.getValue(); - String keyRaw = new String( - StringUtils.join( - (List)thisKey.get("key"), - (String)thisKey.getOrDefault("line_separator", "") - ).getBytes(), - StandardCharsets.UTF_8 - ); - byte[] keyBytes; - switch ((String)thisKey.get("encoding")) { - case "base64": - keyBytes = Utils.decodeBase64String(keyRaw); - break; - case "pem": - PemReader pemReader = new PemReader(new StringReader(keyRaw)); - keyBytes = pemReader.readPemObject().getContent(); - break; - case "raw": - default: - keyBytes = keyRaw.getBytes(); - } - thisKeyType.put((String)key.getKey(), keyBytes); - } - staticKeyMap.put((String)keyType.getKey(), thisKeyType); - } + if (!ciphertextManifestFile.exists()) { + return Collections.emptySet(); + } - final KeyFactory rsaKeyFactory = KeyFactory.getInstance("RSA"); - - List> testCases = (List>)ciphertextManifest.get("test_cases"); - for (Map testCase : testCases) { - Map plaintext = (Map)testCase.get("plaintext"); - Map ciphertext = (Map)testCase.get("ciphertext"); - - short algId = (short) Integer.parseInt((String)testCase.get("algorithm"), 16); - CryptoAlgorithm encryptionAlgorithm = CryptoAlgorithm.deserialize(XCOMPAT_MESSAGE_VERSION, algId); - - List> masterKeys = (List>)testCase.get("master_keys"); - List allMasterKeys = new ArrayList(); - for (Map aMasterKey : masterKeys) { - String providerId = (String)aMasterKey.get("provider_id"); - if (providerId.equals(STATIC_XCOMPAT_NAME) && (boolean)aMasterKey.get("decryptable")) { - String paddingAlgorithm = (String)aMasterKey.getOrDefault("padding_algorithm", ""); - String paddingHash = (String)aMasterKey.getOrDefault("padding_hash", ""); - Integer keyBits = (Integer)aMasterKey.getOrDefault( - "key_bits", - encryptionAlgorithm.getDataKeyLength() * 8 - ); - String keyId = - (String)aMasterKey.get("encryption_algorithm") + "." + - keyBits.toString() + "." + - paddingAlgorithm + "." + - paddingHash; - String encAlg = (String)aMasterKey.get("encryption_algorithm"); - switch (encAlg.toUpperCase()) { - case "RSA": - String cipherBase = "RSA/ECB/"; - String cipherName; - switch (paddingAlgorithm) { - case "OAEP-MGF1": - cipherName = cipherBase + "OAEPWith" + paddingHash + "AndMGF1Padding"; - break; - case "PKCS1": - cipherName = cipherBase + paddingAlgorithm + "Padding"; - break; - default: - throw new IllegalArgumentException("Unknown padding algorithm: " + paddingAlgorithm); - } - PrivateKey privKey = rsaKeyFactory.generatePrivate(new PKCS8EncodedKeySpec(staticKeyMap.get("RSA").get(keyBits.toString()))); - allMasterKeys.add(JceMasterKey.getInstance( - null, - privKey, - STATIC_XCOMPAT_NAME, - keyId, - cipherName - )); - break; - case "AES": - SecretKeySpec spec = new SecretKeySpec( - staticKeyMap.get("AES").get(keyBits.toString()), - 0, - encryptionAlgorithm.getDataKeyLength(), - encryptionAlgorithm.getDataKeyAlgo() - ); - allMasterKeys.add(JceMasterKey.getInstance(spec, STATIC_XCOMPAT_NAME, keyId, AES_GCM)); - break; - default: - throw new IllegalArgumentException("Unknown encryption algorithm: " + encAlg.toUpperCase()); - } - } - } - - if (allMasterKeys.size() > 0) { - final MasterKeyProvider provider = MultipleProviderFactory.buildMultiProvider(allMasterKeys); - testCases_.add(new Object[]{ - baseDirName + File.separator + plaintext.get("filename"), - baseDirName + File.separator + ciphertext.get("filename"), - provider - }); - } + ObjectMapper ciphertextManifestMapper = new ObjectMapper(); + Map ciphertextManifest = + ciphertextManifestMapper.readValue( + ciphertextManifestFile, new TypeReference>() {}); + + HashMap> staticKeyMap = + new HashMap>(); + + Map testKeys = (Map) ciphertextManifest.get("test_keys"); + + for (Map.Entry keyType : testKeys.entrySet()) { + Map keys = (Map) keyType.getValue(); + HashMap thisKeyType = new HashMap(); + for (Map.Entry key : keys.entrySet()) { + Map thisKey = (Map) key.getValue(); + String keyRaw = + new String( + StringUtils.join( + (List) thisKey.get("key"), + (String) thisKey.getOrDefault("line_separator", "")) + .getBytes(), + StandardCharsets.UTF_8); + byte[] keyBytes; + switch ((String) thisKey.get("encoding")) { + case "base64": + keyBytes = Utils.decodeBase64String(keyRaw); + break; + case "pem": + PemReader pemReader = new PemReader(new StringReader(keyRaw)); + keyBytes = pemReader.readPemObject().getContent(); + break; + case "raw": + default: + keyBytes = keyRaw.getBytes(); } - return testCases_; + thisKeyType.put((String) key.getKey(), keyBytes); + } + staticKeyMap.put((String) keyType.getKey(), thisKeyType); } - @Test - public void testDecryptFromFile() throws Exception { - AwsCrypto crypto = AwsCrypto.standard(); - byte ciphertextBytes[] = Files.readAllBytes(Paths.get(ciphertextFileName)); - byte plaintextBytes[] = Files.readAllBytes(Paths.get(plaintextFileName)); - final CryptoResult decryptResult = crypto.decryptData( - masterKeyProvider, - ciphertextBytes - ); - assertArrayEquals(plaintextBytes, (byte[])decryptResult.getResult()); + final KeyFactory rsaKeyFactory = KeyFactory.getInstance("RSA"); + + List> testCases = + (List>) ciphertextManifest.get("test_cases"); + for (Map testCase : testCases) { + Map plaintext = (Map) testCase.get("plaintext"); + Map ciphertext = (Map) testCase.get("ciphertext"); + + short algId = (short) Integer.parseInt((String) testCase.get("algorithm"), 16); + CryptoAlgorithm encryptionAlgorithm = + CryptoAlgorithm.deserialize(XCOMPAT_MESSAGE_VERSION, algId); + + List> masterKeys = + (List>) testCase.get("master_keys"); + List allMasterKeys = new ArrayList(); + for (Map aMasterKey : masterKeys) { + String providerId = (String) aMasterKey.get("provider_id"); + if (providerId.equals(STATIC_XCOMPAT_NAME) && (boolean) aMasterKey.get("decryptable")) { + String paddingAlgorithm = (String) aMasterKey.getOrDefault("padding_algorithm", ""); + String paddingHash = (String) aMasterKey.getOrDefault("padding_hash", ""); + Integer keyBits = + (Integer) + aMasterKey.getOrDefault("key_bits", encryptionAlgorithm.getDataKeyLength() * 8); + String keyId = + (String) aMasterKey.get("encryption_algorithm") + + "." + + keyBits.toString() + + "." + + paddingAlgorithm + + "." + + paddingHash; + String encAlg = (String) aMasterKey.get("encryption_algorithm"); + switch (encAlg.toUpperCase()) { + case "RSA": + String cipherBase = "RSA/ECB/"; + String cipherName; + switch (paddingAlgorithm) { + case "OAEP-MGF1": + cipherName = cipherBase + "OAEPWith" + paddingHash + "AndMGF1Padding"; + break; + case "PKCS1": + cipherName = cipherBase + paddingAlgorithm + "Padding"; + break; + default: + throw new IllegalArgumentException( + "Unknown padding algorithm: " + paddingAlgorithm); + } + PrivateKey privKey = + rsaKeyFactory.generatePrivate( + new PKCS8EncodedKeySpec(staticKeyMap.get("RSA").get(keyBits.toString()))); + allMasterKeys.add( + JceMasterKey.getInstance(null, privKey, STATIC_XCOMPAT_NAME, keyId, cipherName)); + break; + case "AES": + SecretKeySpec spec = + new SecretKeySpec( + staticKeyMap.get("AES").get(keyBits.toString()), + 0, + encryptionAlgorithm.getDataKeyLength(), + encryptionAlgorithm.getDataKeyAlgo()); + allMasterKeys.add( + JceMasterKey.getInstance(spec, STATIC_XCOMPAT_NAME, keyId, AES_GCM)); + break; + default: + throw new IllegalArgumentException( + "Unknown encryption algorithm: " + encAlg.toUpperCase()); + } + } + } + + if (allMasterKeys.size() > 0) { + final MasterKeyProvider provider = + MultipleProviderFactory.buildMultiProvider(allMasterKeys); + testCases_.add( + new Object[] { + baseDirName + File.separator + plaintext.get("filename"), + baseDirName + File.separator + ciphertext.get("filename"), + provider + }); + } } + return testCases_; + } + + @Test + public void testDecryptFromFile() throws Exception { + AwsCrypto crypto = AwsCrypto.standard(); + byte ciphertextBytes[] = Files.readAllBytes(Paths.get(ciphertextFileName)); + byte plaintextBytes[] = Files.readAllBytes(Paths.get(plaintextFileName)); + final CryptoResult decryptResult = crypto.decryptData(masterKeyProvider, ciphertextBytes); + assertArrayEquals(plaintextBytes, (byte[]) decryptResult.getResult()); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/caching/CacheIdentifierTests.java b/src/test/java/com/amazonaws/encryptionsdk/caching/CacheIdentifierTests.java index ad2c11749..dbb1ab832 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/caching/CacheIdentifierTests.java +++ b/src/test/java/com/amazonaws/encryptionsdk/caching/CacheIdentifierTests.java @@ -6,6 +6,14 @@ import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; +import com.amazonaws.encryptionsdk.CommitmentPolicy; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.CryptoMaterialsManager; +import com.amazonaws.encryptionsdk.TestUtils; +import com.amazonaws.encryptionsdk.internal.Utils; +import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; +import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; +import com.amazonaws.encryptionsdk.model.KeyBlob; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; @@ -14,188 +22,205 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; - -import com.amazonaws.encryptionsdk.TestUtils; -import com.amazonaws.encryptionsdk.CommitmentPolicy; import org.bouncycastle.util.encoders.Hex; import org.junit.Test; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.CryptoMaterialsManager; -import com.amazonaws.encryptionsdk.internal.Utils; -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; -import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; -import com.amazonaws.encryptionsdk.model.KeyBlob; - public class CacheIdentifierTests { - static String partitionName = "c15b9079-6d0e-42b6-8784-5e804b025692"; - static Map contextEmpty = Collections.emptyMap(); - static Map contextFull; - static CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; - static { - contextFull = new HashMap<>(); - contextFull.put("this", "is"); - contextFull.put("a", "non-empty"); - contextFull.put("encryption", "context"); - } - - static List keyBlobs = Arrays.asList( - new KeyBlob("this is a provider ID", "this is some key info".getBytes(UTF_8), - "super secret key, now with encryption!".getBytes(UTF_8) - ), - new KeyBlob("another provider ID!", "this is some different key info".getBytes(UTF_8), - "better super secret key, now with encryption!".getBytes(UTF_8) - ) - ); - - @Test - public void pythonTestVecs() throws Exception { - assertEncryptId(partitionName, null, contextEmpty, - "rkrFAso1YyPbOJbmwVMjrPw+wwLJT7xusn8tA8zMe9e3+OqbtfDueB7bvoKLU3fsmdUvZ6eMt7mBp1ThMMB25Q=="); - assertEncryptId(partitionName, ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, - contextEmpty, - "3icBIkLK4V3fVwbm3zSxUdUQV6ZvZYUOLl8buN36g6gDMqAkghcGryxX7QiVABkW1JhB6GRp5z+bzbiuciBcKQ=="); - assertEncryptId(partitionName, null, contextFull, - "IHiUHYOUVUEFTc3BcZPJDlsWct2Qy1A7JdfQl9sQoV/ILIbRpoz9q7RtGd/MlibaGl5ihE66cN8ygM8A5rtYbg=="); - assertEncryptId(partitionName, ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, - contextFull, - "mRNK7qhTb/kJiiyGPgAevp0gwFRcET4KeeNYwZHhoEDvSUzQiDgl8Of+YRDaVzKxAqpNBgcAuFXde9JlaRRsmw=="); - - assertDecryptId(partitionName, - ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256, - Collections.singletonList(keyBlobs.get(0)), - contextEmpty, - "n0zVzk9QIVxhz6ET+aJIKKOJNxtpGtSe1yAbu7WU5l272Iw/jmhlER4psDHJs9Mr8KYiIvLGSXzggNDCc23+9w==" - ); - - assertDecryptId(partitionName, - ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, - keyBlobs, - contextFull, - "+rtwUe38CGnczGmYu12iqGWHIyDyZ44EvYQ4S6ACmsgS8VaEpiw0RTGpDk6Z/7YYN/jVHOAcNKDyCNP8EmstFg==" - ); - } - - void assertDecryptId(String partitionName, CryptoAlgorithm algo, List blobs, Map context, String expect) throws Exception { - DecryptionMaterialsRequest request = - DecryptionMaterialsRequest.newBuilder() - .setAlgorithm(algo) - .setEncryptionContext(context) - .setEncryptedDataKeys(blobs) - .build(); - - byte[] id = getCacheIdentifier(getCMM(partitionName), request); - - assertEquals(expect, Utils.encodeBase64String(id)); - } - - void assertEncryptId(String partitionName, CryptoAlgorithm algo, Map context, String expect) throws Exception { - EncryptionMaterialsRequest request = EncryptionMaterialsRequest.newBuilder() - .setContext(context) - .setRequestedAlgorithm(algo) - .setCommitmentPolicy(commitmentPolicy) - .build(); - - byte[] id = getCacheIdentifier(getCMM(partitionName), request); - - assertEquals(expect, Utils.encodeBase64String(id)); - } - - @Test - public void encryptDigestTestVector() throws Exception { - HashMap contextMap = new HashMap<>(); - - contextMap.put("\0\0TEST", "\0\0test"); - // Note! This key is actually U+10000, but java treats it as a UTF-16 surrogate pair. - // UTF-8 encoding should be 0xF0 0x90 0x80 0x80 - contextMap.put("\uD800\uDC00", "UTF-16 surrogate"); - contextMap.put("\uABCD", "\\uABCD"); - - byte[] id = getCacheIdentifier(getCMM("partition ID"), - EncryptionMaterialsRequest.newBuilder() - .setContext(contextMap) - .setRequestedAlgorithm(null) - .setCommitmentPolicy(commitmentPolicy) - .build() - ); - - assertEquals( - "683328d033fc60a20e3d3936190b33d91aad0143163226af9530e7d1b3de0e96" + - "39c00a2885f9cea09cf9a273bef316a39616475b50adc2441b69f67e1a25145f", - new String(Hex.encode(id))); - - id = getCacheIdentifier(getCMM("partition ID"), - EncryptionMaterialsRequest.newBuilder() - .setContext(contextMap) - .setRequestedAlgorithm(CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256) - .setCommitmentPolicy(commitmentPolicy) - .build() - ); - - assertEquals( - "3dc70ff1d4621059b97179563ab6592dff4319bfaf8ed1a819c96d33d3194d5c" + - "354a361e879d0356e4d9e868170ebc9e934fa5eaf6e6d11de4ee801645723fa9", - new String(Hex.encode(id))); - } - - @Test - public void decryptDigestTestVector() throws Exception { - HashMap contextMap = new HashMap<>(); - - contextMap.put("\0\0TEST", "\0\0test"); - // Note! This key is actually U+10000, but java treats it as a UTF-16 surrogate pair. - // UTF-8 encoding should be 0xF0 0x90 0x80 0x80 - contextMap.put("\uD800\uDC00", "UTF-16 surrogate"); - contextMap.put("\uABCD", "\\uABCD"); - - ArrayList keyBlobs = new ArrayList<>(); - - keyBlobs.addAll( - Arrays.asList( - new KeyBlob("", new byte[] {}, new byte[] {}), // always first - new KeyBlob("\0", new byte[] { 0 }, new byte[] { 0 }), - new KeyBlob("\u0081", new byte[] { (byte) 0x81 }, new byte[] { (byte) 0x81 }), - new KeyBlob("abc", Hex.decode("deadbeef"), Hex.decode("bad0cafe")) - ) - ); - - assertEquals( - "e16344634350fe8cb51e69ec4e0681c84ac7ef2df427bd4de4aefbebcd3ead22" + - "95f1b15a98ce60699e0efbf69dbbc12e2552b16eff84a6e9b5766ee4d69a7897", - - new String(Hex.encode( - getCacheIdentifier(getCMM("partition ID"), - DecryptionMaterialsRequest.newBuilder() - .setAlgorithm(CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256) - .setEncryptionContext(contextMap) - .setEncryptedDataKeys(keyBlobs) - .build() - ) - )) - ); - } - - private byte[] getCacheIdentifier(CachingCryptoMaterialsManager cmm, EncryptionMaterialsRequest request) throws Exception { - Method m = CachingCryptoMaterialsManager.class.getDeclaredMethod("getCacheIdentifier", EncryptionMaterialsRequest.class); - m.setAccessible(true); - - return (byte[])m.invoke(cmm, request); - } - - private byte[] getCacheIdentifier(CachingCryptoMaterialsManager cmm, DecryptionMaterialsRequest request) throws Exception { - Method m = CachingCryptoMaterialsManager.class.getDeclaredMethod("getCacheIdentifier", DecryptionMaterialsRequest.class); - m.setAccessible(true); - - return (byte[])m.invoke(cmm, request); - } - - private CachingCryptoMaterialsManager getCMM(final String partitionName) { - return CachingCryptoMaterialsManager.newBuilder() - .withCache(mock(CryptoMaterialsCache.class)) - .withBackingMaterialsManager(mock(CryptoMaterialsManager.class)) - .withMaxAge(1, TimeUnit.MILLISECONDS) - .withPartitionId(partitionName) - .build(); - } + static String partitionName = "c15b9079-6d0e-42b6-8784-5e804b025692"; + static Map contextEmpty = Collections.emptyMap(); + static Map contextFull; + static CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; + + static { + contextFull = new HashMap<>(); + contextFull.put("this", "is"); + contextFull.put("a", "non-empty"); + contextFull.put("encryption", "context"); + } + + static List keyBlobs = + Arrays.asList( + new KeyBlob( + "this is a provider ID", + "this is some key info".getBytes(UTF_8), + "super secret key, now with encryption!".getBytes(UTF_8)), + new KeyBlob( + "another provider ID!", + "this is some different key info".getBytes(UTF_8), + "better super secret key, now with encryption!".getBytes(UTF_8))); + + @Test + public void pythonTestVecs() throws Exception { + assertEncryptId( + partitionName, + null, + contextEmpty, + "rkrFAso1YyPbOJbmwVMjrPw+wwLJT7xusn8tA8zMe9e3+OqbtfDueB7bvoKLU3fsmdUvZ6eMt7mBp1ThMMB25Q=="); + assertEncryptId( + partitionName, + ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + contextEmpty, + "3icBIkLK4V3fVwbm3zSxUdUQV6ZvZYUOLl8buN36g6gDMqAkghcGryxX7QiVABkW1JhB6GRp5z+bzbiuciBcKQ=="); + assertEncryptId( + partitionName, + null, + contextFull, + "IHiUHYOUVUEFTc3BcZPJDlsWct2Qy1A7JdfQl9sQoV/ILIbRpoz9q7RtGd/MlibaGl5ihE66cN8ygM8A5rtYbg=="); + assertEncryptId( + partitionName, + ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + contextFull, + "mRNK7qhTb/kJiiyGPgAevp0gwFRcET4KeeNYwZHhoEDvSUzQiDgl8Of+YRDaVzKxAqpNBgcAuFXde9JlaRRsmw=="); + + assertDecryptId( + partitionName, + ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256, + Collections.singletonList(keyBlobs.get(0)), + contextEmpty, + "n0zVzk9QIVxhz6ET+aJIKKOJNxtpGtSe1yAbu7WU5l272Iw/jmhlER4psDHJs9Mr8KYiIvLGSXzggNDCc23+9w=="); + + assertDecryptId( + partitionName, + ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + keyBlobs, + contextFull, + "+rtwUe38CGnczGmYu12iqGWHIyDyZ44EvYQ4S6ACmsgS8VaEpiw0RTGpDk6Z/7YYN/jVHOAcNKDyCNP8EmstFg=="); + } + + void assertDecryptId( + String partitionName, + CryptoAlgorithm algo, + List blobs, + Map context, + String expect) + throws Exception { + DecryptionMaterialsRequest request = + DecryptionMaterialsRequest.newBuilder() + .setAlgorithm(algo) + .setEncryptionContext(context) + .setEncryptedDataKeys(blobs) + .build(); + + byte[] id = getCacheIdentifier(getCMM(partitionName), request); + + assertEquals(expect, Utils.encodeBase64String(id)); + } + + void assertEncryptId( + String partitionName, CryptoAlgorithm algo, Map context, String expect) + throws Exception { + EncryptionMaterialsRequest request = + EncryptionMaterialsRequest.newBuilder() + .setContext(context) + .setRequestedAlgorithm(algo) + .setCommitmentPolicy(commitmentPolicy) + .build(); + + byte[] id = getCacheIdentifier(getCMM(partitionName), request); + + assertEquals(expect, Utils.encodeBase64String(id)); + } + + @Test + public void encryptDigestTestVector() throws Exception { + HashMap contextMap = new HashMap<>(); + + contextMap.put("\0\0TEST", "\0\0test"); + // Note! This key is actually U+10000, but java treats it as a UTF-16 surrogate pair. + // UTF-8 encoding should be 0xF0 0x90 0x80 0x80 + contextMap.put("\uD800\uDC00", "UTF-16 surrogate"); + contextMap.put("\uABCD", "\\uABCD"); + + byte[] id = + getCacheIdentifier( + getCMM("partition ID"), + EncryptionMaterialsRequest.newBuilder() + .setContext(contextMap) + .setRequestedAlgorithm(null) + .setCommitmentPolicy(commitmentPolicy) + .build()); + + assertEquals( + "683328d033fc60a20e3d3936190b33d91aad0143163226af9530e7d1b3de0e96" + + "39c00a2885f9cea09cf9a273bef316a39616475b50adc2441b69f67e1a25145f", + new String(Hex.encode(id))); + + id = + getCacheIdentifier( + getCMM("partition ID"), + EncryptionMaterialsRequest.newBuilder() + .setContext(contextMap) + .setRequestedAlgorithm(CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256) + .setCommitmentPolicy(commitmentPolicy) + .build()); + + assertEquals( + "3dc70ff1d4621059b97179563ab6592dff4319bfaf8ed1a819c96d33d3194d5c" + + "354a361e879d0356e4d9e868170ebc9e934fa5eaf6e6d11de4ee801645723fa9", + new String(Hex.encode(id))); + } + + @Test + public void decryptDigestTestVector() throws Exception { + HashMap contextMap = new HashMap<>(); + + contextMap.put("\0\0TEST", "\0\0test"); + // Note! This key is actually U+10000, but java treats it as a UTF-16 surrogate pair. + // UTF-8 encoding should be 0xF0 0x90 0x80 0x80 + contextMap.put("\uD800\uDC00", "UTF-16 surrogate"); + contextMap.put("\uABCD", "\\uABCD"); + + ArrayList keyBlobs = new ArrayList<>(); + + keyBlobs.addAll( + Arrays.asList( + new KeyBlob("", new byte[] {}, new byte[] {}), // always first + new KeyBlob("\0", new byte[] {0}, new byte[] {0}), + new KeyBlob("\u0081", new byte[] {(byte) 0x81}, new byte[] {(byte) 0x81}), + new KeyBlob("abc", Hex.decode("deadbeef"), Hex.decode("bad0cafe")))); + + assertEquals( + "e16344634350fe8cb51e69ec4e0681c84ac7ef2df427bd4de4aefbebcd3ead22" + + "95f1b15a98ce60699e0efbf69dbbc12e2552b16eff84a6e9b5766ee4d69a7897", + new String( + Hex.encode( + getCacheIdentifier( + getCMM("partition ID"), + DecryptionMaterialsRequest.newBuilder() + .setAlgorithm( + CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256) + .setEncryptionContext(contextMap) + .setEncryptedDataKeys(keyBlobs) + .build())))); + } + + private byte[] getCacheIdentifier( + CachingCryptoMaterialsManager cmm, EncryptionMaterialsRequest request) throws Exception { + Method m = + CachingCryptoMaterialsManager.class.getDeclaredMethod( + "getCacheIdentifier", EncryptionMaterialsRequest.class); + m.setAccessible(true); + + return (byte[]) m.invoke(cmm, request); + } + + private byte[] getCacheIdentifier( + CachingCryptoMaterialsManager cmm, DecryptionMaterialsRequest request) throws Exception { + Method m = + CachingCryptoMaterialsManager.class.getDeclaredMethod( + "getCacheIdentifier", DecryptionMaterialsRequest.class); + m.setAccessible(true); + + return (byte[]) m.invoke(cmm, request); + } + + private CachingCryptoMaterialsManager getCMM(final String partitionName) { + return CachingCryptoMaterialsManager.newBuilder() + .withCache(mock(CryptoMaterialsCache.class)) + .withBackingMaterialsManager(mock(CryptoMaterialsManager.class)) + .withMaxAge(1, TimeUnit.MILLISECONDS) + .withPartitionId(partitionName) + .build(); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/caching/CacheTestFixtures.java b/src/test/java/com/amazonaws/encryptionsdk/caching/CacheTestFixtures.java index 4b3f38eef..d3dd3850b 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/caching/CacheTestFixtures.java +++ b/src/test/java/com/amazonaws/encryptionsdk/caching/CacheTestFixtures.java @@ -1,89 +1,90 @@ package com.amazonaws.encryptionsdk.caching; -import static org.mockito.ArgumentMatchers.eq; - -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; -import java.util.Collections; - +import com.amazonaws.encryptionsdk.CommitmentPolicy; import com.amazonaws.encryptionsdk.DataKey; import com.amazonaws.encryptionsdk.DefaultCryptoMaterialsManager; import com.amazonaws.encryptionsdk.MasterKey; import com.amazonaws.encryptionsdk.TestUtils; import com.amazonaws.encryptionsdk.jce.JceMasterKey; -import com.amazonaws.encryptionsdk.CommitmentPolicy; -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.DecryptionMaterials; -import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; +import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; +import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; +import java.util.Collections; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; public class CacheTestFixtures { - private static final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; + private static final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; - private static final MasterKey FIXED_KEY = JceMasterKey.getInstance( - new SecretKeySpec(TestUtils.insecureRandomBytes(16), "AES"), - "prov", - "keyid", - "AES/GCM/NoPadding" - ); + private static final MasterKey FIXED_KEY = + JceMasterKey.getInstance( + new SecretKeySpec(TestUtils.insecureRandomBytes(16), "AES"), + "prov", + "keyid", + "AES/GCM/NoPadding"); - public static EncryptionMaterialsRequest createMaterialsRequest(int index) { - return EncryptionMaterialsRequest.newBuilder() - .setContext(Collections.singletonMap("index", Integer.toString(index))) - .setCommitmentPolicy(commitmentPolicy) - .build(); - } + public static EncryptionMaterialsRequest createMaterialsRequest(int index) { + return EncryptionMaterialsRequest.newBuilder() + .setContext(Collections.singletonMap("index", Integer.toString(index))) + .setCommitmentPolicy(commitmentPolicy) + .build(); + } - public static EncryptionMaterials createMaterialsResult(EncryptionMaterialsRequest request) { - return new DefaultCryptoMaterialsManager(FIXED_KEY).getMaterialsForEncrypt(request) - .toBuilder() - .setCleartextDataKey(new SentinelKey()) - .build(); - } + public static EncryptionMaterials createMaterialsResult(EncryptionMaterialsRequest request) { + return new DefaultCryptoMaterialsManager(FIXED_KEY) + .getMaterialsForEncrypt(request).toBuilder().setCleartextDataKey(new SentinelKey()).build(); + } - public static DecryptionMaterialsRequest createDecryptRequest(int index) { - EncryptionMaterialsRequest mreq = createMaterialsRequest(index); - EncryptionMaterials mres = createMaterialsResult(mreq); + public static DecryptionMaterialsRequest createDecryptRequest(int index) { + EncryptionMaterialsRequest mreq = createMaterialsRequest(index); + EncryptionMaterials mres = createMaterialsResult(mreq); - return createDecryptRequest(mres); - } + return createDecryptRequest(mres); + } - public static DecryptionMaterialsRequest createDecryptRequest(EncryptionMaterials mres) { - return DecryptionMaterialsRequest.newBuilder() - .setAlgorithm(mres.getAlgorithm()) - .setEncryptionContext(mres.getEncryptionContext()) - .setEncryptedDataKeys(mres.getEncryptedDataKeys()) - .build(); - } + public static DecryptionMaterialsRequest createDecryptRequest(EncryptionMaterials mres) { + return DecryptionMaterialsRequest.newBuilder() + .setAlgorithm(mres.getAlgorithm()) + .setEncryptionContext(mres.getEncryptionContext()) + .setEncryptedDataKeys(mres.getEncryptedDataKeys()) + .build(); + } - public static DecryptionMaterials createDecryptResult(DecryptionMaterialsRequest request) { - DecryptionMaterials realResult = new DefaultCryptoMaterialsManager(FIXED_KEY).decryptMaterials(request); - return realResult - .toBuilder() - .setDataKey(new DataKey(new SentinelKey(), - realResult.getDataKey().getEncryptedDataKey(), - realResult.getDataKey().getProviderInformation(), - realResult.getDataKey().getMasterKey())) - .build(); - } + public static DecryptionMaterials createDecryptResult(DecryptionMaterialsRequest request) { + DecryptionMaterials realResult = + new DefaultCryptoMaterialsManager(FIXED_KEY).decryptMaterials(request); + return realResult.toBuilder() + .setDataKey( + new DataKey( + new SentinelKey(), + realResult.getDataKey().getEncryptedDataKey(), + realResult.getDataKey().getProviderInformation(), + realResult.getDataKey().getMasterKey())) + .build(); + } - static EncryptionMaterials createMaterialsResult() { - return createMaterialsResult(createMaterialsRequest(0)); - } + static EncryptionMaterials createMaterialsResult() { + return createMaterialsResult(createMaterialsRequest(0)); + } - // These SentinelKeys let us detect when a particular DecryptionMaterials or EncryptionMaterials is being used, without - // being concerned about matching all of the fields - we can just use object identity. - public static class SentinelKey implements SecretKey { - @Override public String getAlgorithm() { - return "AES"; - } + // These SentinelKeys let us detect when a particular DecryptionMaterials or EncryptionMaterials + // is being used, without + // being concerned about matching all of the fields - we can just use object identity. + public static class SentinelKey implements SecretKey { + @Override + public String getAlgorithm() { + return "AES"; + } - @Override public String getFormat() { - return "RAW"; - } + @Override + public String getFormat() { + return "RAW"; + } - @Override public byte[] getEncoded() { - return null; - } + @Override + public byte[] getEncoded() { + return null; } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/caching/CachingCryptoMaterialsManagerTest.java b/src/test/java/com/amazonaws/encryptionsdk/caching/CachingCryptoMaterialsManagerTest.java index c1f0ba785..2e59be627 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/caching/CachingCryptoMaterialsManagerTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/caching/CachingCryptoMaterialsManagerTest.java @@ -14,394 +14,401 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -import javax.crypto.spec.SecretKeySpec; -import java.util.Arrays; -import java.util.concurrent.TimeUnit; - +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.CryptoMaterialsManager; +import com.amazonaws.encryptionsdk.caching.CryptoMaterialsCache.EncryptCacheEntry; +import com.amazonaws.encryptionsdk.caching.CryptoMaterialsCache.UsageStats; +import com.amazonaws.encryptionsdk.jce.JceMasterKey; import com.amazonaws.encryptionsdk.model.DecryptionMaterials; import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; +import java.util.Arrays; +import java.util.concurrent.TimeUnit; +import javax.crypto.spec.SecretKeySpec; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import com.amazonaws.encryptionsdk.CryptoMaterialsManager; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.caching.CryptoMaterialsCache.EncryptCacheEntry; -import com.amazonaws.encryptionsdk.caching.CryptoMaterialsCache.UsageStats; -import com.amazonaws.encryptionsdk.jce.JceMasterKey; - public class CachingCryptoMaterialsManagerTest { - private static final String PARTITION_ID = "partition ID"; - @Mock private CryptoMaterialsCache cache; - @Mock private CryptoMaterialsManager delegate; - private CachingCryptoMaterialsManager cmm; - private CachingCryptoMaterialsManager.Builder builder; - private long maxAgeMs = 123456789; - - @Before - public void setUp() throws Exception { - MockitoAnnotations.initMocks(this); - - when(cache.putEntryForEncrypt(any(), any(), any(), any())).thenAnswer( - invocation -> entryFor((EncryptionMaterials)invocation.getArguments()[1], UsageStats.ZERO) - ); - when(delegate.getMaterialsForEncrypt(any())).thenThrow(new RuntimeException("Unexpected invocation")); - when(delegate.decryptMaterials(any())).thenThrow(new RuntimeException("Unexpected invocation")); - - builder = CachingCryptoMaterialsManager.newBuilder().withBackingMaterialsManager(delegate) - .withCache(cache) - .withPartitionId(PARTITION_ID) - .withMaxAge(maxAgeMs, TimeUnit.MILLISECONDS) - .withByteUseLimit(200) - .withMessageUseLimit(100); - cmm = builder.build(); - } - - @Test - public void whenCacheIsEmpty_performsCacheMiss() throws Exception { - EncryptionMaterialsRequest request = CacheTestFixtures.createMaterialsRequest(0) - .toBuilder().setPlaintextSize(100).build(); - EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request); - - EncryptCacheEntry entry = setupForCacheMiss(request, result); - - EncryptionMaterials actualResult = cmm.getMaterialsForEncrypt(request); - - assertEquals(result, actualResult); - - verify(delegate).getMaterialsForEncrypt(request); - verify(cache).putEntryForEncrypt(any(), any(), any(), eq(new UsageStats(100, 1))); - } - - @Test - public void whenCacheMisses_correctHintAndUsagePassed() throws Exception { - EncryptionMaterialsRequest request = CacheTestFixtures.createMaterialsRequest(0) - .toBuilder().setPlaintextSize(100).build(); - EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request); - - setupForCacheMiss(request, result); - cmm.getMaterialsForEncrypt(request); - - ArgumentCaptor hintCaptor = ArgumentCaptor.forClass(CryptoMaterialsCache.CacheHint.class); - verify(cache).putEntryForEncrypt(any(), any(), hintCaptor.capture(), any()); - - assertEquals(maxAgeMs, hintCaptor.getValue().getMaxAgeMillis()); - } - - @Test - public void whenCacheHasEntry_performsCacheHit() throws Exception { - EncryptionMaterialsRequest request = CacheTestFixtures.createMaterialsRequest(0) - .toBuilder() - .setPlaintextSize(100) - .build(); - EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request); - EncryptCacheEntry entry = entryFor(result, UsageStats.ZERO); - when(cache.getEntryForEncrypt(any(), any())).thenReturn(entry); - - assertEquals(result, cmm.getMaterialsForEncrypt(request)); - verify(delegate, never()).getMaterialsForEncrypt(any()); - - ArgumentCaptor statsCaptor = ArgumentCaptor.forClass(UsageStats.class); - verify(cache).getEntryForEncrypt(any(), statsCaptor.capture()); - assertEquals(statsCaptor.getValue(), new UsageStats(100, 1)); - } - - @Test - public void whenAlgorithmIsUncachable_resultNotStoredInCache() throws Exception { - EncryptionMaterialsRequest request = CacheTestFixtures.createMaterialsRequest(0) - .toBuilder() - .setPlaintextSize(100) - .build(); - EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request).toBuilder() - .setAlgorithm(CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_NO_KDF) - .build(); - setupForCacheMiss(request, result); - - CachingCryptoMaterialsManager allowNoKdfCMM = CachingCryptoMaterialsManager.newBuilder() - .withBackingMaterialsManager(delegate) - .withCache(cache) - .withPartitionId(PARTITION_ID) - .withMaxAge(maxAgeMs, TimeUnit.MILLISECONDS) - .withByteUseLimit(200) - .withMessageUseLimit(100) - .build(); - - assertEquals(result, allowNoKdfCMM.getMaterialsForEncrypt(request)); - verify(cache, never()).putEntryForEncrypt(any(), any(), any(), any()); - } - - @Test - public void whenInitialUsageExceedsLimit_cacheIsBypassed() throws Exception { - EncryptionMaterialsRequest request = CacheTestFixtures.createMaterialsRequest(0) - .toBuilder() - // Even at _exactly_ the byte-use limit, we won't try the cache, - // because it's unlikely to be useful to leave an entry with zero - // bytes remaining. - .setPlaintextSize(200) - .build(); - EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request).toBuilder() - .setAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY) - .build(); - setupForCacheMiss(request, result); - - assertEquals(result, cmm.getMaterialsForEncrypt(request)); - verifyNoMoreInteractions(cache); - } - - @Test - public void whenCacheEntryIsExhausted_byMessageLimit_performsCacheMiss() throws Exception { - EncryptionMaterialsRequest request = CacheTestFixtures.createMaterialsRequest(0) - .toBuilder() - .setPlaintextSize(100) - .build(); - EncryptionMaterials cacheHitResult = CacheTestFixtures.createMaterialsResult(request); - doReturn(CacheTestFixtures.createMaterialsResult(request)).when(delegate).getMaterialsForEncrypt(request); - - EncryptCacheEntry entry = entryFor(cacheHitResult, new UsageStats(0, 101)); - - when(cache.getEntryForEncrypt(any(), any())).thenReturn(entry); - - EncryptionMaterials returnedResult = cmm.getMaterialsForEncrypt(request); - - assertNotEquals(cacheHitResult, returnedResult); - verify(delegate, times(1)).getMaterialsForEncrypt(any()); - verify(cache).putEntryForEncrypt(any(), eq(returnedResult), any(), any()); - } - - @Test - public void whenEncryptCacheEntryIsExpired_performsCacheMiss() throws Exception { - EncryptionMaterialsRequest request = CacheTestFixtures.createMaterialsRequest(0) - .toBuilder() - .setPlaintextSize(100) - .build(); - EncryptionMaterials cacheHitResult = CacheTestFixtures.createMaterialsResult(request); - doReturn(CacheTestFixtures.createMaterialsResult(request)).when(delegate).getMaterialsForEncrypt(request); - - EncryptCacheEntry entry = entryFor(cacheHitResult, new UsageStats(0, 100)); - when(entry.getEntryCreationTime()).thenReturn(System.currentTimeMillis() - maxAgeMs - 1); - - when(cache.getEntryForEncrypt(any(), any())).thenReturn(entry); - - EncryptionMaterials returnedResult = cmm.getMaterialsForEncrypt(request); - - assertNotEquals(cacheHitResult, returnedResult); - verify(delegate, times(1)).getMaterialsForEncrypt(any()); - verify(cache).putEntryForEncrypt(any(), eq(returnedResult), any(), any()); - verify(entry).invalidate(); + private static final String PARTITION_ID = "partition ID"; + @Mock private CryptoMaterialsCache cache; + @Mock private CryptoMaterialsManager delegate; + private CachingCryptoMaterialsManager cmm; + private CachingCryptoMaterialsManager.Builder builder; + private long maxAgeMs = 123456789; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + + when(cache.putEntryForEncrypt(any(), any(), any(), any())) + .thenAnswer( + invocation -> + entryFor((EncryptionMaterials) invocation.getArguments()[1], UsageStats.ZERO)); + when(delegate.getMaterialsForEncrypt(any())) + .thenThrow(new RuntimeException("Unexpected invocation")); + when(delegate.decryptMaterials(any())).thenThrow(new RuntimeException("Unexpected invocation")); + + builder = + CachingCryptoMaterialsManager.newBuilder() + .withBackingMaterialsManager(delegate) + .withCache(cache) + .withPartitionId(PARTITION_ID) + .withMaxAge(maxAgeMs, TimeUnit.MILLISECONDS) + .withByteUseLimit(200) + .withMessageUseLimit(100); + cmm = builder.build(); + } + + @Test + public void whenCacheIsEmpty_performsCacheMiss() throws Exception { + EncryptionMaterialsRequest request = + CacheTestFixtures.createMaterialsRequest(0).toBuilder().setPlaintextSize(100).build(); + EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request); + + EncryptCacheEntry entry = setupForCacheMiss(request, result); + + EncryptionMaterials actualResult = cmm.getMaterialsForEncrypt(request); + + assertEquals(result, actualResult); + + verify(delegate).getMaterialsForEncrypt(request); + verify(cache).putEntryForEncrypt(any(), any(), any(), eq(new UsageStats(100, 1))); + } + + @Test + public void whenCacheMisses_correctHintAndUsagePassed() throws Exception { + EncryptionMaterialsRequest request = + CacheTestFixtures.createMaterialsRequest(0).toBuilder().setPlaintextSize(100).build(); + EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request); + + setupForCacheMiss(request, result); + cmm.getMaterialsForEncrypt(request); + + ArgumentCaptor hintCaptor = + ArgumentCaptor.forClass(CryptoMaterialsCache.CacheHint.class); + verify(cache).putEntryForEncrypt(any(), any(), hintCaptor.capture(), any()); + + assertEquals(maxAgeMs, hintCaptor.getValue().getMaxAgeMillis()); + } + + @Test + public void whenCacheHasEntry_performsCacheHit() throws Exception { + EncryptionMaterialsRequest request = + CacheTestFixtures.createMaterialsRequest(0).toBuilder().setPlaintextSize(100).build(); + EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request); + EncryptCacheEntry entry = entryFor(result, UsageStats.ZERO); + when(cache.getEntryForEncrypt(any(), any())).thenReturn(entry); + + assertEquals(result, cmm.getMaterialsForEncrypt(request)); + verify(delegate, never()).getMaterialsForEncrypt(any()); + + ArgumentCaptor statsCaptor = ArgumentCaptor.forClass(UsageStats.class); + verify(cache).getEntryForEncrypt(any(), statsCaptor.capture()); + assertEquals(statsCaptor.getValue(), new UsageStats(100, 1)); + } + + @Test + public void whenAlgorithmIsUncachable_resultNotStoredInCache() throws Exception { + EncryptionMaterialsRequest request = + CacheTestFixtures.createMaterialsRequest(0).toBuilder().setPlaintextSize(100).build(); + EncryptionMaterials result = + CacheTestFixtures.createMaterialsResult(request).toBuilder() + .setAlgorithm(CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_NO_KDF) + .build(); + setupForCacheMiss(request, result); + + CachingCryptoMaterialsManager allowNoKdfCMM = + CachingCryptoMaterialsManager.newBuilder() + .withBackingMaterialsManager(delegate) + .withCache(cache) + .withPartitionId(PARTITION_ID) + .withMaxAge(maxAgeMs, TimeUnit.MILLISECONDS) + .withByteUseLimit(200) + .withMessageUseLimit(100) + .build(); + + assertEquals(result, allowNoKdfCMM.getMaterialsForEncrypt(request)); + verify(cache, never()).putEntryForEncrypt(any(), any(), any(), any()); + } + + @Test + public void whenInitialUsageExceedsLimit_cacheIsBypassed() throws Exception { + EncryptionMaterialsRequest request = + CacheTestFixtures.createMaterialsRequest(0).toBuilder() + // Even at _exactly_ the byte-use limit, we won't try the cache, + // because it's unlikely to be useful to leave an entry with zero + // bytes remaining. + .setPlaintextSize(200) + .build(); + EncryptionMaterials result = + CacheTestFixtures.createMaterialsResult(request).toBuilder() + .setAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY) + .build(); + setupForCacheMiss(request, result); + + assertEquals(result, cmm.getMaterialsForEncrypt(request)); + verifyNoMoreInteractions(cache); + } + + @Test + public void whenCacheEntryIsExhausted_byMessageLimit_performsCacheMiss() throws Exception { + EncryptionMaterialsRequest request = + CacheTestFixtures.createMaterialsRequest(0).toBuilder().setPlaintextSize(100).build(); + EncryptionMaterials cacheHitResult = CacheTestFixtures.createMaterialsResult(request); + doReturn(CacheTestFixtures.createMaterialsResult(request)) + .when(delegate) + .getMaterialsForEncrypt(request); + + EncryptCacheEntry entry = entryFor(cacheHitResult, new UsageStats(0, 101)); + + when(cache.getEntryForEncrypt(any(), any())).thenReturn(entry); + + EncryptionMaterials returnedResult = cmm.getMaterialsForEncrypt(request); + + assertNotEquals(cacheHitResult, returnedResult); + verify(delegate, times(1)).getMaterialsForEncrypt(any()); + verify(cache).putEntryForEncrypt(any(), eq(returnedResult), any(), any()); + } + + @Test + public void whenEncryptCacheEntryIsExpired_performsCacheMiss() throws Exception { + EncryptionMaterialsRequest request = + CacheTestFixtures.createMaterialsRequest(0).toBuilder().setPlaintextSize(100).build(); + EncryptionMaterials cacheHitResult = CacheTestFixtures.createMaterialsResult(request); + doReturn(CacheTestFixtures.createMaterialsResult(request)) + .when(delegate) + .getMaterialsForEncrypt(request); + + EncryptCacheEntry entry = entryFor(cacheHitResult, new UsageStats(0, 100)); + when(entry.getEntryCreationTime()).thenReturn(System.currentTimeMillis() - maxAgeMs - 1); + + when(cache.getEntryForEncrypt(any(), any())).thenReturn(entry); + + EncryptionMaterials returnedResult = cmm.getMaterialsForEncrypt(request); + + assertNotEquals(cacheHitResult, returnedResult); + verify(delegate, times(1)).getMaterialsForEncrypt(any()); + verify(cache).putEntryForEncrypt(any(), eq(returnedResult), any(), any()); + verify(entry).invalidate(); + } + + @Test + public void whenCacheEntryIsExhausted_byByteLimit_performsCacheMiss() throws Exception { + EncryptionMaterialsRequest request = + CacheTestFixtures.createMaterialsRequest(0).toBuilder().setPlaintextSize(100).build(); + EncryptionMaterials cacheHitResult = CacheTestFixtures.createMaterialsResult(request); + doReturn(CacheTestFixtures.createMaterialsResult(request)) + .when(delegate) + .getMaterialsForEncrypt(request); + + EncryptCacheEntry entry = entryFor(cacheHitResult, new UsageStats(1_000_000 - 99, 0)); + + when(cache.getEntryForEncrypt(any(), any())).thenReturn(entry); + + EncryptionMaterials returnedResult = cmm.getMaterialsForEncrypt(request); + + assertNotEquals(cacheHitResult, returnedResult); + verify(delegate, times(1)).getMaterialsForEncrypt(any()); + verify(cache).putEntryForEncrypt(any(), eq(returnedResult), any(), any()); + } + + @Test + public void whenStreaming_cacheMiss_withNoSizeHint_doesNotCache() throws Exception { + EncryptionMaterialsRequest request = CacheTestFixtures.createMaterialsRequest(0); + EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request); + EncryptCacheEntry entry = setupForCacheMiss(request, result); + + EncryptionMaterials actualResult = cmm.getMaterialsForEncrypt(request); + + verifyNoMoreInteractions(cache); + } + + @Test + public void whenDecrypting_cacheMiss() throws Exception { + DecryptionMaterialsRequest request = CacheTestFixtures.createDecryptRequest(0); + DecryptionMaterials result = CacheTestFixtures.createDecryptResult(request); + + doReturn(result).when(delegate).decryptMaterials(any()); + + DecryptionMaterials actual = cmm.decryptMaterials(request); + + assertEquals(result, actual); + verify(cache).putEntryForDecrypt(any(), eq(result), any()); + } + + @Test + public void whenDecryptCacheMisses_correctHintPassed() throws Exception { + DecryptionMaterialsRequest request = CacheTestFixtures.createDecryptRequest(0); + DecryptionMaterials result = CacheTestFixtures.createDecryptResult(request); + + doReturn(result).when(delegate).decryptMaterials(any()); + + cmm.decryptMaterials(request); + + ArgumentCaptor hintCaptor = + ArgumentCaptor.forClass(CryptoMaterialsCache.CacheHint.class); + verify(cache).putEntryForDecrypt(any(), any(), hintCaptor.capture()); + + assertEquals(maxAgeMs, hintCaptor.getValue().getMaxAgeMillis()); + } + + @Test + public void whenDecrypting_cacheHit() throws Exception { + DecryptionMaterialsRequest request = CacheTestFixtures.createDecryptRequest(0); + DecryptionMaterials result = CacheTestFixtures.createDecryptResult(request); + + when(cache.getEntryForDecrypt(any())).thenReturn(new TestDecryptCacheEntry(result)); + + DecryptionMaterials actual = cmm.decryptMaterials(request); + + assertEquals(result, actual); + verify(cache, never()).putEntryForDecrypt(any(), any(), any()); + verify(delegate, never()).decryptMaterials(any()); + } + + @Test + public void whenDecrypting_andEntryExpired_cacheMiss() throws Exception { + DecryptionMaterialsRequest request = CacheTestFixtures.createDecryptRequest(0); + DecryptionMaterials result = CacheTestFixtures.createDecryptResult(request); + doReturn(CacheTestFixtures.createDecryptResult(request)).when(delegate).decryptMaterials(any()); + + TestDecryptCacheEntry entry = new TestDecryptCacheEntry(result); + entry.creationTime -= (maxAgeMs + 1); + when(cache.getEntryForDecrypt(any())).thenReturn(entry); + + DecryptionMaterials actual = cmm.decryptMaterials(request); + + assertNotEquals(result, actual); + verify(delegate, times(1)).decryptMaterials(any()); + verify(cache, times(1)).putEntryForDecrypt(any(), any(), any()); + } + + @Test + public void testBuilderValidation() throws Exception { + CachingCryptoMaterialsManager.Builder b = CachingCryptoMaterialsManager.newBuilder(); + + assertThrows(() -> b.withMaxAge(-1, TimeUnit.MILLISECONDS)); + assertThrows(() -> b.withMaxAge(0, TimeUnit.MILLISECONDS)); + assertThrows(() -> b.withMessageUseLimit(-1)); + assertThrows(() -> b.withMessageUseLimit(1L << 33)); + assertThrows(() -> b.withByteUseLimit(-1)); + + assertThrows(b::build); // backing CMM not set + b.withBackingMaterialsManager(delegate); + assertThrows(b::build); // cache not set + b.withCache(cache); + assertThrows(b::build); // max age + b.withMaxAge(1, TimeUnit.SECONDS); + b.build(); + } + + @Test + public void whenBuilderReused_uniquePartitionSet() throws Exception { + EncryptionMaterialsRequest request = + CacheTestFixtures.createMaterialsRequest(0).toBuilder().setPlaintextSize(1).build(); + EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request); + EncryptCacheEntry entry = setupForCacheMiss(request, result); + + CachingCryptoMaterialsManager.Builder builder = + CachingCryptoMaterialsManager.newBuilder() + .withCache(cache) + .withBackingMaterialsManager(delegate) + .withMaxAge(5, TimeUnit.DAYS); + + builder.build().getMaterialsForEncrypt(request); + builder.build().getMaterialsForEncrypt(request); + + ArgumentCaptor idCaptor = ArgumentCaptor.forClass(byte[].class); + verify(cache, times(2)).getEntryForEncrypt(idCaptor.capture(), any()); + + byte[] firstId = idCaptor.getAllValues().get(0); + byte[] secondId = idCaptor.getAllValues().get(1); + + assertFalse(Arrays.equals(firstId, secondId)); + } + + @Test + public void whenMKPPassed_itIsUsed() throws Exception { + JceMasterKey key = + spy( + JceMasterKey.getInstance( + new SecretKeySpec(new byte[16], "AES"), "provider", "keyId", "AES/GCM/NoPadding")); + CryptoMaterialsManager cmm = + CachingCryptoMaterialsManager.newBuilder() + .withCache(cache) + .withMasterKeyProvider(key) + .withMaxAge(5, TimeUnit.DAYS) + .build(); + + cmm.getMaterialsForEncrypt(CacheTestFixtures.createMaterialsRequest(0)); + verify(key).generateDataKey(any(), any()); + } + + private EncryptCacheEntry setupForCacheMiss( + EncryptionMaterialsRequest request, EncryptionMaterials result) throws Exception { + doReturn(result).when(delegate).getMaterialsForEncrypt(request); + EncryptCacheEntry entry = entryFor(result, UsageStats.ZERO); + doReturn(entry).when(cache).putEntryForEncrypt(any(), eq(result), any(), any()); + + return entry; + } + + private EncryptCacheEntry entryFor(EncryptionMaterials result, final UsageStats initialUsage) + throws Exception { + return spy(new TestEncryptCacheEntry(result, initialUsage)); + } + + private static class TestEncryptCacheEntry implements EncryptCacheEntry { + private final EncryptionMaterials result; + private final UsageStats stats; + + public TestEncryptCacheEntry(EncryptionMaterials result, UsageStats initialUsage) { + this.result = result; + stats = initialUsage; } - @Test - public void whenCacheEntryIsExhausted_byByteLimit_performsCacheMiss() throws Exception { - EncryptionMaterialsRequest request = CacheTestFixtures.createMaterialsRequest(0) - .toBuilder() - .setPlaintextSize(100) - .build(); - EncryptionMaterials cacheHitResult = CacheTestFixtures.createMaterialsResult(request); - doReturn(CacheTestFixtures.createMaterialsResult(request)).when(delegate).getMaterialsForEncrypt(request); - - EncryptCacheEntry entry = entryFor(cacheHitResult, new UsageStats(1_000_000 - 99, 0)); - - when(cache.getEntryForEncrypt(any(), any())).thenReturn(entry); - - EncryptionMaterials returnedResult = cmm.getMaterialsForEncrypt(request); - - assertNotEquals(cacheHitResult, returnedResult); - verify(delegate, times(1)).getMaterialsForEncrypt(any()); - verify(cache).putEntryForEncrypt(any(), eq(returnedResult), any(), any()); - } - - @Test - public void whenStreaming_cacheMiss_withNoSizeHint_doesNotCache() throws Exception { - EncryptionMaterialsRequest request = CacheTestFixtures.createMaterialsRequest(0); - EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request); - EncryptCacheEntry entry = setupForCacheMiss(request, result); - - EncryptionMaterials actualResult = cmm.getMaterialsForEncrypt(request); - - verifyNoMoreInteractions(cache); - } - - @Test - public void whenDecrypting_cacheMiss() throws Exception { - DecryptionMaterialsRequest request = CacheTestFixtures.createDecryptRequest(0); - DecryptionMaterials result = CacheTestFixtures.createDecryptResult(request); - - doReturn(result).when(delegate).decryptMaterials(any()); - - DecryptionMaterials actual = cmm.decryptMaterials(request); - - assertEquals(result, actual); - verify(cache).putEntryForDecrypt(any(), eq(result), any()); + @Override + public UsageStats getUsageStats() { + return stats; } - @Test - public void whenDecryptCacheMisses_correctHintPassed() throws Exception { - DecryptionMaterialsRequest request = CacheTestFixtures.createDecryptRequest(0); - DecryptionMaterials result = CacheTestFixtures.createDecryptResult(request); - - doReturn(result).when(delegate).decryptMaterials(any()); - - cmm.decryptMaterials(request); - - ArgumentCaptor hintCaptor = ArgumentCaptor.forClass(CryptoMaterialsCache.CacheHint.class); - verify(cache).putEntryForDecrypt(any(), any(), hintCaptor.capture()); - - assertEquals(maxAgeMs, hintCaptor.getValue().getMaxAgeMillis()); + @Override + public long getEntryCreationTime() { + return System.currentTimeMillis(); } - @Test - public void whenDecrypting_cacheHit() throws Exception { - DecryptionMaterialsRequest request = CacheTestFixtures.createDecryptRequest(0); - DecryptionMaterials result = CacheTestFixtures.createDecryptResult(request); - - when(cache.getEntryForDecrypt(any())).thenReturn(new TestDecryptCacheEntry(result)); - - DecryptionMaterials actual = cmm.decryptMaterials(request); - - assertEquals(result, actual); - verify(cache, never()).putEntryForDecrypt(any(), any(), any()); - verify(delegate, never()).decryptMaterials(any()); + @Override + public EncryptionMaterials getResult() { + return result; } - @Test - public void whenDecrypting_andEntryExpired_cacheMiss() throws Exception { - DecryptionMaterialsRequest request = CacheTestFixtures.createDecryptRequest(0); - DecryptionMaterials result = CacheTestFixtures.createDecryptResult(request); - doReturn(CacheTestFixtures.createDecryptResult(request)).when(delegate).decryptMaterials(any()); - - TestDecryptCacheEntry entry = new TestDecryptCacheEntry(result); - entry.creationTime -= (maxAgeMs + 1); - when(cache.getEntryForDecrypt(any())).thenReturn(entry); + @Override + public void invalidate() {} + } - DecryptionMaterials actual = cmm.decryptMaterials(request); + private class TestDecryptCacheEntry implements CryptoMaterialsCache.DecryptCacheEntry { + private final DecryptionMaterials result; + private long creationTime = System.currentTimeMillis(); - assertNotEquals(result, actual); - verify(delegate, times(1)).decryptMaterials(any()); - verify(cache, times(1)).putEntryForDecrypt(any(), any(), any()); + public TestDecryptCacheEntry(final DecryptionMaterials result) { + this.result = result; } - @Test - public void testBuilderValidation() throws Exception { - CachingCryptoMaterialsManager.Builder b = CachingCryptoMaterialsManager.newBuilder(); - - assertThrows(() -> b.withMaxAge(-1, TimeUnit.MILLISECONDS)); - assertThrows(() -> b.withMaxAge(0, TimeUnit.MILLISECONDS)); - assertThrows(() -> b.withMessageUseLimit(-1)); - assertThrows(() -> b.withMessageUseLimit(1L << 33)); - assertThrows(() -> b.withByteUseLimit(-1)); - - assertThrows(b::build); // backing CMM not set - b.withBackingMaterialsManager(delegate); - assertThrows(b::build); // cache not set - b.withCache(cache); - assertThrows(b::build); // max age - b.withMaxAge(1, TimeUnit.SECONDS); - b.build(); + @Override + public DecryptionMaterials getResult() { + return result; } - @Test - public void whenBuilderReused_uniquePartitionSet() throws Exception { - EncryptionMaterialsRequest request = CacheTestFixtures.createMaterialsRequest(0) - .toBuilder().setPlaintextSize(1).build(); - EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(request); - EncryptCacheEntry entry = setupForCacheMiss(request, result); - - CachingCryptoMaterialsManager.Builder builder = CachingCryptoMaterialsManager.newBuilder() - .withCache(cache) - .withBackingMaterialsManager(delegate) - .withMaxAge(5, TimeUnit.DAYS); - - builder.build().getMaterialsForEncrypt(request); - builder.build().getMaterialsForEncrypt(request); - - ArgumentCaptor idCaptor = ArgumentCaptor.forClass(byte[].class); - verify(cache, times(2)).getEntryForEncrypt(idCaptor.capture(), any()); - - byte[] firstId = idCaptor.getAllValues().get(0); - byte[] secondId = idCaptor.getAllValues().get(1); - - assertFalse(Arrays.equals(firstId, secondId)); - } - - @Test - public void whenMKPPassed_itIsUsed() throws Exception { - JceMasterKey key = spy(JceMasterKey.getInstance(new SecretKeySpec(new byte[16], "AES"), - "provider", - "keyId", - "AES/GCM/NoPadding")); - CryptoMaterialsManager cmm = CachingCryptoMaterialsManager.newBuilder() - .withCache(cache) - .withMasterKeyProvider(key) - .withMaxAge(5, TimeUnit.DAYS) - .build(); - - cmm.getMaterialsForEncrypt(CacheTestFixtures.createMaterialsRequest(0)); - verify(key).generateDataKey(any(), any()); - } - - private EncryptCacheEntry setupForCacheMiss(EncryptionMaterialsRequest request, EncryptionMaterials result) throws Exception { - doReturn(result).when(delegate).getMaterialsForEncrypt(request); - EncryptCacheEntry entry = entryFor(result, UsageStats.ZERO); - doReturn(entry).when(cache).putEntryForEncrypt(any(), eq(result), any(), any()); - - return entry; - } - - private EncryptCacheEntry entryFor( - EncryptionMaterials result, - final UsageStats initialUsage - ) throws Exception { - return spy(new TestEncryptCacheEntry(result, initialUsage)); - } - - private static class TestEncryptCacheEntry implements EncryptCacheEntry { - private final EncryptionMaterials result; - private final UsageStats stats; - - public TestEncryptCacheEntry(EncryptionMaterials result, UsageStats initialUsage) { - this.result = result; - stats = initialUsage; - } - - @Override public UsageStats getUsageStats() { - return stats; - } - - @Override public long getEntryCreationTime() { - return System.currentTimeMillis(); - } - - @Override public EncryptionMaterials getResult() { - return result; - } - - @Override public void invalidate() { - - } - } - - private class TestDecryptCacheEntry implements CryptoMaterialsCache.DecryptCacheEntry{ - private final DecryptionMaterials result; - private long creationTime = System.currentTimeMillis(); - - public TestDecryptCacheEntry(final DecryptionMaterials result) { - this.result = result; - } - - @Override public DecryptionMaterials getResult() { - return result; - } - - @Override public void invalidate() { - - } + @Override + public void invalidate() {} - @Override public long getEntryCreationTime() { - return creationTime; - } + @Override + public long getEntryCreationTime() { + return creationTime; } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCacheTest.java b/src/test/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCacheTest.java index 98a61dd2d..731ab613a 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCacheTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCacheTest.java @@ -1,345 +1,356 @@ package com.amazonaws.encryptionsdk.caching; import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; -import static com.amazonaws.encryptionsdk.caching.CacheTestFixtures.createDecryptRequest; -import static com.amazonaws.encryptionsdk.caching.CacheTestFixtures.createMaterialsResult; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import com.amazonaws.encryptionsdk.caching.CryptoMaterialsCache.UsageStats; +import com.amazonaws.encryptionsdk.model.DecryptionMaterials; +import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import java.lang.reflect.Field; import java.util.Map; import java.util.Optional; - import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import com.amazonaws.encryptionsdk.caching.CryptoMaterialsCache.UsageStats; -import com.amazonaws.encryptionsdk.model.DecryptionMaterials; -import com.amazonaws.encryptionsdk.model.EncryptionMaterials; - public class LocalCryptoMaterialsCacheTest { - public static final String PARTTION_NAME = "foo"; - FakeClock clock; - LocalCryptoMaterialsCache cache; - CryptoMaterialsCache.CacheHint hint = () -> 1000; // maxAge = 1000 - - @Before - public void setUp() { - clock = new FakeClock(); - cache = new LocalCryptoMaterialsCache(5); - cache.clock = clock; + public static final String PARTTION_NAME = "foo"; + FakeClock clock; + LocalCryptoMaterialsCache cache; + CryptoMaterialsCache.CacheHint hint = () -> 1000; // maxAge = 1000 + + @Before + public void setUp() { + clock = new FakeClock(); + cache = new LocalCryptoMaterialsCache(5); + cache.clock = clock; + } + + @Test + public void whenNoEntriesInCache_noEntriesReturned() { + assertNull(cache.getEntryForDecrypt(new byte[10])); + byte[] cacheId = new byte[10]; + assertNull(cache.getEntryForEncrypt(cacheId, UsageStats.ZERO)); + } + + @Test + public void whenEntriesAddedToDecryptCache_correctEntriesReturned() { + DecryptionMaterials result1 = + CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(1)); + DecryptionMaterials result2 = + CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(2)); + + cache.putEntryForDecrypt(new byte[] {1}, result1, hint); + cache.putEntryForDecrypt(new byte[] {2}, result2, hint); + assertEquals(result2, cache.getEntryForDecrypt(new byte[] {2}).getResult()); + assertEquals(result1, cache.getEntryForDecrypt(new byte[] {1}).getResult()); + } + + @Test + public void whenManyDecryptEntriesAdded_LRURespected() { + DecryptionMaterials[] results = new DecryptionMaterials[6]; + + for (int i = 0; i < results.length; i++) { + results[i] = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(i)); } - @Test - public void whenNoEntriesInCache_noEntriesReturned() { - assertNull(cache.getEntryForDecrypt(new byte[10])); - byte[] cacheId = new byte[10]; - assertNull(cache.getEntryForEncrypt(cacheId, UsageStats.ZERO)); - } + cache.putEntryForDecrypt(new byte[] {0}, results[0], hint); + cache.putEntryForDecrypt(new byte[] {1}, results[1], hint); + cache.putEntryForDecrypt(new byte[] {2}, results[2], hint); + cache.putEntryForDecrypt(new byte[] {3}, results[3], hint); + cache.putEntryForDecrypt(new byte[] {4}, results[4], hint); - @Test - public void whenEntriesAddedToDecryptCache_correctEntriesReturned() { - DecryptionMaterials result1 = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(1)); - DecryptionMaterials result2 = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(2)); + // make entry 0 most recently used + assertEquals(results[0], cache.getEntryForDecrypt(new byte[] {0}).getResult()); - cache.putEntryForDecrypt(new byte[]{1}, result1, hint); - cache.putEntryForDecrypt(new byte[]{2}, result2, hint); - assertEquals(result2, cache.getEntryForDecrypt(new byte[]{2}).getResult()); - assertEquals(result1, cache.getEntryForDecrypt(new byte[]{1}).getResult()); - } + // entry 1 is evicted + cache.putEntryForDecrypt(new byte[] {5}, results[5], hint); - @Test - public void whenManyDecryptEntriesAdded_LRURespected() { - DecryptionMaterials[] results = new DecryptionMaterials[6]; + for (int i = 0; i < results.length; i++) { + DecryptionMaterials actualResult = + Optional.ofNullable(cache.getEntryForDecrypt(new byte[] {(byte) i})) + .map(CryptoMaterialsCache.DecryptCacheEntry::getResult) + .orElse(null); + DecryptionMaterials expected = (i == 1) ? null : results[i]; - for (int i = 0; i < results.length; i++) { - results[i] = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(i)); - } - - cache.putEntryForDecrypt(new byte[]{0}, results[0], hint); - cache.putEntryForDecrypt(new byte[]{1}, results[1], hint); - cache.putEntryForDecrypt(new byte[]{2}, results[2], hint); - cache.putEntryForDecrypt(new byte[]{3}, results[3], hint); - cache.putEntryForDecrypt(new byte[]{4}, results[4], hint); - - // make entry 0 most recently used - assertEquals(results[0], cache.getEntryForDecrypt(new byte[] {0}).getResult()); - - // entry 1 is evicted - cache.putEntryForDecrypt(new byte[]{5}, results[5], hint); - - for (int i = 0; i < results.length; i++) { - DecryptionMaterials actualResult = - Optional.ofNullable(cache.getEntryForDecrypt(new byte[] {(byte)i})) - .map(CryptoMaterialsCache.DecryptCacheEntry::getResult) - .orElse(null); - DecryptionMaterials expected = (i == 1) ? null : results[i]; - - assertEquals("index " + i, expected, actualResult); - } + assertEquals("index " + i, expected, actualResult); } - - @Test - public void whenEncryptEntriesAdded_theyCanBeRetrieved() { - EncryptionMaterials - result1a = CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); - EncryptionMaterials - result1b = CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); - EncryptionMaterials - result2 = CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(1)); - - cache.putEntryForEncrypt(new byte[] {0}, result1a, hint, UsageStats.ZERO); - cache.putEntryForEncrypt(new byte[] {0}, result1b, hint, UsageStats.ZERO); - cache.putEntryForEncrypt(new byte[] {1}, result2, hint, UsageStats.ZERO); - - assertEncryptEntry(new byte[]{0}, result1b); - assertEncryptEntry(new byte[]{1}, result2); + } + + @Test + public void whenEncryptEntriesAdded_theyCanBeRetrieved() { + EncryptionMaterials result1a = + CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); + EncryptionMaterials result1b = + CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); + EncryptionMaterials result2 = + CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(1)); + + cache.putEntryForEncrypt(new byte[] {0}, result1a, hint, UsageStats.ZERO); + cache.putEntryForEncrypt(new byte[] {0}, result1b, hint, UsageStats.ZERO); + cache.putEntryForEncrypt(new byte[] {1}, result2, hint, UsageStats.ZERO); + + assertEncryptEntry(new byte[] {0}, result1b); + assertEncryptEntry(new byte[] {1}, result2); + } + + @Test + public void whenInitialUsagePassed_itIsRetained() { + UsageStats stats = new UsageStats(123, 456); + EncryptionMaterials result1a = + CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); + cache.putEntryForEncrypt(new byte[] {0}, result1a, hint, stats); + + assertEquals(stats, cache.getEntryForEncrypt(new byte[] {0}, UsageStats.ZERO).getUsageStats()); + } + + @Test + public void whenManyEncryptEntriesAdded_LRUIsRespected() { + EncryptionMaterials[] results = new EncryptionMaterials[6]; + for (int i = 0; i < results.length; i++) { + results[i] = + CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(i / 3)); + cache.putEntryForEncrypt(new byte[] {(byte) (i)}, results[i], hint, UsageStats.ZERO); } - @Test - public void whenInitialUsagePassed_itIsRetained() { - UsageStats stats = new UsageStats(123, 456); - EncryptionMaterials - result1a = CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); - cache.putEntryForEncrypt(new byte[] {0}, result1a, hint, stats); + for (int i = 0; i < results.length; i++) { + EncryptionMaterials expected = i == 0 ? null : results[i]; - assertEquals(stats, cache.getEntryForEncrypt(new byte[]{0}, UsageStats.ZERO).getUsageStats()); + assertEncryptEntry(new byte[] {(byte) i}, expected); } - - @Test - public void whenManyEncryptEntriesAdded_LRUIsRespected() { - EncryptionMaterials[] results = new EncryptionMaterials[6]; - for (int i = 0; i < results.length; i++) { - results[i] = CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(i / 3)); - cache.putEntryForEncrypt(new byte[]{(byte)(i)}, results[i], hint, UsageStats.ZERO); - } - - for (int i = 0; i < results.length; i++) { - EncryptionMaterials expected = i == 0 ? null : results[i]; - - assertEncryptEntry(new byte[]{(byte)i}, expected); - } + } + + @Test + public void whenManyEncryptEntriesAdded_andEntriesTouched_LRUIsRespected() { + EncryptionMaterials[] results = new EncryptionMaterials[6]; + for (int i = 0; i < 3; i++) { + results[i] = + CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); + cache.putEntryForEncrypt(new byte[] {(byte) i}, results[i], hint, UsageStats.ZERO); } - @Test - public void whenManyEncryptEntriesAdded_andEntriesTouched_LRUIsRespected() { - EncryptionMaterials[] results = new EncryptionMaterials[6]; - for (int i = 0; i < 3; i++) { - results[i] = CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); - cache.putEntryForEncrypt(new byte[]{(byte)i}, results[i], hint, UsageStats.ZERO); - } - - cache.getEntryForEncrypt(new byte[]{0}, UsageStats.ZERO); - - for (int i = 3; i < 6; i++) { - results[i] = CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); - cache.putEntryForEncrypt(new byte[]{(byte)i}, results[i], hint, UsageStats.ZERO); - } - - assertEncryptEntry(new byte[]{0}, results[0]); - assertEncryptEntry(new byte[]{1}, null); - assertEncryptEntry(new byte[]{2}, results[2]); - assertEncryptEntry(new byte[]{3}, results[3]); - assertEncryptEntry(new byte[]{4}, results[4]); - assertEncryptEntry(new byte[]{5}, results[5]); - } + cache.getEntryForEncrypt(new byte[] {0}, UsageStats.ZERO); - @Test - public void whenManyEncryptEntriesAdded_andEntryInvalidated_LRUIsRespected() { - EncryptionMaterials[] results = new EncryptionMaterials[6]; - for (int i = 0; i < 3; i++) { - results[i] = CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); - cache.putEntryForEncrypt(new byte[]{(byte) i}, results[i], hint, UsageStats.ZERO); - } - - cache.getEntryForEncrypt(new byte[]{2}, UsageStats.ZERO).invalidate(); - - for (int i = 3; i < 6; i++) { - results[i] = CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); - cache.putEntryForEncrypt(new byte[]{(byte) i}, results[i], hint, UsageStats.ZERO); - } - - assertEncryptEntry(new byte[]{0}, results[0]); - assertEncryptEntry(new byte[]{1}, results[1]); - assertEncryptEntry(new byte[]{2}, null); - assertEncryptEntry(new byte[]{3}, results[3]); - assertEncryptEntry(new byte[]{4}, results[4]); - assertEncryptEntry(new byte[]{5}, results[5]); + for (int i = 3; i < 6; i++) { + results[i] = + CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); + cache.putEntryForEncrypt(new byte[] {(byte) i}, results[i], hint, UsageStats.ZERO); } - @Test - public void testCacheEntryBehavior() { - EncryptionMaterials result = createResult(); - CryptoMaterialsCache.EncryptCacheEntry e = cache.putEntryForEncrypt(new byte[]{0}, result, hint, - new UsageStats(1, 2)); - assertEquals(clock.now, e.getEntryCreationTime()); - - assertEquals(new UsageStats(1, 2), e.getUsageStats()); - - CryptoMaterialsCache.EncryptCacheEntry e2 = cache.getEntryForEncrypt(new byte[]{0}, new UsageStats(200, 100)); - // Old entry usage is unchanged - assertEquals(new UsageStats(1, 2), e.getUsageStats()); - assertEquals(new UsageStats(201, 102), e2.getUsageStats()); - - e2.invalidate(); - // All EncryptCacheEntry methods should still work after invalidation - Assert.assertEquals(result, e2.getResult()); - assertEquals(new UsageStats(201, 102), e2.getUsageStats()); + assertEncryptEntry(new byte[] {0}, results[0]); + assertEncryptEntry(new byte[] {1}, null); + assertEncryptEntry(new byte[] {2}, results[2]); + assertEncryptEntry(new byte[] {3}, results[3]); + assertEncryptEntry(new byte[] {4}, results[4]); + assertEncryptEntry(new byte[] {5}, results[5]); + } + + @Test + public void whenManyEncryptEntriesAdded_andEntryInvalidated_LRUIsRespected() { + EncryptionMaterials[] results = new EncryptionMaterials[6]; + for (int i = 0; i < 3; i++) { + results[i] = + CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); + cache.putEntryForEncrypt(new byte[] {(byte) i}, results[i], hint, UsageStats.ZERO); } - @Test - public void whenTTLExceeded_encryptEntriesAreEvicted() throws Exception { - EncryptionMaterials result = createResult(); - cache.putEntryForEncrypt(new byte[]{0}, result, () -> 500, UsageStats.ZERO); - clock.now += 500; + cache.getEntryForEncrypt(new byte[] {2}, UsageStats.ZERO).invalidate(); - assertEncryptEntry(new byte[]{0}, result); - clock.now += 1; - assertEncryptEntry(new byte[]{0}, null); + for (int i = 3; i < 6; i++) { + results[i] = + CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); + cache.putEntryForEncrypt(new byte[] {(byte) i}, results[i], hint, UsageStats.ZERO); + } - // Verify that the cache isn't hanging on to memory once it notices the entry is expired - assertEquals(0, getCacheMap(cache).size()); + assertEncryptEntry(new byte[] {0}, results[0]); + assertEncryptEntry(new byte[] {1}, results[1]); + assertEncryptEntry(new byte[] {2}, null); + assertEncryptEntry(new byte[] {3}, results[3]); + assertEncryptEntry(new byte[] {4}, results[4]); + assertEncryptEntry(new byte[] {5}, results[5]); + } + + @Test + public void testCacheEntryBehavior() { + EncryptionMaterials result = createResult(); + CryptoMaterialsCache.EncryptCacheEntry e = + cache.putEntryForEncrypt(new byte[] {0}, result, hint, new UsageStats(1, 2)); + assertEquals(clock.now, e.getEntryCreationTime()); + + assertEquals(new UsageStats(1, 2), e.getUsageStats()); + + CryptoMaterialsCache.EncryptCacheEntry e2 = + cache.getEntryForEncrypt(new byte[] {0}, new UsageStats(200, 100)); + // Old entry usage is unchanged + assertEquals(new UsageStats(1, 2), e.getUsageStats()); + assertEquals(new UsageStats(201, 102), e2.getUsageStats()); + + e2.invalidate(); + // All EncryptCacheEntry methods should still work after invalidation + Assert.assertEquals(result, e2.getResult()); + assertEquals(new UsageStats(201, 102), e2.getUsageStats()); + } + + @Test + public void whenTTLExceeded_encryptEntriesAreEvicted() throws Exception { + EncryptionMaterials result = createResult(); + cache.putEntryForEncrypt(new byte[] {0}, result, () -> 500, UsageStats.ZERO); + clock.now += 500; + + assertEncryptEntry(new byte[] {0}, result); + clock.now += 1; + assertEncryptEntry(new byte[] {0}, null); + + // Verify that the cache isn't hanging on to memory once it notices the entry is expired + assertEquals(0, getCacheMap(cache).size()); + } + + @Test + public void whenManyEntriesExpireAtOnce_expiredEncryptEntriesStillNotReturned() { + // Our active TTL expiration logic will only remove a certain number of entries per call, make + // sure that even + // if we bail out before removing a particular entry, it's still filtered from the return value. + cache = new LocalCryptoMaterialsCache(200); + cache.clock = clock; + + for (int i = 0; i < 100; i++) { + cache.putEntryForEncrypt(new byte[] {(byte) i}, createResult(), () -> 500, UsageStats.ZERO); } - @Test - public void whenManyEntriesExpireAtOnce_expiredEncryptEntriesStillNotReturned() { - // Our active TTL expiration logic will only remove a certain number of entries per call, make sure that even - // if we bail out before removing a particular entry, it's still filtered from the return value. - cache = new LocalCryptoMaterialsCache(200); - cache.clock = clock; + cache.putEntryForEncrypt(new byte[] {(byte) 0xFF}, createResult(), () -> 501, UsageStats.ZERO); + clock.now += 502; - for (int i = 0; i < 100; i++) { - cache.putEntryForEncrypt(new byte[]{(byte)i}, createResult(), () -> 500, UsageStats.ZERO); - } + assertEncryptEntry(new byte[] {(byte) 0xFF}, null); + } - cache.putEntryForEncrypt(new byte[]{(byte)0xFF}, createResult(), () -> 501, UsageStats.ZERO); - clock.now += 502; + @Test + public void whenAccessed_encryptEntryTTLNotReset() { + EncryptionMaterials result = createResult(); + cache.putEntryForEncrypt(new byte[] {0}, result, hint, UsageStats.ZERO); - assertEncryptEntry(new byte[]{(byte)0xFF}, null); - } + clock.now += 1000; + assertEncryptEntry(new byte[] {0}, result); + clock.now += 1; + assertEncryptEntry(new byte[] {0}, null); + } - @Test - public void whenAccessed_encryptEntryTTLNotReset() { - EncryptionMaterials result = createResult(); - cache.putEntryForEncrypt(new byte[]{0}, result, hint, UsageStats.ZERO); + @Test + public void whenTTLExceeded_decryptEntriesAreEvicted() throws Exception { + DecryptionMaterials result = + CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); - clock.now += 1000; - assertEncryptEntry(new byte[]{0}, result); - clock.now += 1; - assertEncryptEntry(new byte[]{0}, null); - } + cache.putEntryForDecrypt(new byte[] {0}, result, hint); - @Test - public void whenTTLExceeded_decryptEntriesAreEvicted() throws Exception { - DecryptionMaterials result = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); + clock.now += 1001; + assertNull(cache.getEntryForDecrypt(new byte[] {0})); - cache.putEntryForDecrypt(new byte[]{0}, result, hint); + // Verify that the cache isn't hanging on to memory once it notices the entry is expired + assertEquals(0, getCacheMap(cache).size()); + } - clock.now += 1001; - assertNull(cache.getEntryForDecrypt(new byte[]{0})); + @Test + public void whenAccessed_decryptEntryTTLNotReset() { + DecryptionMaterials result = + CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); - // Verify that the cache isn't hanging on to memory once it notices the entry is expired - assertEquals(0, getCacheMap(cache).size()); - } + cache.putEntryForDecrypt(new byte[] {0}, result, hint); - @Test - public void whenAccessed_decryptEntryTTLNotReset() { - DecryptionMaterials result = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); + clock.now += 500; + assertNotNull(cache.getEntryForDecrypt(new byte[] {0})); + clock.now += 501; + assertNull(cache.getEntryForDecrypt(new byte[] {0})); + } - cache.putEntryForDecrypt(new byte[]{0}, result, hint); + @Test + public void whenManyEntriesExpireAtOnce_expiredDecryptEntriesStillNotReturned() { + cache = new LocalCryptoMaterialsCache(200); + cache.clock = clock; - clock.now += 500; - assertNotNull(cache.getEntryForDecrypt(new byte[]{0})); - clock.now += 501; - assertNull(cache.getEntryForDecrypt(new byte[]{0})); + for (int i = 0; i < 100; i++) { + cache.putEntryForEncrypt( + new byte[] {(byte) (i + 1)}, createResult(), () -> 500, UsageStats.ZERO); } - @Test - public void whenManyEntriesExpireAtOnce_expiredDecryptEntriesStillNotReturned() { - cache = new LocalCryptoMaterialsCache(200); - cache.clock = clock; + DecryptionMaterials result = + CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); + cache.putEntryForDecrypt(new byte[] {0}, result, () -> 501); - for (int i = 0; i < 100; i++) { - cache.putEntryForEncrypt(new byte[]{(byte)(i + 1)}, createResult(), () -> 500, UsageStats.ZERO); - } + // our encrypt entries will expire first + clock.now += 502; + assertNull(cache.getEntryForDecrypt(new byte[] {0})); + } - DecryptionMaterials result = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); - cache.putEntryForDecrypt(new byte[]{0}, result, () -> 501); + @Test + public void testDecryptInvalidate() { + DecryptionMaterials result = + CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); - // our encrypt entries will expire first - clock.now += 502; - assertNull(cache.getEntryForDecrypt(new byte[]{0})); - } + cache.putEntryForDecrypt(new byte[] {0}, result, hint); - @Test - public void testDecryptInvalidate() { - DecryptionMaterials result = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); + cache.getEntryForDecrypt(new byte[] {0}).invalidate(); + assertNull(cache.getEntryForDecrypt(new byte[] {0})); + } - cache.putEntryForDecrypt(new byte[]{0}, result, hint); + @Test + public void testDecryptEntryCreationTime() { + DecryptionMaterials result = + CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); - cache.getEntryForDecrypt(new byte[]{0}).invalidate(); - assertNull(cache.getEntryForDecrypt(new byte[]{0})); - } + cache.putEntryForDecrypt(new byte[] {0}, result, hint); - @Test - public void testDecryptEntryCreationTime() { - DecryptionMaterials result = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); + assertEquals( + clock.timestamp(), cache.getEntryForDecrypt(new byte[] {0}).getEntryCreationTime()); + } - cache.putEntryForDecrypt(new byte[]{0}, result, hint); + @Test + public void whenIdentifiersDifferInLowOrderBytes_theyAreNotConsideredEquivalent() + throws Exception { + DecryptionMaterials result = + CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); - assertEquals(clock.timestamp(), cache.getEntryForDecrypt(new byte[]{0}).getEntryCreationTime()); - } - - @Test - public void whenIdentifiersDifferInLowOrderBytes_theyAreNotConsideredEquivalent() throws Exception { - DecryptionMaterials result = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); - - cache.putEntryForDecrypt(new byte[128], result, hint); + cache.putEntryForDecrypt(new byte[128], result, hint); - for (int i = 0; i < 128; i++) { - byte[] otherIdentifier = new byte[128]; - otherIdentifier[i]++; + for (int i = 0; i < 128; i++) { + byte[] otherIdentifier = new byte[128]; + otherIdentifier[i]++; - assertNull(cache.getEntryForDecrypt(otherIdentifier)); - } + assertNull(cache.getEntryForDecrypt(otherIdentifier)); } + } - @Test - public void testUsageStatsCtorValidation() { - assertThrows(() -> new UsageStats(1, -1)); - assertThrows(() -> new UsageStats(-1, 1)); - } + @Test + public void testUsageStatsCtorValidation() { + assertThrows(() -> new UsageStats(1, -1)); + assertThrows(() -> new UsageStats(-1, 1)); + } - private EncryptionMaterials createResult() { - return CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); - } + private EncryptionMaterials createResult() { + return CacheTestFixtures.createMaterialsResult(CacheTestFixtures.createMaterialsRequest(0)); + } - private void assertEncryptEntry(byte[] cacheId, EncryptionMaterials expectedResult) { - CryptoMaterialsCache.EncryptCacheEntry entry = cache.getEntryForEncrypt(cacheId, UsageStats.ZERO); - EncryptionMaterials actual = entry == null ? null : entry.getResult(); + private void assertEncryptEntry(byte[] cacheId, EncryptionMaterials expectedResult) { + CryptoMaterialsCache.EncryptCacheEntry entry = + cache.getEntryForEncrypt(cacheId, UsageStats.ZERO); + EncryptionMaterials actual = entry == null ? null : entry.getResult(); - assertEquals(expectedResult, actual); - } + assertEquals(expectedResult, actual); + } - private Map getCacheMap(LocalCryptoMaterialsCache cache) throws Exception { - Field field = LocalCryptoMaterialsCache.class.getDeclaredField("cacheMap"); - field.setAccessible(true); + private Map getCacheMap(LocalCryptoMaterialsCache cache) throws Exception { + Field field = LocalCryptoMaterialsCache.class.getDeclaredField("cacheMap"); + field.setAccessible(true); - return (Map)field.get(cache); - } + return (Map) field.get(cache); + } - private static final class FakeClock implements MsClock { - long now = 0x1_0000_0000L; - - @Override public long timestamp() { - return now; - } + private static final class FakeClock implements MsClock { + long now = 0x1_0000_0000L; + + @Override + public long timestamp() { + return now; } + } } - diff --git a/src/test/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCacheThreadStormTest.java b/src/test/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCacheThreadStormTest.java index 804b148ac..338f725e8 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCacheThreadStormTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/caching/LocalCryptoMaterialsCacheThreadStormTest.java @@ -1,11 +1,12 @@ package com.amazonaws.encryptionsdk.caching; -import static com.amazonaws.encryptionsdk.caching.CacheTestFixtures.createMaterialsResult; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.amazonaws.encryptionsdk.DataKey; +import com.amazonaws.encryptionsdk.caching.CryptoMaterialsCache.UsageStats; +import com.amazonaws.encryptionsdk.model.DecryptionMaterials; +import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import java.lang.reflect.Field; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -25,324 +26,339 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; import java.util.function.Supplier; - import org.junit.Test; -import com.amazonaws.encryptionsdk.DataKey; -import com.amazonaws.encryptionsdk.caching.CryptoMaterialsCache.UsageStats; -import com.amazonaws.encryptionsdk.model.DecryptionMaterials; -import com.amazonaws.encryptionsdk.model.EncryptionMaterials; - public class LocalCryptoMaterialsCacheThreadStormTest { - /* - * This test tests the behavior of LocalCryptoMaterialsCache under contention at the cache level. - * We specifically test: - * - * 1. Gets and puts of encrypt and decrypt entries, including entries under the same cache ID for encrypt - * 2. Invalidations - * 3. Changes to cache capacity - * - * Periodically, we verify that the system state is sane. This is done by inspecting the private members of - * LocalCryptoMaterialsCache and verifying that all cache entries are in the LRU map. - */ - - // Private member accessors - private static final Function> get_cacheMap; - private static final Function> get_expirationQueue; - - private static Function getGetter(Class klass, String fieldName) { + /* + * This test tests the behavior of LocalCryptoMaterialsCache under contention at the cache level. + * We specifically test: + * + * 1. Gets and puts of encrypt and decrypt entries, including entries under the same cache ID for encrypt + * 2. Invalidations + * 3. Changes to cache capacity + * + * Periodically, we verify that the system state is sane. This is done by inspecting the private members of + * LocalCryptoMaterialsCache and verifying that all cache entries are in the LRU map. + */ + + // Private member accessors + private static final Function> get_cacheMap; + private static final Function> get_expirationQueue; + + private static Function getGetter(Class klass, String fieldName) { + try { + Field f = klass.getDeclaredField(fieldName); + f.setAccessible(true); + + return obj -> { try { - Field f = klass.getDeclaredField(fieldName); - f.setAccessible(true); - - return obj -> { - try { - return (R)f.get(obj); - } catch (Exception e) { - throw new RuntimeException(e); - } - }; + return (R) f.get(obj); } catch (Exception e) { - throw new Error(e); + throw new RuntimeException(e); } + }; + } catch (Exception e) { + throw new Error(e); } + } - static { - get_cacheMap = getGetter(LocalCryptoMaterialsCache.class, "cacheMap"); - get_expirationQueue = getGetter(LocalCryptoMaterialsCache.class, "expirationQueue"); - } + static { + get_cacheMap = getGetter(LocalCryptoMaterialsCache.class, "cacheMap"); + get_expirationQueue = getGetter(LocalCryptoMaterialsCache.class, "expirationQueue"); + } - public static void assertConsistent(LocalCryptoMaterialsCache cache) { - synchronized (cache) { - HashSet expirationQueue = new HashSet<>(get_expirationQueue.apply(cache)); - HashSet cacheMap = new HashSet<>(get_cacheMap.apply(cache).values()); + public static void assertConsistent(LocalCryptoMaterialsCache cache) { + synchronized (cache) { + HashSet expirationQueue = new HashSet<>(get_expirationQueue.apply(cache)); + HashSet cacheMap = new HashSet<>(get_cacheMap.apply(cache).values()); - assertEquals("Cache group entries are inconsistent with expiration queue", - cacheMap, expirationQueue); - } + assertEquals( + "Cache group entries are inconsistent with expiration queue", cacheMap, expirationQueue); } + } - LocalCryptoMaterialsCache cache; + LocalCryptoMaterialsCache cache; - // When barrier request = true, all worker threads will join the barrier twice. - CyclicBarrier barrier; - volatile boolean barrierRequest = false; - CountDownLatch stopRequest = new CountDownLatch(1); + // When barrier request = true, all worker threads will join the barrier twice. + CyclicBarrier barrier; + volatile boolean barrierRequest = false; + CountDownLatch stopRequest = new CountDownLatch(1); - // Decrypt results that _might_ be returned. Note that due to race conditions in the test itself, we might be - // missing valid cached values here; if a result is in neither forbiddenKeys nor possibleDecrypts, then we must - // assume that it's allowed to be returned. - ConcurrentHashMap> possibleDecrypts = new ConcurrentHashMap<>(); + // Decrypt results that _might_ be returned. Note that due to race conditions in the test itself, + // we might be + // missing valid cached values here; if a result is in neither forbiddenKeys nor possibleDecrypts, + // then we must + // assume that it's allowed to be returned. + ConcurrentHashMap> + possibleDecrypts = new ConcurrentHashMap<>(); - // The values of the inner map are arbitrary but non-null (we use this effectively like a set) - ConcurrentHashMap> possibleEncrypts = new ConcurrentHashMap<>(); + // The values of the inner map are arbitrary but non-null (we use this effectively like a set) + ConcurrentHashMap> + possibleEncrypts = new ConcurrentHashMap<>(); - // Counters for debugging the test itself. If null, this debug infrastructure is disabled. - private ConcurrentHashMap counters = null; //new ConcurrentHashMap<>(); - void inc(String s) { - if (counters != null) { - counters.computeIfAbsent(s, ignored -> new AtomicLong(0)).incrementAndGet(); - } - } + // Counters for debugging the test itself. If null, this debug infrastructure is disabled. + private ConcurrentHashMap counters = null; // new ConcurrentHashMap<>(); - private static final EncryptionMaterials BASE_ENCRYPT = CacheTestFixtures.createMaterialsResult(); - private static final DecryptionMaterials BASE_DECRYPT - = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); - - private void maybeBarrier() { - if (barrierRequest) { - try { - barrier.await(); - barrier.await(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } + void inc(String s) { + if (counters != null) { + counters.computeIfAbsent(s, ignored -> new AtomicLong(0)).incrementAndGet(); } + } + + private static final EncryptionMaterials BASE_ENCRYPT = CacheTestFixtures.createMaterialsResult(); + private static final DecryptionMaterials BASE_DECRYPT = + CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0)); + + private void maybeBarrier() { + if (barrierRequest) { + try { + barrier.await(); + barrier.await(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + // This thread continually adds items to the decrypt cache, logging ones it added. + // The expectedDecryptMap has multiple items because we don't know if the cache expired the prior + // one; the + // decrypt check thread will check and forget/forbid the expected items that were not found. + public void decryptAddThread() { + int nItemsBeforeRelax = 200_000; + int nItems = 0; + + try { + while (stopRequest.getCount() > 0) { + maybeBarrier(); + + byte[] ref = new byte[3]; + ThreadLocalRandom.current().nextBytes(ref); + ref[0] = 0; + + CacheTestFixtures.SentinelKey key = new CacheTestFixtures.SentinelKey(); + DecryptionMaterials result = + BASE_DECRYPT.toBuilder() + .setDataKey( + new DataKey( + key, new byte[0], new byte[0], BASE_DECRYPT.getDataKey().getMasterKey())) + .build(); + + ConcurrentHashMap expectedDecryptMap = + possibleDecrypts.computeIfAbsent( + ByteBuffer.wrap(ref), ignored -> new ConcurrentHashMap<>()); + + synchronized (expectedDecryptMap) { + cache.putEntryForDecrypt(ref, result, () -> Long.MAX_VALUE); + expectedDecryptMap.put(key, this); + } - // This thread continually adds items to the decrypt cache, logging ones it added. - // The expectedDecryptMap has multiple items because we don't know if the cache expired the prior one; the - // decrypt check thread will check and forget/forbid the expected items that were not found. - public void decryptAddThread() { - int nItemsBeforeRelax = 200_000; - int nItems = 0; - - try { - while (stopRequest.getCount() > 0) { - maybeBarrier(); - - byte[] ref = new byte[3]; - ThreadLocalRandom.current().nextBytes(ref); - ref[0] = 0; - - CacheTestFixtures.SentinelKey key = new CacheTestFixtures.SentinelKey(); - DecryptionMaterials result = BASE_DECRYPT.toBuilder().setDataKey( - new DataKey(key, new byte[0], new byte[0], BASE_DECRYPT.getDataKey().getMasterKey()) - ).build(); - - ConcurrentHashMap expectedDecryptMap - = possibleDecrypts.computeIfAbsent(ByteBuffer.wrap(ref), - ignored -> new ConcurrentHashMap<>()); - - synchronized (expectedDecryptMap) { - cache.putEntryForDecrypt(ref, result, () -> Long.MAX_VALUE); - expectedDecryptMap.put(key, this); - } - - inc("decrypt put"); + inc("decrypt put"); - if (++nItems >= nItemsBeforeRelax) { - Thread.sleep(5); - nItems = 0; - } - } - } catch (Exception e) { - throw new RuntimeException(e); + if (++nItems >= nItemsBeforeRelax) { + Thread.sleep(5); + nItems = 0; } + } + } catch (Exception e) { + throw new RuntimeException(e); } - - // The decrypt check thread verifies that the decrypt results are sane - specifically, if we don't see an item - // that is known to have once been added to the cache, we should not see it reappear later. - public void decryptCheckThread() { - try { - while (stopRequest.getCount() > 0) { - maybeBarrier(); - - byte[] ref = new byte[3]; - ThreadLocalRandom.current().nextBytes(ref); - ref[0] = 0; - - ConcurrentHashMap expectedDecryptMap - = possibleDecrypts.computeIfAbsent(ByteBuffer.wrap(ref), - ignored -> new ConcurrentHashMap<>()); - - synchronized (expectedDecryptMap) { - CryptoMaterialsCache.DecryptCacheEntry result = cache.getEntryForDecrypt(ref); - - CacheTestFixtures.SentinelKey cachedKey = null; - if (result != null) { - inc("decrypt: hit"); - cachedKey = (CacheTestFixtures.SentinelKey) result.getResult().getDataKey().getKey(); - if (expectedDecryptMap.containsKey(cachedKey)) { - inc("decrypt: found key in expected"); - } else { - fail("decrypt: unexpected key"); - } - } else { - inc("decrypt: miss"); - } - - for (CacheTestFixtures.SentinelKey expectedKey : expectedDecryptMap.keySet()) { - if (cachedKey != expectedKey) { - inc("decrypt: prune"); - expectedDecryptMap.remove(expectedKey); - } - } - } + } + + // The decrypt check thread verifies that the decrypt results are sane - specifically, if we don't + // see an item + // that is known to have once been added to the cache, we should not see it reappear later. + public void decryptCheckThread() { + try { + while (stopRequest.getCount() > 0) { + maybeBarrier(); + + byte[] ref = new byte[3]; + ThreadLocalRandom.current().nextBytes(ref); + ref[0] = 0; + + ConcurrentHashMap expectedDecryptMap = + possibleDecrypts.computeIfAbsent( + ByteBuffer.wrap(ref), ignored -> new ConcurrentHashMap<>()); + + synchronized (expectedDecryptMap) { + CryptoMaterialsCache.DecryptCacheEntry result = cache.getEntryForDecrypt(ref); + + CacheTestFixtures.SentinelKey cachedKey = null; + if (result != null) { + inc("decrypt: hit"); + cachedKey = (CacheTestFixtures.SentinelKey) result.getResult().getDataKey().getKey(); + if (expectedDecryptMap.containsKey(cachedKey)) { + inc("decrypt: found key in expected"); + } else { + fail("decrypt: unexpected key"); } - } catch (Exception e) { - throw new RuntimeException(e); + } else { + inc("decrypt: miss"); + } + + for (CacheTestFixtures.SentinelKey expectedKey : expectedDecryptMap.keySet()) { + if (cachedKey != expectedKey) { + inc("decrypt: prune"); + expectedDecryptMap.remove(expectedKey); + } + } } + } + } catch (Exception e) { + throw new RuntimeException(e); } + } + + // Continually adds encryption cache entries. + public void encryptAddThread() { + int nItemsBeforeRelax = 200_000; + int nItems = 0; + + try { + while (stopRequest.getCount() > 0) { + maybeBarrier(); + + byte[] ref = new byte[2]; + ThreadLocalRandom.current().nextBytes(ref); + + EncryptionMaterials result = + BASE_ENCRYPT.toBuilder() + .setCleartextDataKey(new CacheTestFixtures.SentinelKey()) + .build(); + ConcurrentHashMap keys = + possibleEncrypts.computeIfAbsent( + ByteBuffer.wrap(ref), ignored -> new ConcurrentHashMap<>()); + synchronized (keys) { + inc("encrypt: add"); + + cache.putEntryForEncrypt(ref, result, () -> Long.MAX_VALUE, UsageStats.ZERO); + keys.put((CacheTestFixtures.SentinelKey) result.getCleartextDataKey(), this); + } - // Continually adds encryption cache entries. - public void encryptAddThread() { - int nItemsBeforeRelax = 200_000; - int nItems = 0; - - try { - while (stopRequest.getCount() > 0) { - maybeBarrier(); - - byte[] ref = new byte[2]; - ThreadLocalRandom.current().nextBytes(ref); - - EncryptionMaterials result = BASE_ENCRYPT.toBuilder().setCleartextDataKey(new CacheTestFixtures.SentinelKey()).build(); - ConcurrentHashMap keys - = possibleEncrypts.computeIfAbsent(ByteBuffer.wrap(ref), - ignored -> new ConcurrentHashMap<>()); - synchronized (keys) { - inc("encrypt: add"); - - cache.putEntryForEncrypt(ref, result, () -> Long.MAX_VALUE, UsageStats.ZERO); - keys.put((CacheTestFixtures.SentinelKey) result.getCleartextDataKey(), this); - } - - if (++nItems >= nItemsBeforeRelax) { - Thread.sleep(5); - nItems = 0; - } - } - } catch (Exception e) { - throw new RuntimeException(e); + if (++nItems >= nItemsBeforeRelax) { + Thread.sleep(5); + nItems = 0; } + } + } catch (Exception e) { + throw new RuntimeException(e); } - - // Verifies that there is no resurrection, as above. - public void encryptCheckThread() { - try { - while (stopRequest.getCount() > 0) { - maybeBarrier(); - - byte[] ref = new byte[2]; - ThreadLocalRandom.current().nextBytes(ref); - - ConcurrentHashMap allowedKeys - = possibleEncrypts.computeIfAbsent(ByteBuffer.wrap(ref), - ignored -> new ConcurrentHashMap<>()); - - synchronized (allowedKeys) { - HashSet foundKeys = new HashSet<>(); - CryptoMaterialsCache.EncryptCacheEntry ece = cache.getEntryForEncrypt(ref, UsageStats.ZERO); - - if (ece != null) { - foundKeys.add((CacheTestFixtures.SentinelKey)ece.getResult().getCleartextDataKey()); - } - - if (foundKeys.isEmpty()) { - inc("encrypt check: empty foundRefs"); - } else { - inc("encrypt check: non-empty foundRefs"); - } - - foundKeys.forEach(foundKey -> { - if (!allowedKeys.containsKey(foundKey)) { - fail("encrypt check: unexpected key; " + allowedKeys + " " + foundKeys); - } - }); - - allowedKeys.keySet().forEach(allowedKey -> { - if (!foundKeys.contains(allowedKey)) { - inc("encrypt check: prune"); - // safe since this is a concurrent map - allowedKeys.remove(allowedKey); - } - }); + } + + // Verifies that there is no resurrection, as above. + public void encryptCheckThread() { + try { + while (stopRequest.getCount() > 0) { + maybeBarrier(); + + byte[] ref = new byte[2]; + ThreadLocalRandom.current().nextBytes(ref); + + ConcurrentHashMap allowedKeys = + possibleEncrypts.computeIfAbsent( + ByteBuffer.wrap(ref), ignored -> new ConcurrentHashMap<>()); + + synchronized (allowedKeys) { + HashSet foundKeys = new HashSet<>(); + CryptoMaterialsCache.EncryptCacheEntry ece = + cache.getEntryForEncrypt(ref, UsageStats.ZERO); + + if (ece != null) { + foundKeys.add((CacheTestFixtures.SentinelKey) ece.getResult().getCleartextDataKey()); + } + + if (foundKeys.isEmpty()) { + inc("encrypt check: empty foundRefs"); + } else { + inc("encrypt check: non-empty foundRefs"); + } + + foundKeys.forEach( + foundKey -> { + if (!allowedKeys.containsKey(foundKey)) { + fail("encrypt check: unexpected key; " + allowedKeys + " " + foundKeys); } - } - } catch (Exception e) { - throw new RuntimeException(e); + }); + + allowedKeys + .keySet() + .forEach( + allowedKey -> { + if (!foundKeys.contains(allowedKey)) { + inc("encrypt check: prune"); + // safe since this is a concurrent map + allowedKeys.remove(allowedKey); + } + }); } + } + } catch (Exception e) { + throw new RuntimeException(e); } - - // Performs a consistency check of the cache entries vs the LRU tracker periodically. Due to the high overhead - // of this test, we run it infrequently. - public void checkThread() { - try { - while (!stopRequest.await(5000, TimeUnit.MILLISECONDS)) { - barrierRequest = true; - barrier.await(); - - assertConsistent(cache); - inc("consistency check passed"); - - barrier.await(); - } - } catch (Exception e) { - throw new RuntimeException(e); - } + } + + // Performs a consistency check of the cache entries vs the LRU tracker periodically. Due to the + // high overhead + // of this test, we run it infrequently. + public void checkThread() { + try { + while (!stopRequest.await(5000, TimeUnit.MILLISECONDS)) { + barrierRequest = true; + barrier.await(); + + assertConsistent(cache); + inc("consistency check passed"); + + barrier.await(); + } + } catch (Exception e) { + throw new RuntimeException(e); } + } - @Test - public void test() throws Exception { - cache = new LocalCryptoMaterialsCache(100_000); + @Test + public void test() throws Exception { + cache = new LocalCryptoMaterialsCache(100_000); - ArrayList> futures = new ArrayList<>(); - ExecutorService es = Executors.newCachedThreadPool(); + ArrayList> futures = new ArrayList<>(); + ExecutorService es = Executors.newCachedThreadPool(); - ArrayList>> starters = new ArrayList<>(); + ArrayList>> starters = new ArrayList<>(); - for (int i = 0; i < 2; i++) { - starters.add(() -> CompletableFuture.runAsync(this::encryptAddThread, es)); - starters.add(() -> CompletableFuture.runAsync(this::encryptCheckThread, es)); - starters.add(() -> CompletableFuture.runAsync(this::decryptAddThread, es)); - starters.add(() -> CompletableFuture.runAsync(this::decryptCheckThread, es)); - } - starters.add(() -> CompletableFuture.runAsync(this::checkThread, es)); + for (int i = 0; i < 2; i++) { + starters.add(() -> CompletableFuture.runAsync(this::encryptAddThread, es)); + starters.add(() -> CompletableFuture.runAsync(this::encryptCheckThread, es)); + starters.add(() -> CompletableFuture.runAsync(this::decryptAddThread, es)); + starters.add(() -> CompletableFuture.runAsync(this::decryptCheckThread, es)); + } + starters.add(() -> CompletableFuture.runAsync(this::checkThread, es)); - barrier = new CyclicBarrier(starters.size()); + barrier = new CyclicBarrier(starters.size()); - try { - starters.forEach(s -> futures.add(s.get())); + try { + starters.forEach(s -> futures.add(s.get())); - CompletableFuture metaFuture = CompletableFuture.anyOf(futures.toArray(new CompletableFuture[0])); + CompletableFuture metaFuture = + CompletableFuture.anyOf(futures.toArray(new CompletableFuture[0])); - try { - metaFuture.get(10, TimeUnit.SECONDS); - fail("unexpected termination"); - } catch (TimeoutException e) { - // ok - } - } finally { - stopRequest.countDown(); - es.shutdownNow(); + try { + metaFuture.get(10, TimeUnit.SECONDS); + fail("unexpected termination"); + } catch (TimeoutException e) { + // ok + } + } finally { + stopRequest.countDown(); + es.shutdownNow(); - es.awaitTermination(1, TimeUnit.SECONDS); + es.awaitTermination(1, TimeUnit.SECONDS); - if (counters != null) { - new TreeMap<>(counters).forEach((k, v) -> System.out.println(String.format("%s: %d", k, v.get()))); - } - } + if (counters != null) { + new TreeMap<>(counters) + .forEach((k, v) -> System.out.println(String.format("%s: %d", k, v.get()))); + } } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/caching/NullCryptoMaterialsCacheTest.java b/src/test/java/com/amazonaws/encryptionsdk/caching/NullCryptoMaterialsCacheTest.java index ca3f028a2..d5739d9f9 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/caching/NullCryptoMaterialsCacheTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/caching/NullCryptoMaterialsCacheTest.java @@ -3,47 +3,45 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import org.junit.Test; - -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.DecryptionMaterials; -import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; +import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; +import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; +import org.junit.Test; public class NullCryptoMaterialsCacheTest { - @Test - public void testEncryptPath() { - NullCryptoMaterialsCache cache = new NullCryptoMaterialsCache(); - - EncryptionMaterialsRequest req = CacheTestFixtures.createMaterialsRequest(1); - EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(req); - - CryptoMaterialsCache.UsageStats stats = new CryptoMaterialsCache.UsageStats(123, 456); - CryptoMaterialsCache.EncryptCacheEntry entry = cache.putEntryForEncrypt( - new byte[1], result, () -> Long.MAX_VALUE, - stats); - assertEquals(result, entry.getResult()); - assertFalse(entry.getEntryCreationTime() > System.currentTimeMillis()); - assertEquals(stats, entry.getUsageStats());; - - // the entry should not be in the "cache" - byte[] cacheId = new byte[1]; - assertNull(cache.getEntryForEncrypt(cacheId, CryptoMaterialsCache.UsageStats.ZERO)); - - entry.invalidate(); // shouldn't throw - } - - @Test - public void testDecryptPath() { - NullCryptoMaterialsCache cache = new NullCryptoMaterialsCache(); - - DecryptionMaterialsRequest request = CacheTestFixtures.createDecryptRequest(1); - DecryptionMaterials result = CacheTestFixtures.createDecryptResult(request); - - assertNull(cache.getEntryForDecrypt(new byte[1])); - cache.putEntryForDecrypt(new byte[1], result, () -> Long.MAX_VALUE); - assertNull(cache.getEntryForDecrypt(new byte[1])); - } + @Test + public void testEncryptPath() { + NullCryptoMaterialsCache cache = new NullCryptoMaterialsCache(); + + EncryptionMaterialsRequest req = CacheTestFixtures.createMaterialsRequest(1); + EncryptionMaterials result = CacheTestFixtures.createMaterialsResult(req); + + CryptoMaterialsCache.UsageStats stats = new CryptoMaterialsCache.UsageStats(123, 456); + CryptoMaterialsCache.EncryptCacheEntry entry = + cache.putEntryForEncrypt(new byte[1], result, () -> Long.MAX_VALUE, stats); + assertEquals(result, entry.getResult()); + assertFalse(entry.getEntryCreationTime() > System.currentTimeMillis()); + assertEquals(stats, entry.getUsageStats()); + ; + + // the entry should not be in the "cache" + byte[] cacheId = new byte[1]; + assertNull(cache.getEntryForEncrypt(cacheId, CryptoMaterialsCache.UsageStats.ZERO)); + + entry.invalidate(); // shouldn't throw + } + + @Test + public void testDecryptPath() { + NullCryptoMaterialsCache cache = new NullCryptoMaterialsCache(); + + DecryptionMaterialsRequest request = CacheTestFixtures.createDecryptRequest(1); + DecryptionMaterials result = CacheTestFixtures.createDecryptResult(request); + + assertNull(cache.getEntryForDecrypt(new byte[1])); + cache.putEntryForDecrypt(new byte[1], result, () -> Long.MAX_VALUE); + assertNull(cache.getEntryForDecrypt(new byte[1])); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/AwsKmsCmkArnInfoTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/AwsKmsCmkArnInfoTest.java index 47f82253b..4bf14076b 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/AwsKmsCmkArnInfoTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/AwsKmsCmkArnInfoTest.java @@ -3,487 +3,442 @@ package com.amazonaws.encryptionsdk.internal; +import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + import org.junit.Test; import org.junit.experimental.runners.Enclosed; import org.junit.jupiter.api.DisplayName; import org.junit.runner.RunWith; -import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - @RunWith(Enclosed.class) public class AwsKmsCmkArnInfoTest { - public static class splitArn { - @Test - public void basic_use() { - String[] test = AwsKmsCmkArnInfo.AwsKmsArnParts.splitArn("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"); - assertEquals(test.length, 6); - } - - @Test - public void with_fewer_elements() { - String[] test = AwsKmsCmkArnInfo.AwsKmsArnParts.splitArn("arn:aws:kms:us-west-2"); - assertEquals(test.length, 4); - } - - @Test - public void with_valid_arn_but_not_kms_valid() { - String[] test = AwsKmsCmkArnInfo.AwsKmsArnParts.splitArn("arn:aws:kms:us-west-2:111122223333:key:mrk-edb7fe6942894d32ac46dbb1c922d574"); - assertEquals(test.length, 6); - } + public static class splitArn { + @Test + public void basic_use() { + String[] test = + AwsKmsCmkArnInfo.AwsKmsArnParts.splitArn( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"); + assertEquals(test.length, 6); + } + + @Test + public void with_fewer_elements() { + String[] test = AwsKmsCmkArnInfo.AwsKmsArnParts.splitArn("arn:aws:kms:us-west-2"); + assertEquals(test.length, 4); + } + + @Test + public void with_valid_arn_but_not_kms_valid() { + String[] test = + AwsKmsCmkArnInfo.AwsKmsArnParts.splitArn( + "arn:aws:kms:us-west-2:111122223333:key:mrk-edb7fe6942894d32ac46dbb1c922d574"); + assertEquals(test.length, 6); + } + } + + public static class splitResourceParts { + @Test + public void basic_use() { + String[] test = + AwsKmsCmkArnInfo.AwsKmsArnParts.Resource.splitResourceParts( + "key/mrk-edb7fe6942894d32ac46dbb1c922d574"); + assertEquals(test.length, 2); + } + } + + public static class parseInfoFromKeyArn { + @Test + public void basic_use() { + AwsKmsCmkArnInfo test = + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"); + assertNotNull(test); + assertEquals(test.getPartition(), "aws"); + assertEquals(test.getRegion(), "us-west-2"); + assertEquals(test.getAccountId(), "111122223333"); + assertEquals(test.getAccountId(), "111122223333"); + assertEquals(test.getResourceType(), "key"); + assertEquals(test.getResource(), "mrk-edb7fe6942894d32ac46dbb1c922d574"); + } + + @Test + @DisplayName("Precondition: keyArn must be a string.") + public void keyArn_must_be_string_with_content() { + assertEquals(AwsKmsCmkArnInfo.parseInfoFromKeyArn(""), null); + assertEquals(AwsKmsCmkArnInfo.parseInfoFromKeyArn(null), null); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // = type=test + // # MUST start with string "arn" + public void not_well_formed() { + assertEquals( + AwsKmsCmkArnInfo.parseInfoFromKeyArn("key/mrk-edb7fe6942894d32ac46dbb1c922d574"), null); + assertEquals(AwsKmsCmkArnInfo.parseInfoFromKeyArn("alias/my-key"), null); + assertEquals( + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "not-an-arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + null); } - public static class splitResourceParts { - @Test - public void basic_use() { - String[] test = AwsKmsCmkArnInfo.AwsKmsArnParts.Resource.splitResourceParts("key/mrk-edb7fe6942894d32ac46dbb1c922d574"); - assertEquals(test.length, 2); - } + @Test + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // = type=test + // # The service MUST be the string "kms" + public void not_kms_service() { + assertEquals( + AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:sqs:us-east-2:444455556666:queue1"), null); } - public static class parseInfoFromKeyArn { - @Test - public void basic_use() { - AwsKmsCmkArnInfo test = AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"); - assertNotNull(test); - assertEquals(test.getPartition(), "aws"); - assertEquals(test.getRegion(), "us-west-2"); - assertEquals(test.getAccountId(), "111122223333"); - assertEquals(test.getAccountId(), "111122223333"); - assertEquals(test.getResourceType(), "key"); - assertEquals(test.getResource(), "mrk-edb7fe6942894d32ac46dbb1c922d574"); - } - - @Test - @DisplayName("Precondition: keyArn must be a string.") - public void keyArn_must_be_string_with_content() { - assertEquals( - AwsKmsCmkArnInfo.parseInfoFromKeyArn(""), - null - ); - assertEquals( - AwsKmsCmkArnInfo.parseInfoFromKeyArn(null), - null - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //= type=test - //# MUST start with string "arn" - public void not_well_formed() { - assertEquals( - AwsKmsCmkArnInfo.parseInfoFromKeyArn("key/mrk-edb7fe6942894d32ac46dbb1c922d574"), - null - ); - assertEquals( - AwsKmsCmkArnInfo.parseInfoFromKeyArn("alias/my-key"), - null - ); - assertEquals( - AwsKmsCmkArnInfo.parseInfoFromKeyArn("not-an-arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), - null - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //= type=test - //# The service MUST be the string "kms" - public void not_kms_service() { - assertEquals( - AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:sqs:us-east-2:444455556666:queue1"), - null - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //= type=test - //# The partition MUST be a non-empty - public void partition_non_empty() { - assertEquals( - AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn::kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), - null - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //= type=test - //# The region MUST be a non-empty string - public void region_non_empty() { - assertEquals( - AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms::111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), - null - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //= type=test - //# The account MUST be a non-empty string - public void account_non_empty() { - assertEquals( - AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms:us-west-2::key/mrk-edb7fe6942894d32ac46dbb1c922d574"), - null - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //= type=test - //# The resource section MUST be non-empty and MUST be split by a - //# single "/" any additional "/" are included in the resource id - public void resource_non_empty() { - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //= type=test - //# The resource id MUST be a non-empty string - assertEquals( - AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:"), - null - ); - assertEquals( - // This is a valid ARN but not valid for AWS KMS - AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:key:mrk-edb7fe6942894d32ac46dbb1c922d574"), - null - ); - final AwsKmsCmkArnInfo arn = AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:alias/has/slashes"); - assertNotNull(arn); - assertEquals(arn.getResource(), "has/slashes"); - } - - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 - //= type=test - //# The resource type MUST be either "alias" or "key" - public void resource_type_key_or_alias() { - assertEquals( - AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:not-key/mrk-edb7fe6942894d32ac46dbb1c922d574"), - null - ); - } + @Test + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // = type=test + // # The partition MUST be a non-empty + public void partition_non_empty() { + assertEquals( + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "arn::kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + null); } - public static class validAwsKmsIdentifier { - - @Test - public void basic_use() { - AwsKmsCmkArnInfo.validAwsKmsIdentifier("mrk-edb7fe6942894d32ac46dbb1c922d574"); - AwsKmsCmkArnInfo.validAwsKmsIdentifier("alias/my-alias"); - AwsKmsCmkArnInfo.validAwsKmsIdentifier("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"); - AwsKmsCmkArnInfo.validAwsKmsIdentifier("arn:aws:kms:us-west-2:111122223333:alias/my-alias"); - } - - @Test - @DisplayName("Exceptional Postcondition: Null or empty string is not a valid identifier.") - public void must_have_content() { - assertThrows( - IllegalArgumentException.class, - "Null or empty string is not a valid Aws KMS identifier.", - () -> AwsKmsCmkArnInfo.validAwsKmsIdentifier("")); - assertThrows( - IllegalArgumentException.class, - "Null or empty string is not a valid Aws KMS identifier.", - () -> AwsKmsCmkArnInfo.validAwsKmsIdentifier(null)); - } - - @Test - @DisplayName("Exceptional Postcondition: Things that start with `arn:` MUST be ARNs.") - public void arn_must_be_arn() { - assertThrows( - IllegalArgumentException.class, - "Invalid ARN used as an identifier.", - () -> AwsKmsCmkArnInfo.validAwsKmsIdentifier("arn:aws:dynamodb:us-east-2:123456789012:table/myDynamoDBTable")); - } - - @Test - @DisplayName("Postcondition: Raw alias starts with `alias/`.") - public void alias_is_valid() { - AwsKmsCmkArnInfo.validAwsKmsIdentifier("alias/some/kind/of/alias"); - } - - @Test - @DisplayName("Postcondition: There are no requirements on key ids.") - public void anything_else_is_key_id() { - AwsKmsCmkArnInfo.validAwsKmsIdentifier("mrk-edb7fe6942894d32ac46dbb1c922d574"); - AwsKmsCmkArnInfo.validAwsKmsIdentifier("b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"); - } + @Test + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // = type=test + // # The region MUST be a non-empty string + public void region_non_empty() { + assertEquals( + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "arn:aws:kms::111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + null); } - public static class isMRK { - @Test - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 - //= type=test - //# This function MUST take a single AWS KMS identifier - public void basic_use() { - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 - //= type=test - //# If the input starts - //# with "mrk-", this is a multi-Region key id and MUST return true. - assertEquals( - AwsKmsCmkArnInfo.isMRK("mrk-edb7fe6942894d32ac46dbb1c922d574"), - true - ); - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 - //= type=test - //# If the input starts with "alias/", this an AWS KMS alias and - //# not a multi-Region key id and MUST return false. - assertEquals( - AwsKmsCmkArnInfo.isMRK("alias/mrk-1234"), - false - ); - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 - //= type=test - //# If - //# the input does not start with any of the above, this is not a multi- - //# Region key id and MUST return false. - assertEquals( - AwsKmsCmkArnInfo.isMRK("64339c87-2ae4-42b1-8875-c83fc47acc97"), - false - ); - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 - //= type=test - //# If the input starts with "arn:", this MUST return the output of - //# identifying an an AWS KMS multi-Region ARN (aws-kms-key- - //# arn.md#identifying-an-an-aws-kms-multi-region-arn) called with this - //# input. - assertEquals( - AwsKmsCmkArnInfo.isMRK("arn:aws:kms:us-west-2:111122223333:alias/mrk-edb7fe6942894d32ac46dbb1c922d574"), - false - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 - //= type=test - //# If the input is an invalid AWS KMS ARN this function MUST error. - public void invalid_arn() { - assertThrows( - () -> AwsKmsCmkArnInfo.isMRK(AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:not-key/mrk-edb7fe6942894d32ac46dbb1c922d574")) - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 - //= type=test - //# If resource type is "alias", this is an AWS KMS alias ARN and MUST - //# return false. - public void with_an_alias_AwsKmsCmkArnInfo() { - assertEquals( - AwsKmsCmkArnInfo.isMRK(AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:alias/mrk-edb7fe6942894d32ac46dbb1c922d574")), - false - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 - //= type=test - //# This function MUST take a single AWS KMS ARN - public void with_an_mrk_AwsKmsCmkArnInfo() { - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 - //= type=test - //# If resource type is "key" and resource ID starts with - //# "mrk-", this is a AWS KMS multi-Region key ARN and MUST return true. - assertEquals( - AwsKmsCmkArnInfo.isMRK(AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574")), - true - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 - //= type=test - //# If resource type is "key" and resource ID does not start with "mrk-", - //# this is a (single-region) AWS KMS key ARN and MUST return false. - public void with_an_srk_AwsKmsCmkArnInfo() { - assertEquals( - AwsKmsCmkArnInfo.isMRK(AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f")), - false - ); - } + @Test + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // = type=test + // # The account MUST be a non-empty string + public void account_non_empty() { + assertEquals( + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "arn:aws:kms:us-west-2::key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + null); } - public static class awsKmsArnMatchForDecrypt { - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 - //= type=test - //# The caller MUST provide: - public void basic_use() { - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" - ), - true - ); - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" - ), - true - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 - //= type=test - //# If both identifiers are identical, this function MUST return "true". - public void string_match_cases() { - - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" - ), - true - ); - - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:aws:kms:us-west-2:111122223333:key/64339c87-2ae4-42b1-8875-c83fc47acc97", - "arn:aws:kms:us-west-2:111122223333:key/64339c87-2ae4-42b1-8875-c83fc47acc97" - ), - true - ); - - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:aws:kms:us-west-2:111122223333:alias/my-name", - "arn:aws:kms:us-west-2:111122223333:alias/my-name" - ), - true - ); - - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "alias/my-raw-alias", - "alias/my-raw-alias" - ), - true - ); - - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "64339c87-2ae4-42b1-8875-c83fc47acc97", - "64339c87-2ae4-42b1-8875-c83fc47acc97" - ), - true - ); - - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "c83fc47acc97", - "64339c87-2ae4-42b1-8875-c83fc47acc97" - ), - false - ); - } - - @Test - @DisplayName("Check for early return (Postcondition): Both identifiers are not ARNs and not equal, therefore they can not match.") - public void flexibility_for_only_arns() { - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - "mrk-edb7fe6942894d32ac46dbb1c922d574" - ), - false - ); - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "mrk-edb7fe6942894d32ac46dbb1c922d574", - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" - ), - false - ); - } - - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 - //= type=test - //# Otherwise if either input is not identified as a multi-Region key - //# (aws-kms-key-arn.md#identifying-an-aws-kms-multi-region-key), then - //# this function MUST return "false". - public void no_flexibility_for_non_mrks() { - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:aws:kms:us-west-2:111122223333:key/64339c87-2ae4-42b1-8875-c83fc47acc97", - "arn:aws:kms:us-east-1:111122223333:key/64339c87-2ae4-42b1-8875-c83fc47acc97" - ), - false - ); - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:aws:kms:us-west-2:111122223333:alias/mrk-someOtherName", - "arn:aws:kms:us-east-1:111122223333:alias/mrk-someOtherName" - ), - false - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 - //= type=test - //# Otherwise if both inputs are - //# identified as a multi-Region keys (aws-kms-key-arn.md#identifying-an- - //# aws-kms-multi-region-key), this function MUST return the result of - //# comparing the "partition", "service", "accountId", "resourceType", - //# and "resource" parts of both ARN inputs. - public void all_elements_must_match() { - // Different partition - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:not-aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" - ), - false - ); - // Different account - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:aws:kms:us-east-1:333322221111:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" - ), - false - ); - // Different resource type - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:not-aws:kms:us-east-1:111122223333:not-key/mrk-edb7fe6942894d32ac46dbb1c922d574", - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" - ), - false - ); - // Different resource - assertEquals( - AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( - "arn:aws:kms:us-east-1:111122223333:key/mrk-475d229c1bbd64ca23d4982496ef7bde", - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" - ), - false - ); - } + @Test + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // = type=test + // # The resource section MUST be non-empty and MUST be split by a + // # single "/" any additional "/" are included in the resource id + public void resource_non_empty() { + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // = type=test + // # The resource id MUST be a non-empty string + assertEquals( + AwsKmsCmkArnInfo.parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:"), null); + assertEquals( + // This is a valid ARN but not valid for AWS KMS + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "arn:aws:kms:us-west-2:111122223333:key:mrk-edb7fe6942894d32ac46dbb1c922d574"), + null); + final AwsKmsCmkArnInfo arn = + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "arn:aws:kms:us-west-2:111122223333:alias/has/slashes"); + assertNotNull(arn); + assertEquals(arn.getResource(), "has/slashes"); } - public static class to_string_tests { - @Test - public void basic_use() { - final String arn = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final String region = "us-east-1"; - final AwsKmsCmkArnInfo test = AwsKmsCmkArnInfo.parseInfoFromKeyArn(arn); - - assertEquals(arn, test.toString()); - assertEquals( - "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - test.toString("us-east-1") - ); - } + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.5 + // = type=test + // # The resource type MUST be either "alias" or "key" + public void resource_type_key_or_alias() { + assertEquals( + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "arn:aws:kms:us-west-2:111122223333:not-key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + null); } + } + public static class validAwsKmsIdentifier { + + @Test + public void basic_use() { + AwsKmsCmkArnInfo.validAwsKmsIdentifier("mrk-edb7fe6942894d32ac46dbb1c922d574"); + AwsKmsCmkArnInfo.validAwsKmsIdentifier("alias/my-alias"); + AwsKmsCmkArnInfo.validAwsKmsIdentifier( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"); + AwsKmsCmkArnInfo.validAwsKmsIdentifier("arn:aws:kms:us-west-2:111122223333:alias/my-alias"); + } + + @Test + @DisplayName("Exceptional Postcondition: Null or empty string is not a valid identifier.") + public void must_have_content() { + assertThrows( + IllegalArgumentException.class, + "Null or empty string is not a valid Aws KMS identifier.", + () -> AwsKmsCmkArnInfo.validAwsKmsIdentifier("")); + assertThrows( + IllegalArgumentException.class, + "Null or empty string is not a valid Aws KMS identifier.", + () -> AwsKmsCmkArnInfo.validAwsKmsIdentifier(null)); + } + + @Test + @DisplayName("Exceptional Postcondition: Things that start with `arn:` MUST be ARNs.") + public void arn_must_be_arn() { + assertThrows( + IllegalArgumentException.class, + "Invalid ARN used as an identifier.", + () -> + AwsKmsCmkArnInfo.validAwsKmsIdentifier( + "arn:aws:dynamodb:us-east-2:123456789012:table/myDynamoDBTable")); + } + + @Test + @DisplayName("Postcondition: Raw alias starts with `alias/`.") + public void alias_is_valid() { + AwsKmsCmkArnInfo.validAwsKmsIdentifier("alias/some/kind/of/alias"); + } + + @Test + @DisplayName("Postcondition: There are no requirements on key ids.") + public void anything_else_is_key_id() { + AwsKmsCmkArnInfo.validAwsKmsIdentifier("mrk-edb7fe6942894d32ac46dbb1c922d574"); + AwsKmsCmkArnInfo.validAwsKmsIdentifier("b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"); + } + } + + public static class isMRK { + @Test + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 + // = type=test + // # This function MUST take a single AWS KMS identifier + public void basic_use() { + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 + // = type=test + // # If the input starts + // # with "mrk-", this is a multi-Region key id and MUST return true. + assertEquals(AwsKmsCmkArnInfo.isMRK("mrk-edb7fe6942894d32ac46dbb1c922d574"), true); + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 + // = type=test + // # If the input starts with "alias/", this an AWS KMS alias and + // # not a multi-Region key id and MUST return false. + assertEquals(AwsKmsCmkArnInfo.isMRK("alias/mrk-1234"), false); + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 + // = type=test + // # If + // # the input does not start with any of the above, this is not a multi- + // # Region key id and MUST return false. + assertEquals(AwsKmsCmkArnInfo.isMRK("64339c87-2ae4-42b1-8875-c83fc47acc97"), false); + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.9 + // = type=test + // # If the input starts with "arn:", this MUST return the output of + // # identifying an an AWS KMS multi-Region ARN (aws-kms-key- + // # arn.md#identifying-an-an-aws-kms-multi-region-arn) called with this + // # input. + assertEquals( + AwsKmsCmkArnInfo.isMRK( + "arn:aws:kms:us-west-2:111122223333:alias/mrk-edb7fe6942894d32ac46dbb1c922d574"), + false); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 + // = type=test + // # If the input is an invalid AWS KMS ARN this function MUST error. + public void invalid_arn() { + assertThrows( + () -> + AwsKmsCmkArnInfo.isMRK( + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "arn:aws:kms:us-west-2:111122223333:not-key/mrk-edb7fe6942894d32ac46dbb1c922d574"))); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 + // = type=test + // # If resource type is "alias", this is an AWS KMS alias ARN and MUST + // # return false. + public void with_an_alias_AwsKmsCmkArnInfo() { + assertEquals( + AwsKmsCmkArnInfo.isMRK( + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "arn:aws:kms:us-west-2:111122223333:alias/mrk-edb7fe6942894d32ac46dbb1c922d574")), + false); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 + // = type=test + // # This function MUST take a single AWS KMS ARN + public void with_an_mrk_AwsKmsCmkArnInfo() { + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 + // = type=test + // # If resource type is "key" and resource ID starts with + // # "mrk-", this is a AWS KMS multi-Region key ARN and MUST return true. + assertEquals( + AwsKmsCmkArnInfo.isMRK( + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574")), + true); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-key-arn.txt#2.8 + // = type=test + // # If resource type is "key" and resource ID does not start with "mrk-", + // # this is a (single-region) AWS KMS key ARN and MUST return false. + public void with_an_srk_AwsKmsCmkArnInfo() { + assertEquals( + AwsKmsCmkArnInfo.isMRK( + AwsKmsCmkArnInfo.parseInfoFromKeyArn( + "arn:aws:kms:us-west-2:111122223333:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f")), + false); + } + } + + public static class awsKmsArnMatchForDecrypt { + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 + // = type=test + // # The caller MUST provide: + public void basic_use() { + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + true); + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + true); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 + // = type=test + // # If both identifiers are identical, this function MUST return "true". + public void string_match_cases() { + + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + true); + + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:aws:kms:us-west-2:111122223333:key/64339c87-2ae4-42b1-8875-c83fc47acc97", + "arn:aws:kms:us-west-2:111122223333:key/64339c87-2ae4-42b1-8875-c83fc47acc97"), + true); + + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:aws:kms:us-west-2:111122223333:alias/my-name", + "arn:aws:kms:us-west-2:111122223333:alias/my-name"), + true); + + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt("alias/my-raw-alias", "alias/my-raw-alias"), + true); + + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "64339c87-2ae4-42b1-8875-c83fc47acc97", "64339c87-2ae4-42b1-8875-c83fc47acc97"), + true); + + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "c83fc47acc97", "64339c87-2ae4-42b1-8875-c83fc47acc97"), + false); + } + + @Test + @DisplayName( + "Check for early return (Postcondition): Both identifiers are not ARNs and not equal, therefore they can not match.") + public void flexibility_for_only_arns() { + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "mrk-edb7fe6942894d32ac46dbb1c922d574"), + false); + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "mrk-edb7fe6942894d32ac46dbb1c922d574", + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + false); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 + // = type=test + // # Otherwise if either input is not identified as a multi-Region key + // # (aws-kms-key-arn.md#identifying-an-aws-kms-multi-region-key), then + // # this function MUST return "false". + public void no_flexibility_for_non_mrks() { + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:aws:kms:us-west-2:111122223333:key/64339c87-2ae4-42b1-8875-c83fc47acc97", + "arn:aws:kms:us-east-1:111122223333:key/64339c87-2ae4-42b1-8875-c83fc47acc97"), + false); + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:aws:kms:us-west-2:111122223333:alias/mrk-someOtherName", + "arn:aws:kms:us-east-1:111122223333:alias/mrk-someOtherName"), + false); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-match-for-decrypt.txt#2.5 + // = type=test + // # Otherwise if both inputs are + // # identified as a multi-Region keys (aws-kms-key-arn.md#identifying-an- + // # aws-kms-multi-region-key), this function MUST return the result of + // # comparing the "partition", "service", "accountId", "resourceType", + // # and "resource" parts of both ARN inputs. + public void all_elements_must_match() { + // Different partition + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:not-aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + false); + // Different account + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:aws:kms:us-east-1:333322221111:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + false); + // Different resource type + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:not-aws:kms:us-east-1:111122223333:not-key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + false); + // Different resource + assertEquals( + AwsKmsCmkArnInfo.awsKmsArnMatchForDecrypt( + "arn:aws:kms:us-east-1:111122223333:key/mrk-475d229c1bbd64ca23d4982496ef7bde", + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + false); + } + } + + public static class to_string_tests { + @Test + public void basic_use() { + final String arn = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final String region = "us-east-1"; + final AwsKmsCmkArnInfo test = AwsKmsCmkArnInfo.parseInfoFromKeyArn(arn); + + assertEquals(arn, test.toString()); + assertEquals( + "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + test.toString("us-east-1")); + } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/BlockDecryptionHandlerTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/BlockDecryptionHandlerTest.java index df776a1e4..60d8257c2 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/BlockDecryptionHandlerTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/BlockDecryptionHandlerTest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -15,97 +15,85 @@ import static org.junit.Assert.assertTrue; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.TestUtils; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import java.nio.ByteBuffer; import java.security.SecureRandom; - import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; - -import com.amazonaws.encryptionsdk.TestUtils; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import org.junit.Before; import org.junit.Test; -import com.amazonaws.encryptionsdk.AwsCrypto; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; - public class BlockDecryptionHandlerTest { - private static final SecureRandom RND = new SecureRandom(); - private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG; - private final byte[] messageId_ = new byte[cryptoAlgorithm_.getMessageIdLength()]; - private final byte nonceLen_ = cryptoAlgorithm_.getNonceLen(); - private final byte[] dataKeyBytes_ = new byte[cryptoAlgorithm_.getKeyLength()]; - private final SecretKey dataKey_ = new SecretKeySpec(dataKeyBytes_, "AES"); - - private final BlockDecryptionHandler blockDecryptionHandler_ = new BlockDecryptionHandler( - dataKey_, - nonceLen_, - cryptoAlgorithm_, - messageId_); - - @Before - public void setup() { - RND.nextBytes(messageId_); - RND.nextBytes(dataKeyBytes_); - } - - @Test - public void estimateOutputSize() { - final int inLen = 1; - final int outSize = blockDecryptionHandler_.estimateOutputSize(inLen); - - // the estimated output size must at least be equal to inLen. - assertTrue(outSize >= inLen); - } - - @Test(expected= BadCiphertextException.class) - public void doFinalCalledWhileNotComplete() { - blockDecryptionHandler_.doFinal(new byte[1], 0); - } - - @Test(expected = AwsCryptoException.class) - public void decryptMaxContentLength() { - final BlockEncryptionHandler blockEncryptionHandler = new BlockEncryptionHandler( - dataKey_, - nonceLen_, - cryptoAlgorithm_, - messageId_); - final byte[] in = new byte[0]; - final int outLen = blockEncryptionHandler.estimateOutputSize(in.length); - final byte[] out = new byte[outLen]; - - blockEncryptionHandler.processBytes(in, 0, in.length, out, 0); - blockEncryptionHandler.doFinal(out, 0); - - final ByteBuffer outBuff = ByteBuffer.wrap(out); - // pull out nonce to get to content length. - final byte[] nonce = new byte[nonceLen_]; - outBuff.get(nonce); - // set content length to integer max value + 1. - outBuff.putLong(Integer.MAX_VALUE + 1L); - - final int decryptedOutLen = blockDecryptionHandler_.estimateOutputSize(outLen); - final byte[] decryptedOut = new byte[decryptedOutLen]; - blockDecryptionHandler_.processBytes(outBuff.array(), 0, outBuff.array().length, decryptedOut, 0); - } - - @Test(expected = AwsCryptoException.class) - public void processBytesCalledWhileComplete() { - final BlockEncryptionHandler blockEncryptionHandler = new BlockEncryptionHandler( - dataKey_, - nonceLen_, - cryptoAlgorithm_, - messageId_); - final byte[] in = new byte[0]; - final int outLen = blockEncryptionHandler.estimateOutputSize(in.length); - final byte[] out = new byte[outLen]; - - blockEncryptionHandler.processBytes(in, 0, in.length, out, 0); - blockEncryptionHandler.doFinal(out, 0); - - final byte[] decryptedOut = new byte[outLen]; - blockDecryptionHandler_.processBytes(out, 0, outLen, decryptedOut, 0); - blockDecryptionHandler_.processBytes(out, 0, outLen, decryptedOut, 0); - } + private static final SecureRandom RND = new SecureRandom(); + private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG; + private final byte[] messageId_ = new byte[cryptoAlgorithm_.getMessageIdLength()]; + private final byte nonceLen_ = cryptoAlgorithm_.getNonceLen(); + private final byte[] dataKeyBytes_ = new byte[cryptoAlgorithm_.getKeyLength()]; + private final SecretKey dataKey_ = new SecretKeySpec(dataKeyBytes_, "AES"); + + private final BlockDecryptionHandler blockDecryptionHandler_ = + new BlockDecryptionHandler(dataKey_, nonceLen_, cryptoAlgorithm_, messageId_); + + @Before + public void setup() { + RND.nextBytes(messageId_); + RND.nextBytes(dataKeyBytes_); + } + + @Test + public void estimateOutputSize() { + final int inLen = 1; + final int outSize = blockDecryptionHandler_.estimateOutputSize(inLen); + + // the estimated output size must at least be equal to inLen. + assertTrue(outSize >= inLen); + } + + @Test(expected = BadCiphertextException.class) + public void doFinalCalledWhileNotComplete() { + blockDecryptionHandler_.doFinal(new byte[1], 0); + } + + @Test(expected = AwsCryptoException.class) + public void decryptMaxContentLength() { + final BlockEncryptionHandler blockEncryptionHandler = + new BlockEncryptionHandler(dataKey_, nonceLen_, cryptoAlgorithm_, messageId_); + final byte[] in = new byte[0]; + final int outLen = blockEncryptionHandler.estimateOutputSize(in.length); + final byte[] out = new byte[outLen]; + + blockEncryptionHandler.processBytes(in, 0, in.length, out, 0); + blockEncryptionHandler.doFinal(out, 0); + + final ByteBuffer outBuff = ByteBuffer.wrap(out); + // pull out nonce to get to content length. + final byte[] nonce = new byte[nonceLen_]; + outBuff.get(nonce); + // set content length to integer max value + 1. + outBuff.putLong(Integer.MAX_VALUE + 1L); + + final int decryptedOutLen = blockDecryptionHandler_.estimateOutputSize(outLen); + final byte[] decryptedOut = new byte[decryptedOutLen]; + blockDecryptionHandler_.processBytes( + outBuff.array(), 0, outBuff.array().length, decryptedOut, 0); + } + + @Test(expected = AwsCryptoException.class) + public void processBytesCalledWhileComplete() { + final BlockEncryptionHandler blockEncryptionHandler = + new BlockEncryptionHandler(dataKey_, nonceLen_, cryptoAlgorithm_, messageId_); + final byte[] in = new byte[0]; + final int outLen = blockEncryptionHandler.estimateOutputSize(in.length); + final byte[] out = new byte[outLen]; + + blockEncryptionHandler.processBytes(in, 0, in.length, out, 0); + blockEncryptionHandler.doFinal(out, 0); + + final byte[] decryptedOut = new byte[outLen]; + blockDecryptionHandler_.processBytes(out, 0, outLen, decryptedOut, 0); + blockDecryptionHandler_.processBytes(out, 0, outLen, decryptedOut, 0); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/BlockEncryptionHandlerTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/BlockEncryptionHandlerTest.java index 14c83a4f3..880fdc74a 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/BlockEncryptionHandlerTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/BlockEncryptionHandlerTest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -16,62 +16,55 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.TestUtils; +import com.amazonaws.encryptionsdk.model.CipherBlockHeaders; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; - -import com.amazonaws.encryptionsdk.TestUtils; import org.junit.Before; import org.junit.Test; -import com.amazonaws.encryptionsdk.AwsCrypto; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.model.CipherBlockHeaders; -import com.amazonaws.encryptionsdk.model.CipherFrameHeaders; - public class BlockEncryptionHandlerTest { - private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG; - private final byte[] messageId_ = RandomBytesGenerator.generate(cryptoAlgorithm_.getMessageIdLength()); - private final byte nonceLen_ = cryptoAlgorithm_.getNonceLen(); - private final byte[] dataKeyBytes_ = RandomBytesGenerator.generate(cryptoAlgorithm_.getKeyLength()); - private final SecretKey encryptionKey_ = new SecretKeySpec(dataKeyBytes_, "AES"); + private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG; + private final byte[] messageId_ = + RandomBytesGenerator.generate(cryptoAlgorithm_.getMessageIdLength()); + private final byte nonceLen_ = cryptoAlgorithm_.getNonceLen(); + private final byte[] dataKeyBytes_ = + RandomBytesGenerator.generate(cryptoAlgorithm_.getKeyLength()); + private final SecretKey encryptionKey_ = new SecretKeySpec(dataKeyBytes_, "AES"); - private BlockEncryptionHandler blockEncryptionHandler_; + private BlockEncryptionHandler blockEncryptionHandler_; - @Before - public void setUp() throws Exception { - blockEncryptionHandler_ = new BlockEncryptionHandler( - encryptionKey_, - nonceLen_, - cryptoAlgorithm_, - messageId_ - ); - } + @Before + public void setUp() throws Exception { + blockEncryptionHandler_ = + new BlockEncryptionHandler(encryptionKey_, nonceLen_, cryptoAlgorithm_, messageId_); + } - @Test - public void emptyOutBytes() { - final int outLen = 0; - final byte[] out = new byte[outLen]; - final int processedLen = blockEncryptionHandler_.doFinal(out, 0); - assertEquals(outLen, processedLen); - } + @Test + public void emptyOutBytes() { + final int outLen = 0; + final byte[] out = new byte[outLen]; + final int processedLen = blockEncryptionHandler_.doFinal(out, 0); + assertEquals(outLen, processedLen); + } - @Test - public void correctIVGenerated() throws Exception { - final byte[] out = new byte[1024]; - int outOff = blockEncryptionHandler_.processBytes(new byte[1], 0, 1, out, 0).getBytesWritten(); - final int processedLen = blockEncryptionHandler_.doFinal(out, outOff); + @Test + public void correctIVGenerated() throws Exception { + final byte[] out = new byte[1024]; + int outOff = blockEncryptionHandler_.processBytes(new byte[1], 0, 1, out, 0).getBytesWritten(); + final int processedLen = blockEncryptionHandler_.doFinal(out, outOff); - CipherBlockHeaders headers = new CipherBlockHeaders(); - headers.setNonceLength(cryptoAlgorithm_.getNonceLen()); - headers.deserialize(out, 0); + CipherBlockHeaders headers = new CipherBlockHeaders(); + headers.setNonceLength(cryptoAlgorithm_.getNonceLen()); + headers.deserialize(out, 0); - assertArrayEquals( - new byte[] { - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 1 - }, - headers.getNonce() - ); - } + assertArrayEquals( + new byte[] { + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 1 + }, + headers.getNonce()); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/CipherHandlerTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/CipherHandlerTest.java index adc1083f3..6f911cdb6 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/CipherHandlerTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/CipherHandlerTest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -15,75 +15,75 @@ import static org.junit.Assert.assertTrue; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.TestUtils; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import java.util.Arrays; import java.util.EnumSet; - import javax.crypto.Cipher; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; - -import com.amazonaws.encryptionsdk.TestUtils; import org.junit.Test; -import com.amazonaws.encryptionsdk.AwsCrypto; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; - public class CipherHandlerTest { - private final int contentLen_ = 1024; // 1KB - private final byte[] contentAad_ = "Test string AAD".getBytes(); - - @Test - public void encryptDecryptWithAllAlgos() { - for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { - assertTrue(encryptDecryptContent(cryptoAlg)); - assertTrue(encryptDecryptEmptyContent(cryptoAlg)); - } + private final int contentLen_ = 1024; // 1KB + private final byte[] contentAad_ = "Test string AAD".getBytes(); + + @Test + public void encryptDecryptWithAllAlgos() { + for (final CryptoAlgorithm cryptoAlg : EnumSet.allOf(CryptoAlgorithm.class)) { + assertTrue(encryptDecryptContent(cryptoAlg)); + assertTrue(encryptDecryptEmptyContent(cryptoAlg)); } - - @Test(expected = BadCiphertextException.class) - public void tamperCiphertext() { - final CryptoAlgorithm cryptoAlgorithm = TestUtils.DEFAULT_TEST_CRYPTO_ALG; - final byte[] content = RandomBytesGenerator.generate(contentLen_); - final byte[] keyBytes = RandomBytesGenerator.generate(cryptoAlgorithm.getKeyLength()); - final byte[] nonce = RandomBytesGenerator.generate(cryptoAlgorithm.getNonceLen()); - - final SecretKey key = new SecretKeySpec(keyBytes, cryptoAlgorithm.getKeyAlgo()); - CipherHandler cipherHandler = createCipherHandler(key, cryptoAlgorithm, Cipher.ENCRYPT_MODE); - final byte[] encryptedBytes = cipherHandler.cipherData(nonce, contentAad_, content, 0, content.length); - - encryptedBytes[0] += 1; // tamper the first byte in ciphertext - - cipherHandler = createCipherHandler(key, cryptoAlgorithm, Cipher.DECRYPT_MODE); + } + + @Test(expected = BadCiphertextException.class) + public void tamperCiphertext() { + final CryptoAlgorithm cryptoAlgorithm = TestUtils.DEFAULT_TEST_CRYPTO_ALG; + final byte[] content = RandomBytesGenerator.generate(contentLen_); + final byte[] keyBytes = RandomBytesGenerator.generate(cryptoAlgorithm.getKeyLength()); + final byte[] nonce = RandomBytesGenerator.generate(cryptoAlgorithm.getNonceLen()); + + final SecretKey key = new SecretKeySpec(keyBytes, cryptoAlgorithm.getKeyAlgo()); + CipherHandler cipherHandler = createCipherHandler(key, cryptoAlgorithm, Cipher.ENCRYPT_MODE); + final byte[] encryptedBytes = + cipherHandler.cipherData(nonce, contentAad_, content, 0, content.length); + + encryptedBytes[0] += 1; // tamper the first byte in ciphertext + + cipherHandler = createCipherHandler(key, cryptoAlgorithm, Cipher.DECRYPT_MODE); + cipherHandler.cipherData(nonce, contentAad_, encryptedBytes, 0, encryptedBytes.length); + } + + private boolean encryptDecryptContent(final CryptoAlgorithm cryptoAlgorithm) { + final byte[] content = RandomBytesGenerator.generate(contentLen_); + final byte[] result = encryptDecrypt(content, cryptoAlgorithm); + return Arrays.equals(content, result) ? true : false; + } + + private boolean encryptDecryptEmptyContent(final CryptoAlgorithm cryptoAlgorithm) { + final byte[] result = encryptDecrypt(new byte[0], cryptoAlgorithm); + return (result.length == 0) ? true : false; + } + + private byte[] encryptDecrypt(final byte[] content, final CryptoAlgorithm cryptoAlgorithm) { + final byte[] keyBytes = RandomBytesGenerator.generate(cryptoAlgorithm.getKeyLength()); + final byte[] nonce = RandomBytesGenerator.generate(cryptoAlgorithm.getNonceLen()); + + final SecretKey key = new SecretKeySpec(keyBytes, cryptoAlgorithm.getKeyAlgo()); + CipherHandler cipherHandler = createCipherHandler(key, cryptoAlgorithm, Cipher.ENCRYPT_MODE); + final byte[] encryptedBytes = + cipherHandler.cipherData(nonce, contentAad_, content, 0, content.length); + + cipherHandler = createCipherHandler(key, cryptoAlgorithm, Cipher.DECRYPT_MODE); + final byte[] decryptedBytes = cipherHandler.cipherData(nonce, contentAad_, encryptedBytes, 0, encryptedBytes.length); - } - - private boolean encryptDecryptContent(final CryptoAlgorithm cryptoAlgorithm) { - final byte[] content = RandomBytesGenerator.generate(contentLen_); - final byte[] result = encryptDecrypt(content, cryptoAlgorithm); - return Arrays.equals(content, result) ? true : false; - } - - private boolean encryptDecryptEmptyContent(final CryptoAlgorithm cryptoAlgorithm) { - final byte[] result = encryptDecrypt(new byte[0], cryptoAlgorithm); - return (result.length == 0) ? true : false; - } - private byte[] encryptDecrypt(final byte[] content, final CryptoAlgorithm cryptoAlgorithm) { - final byte[] keyBytes = RandomBytesGenerator.generate(cryptoAlgorithm.getKeyLength()); - final byte[] nonce = RandomBytesGenerator.generate(cryptoAlgorithm.getNonceLen()); + return decryptedBytes; + } - final SecretKey key = new SecretKeySpec(keyBytes, cryptoAlgorithm.getKeyAlgo()); - CipherHandler cipherHandler = createCipherHandler(key, cryptoAlgorithm, Cipher.ENCRYPT_MODE); - final byte[] encryptedBytes = cipherHandler.cipherData( nonce, contentAad_, content, 0, content.length); - - cipherHandler = createCipherHandler(key, cryptoAlgorithm, Cipher.DECRYPT_MODE); - final byte[] decryptedBytes = cipherHandler.cipherData(nonce, contentAad_, encryptedBytes, 0, encryptedBytes.length); - - return decryptedBytes; - } - - private CipherHandler createCipherHandler(final SecretKey key, final CryptoAlgorithm cryptoAlgorithm, final int mode) { - return new CipherHandler(key, mode, cryptoAlgorithm); - } + private CipherHandler createCipherHandler( + final SecretKey key, final CryptoAlgorithm cryptoAlgorithm, final int mode) { + return new CipherHandler(key, mode, cryptoAlgorithm); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/CommittedKeyTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/CommittedKeyTest.java index e470275b2..b33f0bec2 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/CommittedKeyTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/CommittedKeyTest.java @@ -3,17 +3,6 @@ package com.amazonaws.encryptionsdk.internal; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.TestUtils; -import org.bouncycastle.util.Arrays; -import org.junit.Test; - -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; -import java.io.ByteArrayOutputStream; -import java.io.DataOutputStream; -import java.nio.charset.StandardCharsets; - import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; import static com.amazonaws.encryptionsdk.TestUtils.insecureRandomBytes; import static org.junit.Assert.assertArrayEquals; @@ -21,119 +10,165 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.TestUtils; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.nio.charset.StandardCharsets; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import org.bouncycastle.util.Arrays; +import org.junit.Test; + public class CommittedKeyTest { - @Test - public void testGenerate() { - final CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; - SecretKeySpec secretKey = new SecretKeySpec(Utils.decodeBase64String(TestUtils.messageWithCommitKeyDEKBase64), algorithm.getDataKeyAlgo()); - CommittedKey committedKey = CommittedKey.generate(algorithm, secretKey, Utils.decodeBase64String(TestUtils.messageWithCommitKeyMessageIdBase64)); - assertNotNull(committedKey); - assertEquals(TestUtils.messageWithCommitKeyCryptoAlgorithm.getKeyAlgo(), committedKey.getKey().getAlgorithm()); - assertArrayEquals(Utils.decodeBase64String(TestUtils.messageWithCommitKeyCommitmentBase64), committedKey.getCommitment()); - } - - @Test - public void testGenerateBadNonceLen() { - final CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; - SecretKeySpec secretKey = new SecretKeySpec(Utils.decodeBase64String(TestUtils.messageWithCommitKeyDEKBase64), algorithm.getDataKeyAlgo()); - assertThrows(IllegalArgumentException.class, "Invalid nonce size", - () -> CommittedKey.generate(algorithm, secretKey, new byte[algorithm.getCommitmentNonceLength() + 1])); - } - - @Test - public void testGenerateIncorrectMismatchedKeySpecAlgorithm() { - final CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; - SecretKeySpec secretKey = new SecretKeySpec(new byte[algorithm.getDataKeyLength()],"incorrectAlgorithm"); - assertThrows(IllegalArgumentException.class, "DataKey of incorrect algorithm.", - () -> CommittedKey.generate(algorithm, secretKey, new byte[algorithm.getCommitmentNonceLength()])); - } - - @Test - public void testGenerateIncorrectDataKeyLenForAlgorithm() { - final CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; - SecretKeySpec secretKey = new SecretKeySpec(new byte[algorithm.getDataKeyLength() + 1], algorithm.getDataKeyAlgo()); - assertThrows(IllegalArgumentException.class, "DataKey of incorrect length.", - () -> CommittedKey.generate(algorithm, secretKey, new byte[algorithm.getCommitmentNonceLength()])); - } - - @Test - public void testGenerateNonCommittingAlgorithm() { - final CryptoAlgorithm algorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - SecretKeySpec secretKey = new SecretKeySpec(new byte[algorithm.getDataKeyLength()], algorithm.getDataKeyAlgo()); - assertThrows(IllegalArgumentException.class, "Algorithm does not support key commitment.", - () -> CommittedKey.generate(algorithm, secretKey, new byte[algorithm.getCommitmentNonceLength()])); - } - - @Test - public void testGenerateCommittedKeySmokeTest() throws Exception { - // This test intentionally using different techniques - // to assemble the labels and constants. - - // Commitment Nonce N1 is equal to the Message Id which is a 32 byte random value. - // Normally this needs to be cryptographically secure, but we can relax this for improved performance in testing. - final byte[] n1 = insecureRandomBytes(32); - - // Hash for HKDF is SHA-512 - final HmacKeyDerivationFunction hkdf = HmacKeyDerivationFunction.getInstance("HmacSHA512"); - - // K_R (Raw keying material, a.k.a. data key) is 256 bits (32 bytes) - // Normally this needs to be cryptographically secure, but we can relax this for improved performance in testing. - final byte[] k_r = insecureRandomBytes(32); - final SecretKey k_rKey = new SecretKeySpec(k_r, "HkdfSHA512"); // We also need K_R in this format for later use - - // Output key size for Encryption Key is 256 bits (32 bytes) - final int l_e = 32; - - // Output key size for Commitment Value is 256 bits (32 bytes) - final int l_c = 32; - - // KeyLabel is "DERIVEKEY" as UTF-8 encoded bytes - final byte[] keyLabel = "DERIVEKEY".getBytes(StandardCharsets.UTF_8); - - // CommitLabel is "COMMITKEY" as UTF-8 encoded bytes - final byte[] commitLabel = "COMMITKEY".getBytes(StandardCharsets.UTF_8); - - // PRK is HKDF-Extract(salt=N_1, initialKeyingMaterial=K_R) - hkdf.init(k_r /* IKM */, n1 /* Salt */); - - // Not final because we'll rerun this with the other algorithm - CryptoAlgorithm alg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; - // Info for K_E is Algorithm ID || KeyLabel. - // We intentionally construct this in a different manner from the tested implemention. - // This technique is harder to get wrong but less performant. - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - DataOutputStream out = new DataOutputStream(baos); - out.writeShort(alg.getValue()); - out.write(keyLabel); - out.close(); - - // K_E := HKDF-Expand(prk=PRK, info=Algorithm ID || KeyLabel, L=L_E) - byte[] k_e = hkdf.deriveKey(baos.toByteArray(), l_e); - - // K_C = HKDF-Expand(prk=PRK, info=CommitLabel, L=LC) - final byte[] k_c = hkdf.deriveKey(commitLabel, l_c); - - // Now that we have the expected values, test reality - CommittedKey committedKey = CommittedKey.generate(alg, k_rKey, n1); - assertArrayEquals("K_C for " + alg, k_c, committedKey.getCommitment()); - assertArrayEquals("K_E for " + alg, k_e, committedKey.getKey().getEncoded()); - - // Now test it with the second algorithm. - // Since the commitment value doesn't include the algorithm Id, - // K_C should remain unchanged and only K_E should vary. - alg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384; - baos = new ByteArrayOutputStream(); - out = new DataOutputStream(baos); - out.writeShort(alg.getValue()); - out.write(keyLabel); - out.close(); - final byte[] k_e2 = hkdf.deriveKey(baos.toByteArray(), l_e); - - // Now that we have the expected values, test reality - committedKey = CommittedKey.generate(alg, k_rKey, n1); - assertArrayEquals("K_C for " + alg, k_c, committedKey.getCommitment()); - assertArrayEquals("K_E for " + alg, k_e2, committedKey.getKey().getEncoded()); - assertFalse("K_E must be different for different algorithms", Arrays.areEqual(k_e, k_e2)); - } + @Test + public void testGenerate() { + final CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; + SecretKeySpec secretKey = + new SecretKeySpec( + Utils.decodeBase64String(TestUtils.messageWithCommitKeyDEKBase64), + algorithm.getDataKeyAlgo()); + CommittedKey committedKey = + CommittedKey.generate( + algorithm, + secretKey, + Utils.decodeBase64String(TestUtils.messageWithCommitKeyMessageIdBase64)); + assertNotNull(committedKey); + assertEquals( + TestUtils.messageWithCommitKeyCryptoAlgorithm.getKeyAlgo(), + committedKey.getKey().getAlgorithm()); + assertArrayEquals( + Utils.decodeBase64String(TestUtils.messageWithCommitKeyCommitmentBase64), + committedKey.getCommitment()); + } + + @Test + public void testGenerateBadNonceLen() { + final CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; + SecretKeySpec secretKey = + new SecretKeySpec( + Utils.decodeBase64String(TestUtils.messageWithCommitKeyDEKBase64), + algorithm.getDataKeyAlgo()); + assertThrows( + IllegalArgumentException.class, + "Invalid nonce size", + () -> + CommittedKey.generate( + algorithm, secretKey, new byte[algorithm.getCommitmentNonceLength() + 1])); + } + + @Test + public void testGenerateIncorrectMismatchedKeySpecAlgorithm() { + final CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; + SecretKeySpec secretKey = + new SecretKeySpec(new byte[algorithm.getDataKeyLength()], "incorrectAlgorithm"); + assertThrows( + IllegalArgumentException.class, + "DataKey of incorrect algorithm.", + () -> + CommittedKey.generate( + algorithm, secretKey, new byte[algorithm.getCommitmentNonceLength()])); + } + + @Test + public void testGenerateIncorrectDataKeyLenForAlgorithm() { + final CryptoAlgorithm algorithm = TestUtils.messageWithCommitKeyCryptoAlgorithm; + SecretKeySpec secretKey = + new SecretKeySpec(new byte[algorithm.getDataKeyLength() + 1], algorithm.getDataKeyAlgo()); + assertThrows( + IllegalArgumentException.class, + "DataKey of incorrect length.", + () -> + CommittedKey.generate( + algorithm, secretKey, new byte[algorithm.getCommitmentNonceLength()])); + } + + @Test + public void testGenerateNonCommittingAlgorithm() { + final CryptoAlgorithm algorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + SecretKeySpec secretKey = + new SecretKeySpec(new byte[algorithm.getDataKeyLength()], algorithm.getDataKeyAlgo()); + assertThrows( + IllegalArgumentException.class, + "Algorithm does not support key commitment.", + () -> + CommittedKey.generate( + algorithm, secretKey, new byte[algorithm.getCommitmentNonceLength()])); + } + + @Test + public void testGenerateCommittedKeySmokeTest() throws Exception { + // This test intentionally using different techniques + // to assemble the labels and constants. + + // Commitment Nonce N1 is equal to the Message Id which is a 32 byte random value. + // Normally this needs to be cryptographically secure, but we can relax this for improved + // performance in testing. + final byte[] n1 = insecureRandomBytes(32); + + // Hash for HKDF is SHA-512 + final HmacKeyDerivationFunction hkdf = HmacKeyDerivationFunction.getInstance("HmacSHA512"); + + // K_R (Raw keying material, a.k.a. data key) is 256 bits (32 bytes) + // Normally this needs to be cryptographically secure, but we can relax this for improved + // performance in testing. + final byte[] k_r = insecureRandomBytes(32); + final SecretKey k_rKey = + new SecretKeySpec(k_r, "HkdfSHA512"); // We also need K_R in this format for later use + + // Output key size for Encryption Key is 256 bits (32 bytes) + final int l_e = 32; + + // Output key size for Commitment Value is 256 bits (32 bytes) + final int l_c = 32; + + // KeyLabel is "DERIVEKEY" as UTF-8 encoded bytes + final byte[] keyLabel = "DERIVEKEY".getBytes(StandardCharsets.UTF_8); + + // CommitLabel is "COMMITKEY" as UTF-8 encoded bytes + final byte[] commitLabel = "COMMITKEY".getBytes(StandardCharsets.UTF_8); + + // PRK is HKDF-Extract(salt=N_1, initialKeyingMaterial=K_R) + hkdf.init(k_r /* IKM */, n1 /* Salt */); + + // Not final because we'll rerun this with the other algorithm + CryptoAlgorithm alg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; + // Info for K_E is Algorithm ID || KeyLabel. + // We intentionally construct this in a different manner from the tested implemention. + // This technique is harder to get wrong but less performant. + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream out = new DataOutputStream(baos); + out.writeShort(alg.getValue()); + out.write(keyLabel); + out.close(); + + // K_E := HKDF-Expand(prk=PRK, info=Algorithm ID || KeyLabel, L=L_E) + byte[] k_e = hkdf.deriveKey(baos.toByteArray(), l_e); + + // K_C = HKDF-Expand(prk=PRK, info=CommitLabel, L=LC) + final byte[] k_c = hkdf.deriveKey(commitLabel, l_c); + + // Now that we have the expected values, test reality + CommittedKey committedKey = CommittedKey.generate(alg, k_rKey, n1); + assertArrayEquals("K_C for " + alg, k_c, committedKey.getCommitment()); + assertArrayEquals("K_E for " + alg, k_e, committedKey.getKey().getEncoded()); + + // Now test it with the second algorithm. + // Since the commitment value doesn't include the algorithm Id, + // K_C should remain unchanged and only K_E should vary. + alg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384; + baos = new ByteArrayOutputStream(); + out = new DataOutputStream(baos); + out.writeShort(alg.getValue()); + out.write(keyLabel); + out.close(); + final byte[] k_e2 = hkdf.deriveKey(baos.toByteArray(), l_e); + + // Now that we have the expected values, test reality + committedKey = CommittedKey.generate(alg, k_rKey, n1); + assertArrayEquals("K_C for " + alg, k_c, committedKey.getCommitment()); + assertArrayEquals("K_E for " + alg, k_e2, committedKey.getKey().getEncoded()); + assertFalse("K_E must be different for different algorithms", Arrays.areEqual(k_e, k_e2)); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/DecryptionHandlerTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/DecryptionHandlerTest.java index 8e0f53554..7ed598115 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/DecryptionHandlerTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/DecryptionHandlerTest.java @@ -3,17 +3,10 @@ package com.amazonaws.encryptionsdk.internal; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import com.amazonaws.encryptionsdk.jce.JceMasterKey; -import com.amazonaws.encryptionsdk.ParsedCiphertext; -import com.amazonaws.encryptionsdk.model.CiphertextHeaders; -import org.junit.Before; -import org.junit.Test; +import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.CommitmentPolicy; @@ -31,552 +24,761 @@ import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory; - +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Test; -import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; - public class DecryptionHandlerTest { - private StaticMasterKey masterKeyProvider_; - private final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; - private final CommitmentPolicy requireReadPolicy = CommitmentPolicy.RequireEncryptRequireDecrypt; - private final List allowReadPolicies = Arrays.asList(CommitmentPolicy.RequireEncryptAllowDecrypt, + private StaticMasterKey masterKeyProvider_; + private final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; + private final CommitmentPolicy requireReadPolicy = CommitmentPolicy.RequireEncryptRequireDecrypt; + private final List allowReadPolicies = + Arrays.asList( + CommitmentPolicy.RequireEncryptAllowDecrypt, CommitmentPolicy.ForbidEncryptAllowDecrypt); + private final SignaturePolicy signaturePolicy = SignaturePolicy.AllowEncryptAllowDecrypt; + + @Before + public void init() { + masterKeyProvider_ = new StaticMasterKey("testmaterial"); + } + + @Test(expected = NullPointerException.class) + public void nullMasterKey() { + DecryptionHandler.create( + (MasterKey) null, + commitmentPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + } + + @Test + public void nullCommitment() { + final byte[] ciphertext = + getTestHeaders( + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, CommitmentPolicy.ForbidEncryptAllowDecrypt); - private final SignaturePolicy signaturePolicy = SignaturePolicy.AllowEncryptAllowDecrypt; - - @Before - public void init() { - masterKeyProvider_ = new StaticMasterKey("testmaterial"); - } - - @Test(expected = NullPointerException.class) - public void nullMasterKey() { - DecryptionHandler.create((MasterKey)null, commitmentPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - } - - @Test - public void nullCommitment() { - final byte[] ciphertext = getTestHeaders(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, CommitmentPolicy.ForbidEncryptAllowDecrypt); - - assertThrows(NullPointerException.class, () -> - DecryptionHandler.create(masterKeyProvider_, new ParsedCiphertext(ciphertext), - null, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - assertThrows(NullPointerException.class, () -> - DecryptionHandler.create(masterKeyProvider_, - null, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - } - - @Test - public void nullSignaturePolicy() { - final byte[] ciphertext = getTestHeaders(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, CommitmentPolicy.ForbidEncryptAllowDecrypt); - - assertThrows(NullPointerException.class, () -> - DecryptionHandler.create(masterKeyProvider_, new ParsedCiphertext(ciphertext), - commitmentPolicy, null, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - assertThrows(NullPointerException.class, () -> - DecryptionHandler.create(masterKeyProvider_, - commitmentPolicy, null, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); } - - - @Test(expected = AwsCryptoException.class) - public void invalidLenProcessBytes() { - final DecryptionHandler decryptionHandler = - DecryptionHandler.create(masterKeyProvider_, commitmentPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final byte[] in = new byte[1]; - final byte[] out = new byte[1]; - decryptionHandler.processBytes(in, 0, -1, out, 0); - } - - @Test(expected = AwsCryptoException.class) - public void maxLenProcessBytes() { - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, commitmentPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - // Create input of size 3 bytes: 1 byte containing version, 1 byte - // containing type, and 1 byte containing half of the algoId short - // primitive. Only 1 byte of the algoId is provided because this - // forces the decryption handler to buffer that 1 byte while waiting for - // the other byte. We do this so we can specify an input of max - // value and the total bytes to parse will become max value + 1. - final byte[] in = new byte[3]; - final byte[] out = new byte[3]; - in[1] = CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA.getValue(); - - decryptionHandler.processBytes(in, 0, in.length, out, 0); - decryptionHandler.processBytes(in, 0, Integer.MAX_VALUE, out, 0); - } - - @Test - public void maxInputLength() { - final byte[] testMessage = getTestMessage(TestUtils.DEFAULT_TEST_CRYPTO_ALG, CommitmentPolicy.RequireEncryptRequireDecrypt); - final byte[] out = new byte[100]; - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, commitmentPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - decryptionHandler.setMaxInputLength(testMessage.length - 1); - - assertThrows(IllegalStateException.class, () -> - decryptionHandler.processBytes(testMessage, 0, testMessage.length, out, 0)); - } - - @Test - public void maxInputLengthIncludingParsedCiphertext() { - final byte[] testMessage = getTestMessage(TestUtils.DEFAULT_TEST_CRYPTO_ALG, CommitmentPolicy.RequireEncryptRequireDecrypt); - final byte[] out = new byte[100]; - ParsedCiphertext parsedHeaders = new ParsedCiphertext(testMessage); - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, parsedHeaders, commitmentPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - decryptionHandler.setMaxInputLength(testMessage.length - 1); - - assertThrows(IllegalStateException.class, () -> - decryptionHandler.processBytes(testMessage, parsedHeaders.getOffset(), testMessage.length - parsedHeaders.getOffset(), - out, 0)); - } - - @Test - public void maxInputLengthIncludingCiphertextHeaders() { - final byte[] testMessage = getTestMessage(TestUtils.DEFAULT_TEST_CRYPTO_ALG, CommitmentPolicy.RequireEncryptRequireDecrypt); - final byte[] out = new byte[100]; - ParsedCiphertext parsedHeaders = new ParsedCiphertext(testMessage); - CiphertextHeaders headers = new CiphertextHeaders(); - headers.deserialize(parsedHeaders.getCiphertext(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, headers, commitmentPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - decryptionHandler.setMaxInputLength(testMessage.length - 1); - - assertThrows(IllegalStateException.class, () -> - decryptionHandler.processBytes(testMessage, parsedHeaders.getOffset(), testMessage.length - parsedHeaders.getOffset(), - out, 0)); - } - - @Test(expected = BadCiphertextException.class) - public void headerIntegrityFailure() { - byte[] ciphertext = getTestHeaders(); - - // tamper the fifth byte in the header which corresponds to the first - // byte of the message identifier. We do this because tampering the - // first four bytes will be detected as invalid values during parsing. - ciphertext[5] += 1; - - // attempt to decrypt with the tampered header. - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, commitmentPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); - final byte[] plaintext = new byte[plaintextLen]; - decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); - } - - @Test(expected = BadCiphertextException.class) - public void invalidVersion() { - byte[] ciphertext = getTestHeaders(); - - // set byte containing version to invalid value. - ciphertext[0] = 0; // NOTE: This will need to be updated should 0 ever be a valid version - - // attempt to decrypt with the tampered header. - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, commitmentPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); - final byte[] plaintext = new byte[plaintextLen]; - decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); - } - - @Test(expected = AwsCryptoException.class) - public void invalidCMK() { - final byte[] ciphertext = getTestHeaders(); - - masterKeyProvider_.setKeyId(masterKeyProvider_.getKeyId() + "nonsense"); - - // attempt to decrypt with the tampered header. - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, commitmentPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); - final byte[] plaintext = new byte[plaintextLen]; - decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); - } - - @Test - public void validAlgForCommitmentPolicyCreate() { - // ensure we can decrypt non-committing algs with the policies that allow it - final CryptoAlgorithm nonCommittingAlg = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - for (CommitmentPolicy policy : allowReadPolicies) { - final byte[] ciphertext = getTestHeaders(nonCommittingAlg, CommitmentPolicy.ForbidEncryptAllowDecrypt); - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, policy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - // expected plaintext is zero length - final byte[] plaintext = new byte[0]; - ProcessingSummary processingSummary = decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); - assertEquals(ciphertext.length, processingSummary.getBytesProcessed()); - assertArrayEquals(new byte[0], plaintext); - } - - // ensure we can decrypt committing algs with all policies - final CryptoAlgorithm committingAlg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; - for (CommitmentPolicy policy : CommitmentPolicy.values()) { - final byte[] ciphertext = getTestHeaders(committingAlg, CommitmentPolicy.RequireEncryptRequireDecrypt); - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, policy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - // expected plaintext is zero length - final byte[] plaintext = new byte[0]; - ProcessingSummary processingSummary = decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); - assertEquals(ciphertext.length, processingSummary.getBytesProcessed()); - assertArrayEquals(new byte[0], plaintext); - } - } - - @Test - public void invalidAlgForCommitmentPolicyCreateWithoutHeaders() { - final CryptoAlgorithm nonCommittingAlg = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; - final byte[] ciphertext = getTestHeaders(nonCommittingAlg, CommitmentPolicy.ForbidEncryptAllowDecrypt); - - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, requireReadPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); - final byte[] plaintext = new byte[plaintextLen]; - - assertThrows(AwsCryptoException.class, () -> decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0)); - } - - @Test - public void invalidAlgForCommitmentPolicyCreateWithHeaders() { - final CryptoAlgorithm nonCommittingAlg = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; - final byte[] ciphertext = getTestHeaders(nonCommittingAlg, CommitmentPolicy.ForbidEncryptAllowDecrypt); - - assertThrows(AwsCryptoException.class, - () -> DecryptionHandler.create(masterKeyProvider_, new ParsedCiphertext(ciphertext), requireReadPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - } - - @Test - public void validAlgForSignaturePolicyCreate() { - // ensure we can decrypt non-signing algs with the policy that allows it - final CryptoAlgorithm nonSigningAlg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; - final byte[] ciphertext = getTestHeaders(nonSigningAlg, commitmentPolicy); - final DecryptionHandler decryptionHandler = - DecryptionHandler.create(masterKeyProvider_, commitmentPolicy, SignaturePolicy.AllowEncryptAllowDecrypt, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - // expected plaintext is zero length - final byte[] plaintext = new byte[0]; - ProcessingSummary processingSummary = decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); - assertEquals(ciphertext.length, processingSummary.getBytesProcessed()); - assertArrayEquals(new byte[0], plaintext); - } - - @Test - public void invalidAlgForSignaturePolicyCreateWithoutHeaders() { - final CryptoAlgorithm signingAlg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384; - final byte[] ciphertext = getTestHeaders(signingAlg, commitmentPolicy); - - final DecryptionHandler decryptionHandler = - DecryptionHandler.create(masterKeyProvider_, commitmentPolicy, SignaturePolicy.AllowEncryptForbidDecrypt, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); - final byte[] plaintext = new byte[plaintextLen]; - assertThrows(AwsCryptoException.class, () -> decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0)); - } - - @Test - public void invalidAlgForSignaturePolicyCreateWithHeaders() { - final CryptoAlgorithm signingAlg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384; - final byte[] ciphertext = getTestHeaders(signingAlg, commitmentPolicy); - - assertThrows(AwsCryptoException.class, - () -> DecryptionHandler.create(masterKeyProvider_, new ParsedCiphertext(ciphertext), - commitmentPolicy, SignaturePolicy.AllowEncryptForbidDecrypt, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - } - - private byte[] getTestHeaders() { - return getTestHeaders(TestUtils.DEFAULT_TEST_CRYPTO_ALG, TestUtils.DEFAULT_TEST_COMMITMENT_POLICY, 1); - } - - private byte[] getTestHeaders(CryptoAlgorithm cryptoAlgorithm, CommitmentPolicy policy) { - return getTestHeaders(cryptoAlgorithm, policy, 1); - } - - private byte[] getTestHeaders(int numEdks) { - return getTestHeaders(TestUtils.DEFAULT_TEST_CRYPTO_ALG, TestUtils.DEFAULT_TEST_COMMITMENT_POLICY, numEdks); - } - - private byte[] getTestHeaders(CryptoAlgorithm cryptoAlgorithm, CommitmentPolicy policy, int numEdks) { - // Note that it's questionable to assume that failing to call doFinal() on the encryption handler - // always results in only outputting the header! - return getTestMessage(cryptoAlgorithm, policy, numEdks, false); - } - - private byte[] getTestMessage(CryptoAlgorithm cryptoAlgorithm, CommitmentPolicy policy) { - return getTestMessage(cryptoAlgorithm, policy, 1, true); - } - - private byte[] getTestMessage(CryptoAlgorithm cryptoAlgorithm, CommitmentPolicy policy, int numEdks, boolean doFinal) { - final int frameSize_ = AwsCrypto.getDefaultFrameSize(); - final Map encryptionContext = Collections. emptyMap(); - - final EncryptionMaterialsRequest encryptionMaterialsRequest = EncryptionMaterialsRequest.newBuilder() - .setContext(encryptionContext) - .setRequestedAlgorithm(cryptoAlgorithm) - .setCommitmentPolicy(policy) - .build(); - - List> providers = new ArrayList<>(); - for (int i = 0; i < numEdks; i++) { - providers.add(masterKeyProvider_); - } - MasterKeyProvider provider = MultipleProviderFactory.buildMultiProvider(providers); - - final EncryptionMaterials encryptionMaterials = new DefaultCryptoMaterialsManager(provider) - .getMaterialsForEncrypt(encryptionMaterialsRequest); - - final EncryptionHandler encryptionHandler = new EncryptionHandler(frameSize_, encryptionMaterials, policy); - - // create the ciphertext headers by calling encryption handler. - final byte[] in = new byte[0]; - final int ciphertextLen = encryptionHandler.estimateOutputSize(in.length); - final byte[] ciphertext = new byte[ciphertextLen]; - ProcessingSummary summary = encryptionHandler.processBytes(in, 0, in.length, ciphertext, 0); - if (doFinal) { - encryptionHandler.doFinal(ciphertext, summary.getBytesWritten()); - } - return ciphertext; - } - - @Test(expected = AwsCryptoException.class) - public void invalidOffsetProcessBytes() { - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, commitmentPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final byte[] in = new byte[1]; - final byte[] out = new byte[1]; - decryptionHandler.processBytes(in, -1, in.length, out, 0); - } - - @Test(expected = BadCiphertextException.class) - public void incompleteCiphertext() { - byte[] ciphertext = getTestHeaders(); - - CiphertextHeaders h = new CiphertextHeaders(); - h.deserialize(ciphertext, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, commitmentPolicy, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final byte[] out = new byte[1]; - - // Note the " - 1" is a bit deceptive: the ciphertext SHOULD already be incomplete because we - // called getTestHeaders() above, so the whole body is missing! - decryptionHandler.processBytes(ciphertext, 0, ciphertext.length - 1, out, 0); - decryptionHandler.doFinal(out, 0); - } - - @Test - public void incompleteCiphertextV2() { - byte[] ciphertext = Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64); - final DecryptionHandler decryptionHandler = DecryptionHandler.create( - TestUtils.messageWithCommitKeyMasterKey, - CommitmentPolicy.RequireEncryptRequireDecrypt, - signaturePolicy, - CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final byte[] out = new byte[1]; - - decryptionHandler.processBytes(ciphertext, 0, ciphertext.length - 1, out, 0); - assertThrows(BadCiphertextException.class, "Unable to process entire ciphertext.", - () -> decryptionHandler.doFinal(out, 0)); - } - - @Test - public void incompleteCiphertextSigned() { - byte[] ciphertext = getTestMessage(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, - CommitmentPolicy.RequireEncryptRequireDecrypt); - final DecryptionHandler decryptionHandler = DecryptionHandler.create( + assertThrows( + NullPointerException.class, + () -> + DecryptionHandler.create( masterKeyProvider_, - CommitmentPolicy.RequireEncryptRequireDecrypt, + new ParsedCiphertext(ciphertext), + null, signaturePolicy, - CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final byte[] out = new byte[1]; - - decryptionHandler.processBytes(ciphertext, 0, ciphertext.length - 1, out, 0); - assertThrows(BadCiphertextException.class, "Unable to process entire ciphertext.", - () -> decryptionHandler.doFinal(out, 0)); - } - - @Test - public void headerV2HeaderIntegrityFailure() { - byte[] ciphertext = Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64); - - // Tamper the bytes that corresponds to the frame length. - // This is the only reasonable way to tamper with this handcrafted message's - // header which can still be successfully parsed. - ciphertext[134] += 1; - - // attempt to decrypt with the tampered header. - final DecryptionHandler decryptionHandler = DecryptionHandler.create( - TestUtils.messageWithCommitKeyMasterKey, - CommitmentPolicy.RequireEncryptRequireDecrypt, - signaturePolicy, - CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); - final byte[] plaintext = new byte[plaintextLen]; - assertThrows(BadCiphertextException.class, "Header integrity check failed", () -> - decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0)); - } - - @Test - public void headerV2BodyIntegrityFailure() { - byte[] ciphertext = Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64); - - // Tamper the bytes that corresponds to the body auth - ciphertext[ciphertext.length - 1] += 1; - - // attempt to decrypt with the tampered header. - final DecryptionHandler decryptionHandler = DecryptionHandler.create( - TestUtils.messageWithCommitKeyMasterKey, - CommitmentPolicy.RequireEncryptRequireDecrypt, - signaturePolicy, - CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); - final byte[] plaintext = new byte[plaintextLen]; - assertThrows(BadCiphertextException.class, "Tag mismatch", () -> - decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0)); - } - - @Test - public void withLessThanMaxEdks() { - byte[] header = getTestHeaders(2); - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, CommitmentPolicy.RequireEncryptAllowDecrypt, signaturePolicy, 3); - final int plaintextLen = decryptionHandler.estimateOutputSize(header.length); - final byte[] plaintext = new byte[plaintextLen]; - decryptionHandler.processBytes(header, 0, header.length, plaintext, 0); - } - - @Test - public void withMaxEdks() { - byte[] header = getTestHeaders(3); - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, CommitmentPolicy.RequireEncryptAllowDecrypt, signaturePolicy, 3); - final int plaintextLen = decryptionHandler.estimateOutputSize(header.length); - final byte[] plaintext = new byte[plaintextLen]; - decryptionHandler.processBytes(header, 0, header.length, plaintext, 0); - } - - @Test - public void withMoreThanMaxEdks() { - byte[] header = getTestHeaders(4); - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, CommitmentPolicy.RequireEncryptAllowDecrypt, signaturePolicy, 3); - final int plaintextLen = decryptionHandler.estimateOutputSize(header.length); - final byte[] plaintext = new byte[plaintextLen]; - assertThrows(AwsCryptoException.class, "Ciphertext encrypted data keys exceed maxEncryptedDataKeys", () -> - decryptionHandler.processBytes(header, 0, header.length, plaintext, 0) - ); - } - - @Test - public void withNoMaxEdks() { - byte[] header = getTestHeaders(1 << 16 - 1); - final DecryptionHandler decryptionHandler = DecryptionHandler.create(masterKeyProvider_, CommitmentPolicy.RequireEncryptAllowDecrypt, signaturePolicy, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final int plaintextLen = decryptionHandler.estimateOutputSize(header.length); - final byte[] plaintext = new byte[plaintextLen]; - decryptionHandler.processBytes(header, 0, header.length, plaintext, 0); - } - - public void validSignatureAcrossMultipleBlocks() { - byte[] ciphertext = getTestMessage(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, - CommitmentPolicy.RequireEncryptRequireDecrypt); - final DecryptionHandler decryptionHandler = DecryptionHandler.create( + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + assertThrows( + NullPointerException.class, + () -> + DecryptionHandler.create( masterKeyProvider_, - CommitmentPolicy.RequireEncryptRequireDecrypt, + null, signaturePolicy, - CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final byte[] out = new byte[1]; - - // Parse the header and body - final int headerLength = 388; - decryptionHandler.processBytes(ciphertext, 0, headerLength, out, 0); - - // Parse the footer across two calls - // This used to fail to verify the signature because partial bytes were dropped instead. - // The overall decryption would still succeed because the completeness check in doFinal - // used to not include the footer. - // The number of bytes read in the first chunk is completely arbitrary. The - // parameterized CryptoOutputStreamTest tests covers lots of possible chunk - // sizes much more thoroughly. This is just a very explicit regression unit test for a known - // issue that is now fixed. - final int firstChunkLength = 12; - final int firstChunkOffset = headerLength; - final int secondChunkOffset = headerLength + firstChunkLength; - final int secondChunkLength = ciphertext.length - secondChunkOffset; - decryptionHandler.processBytes(ciphertext, firstChunkOffset, firstChunkLength, out, 0); - decryptionHandler.processBytes(ciphertext, secondChunkOffset, secondChunkLength, out, 0); - decryptionHandler.doFinal(out, 0); - } - - @Test - public void invalidSignatureAcrossMultipleBlocks() { - byte[] ciphertext = getTestMessage(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, - CommitmentPolicy.RequireEncryptRequireDecrypt); - final DecryptionHandler decryptionHandler = DecryptionHandler.create( - masterKeyProvider_, - CommitmentPolicy.RequireEncryptRequireDecrypt, - signaturePolicy, - CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final byte[] out = new byte[1]; - - // Parse the header and body - decryptionHandler.processBytes(ciphertext, 0, 388, out, 0); - - // Process extra bytes before processing the actual signature bytes. - // This used to actually work because the handler failed to buffer the unparsed bytes - // across calls. To regression test this properly we have to parse the two bytes for the length... - decryptionHandler.processBytes(ciphertext, 388, 2, out, 0); - // ...and after that any bytes fewer than that length would previously be dropped. - decryptionHandler.processBytes(new byte[10], 0, 10, out, 0); - assertThrows(BadCiphertextException.class, "Bad trailing signature", () -> - decryptionHandler.processBytes(ciphertext, 390, ciphertext.length - 390, out, 0)); - } + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + } + + @Test + public void nullSignaturePolicy() { + final byte[] ciphertext = + getTestHeaders( + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + CommitmentPolicy.ForbidEncryptAllowDecrypt); - @Test - public void setMaxInputLength() { - byte[] ciphertext = getTestMessage(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, - CommitmentPolicy.RequireEncryptRequireDecrypt); - final DecryptionHandler decryptionHandler = DecryptionHandler.create( + assertThrows( + NullPointerException.class, + () -> + DecryptionHandler.create( masterKeyProvider_, - CommitmentPolicy.RequireEncryptRequireDecrypt, - signaturePolicy, - CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - decryptionHandler.setMaxInputLength(ciphertext.length - 1); - - assertEquals(decryptionHandler.getMaxInputLength(), (long)ciphertext.length - 1); - - final byte[] out = new byte[1]; - assertThrows(IllegalStateException.class, "Ciphertext size exceeds size bound", () -> - decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, out, 0)); - } - - @Test - public void setMaxInputLengthThrowsIfAlreadyOver() { - byte[] ciphertext = getTestMessage(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, - CommitmentPolicy.RequireEncryptRequireDecrypt); - final DecryptionHandler decryptionHandler = DecryptionHandler.create( + new ParsedCiphertext(ciphertext), + commitmentPolicy, + null, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + assertThrows( + NullPointerException.class, + () -> + DecryptionHandler.create( masterKeyProvider_, - CommitmentPolicy.RequireEncryptRequireDecrypt, - signaturePolicy, - CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final byte[] out = new byte[1]; - decryptionHandler.processBytes(ciphertext, 0, ciphertext.length - 1, out, 0); - assertFalse(decryptionHandler.isComplete()); - - assertThrows(IllegalStateException.class, "Ciphertext size exceeds size bound", () -> - decryptionHandler.setMaxInputLength(ciphertext.length - 2)); - } - - @Test - public void setMaxInputLengthAcceptsSmallerValue() { - final DecryptionHandler decryptionHandler = DecryptionHandler.create( + commitmentPolicy, + null, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + } + + @Test(expected = AwsCryptoException.class) + public void invalidLenProcessBytes() { + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + commitmentPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final byte[] in = new byte[1]; + final byte[] out = new byte[1]; + decryptionHandler.processBytes(in, 0, -1, out, 0); + } + + @Test(expected = AwsCryptoException.class) + public void maxLenProcessBytes() { + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + commitmentPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + // Create input of size 3 bytes: 1 byte containing version, 1 byte + // containing type, and 1 byte containing half of the algoId short + // primitive. Only 1 byte of the algoId is provided because this + // forces the decryption handler to buffer that 1 byte while waiting for + // the other byte. We do this so we can specify an input of max + // value and the total bytes to parse will become max value + 1. + final byte[] in = new byte[3]; + final byte[] out = new byte[3]; + in[1] = CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA.getValue(); + + decryptionHandler.processBytes(in, 0, in.length, out, 0); + decryptionHandler.processBytes(in, 0, Integer.MAX_VALUE, out, 0); + } + + @Test + public void maxInputLength() { + final byte[] testMessage = + getTestMessage( + TestUtils.DEFAULT_TEST_CRYPTO_ALG, CommitmentPolicy.RequireEncryptRequireDecrypt); + final byte[] out = new byte[100]; + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + commitmentPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + decryptionHandler.setMaxInputLength(testMessage.length - 1); + + assertThrows( + IllegalStateException.class, + () -> decryptionHandler.processBytes(testMessage, 0, testMessage.length, out, 0)); + } + + @Test + public void maxInputLengthIncludingParsedCiphertext() { + final byte[] testMessage = + getTestMessage( + TestUtils.DEFAULT_TEST_CRYPTO_ALG, CommitmentPolicy.RequireEncryptRequireDecrypt); + final byte[] out = new byte[100]; + ParsedCiphertext parsedHeaders = new ParsedCiphertext(testMessage); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + parsedHeaders, + commitmentPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + decryptionHandler.setMaxInputLength(testMessage.length - 1); + + assertThrows( + IllegalStateException.class, + () -> + decryptionHandler.processBytes( + testMessage, + parsedHeaders.getOffset(), + testMessage.length - parsedHeaders.getOffset(), + out, + 0)); + } + + @Test + public void maxInputLengthIncludingCiphertextHeaders() { + final byte[] testMessage = + getTestMessage( + TestUtils.DEFAULT_TEST_CRYPTO_ALG, CommitmentPolicy.RequireEncryptRequireDecrypt); + final byte[] out = new byte[100]; + ParsedCiphertext parsedHeaders = new ParsedCiphertext(testMessage); + CiphertextHeaders headers = new CiphertextHeaders(); + headers.deserialize( + parsedHeaders.getCiphertext(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + headers, + commitmentPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + decryptionHandler.setMaxInputLength(testMessage.length - 1); + + assertThrows( + IllegalStateException.class, + () -> + decryptionHandler.processBytes( + testMessage, + parsedHeaders.getOffset(), + testMessage.length - parsedHeaders.getOffset(), + out, + 0)); + } + + @Test(expected = BadCiphertextException.class) + public void headerIntegrityFailure() { + byte[] ciphertext = getTestHeaders(); + + // tamper the fifth byte in the header which corresponds to the first + // byte of the message identifier. We do this because tampering the + // first four bytes will be detected as invalid values during parsing. + ciphertext[5] += 1; + + // attempt to decrypt with the tampered header. + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + commitmentPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); + final byte[] plaintext = new byte[plaintextLen]; + decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); + } + + @Test(expected = BadCiphertextException.class) + public void invalidVersion() { + byte[] ciphertext = getTestHeaders(); + + // set byte containing version to invalid value. + ciphertext[0] = 0; // NOTE: This will need to be updated should 0 ever be a valid version + + // attempt to decrypt with the tampered header. + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + commitmentPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); + final byte[] plaintext = new byte[plaintextLen]; + decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); + } + + @Test(expected = AwsCryptoException.class) + public void invalidCMK() { + final byte[] ciphertext = getTestHeaders(); + + masterKeyProvider_.setKeyId(masterKeyProvider_.getKeyId() + "nonsense"); + + // attempt to decrypt with the tampered header. + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + commitmentPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); + final byte[] plaintext = new byte[plaintextLen]; + decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); + } + + @Test + public void validAlgForCommitmentPolicyCreate() { + // ensure we can decrypt non-committing algs with the policies that allow it + final CryptoAlgorithm nonCommittingAlg = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + for (CommitmentPolicy policy : allowReadPolicies) { + final byte[] ciphertext = + getTestHeaders(nonCommittingAlg, CommitmentPolicy.ForbidEncryptAllowDecrypt); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + policy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + // expected plaintext is zero length + final byte[] plaintext = new byte[0]; + ProcessingSummary processingSummary = + decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); + assertEquals(ciphertext.length, processingSummary.getBytesProcessed()); + assertArrayEquals(new byte[0], plaintext); + } + + // ensure we can decrypt committing algs with all policies + final CryptoAlgorithm committingAlg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; + for (CommitmentPolicy policy : CommitmentPolicy.values()) { + final byte[] ciphertext = + getTestHeaders(committingAlg, CommitmentPolicy.RequireEncryptRequireDecrypt); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + policy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + // expected plaintext is zero length + final byte[] plaintext = new byte[0]; + ProcessingSummary processingSummary = + decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); + assertEquals(ciphertext.length, processingSummary.getBytesProcessed()); + assertArrayEquals(new byte[0], plaintext); + } + } + + @Test + public void invalidAlgForCommitmentPolicyCreateWithoutHeaders() { + final CryptoAlgorithm nonCommittingAlg = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; + final byte[] ciphertext = + getTestHeaders(nonCommittingAlg, CommitmentPolicy.ForbidEncryptAllowDecrypt); + + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + requireReadPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); + final byte[] plaintext = new byte[plaintextLen]; + + assertThrows( + AwsCryptoException.class, + () -> decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0)); + } + + @Test + public void invalidAlgForCommitmentPolicyCreateWithHeaders() { + final CryptoAlgorithm nonCommittingAlg = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; + final byte[] ciphertext = + getTestHeaders(nonCommittingAlg, CommitmentPolicy.ForbidEncryptAllowDecrypt); + + assertThrows( + AwsCryptoException.class, + () -> + DecryptionHandler.create( masterKeyProvider_, - CommitmentPolicy.RequireEncryptRequireDecrypt, + new ParsedCiphertext(ciphertext), + requireReadPolicy, signaturePolicy, - CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - decryptionHandler.setMaxInputLength(100); - assertEquals(decryptionHandler.getMaxInputLength(), 100); - - decryptionHandler.setMaxInputLength(10); - assertEquals(decryptionHandler.getMaxInputLength(), 10); - } - - @Test - public void setMaxInputLengthIgnoresLargerValue() { - final DecryptionHandler decryptionHandler = DecryptionHandler.create( + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + } + + @Test + public void validAlgForSignaturePolicyCreate() { + // ensure we can decrypt non-signing algs with the policy that allows it + final CryptoAlgorithm nonSigningAlg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; + final byte[] ciphertext = getTestHeaders(nonSigningAlg, commitmentPolicy); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + commitmentPolicy, + SignaturePolicy.AllowEncryptAllowDecrypt, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + // expected plaintext is zero length + final byte[] plaintext = new byte[0]; + ProcessingSummary processingSummary = + decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0); + assertEquals(ciphertext.length, processingSummary.getBytesProcessed()); + assertArrayEquals(new byte[0], plaintext); + } + + @Test + public void invalidAlgForSignaturePolicyCreateWithoutHeaders() { + final CryptoAlgorithm signingAlg = + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384; + final byte[] ciphertext = getTestHeaders(signingAlg, commitmentPolicy); + + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + commitmentPolicy, + SignaturePolicy.AllowEncryptForbidDecrypt, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); + final byte[] plaintext = new byte[plaintextLen]; + + assertThrows( + AwsCryptoException.class, + () -> decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0)); + } + + @Test + public void invalidAlgForSignaturePolicyCreateWithHeaders() { + final CryptoAlgorithm signingAlg = + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384; + final byte[] ciphertext = getTestHeaders(signingAlg, commitmentPolicy); + + assertThrows( + AwsCryptoException.class, + () -> + DecryptionHandler.create( masterKeyProvider_, - CommitmentPolicy.RequireEncryptRequireDecrypt, - signaturePolicy, - CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - decryptionHandler.setMaxInputLength(10); - assertEquals(decryptionHandler.getMaxInputLength(), 10); - - decryptionHandler.setMaxInputLength(100); - assertEquals(decryptionHandler.getMaxInputLength(), 10); - } + new ParsedCiphertext(ciphertext), + commitmentPolicy, + SignaturePolicy.AllowEncryptForbidDecrypt, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + } + + private byte[] getTestHeaders() { + return getTestHeaders( + TestUtils.DEFAULT_TEST_CRYPTO_ALG, TestUtils.DEFAULT_TEST_COMMITMENT_POLICY, 1); + } + + private byte[] getTestHeaders(CryptoAlgorithm cryptoAlgorithm, CommitmentPolicy policy) { + return getTestHeaders(cryptoAlgorithm, policy, 1); + } + + private byte[] getTestHeaders(int numEdks) { + return getTestHeaders( + TestUtils.DEFAULT_TEST_CRYPTO_ALG, TestUtils.DEFAULT_TEST_COMMITMENT_POLICY, numEdks); + } + + private byte[] getTestHeaders( + CryptoAlgorithm cryptoAlgorithm, CommitmentPolicy policy, int numEdks) { + // Note that it's questionable to assume that failing to call doFinal() on the encryption + // handler + // always results in only outputting the header! + return getTestMessage(cryptoAlgorithm, policy, numEdks, false); + } + + private byte[] getTestMessage(CryptoAlgorithm cryptoAlgorithm, CommitmentPolicy policy) { + return getTestMessage(cryptoAlgorithm, policy, 1, true); + } + + private byte[] getTestMessage( + CryptoAlgorithm cryptoAlgorithm, CommitmentPolicy policy, int numEdks, boolean doFinal) { + final int frameSize_ = AwsCrypto.getDefaultFrameSize(); + final Map encryptionContext = Collections.emptyMap(); + + final EncryptionMaterialsRequest encryptionMaterialsRequest = + EncryptionMaterialsRequest.newBuilder() + .setContext(encryptionContext) + .setRequestedAlgorithm(cryptoAlgorithm) + .setCommitmentPolicy(policy) + .build(); + + List> providers = new ArrayList<>(); + for (int i = 0; i < numEdks; i++) { + providers.add(masterKeyProvider_); + } + MasterKeyProvider provider = MultipleProviderFactory.buildMultiProvider(providers); + + final EncryptionMaterials encryptionMaterials = + new DefaultCryptoMaterialsManager(provider) + .getMaterialsForEncrypt(encryptionMaterialsRequest); + + final EncryptionHandler encryptionHandler = + new EncryptionHandler(frameSize_, encryptionMaterials, policy); + + // create the ciphertext headers by calling encryption handler. + final byte[] in = new byte[0]; + final int ciphertextLen = encryptionHandler.estimateOutputSize(in.length); + final byte[] ciphertext = new byte[ciphertextLen]; + ProcessingSummary summary = encryptionHandler.processBytes(in, 0, in.length, ciphertext, 0); + if (doFinal) { + encryptionHandler.doFinal(ciphertext, summary.getBytesWritten()); + } + return ciphertext; + } + + @Test(expected = AwsCryptoException.class) + public void invalidOffsetProcessBytes() { + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + commitmentPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final byte[] in = new byte[1]; + final byte[] out = new byte[1]; + decryptionHandler.processBytes(in, -1, in.length, out, 0); + } + + @Test(expected = BadCiphertextException.class) + public void incompleteCiphertext() { + byte[] ciphertext = getTestHeaders(); + + CiphertextHeaders h = new CiphertextHeaders(); + h.deserialize(ciphertext, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + commitmentPolicy, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final byte[] out = new byte[1]; + + // Note the " - 1" is a bit deceptive: the ciphertext SHOULD already be incomplete because we + // called getTestHeaders() above, so the whole body is missing! + decryptionHandler.processBytes(ciphertext, 0, ciphertext.length - 1, out, 0); + decryptionHandler.doFinal(out, 0); + } + + @Test + public void incompleteCiphertextV2() { + byte[] ciphertext = Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + TestUtils.messageWithCommitKeyMasterKey, + CommitmentPolicy.RequireEncryptRequireDecrypt, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final byte[] out = new byte[1]; + + decryptionHandler.processBytes(ciphertext, 0, ciphertext.length - 1, out, 0); + assertThrows( + BadCiphertextException.class, + "Unable to process entire ciphertext.", + () -> decryptionHandler.doFinal(out, 0)); + } + + @Test + public void incompleteCiphertextSigned() { + byte[] ciphertext = + getTestMessage( + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, + CommitmentPolicy.RequireEncryptRequireDecrypt); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + CommitmentPolicy.RequireEncryptRequireDecrypt, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final byte[] out = new byte[1]; + + decryptionHandler.processBytes(ciphertext, 0, ciphertext.length - 1, out, 0); + assertThrows( + BadCiphertextException.class, + "Unable to process entire ciphertext.", + () -> decryptionHandler.doFinal(out, 0)); + } + + @Test + public void headerV2HeaderIntegrityFailure() { + byte[] ciphertext = Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64); + + // Tamper the bytes that corresponds to the frame length. + // This is the only reasonable way to tamper with this handcrafted message's + // header which can still be successfully parsed. + ciphertext[134] += 1; + + // attempt to decrypt with the tampered header. + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + TestUtils.messageWithCommitKeyMasterKey, + CommitmentPolicy.RequireEncryptRequireDecrypt, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); + final byte[] plaintext = new byte[plaintextLen]; + assertThrows( + BadCiphertextException.class, + "Header integrity check failed", + () -> decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0)); + } + + @Test + public void headerV2BodyIntegrityFailure() { + byte[] ciphertext = Utils.decodeBase64String(TestUtils.messageWithCommitKeyBase64); + + // Tamper the bytes that corresponds to the body auth + ciphertext[ciphertext.length - 1] += 1; + + // attempt to decrypt with the tampered header. + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + TestUtils.messageWithCommitKeyMasterKey, + CommitmentPolicy.RequireEncryptRequireDecrypt, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final int plaintextLen = decryptionHandler.estimateOutputSize(ciphertext.length); + final byte[] plaintext = new byte[plaintextLen]; + assertThrows( + BadCiphertextException.class, + "Tag mismatch", + () -> decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, plaintext, 0)); + } + + @Test + public void withLessThanMaxEdks() { + byte[] header = getTestHeaders(2); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, CommitmentPolicy.RequireEncryptAllowDecrypt, signaturePolicy, 3); + final int plaintextLen = decryptionHandler.estimateOutputSize(header.length); + final byte[] plaintext = new byte[plaintextLen]; + decryptionHandler.processBytes(header, 0, header.length, plaintext, 0); + } + + @Test + public void withMaxEdks() { + byte[] header = getTestHeaders(3); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, CommitmentPolicy.RequireEncryptAllowDecrypt, signaturePolicy, 3); + final int plaintextLen = decryptionHandler.estimateOutputSize(header.length); + final byte[] plaintext = new byte[plaintextLen]; + decryptionHandler.processBytes(header, 0, header.length, plaintext, 0); + } + + @Test + public void withMoreThanMaxEdks() { + byte[] header = getTestHeaders(4); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, CommitmentPolicy.RequireEncryptAllowDecrypt, signaturePolicy, 3); + final int plaintextLen = decryptionHandler.estimateOutputSize(header.length); + final byte[] plaintext = new byte[plaintextLen]; + assertThrows( + AwsCryptoException.class, + "Ciphertext encrypted data keys exceed maxEncryptedDataKeys", + () -> decryptionHandler.processBytes(header, 0, header.length, plaintext, 0)); + } + + @Test + public void withNoMaxEdks() { + byte[] header = getTestHeaders(1 << 16 - 1); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + CommitmentPolicy.RequireEncryptAllowDecrypt, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final int plaintextLen = decryptionHandler.estimateOutputSize(header.length); + final byte[] plaintext = new byte[plaintextLen]; + decryptionHandler.processBytes(header, 0, header.length, plaintext, 0); + } + + public void validSignatureAcrossMultipleBlocks() { + byte[] ciphertext = + getTestMessage( + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, + CommitmentPolicy.RequireEncryptRequireDecrypt); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + CommitmentPolicy.RequireEncryptRequireDecrypt, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final byte[] out = new byte[1]; + + // Parse the header and body + final int headerLength = 388; + decryptionHandler.processBytes(ciphertext, 0, headerLength, out, 0); + + // Parse the footer across two calls + // This used to fail to verify the signature because partial bytes were dropped instead. + // The overall decryption would still succeed because the completeness check in doFinal + // used to not include the footer. + // The number of bytes read in the first chunk is completely arbitrary. The + // parameterized CryptoOutputStreamTest tests covers lots of possible chunk + // sizes much more thoroughly. This is just a very explicit regression unit test for a known + // issue that is now fixed. + final int firstChunkLength = 12; + final int firstChunkOffset = headerLength; + final int secondChunkOffset = headerLength + firstChunkLength; + final int secondChunkLength = ciphertext.length - secondChunkOffset; + decryptionHandler.processBytes(ciphertext, firstChunkOffset, firstChunkLength, out, 0); + decryptionHandler.processBytes(ciphertext, secondChunkOffset, secondChunkLength, out, 0); + decryptionHandler.doFinal(out, 0); + } + + @Test + public void invalidSignatureAcrossMultipleBlocks() { + byte[] ciphertext = + getTestMessage( + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, + CommitmentPolicy.RequireEncryptRequireDecrypt); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + CommitmentPolicy.RequireEncryptRequireDecrypt, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final byte[] out = new byte[1]; + + // Parse the header and body + decryptionHandler.processBytes(ciphertext, 0, 388, out, 0); + + // Process extra bytes before processing the actual signature bytes. + // This used to actually work because the handler failed to buffer the unparsed bytes + // across calls. To regression test this properly we have to parse the two bytes for the + // length... + decryptionHandler.processBytes(ciphertext, 388, 2, out, 0); + // ...and after that any bytes fewer than that length would previously be dropped. + decryptionHandler.processBytes(new byte[10], 0, 10, out, 0); + assertThrows( + BadCiphertextException.class, + "Bad trailing signature", + () -> decryptionHandler.processBytes(ciphertext, 390, ciphertext.length - 390, out, 0)); + } + + @Test + public void setMaxInputLength() { + byte[] ciphertext = + getTestMessage( + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, + CommitmentPolicy.RequireEncryptRequireDecrypt); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + CommitmentPolicy.RequireEncryptRequireDecrypt, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + decryptionHandler.setMaxInputLength(ciphertext.length - 1); + + assertEquals(decryptionHandler.getMaxInputLength(), (long) ciphertext.length - 1); + + final byte[] out = new byte[1]; + assertThrows( + IllegalStateException.class, + "Ciphertext size exceeds size bound", + () -> decryptionHandler.processBytes(ciphertext, 0, ciphertext.length, out, 0)); + } + + @Test + public void setMaxInputLengthThrowsIfAlreadyOver() { + byte[] ciphertext = + getTestMessage( + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY_ECDSA_P384, + CommitmentPolicy.RequireEncryptRequireDecrypt); + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + CommitmentPolicy.RequireEncryptRequireDecrypt, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final byte[] out = new byte[1]; + decryptionHandler.processBytes(ciphertext, 0, ciphertext.length - 1, out, 0); + assertFalse(decryptionHandler.isComplete()); + + assertThrows( + IllegalStateException.class, + "Ciphertext size exceeds size bound", + () -> decryptionHandler.setMaxInputLength(ciphertext.length - 2)); + } + + @Test + public void setMaxInputLengthAcceptsSmallerValue() { + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + CommitmentPolicy.RequireEncryptRequireDecrypt, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + decryptionHandler.setMaxInputLength(100); + assertEquals(decryptionHandler.getMaxInputLength(), 100); + + decryptionHandler.setMaxInputLength(10); + assertEquals(decryptionHandler.getMaxInputLength(), 10); + } + + @Test + public void setMaxInputLengthIgnoresLargerValue() { + final DecryptionHandler decryptionHandler = + DecryptionHandler.create( + masterKeyProvider_, + CommitmentPolicy.RequireEncryptRequireDecrypt, + signaturePolicy, + CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + decryptionHandler.setMaxInputLength(10); + assertEquals(decryptionHandler.getMaxInputLength(), 10); + + decryptionHandler.setMaxInputLength(100); + assertEquals(decryptionHandler.getMaxInputLength(), 10); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/EncContextSerializerTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/EncContextSerializerTest.java index e1a47bc7b..92930df05 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/EncContextSerializerTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/EncContextSerializerTest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -16,396 +16,404 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import java.nio.ByteBuffer; import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.UUID; - import org.junit.Test; -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; - public class EncContextSerializerTest { - @Test - public void nullContext() { - final byte[] ctxBytes = EncryptionContextSerializer.serialize(null); - final Map result = EncryptionContextSerializer.deserialize(ctxBytes); - assertEquals(null, result); - } - - @Test - public void emptyContext() { - testMap(Collections. emptyMap()); - } - - @Test - public void singletonContext() { - testMap(Collections.singletonMap("Alice:", "trusts Bob")); + @Test + public void nullContext() { + final byte[] ctxBytes = EncryptionContextSerializer.serialize(null); + final Map result = EncryptionContextSerializer.deserialize(ctxBytes); + assertEquals(null, result); + } + + @Test + public void emptyContext() { + testMap(Collections.emptyMap()); + } + + @Test + public void singletonContext() { + testMap(Collections.singletonMap("Alice:", "trusts Bob")); + } + + @Test + public void contextOrdering() throws Exception { + // Context keys should be sorted by unsigned byte order + Map map = new HashMap<>(); + + map.put("\0", "\0"); + map.put("\u0081", "\u0081"); // 0xC2 0x81 in UTF8 + + assertArrayEquals( + new byte[] { + 0, + 2, + // "\0" + 0, + 1, + (byte) '\0', + // "\0" + 0, + 1, + (byte) '\0', + // "\u0081" + 0, + 2, + (byte) 0xC2, + (byte) 0x81, + // "\u0081" + 0, + 2, + (byte) 0xC2, + (byte) 0x81, + }, + EncryptionContextSerializer.serialize(map)); + } + + @Test + public void smallContext() { + final Map map = new HashMap(); + map.put("Alice:", "trusts Bob"); + map.put("Bob:", "trusts Trent"); + testMap(map); + } + + @Test + public void largeContext() { + final int size = 100; + final Map ctx = new HashMap(size); + for (int x = 0; x < size; x++) { + ctx.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); } - - @Test - public void contextOrdering() throws Exception { - // Context keys should be sorted by unsigned byte order - Map map = new HashMap<>(); - - map.put("\0", "\0"); - map.put("\u0081", "\u0081"); // 0xC2 0x81 in UTF8 - - assertArrayEquals( - new byte[] { - 0, 2, - // "\0" - 0, 1, (byte)'\0', - // "\0" - 0, 1, (byte)'\0', - // "\u0081" - 0, 2, (byte)0xC2, (byte)0x81, - // "\u0081" - 0, 2, (byte)0xC2, (byte)0x81, - }, - EncryptionContextSerializer.serialize(map) - ); - } - - @Test - public void smallContext() { - final Map map = new HashMap(); - map.put("Alice:", "trusts Bob"); - map.put("Bob:", "trusts Trent"); - testMap(map); - } - - @Test - public void largeContext() { - final int size = 100; - final Map ctx = new HashMap(size); - for (int x = 0; x < size; x++) { - ctx.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); - } - testMap(ctx); + testMap(ctx); + } + + @Test(expected = AwsCryptoException.class) + public void overlyLargeContext() { + final int size = Short.MAX_VALUE; + final Map ctx = new HashMap(size); + // we want to be at least 1 over the (max) size. + for (int x = 0; x <= size; x++) { + ctx.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); } - - @Test(expected = AwsCryptoException.class) - public void overlyLargeContext() { - final int size = Short.MAX_VALUE; - final Map ctx = new HashMap(size); - // we want to be at least 1 over the (max) size. - for (int x = 0; x <= size; x++) { - ctx.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); - } - testMap(ctx); + testMap(ctx); + } + + @Test(expected = AwsCryptoException.class) + public void overlyLargeKey() { + final int size = 10; + final Map ctx = new HashMap(size); + final char[] keyChars = new char[Short.MAX_VALUE + 1]; + final String key = new String(keyChars); + + for (int x = 0; x < size; x++) { + ctx.put(key, UUID.randomUUID().toString()); } - - @Test(expected = AwsCryptoException.class) - public void overlyLargeKey() { - final int size = 10; - final Map ctx = new HashMap(size); - final char[] keyChars = new char[Short.MAX_VALUE + 1]; - final String key = new String(keyChars); - - for (int x = 0; x < size; x++) { - ctx.put(key, UUID.randomUUID().toString()); - } - testMap(ctx); + testMap(ctx); + } + + @Test(expected = AwsCryptoException.class) + public void overlyLargeValue() { + final int size = 10; + final Map ctx = new HashMap(size); + final char[] valueChars = new char[Short.MAX_VALUE + 1]; + final String value = new String(valueChars); + + for (int x = 0; x < size; x++) { + ctx.put(UUID.randomUUID().toString(), value); } - - @Test(expected = AwsCryptoException.class) - public void overlyLargeValue() { - final int size = 10; - final Map ctx = new HashMap(size); - final char[] valueChars = new char[Short.MAX_VALUE + 1]; - final String value = new String(valueChars); - - for (int x = 0; x < size; x++) { - ctx.put(UUID.randomUUID().toString(), value); - } - testMap(ctx); + testMap(ctx); + } + + @Test(expected = AwsCryptoException.class) + public void overlyLargeContextBytes() { + final char[] keyChars = new char[Short.MAX_VALUE]; + final String key = new String(keyChars); + final char[] valueChars = new char[Short.MAX_VALUE]; + final String value = new String(valueChars); + + testMap(Collections.singletonMap(key, value)); + } + + @Test(expected = IllegalArgumentException.class) + public void contextWithBadUnicodeKey() { + final StringBuilder invalidString = new StringBuilder("Valid text"); + // Loop over invalid unicode codepoints + for (int x = 0xd800; x <= 0xdfff; x++) { + invalidString.appendCodePoint(x); } - - @Test(expected = AwsCryptoException.class) - public void overlyLargeContextBytes() { - final char[] keyChars = new char[Short.MAX_VALUE]; - final String key = new String(keyChars); - final char[] valueChars = new char[Short.MAX_VALUE]; - final String value = new String(valueChars); - - testMap(Collections.singletonMap(key, value)); + testMap(Collections.singletonMap(invalidString.toString(), "Valid value")); + } + + @Test(expected = IllegalArgumentException.class) + public void contextWithBadUnicodeValue() { + final StringBuilder invalidString = new StringBuilder("Base valid text"); + for (int x = 0xd800; x <= 0xdfff; x++) { // Invalid unicode codepoints + invalidString.appendCodePoint(x); } + testMap(Collections.singletonMap("Valid key", invalidString.toString())); + } + + @Test(expected = AwsCryptoException.class) + public void contextWithEmptyKey() { + testMap(Collections.singletonMap("", "Value for empty key")); + } + + @Test(expected = AwsCryptoException.class) + public void contextWithEmptyValue() { + testMap(Collections.singletonMap("Key for empty value", "")); + } + + @Test(expected = AwsCryptoException.class) + public void contextWithEmptyKeyAndValue() { + testMap(Collections.singletonMap("", "")); + } + + @Test(expected = AwsCryptoException.class) + public void contextWithNullKey() { + testMap(Collections.singletonMap((String) null, "value for null key")); + } + + @Test(expected = AwsCryptoException.class) + public void contextWithNullValue() { + testMap(Collections.singletonMap("Key for null value", (String) null)); + } + + @Test(expected = AwsCryptoException.class) + public void contextWithNullKeyAndValue() { + testMap(Collections.singletonMap((String) null, (String) null)); + } + + @Test(expected = AwsCryptoException.class) + public void contextWithLargeKey() { + final Map ctx = new HashMap(); + ctx.put("Alice:", "trusts Bob"); + + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); + + // Pull out entry count to move to key pos + ctxBuff.getShort(); + // Overwrite key length + ctxBuff.putShort((short) Constants.UNSIGNED_SHORT_MAX_VAL); + + // The actual call which should fail + EncryptionContextSerializer.deserialize(ctxBuff.array()); + } + + @Test(expected = AwsCryptoException.class) + public void contextWithShortKey() { + final Map ctx = new HashMap(); + ctx.put("Alice:", "trusts Bob"); + + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); + + // Pull out entry count to move to key pos + ctxBuff.getShort(); + // Overwrite key length with 0 + ctxBuff.putShort((short) 0); + + // The actual call which should fail + EncryptionContextSerializer.deserialize(ctxBuff.array()); + } + + @Test(expected = AwsCryptoException.class) + public void contextWithNegativeKey() { + final Map ctx = new HashMap(); + ctx.put("Alice:", "trusts Bob"); + + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); + + // Pull out entry count to move to key pos + ctxBuff.getShort(); + // Overwrite key length with -1. + ctxBuff.putShort((short) -1); + + // The actual call which should fail + EncryptionContextSerializer.deserialize(ctxBuff.array()); + } + + @Test(expected = AwsCryptoException.class) + public void contextWithLargeValue() { + final Map ctx = new HashMap(); + ctx.put("Alice:", "trusts Bob"); + + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); + + // Pull out entry count to move to key pos + ctxBuff.getShort(); - @Test(expected = IllegalArgumentException.class) - public void contextWithBadUnicodeKey() { - final StringBuilder invalidString = new StringBuilder("Valid text"); - // Loop over invalid unicode codepoints - for (int x = 0xd800; x <= 0xdfff; x++) { - invalidString.appendCodePoint(x); - } - testMap(Collections.singletonMap(invalidString.toString(), "Valid value")); - } + // Pull out key length and bytes. + final short keyLen = ctxBuff.getShort(); + final byte[] key = new byte[keyLen]; + ctxBuff.get(key); + + // Overwrite value length + ctxBuff.putShort((short) Constants.UNSIGNED_SHORT_MAX_VAL); + + // The actual call which should fail + EncryptionContextSerializer.deserialize(ctxBuff.array()); + } - @Test(expected = IllegalArgumentException.class) - public void contextWithBadUnicodeValue() { - final StringBuilder invalidString = new StringBuilder("Base valid text"); - for (int x = 0xd800; x <= 0xdfff; x++) { // Invalid unicode codepoints - invalidString.appendCodePoint(x); - } - testMap(Collections.singletonMap("Valid key", invalidString.toString())); - } + @Test(expected = AwsCryptoException.class) + public void contextWithShortValue() { + final Map ctx = new HashMap(); + ctx.put("Alice:", "trusts Bob"); - @Test(expected = AwsCryptoException.class) - public void contextWithEmptyKey() { - testMap(Collections.singletonMap("", "Value for empty key")); - } + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); - @Test(expected = AwsCryptoException.class) - public void contextWithEmptyValue() { - testMap(Collections.singletonMap("Key for empty value", "")); - } + // Pull out entry count to move to key pos + ctxBuff.getShort(); - @Test(expected = AwsCryptoException.class) - public void contextWithEmptyKeyAndValue() { - testMap(Collections.singletonMap("", "")); - } + // Pull out key length and bytes. + final short keyLen = ctxBuff.getShort(); + final byte[] key = new byte[keyLen]; + ctxBuff.get(key); - @Test(expected = AwsCryptoException.class) - public void contextWithNullKey() { - testMap(Collections.singletonMap((String) null, "value for null key")); - } + // Overwrite value length + ctxBuff.putShort((short) 0); - @Test(expected = AwsCryptoException.class) - public void contextWithNullValue() { - testMap(Collections.singletonMap("Key for null value", (String) null)); - } + // The actual call which should fail + EncryptionContextSerializer.deserialize(ctxBuff.array()); + } - @Test(expected = AwsCryptoException.class) - public void contextWithNullKeyAndValue() { - testMap(Collections.singletonMap((String) null, (String) null)); - } + @Test(expected = AwsCryptoException.class) + public void contextWithNegativeValue() { + final Map ctx = new HashMap(); + ctx.put("Alice:", "trusts Bob"); - @Test(expected = AwsCryptoException.class) - public void contextWithLargeKey() { - final Map ctx = new HashMap(); - ctx.put("Alice:", "trusts Bob"); + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); + // Pull out entry count to move to key pos + ctxBuff.getShort(); - // Pull out entry count to move to key pos - ctxBuff.getShort(); - // Overwrite key length - ctxBuff.putShort((short) Constants.UNSIGNED_SHORT_MAX_VAL); + // Pull out key length and bytes. + final short keyLen = ctxBuff.getShort(); + final byte[] key = new byte[keyLen]; + ctxBuff.get(key); - // The actual call which should fail - EncryptionContextSerializer.deserialize(ctxBuff.array()); - } + // Overwrite value length + ctxBuff.putShort((short) -1); - @Test(expected = AwsCryptoException.class) - public void contextWithShortKey() { - final Map ctx = new HashMap(); - ctx.put("Alice:", "trusts Bob"); + // The actual call which should fail + EncryptionContextSerializer.deserialize(ctxBuff.array()); + } - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); + @Test(expected = AwsCryptoException.class) + public void contextWithNegativeCount() { + final Map ctx = new HashMap(); + ctx.put("Alice:", "trusts Bob"); + ctx.put("Bob:", "trusts Trent"); - // Pull out entry count to move to key pos - ctxBuff.getShort(); - // Overwrite key length with 0 - ctxBuff.putShort((short) 0); + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); - // The actual call which should fail - EncryptionContextSerializer.deserialize(ctxBuff.array()); - } + // Overwrite entry count + ctxBuff.putShort((short) -1); - @Test(expected = AwsCryptoException.class) - public void contextWithNegativeKey() { - final Map ctx = new HashMap(); - ctx.put("Alice:", "trusts Bob"); + EncryptionContextSerializer.deserialize(ctxBuff.array()); + } - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); + @Test(expected = AwsCryptoException.class) + public void contextWithZeroCount() { + final Map ctx = new HashMap(); + ctx.put("Alice:", "trusts Bob"); + ctx.put("Bob:", "trusts Trent"); - // Pull out entry count to move to key pos - ctxBuff.getShort(); - // Overwrite key length with -1. - ctxBuff.putShort((short) -1); + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); - // The actual call which should fail - EncryptionContextSerializer.deserialize(ctxBuff.array()); - } + // Overwrite entry count + ctxBuff.putShort((short) 0); - @Test(expected = AwsCryptoException.class) - public void contextWithLargeValue() { - final Map ctx = new HashMap(); - ctx.put("Alice:", "trusts Bob"); + EncryptionContextSerializer.deserialize(ctxBuff.array()); + } - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); + @Test(expected = AwsCryptoException.class) + public void contextWithInvalidCount() { + final Map ctx = new HashMap(); + ctx.put("Alice:", "trusts Bob"); + ctx.put("Bob:", "trusts Trent"); - // Pull out entry count to move to key pos - ctxBuff.getShort(); + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); - // Pull out key length and bytes. - final short keyLen = ctxBuff.getShort(); - final byte[] key = new byte[keyLen]; - ctxBuff.get(key); + // Overwrite count with more than what we have + ctxBuff.putShort((short) 100); - // Overwrite value length - ctxBuff.putShort((short) Constants.UNSIGNED_SHORT_MAX_VAL); + // The actual call which should fail + EncryptionContextSerializer.deserialize(ctxBuff.array()); + } - // The actual call which should fail - EncryptionContextSerializer.deserialize(ctxBuff.array()); - } + @Test(expected = IllegalArgumentException.class) + public void contextWithInvalidCharacters() { + final Map ctx = new HashMap(); + ctx.put("Alice:", "trusts Bob"); - @Test(expected = AwsCryptoException.class) - public void contextWithShortValue() { - final Map ctx = new HashMap(); - ctx.put("Alice:", "trusts Bob"); + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); + // Pull out entry count to move to key pos + ctxBuff.getShort(); - // Pull out entry count to move to key pos - ctxBuff.getShort(); + // Pull out key length and bytes. + final short keyLen = ctxBuff.getShort(); + ctxBuff.mark(); - // Pull out key length and bytes. - final short keyLen = ctxBuff.getShort(); - final byte[] key = new byte[keyLen]; - ctxBuff.get(key); + final byte[] key = new byte[keyLen]; + ctxBuff.get(key); - // Overwrite value length - ctxBuff.putShort((short) 0); + // set the first two bytes of the key to an invalid + // unicode character: 0xd800. + key[0] = 0x0; + key[1] = (byte) 0xd8; - // The actual call which should fail - EncryptionContextSerializer.deserialize(ctxBuff.array()); - } + ctxBuff.reset(); + ctxBuff.put(key); - @Test(expected = AwsCryptoException.class) - public void contextWithNegativeValue() { - final Map ctx = new HashMap(); - ctx.put("Alice:", "trusts Bob"); + // The actual call which should fail + EncryptionContextSerializer.deserialize(ctxBuff.array()); + } - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); + @Test(expected = AwsCryptoException.class) + public void contextWithDuplicateEntries() { + final Map ctx = Collections.singletonMap("Alice:", "trusts Bob"); - // Pull out entry count to move to key pos - ctxBuff.getShort(); + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); + // Don't duplicate the entry count + final ByteBuffer dupCtxBuff = ByteBuffer.allocate((2 * ctxBytes.length) - 2); - // Pull out key length and bytes. - final short keyLen = ctxBuff.getShort(); - final byte[] key = new byte[keyLen]; - ctxBuff.get(key); + // Set to 2 entries + dupCtxBuff.putShort((short) 2); - // Overwrite value length - ctxBuff.putShort((short) -1); + // Pull out entry count to move to key pos + ctxBuff.getShort(); + // From here to the end is a single entry, copy it + final byte[] entry = new byte[ctxBuff.remaining()]; + ctxBuff.get(entry); - // The actual call which should fail - EncryptionContextSerializer.deserialize(ctxBuff.array()); - } + dupCtxBuff.put(entry); + dupCtxBuff.put(entry); - @Test(expected = AwsCryptoException.class) - public void contextWithNegativeCount() { - final Map ctx = new HashMap(); - ctx.put("Alice:", "trusts Bob"); - ctx.put("Bob:", "trusts Trent"); + EncryptionContextSerializer.deserialize(dupCtxBuff.array()); + } - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); - - // Overwrite entry count - ctxBuff.putShort((short) -1); - - EncryptionContextSerializer.deserialize(ctxBuff.array()); - } - - @Test(expected = AwsCryptoException.class) - public void contextWithZeroCount() { - final Map ctx = new HashMap(); - ctx.put("Alice:", "trusts Bob"); - ctx.put("Bob:", "trusts Trent"); - - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); - - // Overwrite entry count - ctxBuff.putShort((short) 0); - - EncryptionContextSerializer.deserialize(ctxBuff.array()); - } - - @Test(expected = AwsCryptoException.class) - public void contextWithInvalidCount() { - final Map ctx = new HashMap(); - ctx.put("Alice:", "trusts Bob"); - ctx.put("Bob:", "trusts Trent"); - - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); - - // Overwrite count with more than what we have - ctxBuff.putShort((short) 100); - - // The actual call which should fail - EncryptionContextSerializer.deserialize(ctxBuff.array()); - } - - @Test(expected = IllegalArgumentException.class) - public void contextWithInvalidCharacters() { - final Map ctx = new HashMap(); - ctx.put("Alice:", "trusts Bob"); - - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); - - // Pull out entry count to move to key pos - ctxBuff.getShort(); - - // Pull out key length and bytes. - final short keyLen = ctxBuff.getShort(); - ctxBuff.mark(); - - final byte[] key = new byte[keyLen]; - ctxBuff.get(key); - - // set the first two bytes of the key to an invalid - // unicode character: 0xd800. - key[0] = 0x0; - key[1] = (byte) 0xd8; - - ctxBuff.reset(); - ctxBuff.put(key); - - // The actual call which should fail - EncryptionContextSerializer.deserialize(ctxBuff.array()); - } - - @Test(expected = AwsCryptoException.class) - public void contextWithDuplicateEntries() { - final Map ctx = Collections.singletonMap("Alice:", "trusts Bob"); - - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final ByteBuffer ctxBuff = ByteBuffer.wrap(ctxBytes); - // Don't duplicate the entry count - final ByteBuffer dupCtxBuff = ByteBuffer.allocate((2 * ctxBytes.length) - 2); - - // Set to 2 entries - dupCtxBuff.putShort((short) 2); - - // Pull out entry count to move to key pos - ctxBuff.getShort(); - // From here to the end is a single entry, copy it - final byte[] entry = new byte[ctxBuff.remaining()]; - ctxBuff.get(entry); - - dupCtxBuff.put(entry); - dupCtxBuff.put(entry); - - EncryptionContextSerializer.deserialize(dupCtxBuff.array()); - } - - private void testMap(final Map ctx) { - final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); - final Map result = EncryptionContextSerializer.deserialize(ctxBytes); - assertEquals(ctx, result); - } + private void testMap(final Map ctx) { + final byte[] ctxBytes = EncryptionContextSerializer.serialize(Collections.unmodifiableMap(ctx)); + final Map result = EncryptionContextSerializer.deserialize(ctxBytes); + assertEquals(ctx, result); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/EncryptionHandlerTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/EncryptionHandlerTest.java index a5271f508..8c04373c7 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/EncryptionHandlerTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/EncryptionHandlerTest.java @@ -10,196 +10,236 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import com.amazonaws.encryptionsdk.TestUtils; -import com.amazonaws.encryptionsdk.CommitmentPolicy; -import org.junit.Test; - import com.amazonaws.encryptionsdk.AwsCrypto; +import com.amazonaws.encryptionsdk.CommitmentPolicy; import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.DefaultCryptoMaterialsManager; +import com.amazonaws.encryptionsdk.TestUtils; import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.junit.Test; public class EncryptionHandlerTest { - private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG; - private final int frameSize_ = AwsCrypto.getDefaultFrameSize(); - private final Map encryptionContext_ = Collections. emptyMap(); - private StaticMasterKey masterKeyProvider = new StaticMasterKey("mock"); - private final List cmks_ = Collections.singletonList(masterKeyProvider); - private final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; - private EncryptionMaterialsRequest testRequest - = EncryptionMaterialsRequest.newBuilder() - .setContext(encryptionContext_) - .setRequestedAlgorithm(cryptoAlgorithm_) - .setCommitmentPolicy(commitmentPolicy) - .build(); - - private EncryptionMaterials testResult = new DefaultCryptoMaterialsManager(masterKeyProvider) - .getMaterialsForEncrypt(testRequest); - - @Test - public void badArguments() { - assertThrows( - () -> new EncryptionHandler(frameSize_, testResult.toBuilder().setAlgorithm(null).build(), commitmentPolicy) - ); - - assertThrows( - () -> new EncryptionHandler(frameSize_, testResult.toBuilder().setEncryptionContext(null).build(), commitmentPolicy) - ); - - assertThrows( - () -> new EncryptionHandler(frameSize_, testResult.toBuilder().setEncryptedDataKeys(null).build(), commitmentPolicy) - ); - - assertThrows( - () -> new EncryptionHandler(frameSize_, testResult.toBuilder().setEncryptedDataKeys(emptyList()).build(), commitmentPolicy) - ); - - assertThrows( - () -> new EncryptionHandler(frameSize_, testResult.toBuilder().setCleartextDataKey(null).build(), commitmentPolicy) - ); - - assertThrows( - () -> new EncryptionHandler(frameSize_, testResult.toBuilder().setMasterKeys(null).build(), commitmentPolicy) - ); - - assertThrows( - () -> new EncryptionHandler(-1, testResult, commitmentPolicy) - ); - - assertThrows( - () -> new EncryptionHandler(frameSize_, testResult, null) - ); - } - - @Test(expected = AwsCryptoException.class) - public void invalidLenProcessBytes() { - final EncryptionHandler encryptionHandler = new EncryptionHandler(frameSize_, testResult, commitmentPolicy); - - final byte[] in = new byte[1]; - final byte[] out = new byte[1]; - encryptionHandler.processBytes(in, 0, -1, out, 0); - } - - @Test(expected = AwsCryptoException.class) - public void invalidOffsetProcessBytes() { - final EncryptionHandler encryptionHandler = new EncryptionHandler(frameSize_, testResult, commitmentPolicy); - - final byte[] in = new byte[1]; - final byte[] out = new byte[1]; - encryptionHandler.processBytes(in, -1, in.length, out, 0); - } - - @Test - public void whenEncrypting_headerIVIsZero() throws Exception { - final EncryptionHandler encryptionHandler = new EncryptionHandler(frameSize_, testResult, commitmentPolicy); - - assertArrayEquals( - new byte[encryptionHandler.getHeaders().getCryptoAlgoId().getNonceLen()], - encryptionHandler.getHeaders().getHeaderNonce() - ); - } - - @Test - public void whenConstructWithForbidPolicyAndCommittingAlg_fails() throws Exception { - final EncryptionMaterials resultWithV2Alg = testResult.toBuilder().setAlgorithm(TestUtils.DEFAULT_TEST_CRYPTO_ALG).build(); - assertThrows(AwsCryptoException.class, () -> new EncryptionHandler(frameSize_, resultWithV2Alg, CommitmentPolicy.ForbidEncryptAllowDecrypt)); - } - - @Test - public void whenConstructWithForbidPolicyAndNonCommittingAlg_succeeds() throws Exception { - final CryptoAlgorithm algorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; - final EncryptionMaterialsRequest requestForMaterialsWithoutCommitment = EncryptionMaterialsRequest.newBuilder() - .setContext(encryptionContext_) - .setRequestedAlgorithm(algorithm) - .setCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) - .build(); - final EncryptionMaterials materials = new DefaultCryptoMaterialsManager(masterKeyProvider) - .getMaterialsForEncrypt(requestForMaterialsWithoutCommitment); - - EncryptionHandler handler = new EncryptionHandler(frameSize_, materials, CommitmentPolicy.ForbidEncryptAllowDecrypt); - assertNotNull(handler); - assertEquals(algorithm, handler.getHeaders().getCryptoAlgoId()); - } - - @Test - public void whenConstructWithRequirePolicyAndNonCommittingAlg_fails() throws Exception { - final EncryptionMaterials resultWithV1Alg = testResult.toBuilder() - .setAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384) - .build(); - - assertThrows(AwsCryptoException.class, () -> new EncryptionHandler(frameSize_, resultWithV1Alg, CommitmentPolicy.RequireEncryptRequireDecrypt)); - assertThrows(AwsCryptoException.class, () -> new EncryptionHandler(frameSize_, resultWithV1Alg, CommitmentPolicy.RequireEncryptAllowDecrypt)); - } - - @Test - public void whenConstructWithRequirePolicyAndCommittingAlg_succeeds() throws Exception { - final CryptoAlgorithm algorithm = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; - final EncryptionMaterialsRequest requestForMaterialsWithCommitment = EncryptionMaterialsRequest.newBuilder() - .setContext(encryptionContext_) - .setRequestedAlgorithm(algorithm) - .setCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt) - .build(); - final EncryptionMaterials materials = new DefaultCryptoMaterialsManager(masterKeyProvider) - .getMaterialsForEncrypt(requestForMaterialsWithCommitment); - final List requireWritePolicies = Arrays.asList( - CommitmentPolicy.RequireEncryptAllowDecrypt,CommitmentPolicy.RequireEncryptRequireDecrypt); - - for (CommitmentPolicy policy : requireWritePolicies) { - EncryptionHandler handler = new EncryptionHandler(frameSize_, materials, policy); - assertNotNull(handler); - assertEquals(algorithm, handler.getHeaders().getCryptoAlgoId()); - } - } - - @Test - public void setMaxInputLength() { - byte[] plaintext = "Don't warn the tadpoles".getBytes(); - final EncryptionHandler encryptionHandler = new EncryptionHandler(frameSize_, testResult, commitmentPolicy); - encryptionHandler.setMaxInputLength(plaintext.length - 1); - - assertEquals(encryptionHandler.getMaxInputLength(), (long)plaintext.length - 1); - - final byte[] out = new byte[1]; - assertThrows(IllegalStateException.class, "Plaintext size exceeds max input size limit", () -> - encryptionHandler.processBytes(plaintext, 0, plaintext.length, out, 0)); - } - - @Test - public void setMaxInputLengthThrowsIfAlreadyOver() { - byte[] plaintext = "Don't warn the tadpoles".getBytes(); - final EncryptionHandler encryptionHandler = new EncryptionHandler(frameSize_, testResult, commitmentPolicy); - final byte[] out = new byte[1024]; - encryptionHandler.processBytes(plaintext, 0, plaintext.length - 1, out, 0); - assertFalse(encryptionHandler.isComplete()); - - assertThrows(IllegalStateException.class, "Plaintext size exceeds max input size limit", () -> - encryptionHandler.setMaxInputLength(plaintext.length - 2)); - } - - @Test - public void setMaxInputLengthAcceptsSmallerValue() { - final EncryptionHandler encryptionHandler = new EncryptionHandler(frameSize_, testResult, commitmentPolicy); - encryptionHandler.setMaxInputLength(100); - assertEquals(encryptionHandler.getMaxInputLength(), 100); - - encryptionHandler.setMaxInputLength(10); - assertEquals(encryptionHandler.getMaxInputLength(), 10); - } - - @Test - public void setMaxInputLengthIgnoresLargerValue() { - final EncryptionHandler encryptionHandler = new EncryptionHandler(frameSize_, testResult, commitmentPolicy); - encryptionHandler.setMaxInputLength(10); - assertEquals(encryptionHandler.getMaxInputLength(), 10); - - encryptionHandler.setMaxInputLength(100); - assertEquals(encryptionHandler.getMaxInputLength(), 10); + private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG; + private final int frameSize_ = AwsCrypto.getDefaultFrameSize(); + private final Map encryptionContext_ = Collections.emptyMap(); + private StaticMasterKey masterKeyProvider = new StaticMasterKey("mock"); + private final List cmks_ = Collections.singletonList(masterKeyProvider); + private final CommitmentPolicy commitmentPolicy = TestUtils.DEFAULT_TEST_COMMITMENT_POLICY; + private EncryptionMaterialsRequest testRequest = + EncryptionMaterialsRequest.newBuilder() + .setContext(encryptionContext_) + .setRequestedAlgorithm(cryptoAlgorithm_) + .setCommitmentPolicy(commitmentPolicy) + .build(); + + private EncryptionMaterials testResult = + new DefaultCryptoMaterialsManager(masterKeyProvider).getMaterialsForEncrypt(testRequest); + + @Test + public void badArguments() { + assertThrows( + () -> + new EncryptionHandler( + frameSize_, testResult.toBuilder().setAlgorithm(null).build(), commitmentPolicy)); + + assertThrows( + () -> + new EncryptionHandler( + frameSize_, + testResult.toBuilder().setEncryptionContext(null).build(), + commitmentPolicy)); + + assertThrows( + () -> + new EncryptionHandler( + frameSize_, + testResult.toBuilder().setEncryptedDataKeys(null).build(), + commitmentPolicy)); + + assertThrows( + () -> + new EncryptionHandler( + frameSize_, + testResult.toBuilder().setEncryptedDataKeys(emptyList()).build(), + commitmentPolicy)); + + assertThrows( + () -> + new EncryptionHandler( + frameSize_, + testResult.toBuilder().setCleartextDataKey(null).build(), + commitmentPolicy)); + + assertThrows( + () -> + new EncryptionHandler( + frameSize_, testResult.toBuilder().setMasterKeys(null).build(), commitmentPolicy)); + + assertThrows(() -> new EncryptionHandler(-1, testResult, commitmentPolicy)); + + assertThrows(() -> new EncryptionHandler(frameSize_, testResult, null)); + } + + @Test(expected = AwsCryptoException.class) + public void invalidLenProcessBytes() { + final EncryptionHandler encryptionHandler = + new EncryptionHandler(frameSize_, testResult, commitmentPolicy); + + final byte[] in = new byte[1]; + final byte[] out = new byte[1]; + encryptionHandler.processBytes(in, 0, -1, out, 0); + } + + @Test(expected = AwsCryptoException.class) + public void invalidOffsetProcessBytes() { + final EncryptionHandler encryptionHandler = + new EncryptionHandler(frameSize_, testResult, commitmentPolicy); + + final byte[] in = new byte[1]; + final byte[] out = new byte[1]; + encryptionHandler.processBytes(in, -1, in.length, out, 0); + } + + @Test + public void whenEncrypting_headerIVIsZero() throws Exception { + final EncryptionHandler encryptionHandler = + new EncryptionHandler(frameSize_, testResult, commitmentPolicy); + + assertArrayEquals( + new byte[encryptionHandler.getHeaders().getCryptoAlgoId().getNonceLen()], + encryptionHandler.getHeaders().getHeaderNonce()); + } + + @Test + public void whenConstructWithForbidPolicyAndCommittingAlg_fails() throws Exception { + final EncryptionMaterials resultWithV2Alg = + testResult.toBuilder().setAlgorithm(TestUtils.DEFAULT_TEST_CRYPTO_ALG).build(); + assertThrows( + AwsCryptoException.class, + () -> + new EncryptionHandler( + frameSize_, resultWithV2Alg, CommitmentPolicy.ForbidEncryptAllowDecrypt)); + } + + @Test + public void whenConstructWithForbidPolicyAndNonCommittingAlg_succeeds() throws Exception { + final CryptoAlgorithm algorithm = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; + final EncryptionMaterialsRequest requestForMaterialsWithoutCommitment = + EncryptionMaterialsRequest.newBuilder() + .setContext(encryptionContext_) + .setRequestedAlgorithm(algorithm) + .setCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) + .build(); + final EncryptionMaterials materials = + new DefaultCryptoMaterialsManager(masterKeyProvider) + .getMaterialsForEncrypt(requestForMaterialsWithoutCommitment); + + EncryptionHandler handler = + new EncryptionHandler(frameSize_, materials, CommitmentPolicy.ForbidEncryptAllowDecrypt); + assertNotNull(handler); + assertEquals(algorithm, handler.getHeaders().getCryptoAlgoId()); + } + + @Test + public void whenConstructWithRequirePolicyAndNonCommittingAlg_fails() throws Exception { + final EncryptionMaterials resultWithV1Alg = + testResult.toBuilder() + .setAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384) + .build(); + + assertThrows( + AwsCryptoException.class, + () -> + new EncryptionHandler( + frameSize_, resultWithV1Alg, CommitmentPolicy.RequireEncryptRequireDecrypt)); + assertThrows( + AwsCryptoException.class, + () -> + new EncryptionHandler( + frameSize_, resultWithV1Alg, CommitmentPolicy.RequireEncryptAllowDecrypt)); + } + + @Test + public void whenConstructWithRequirePolicyAndCommittingAlg_succeeds() throws Exception { + final CryptoAlgorithm algorithm = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; + final EncryptionMaterialsRequest requestForMaterialsWithCommitment = + EncryptionMaterialsRequest.newBuilder() + .setContext(encryptionContext_) + .setRequestedAlgorithm(algorithm) + .setCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt) + .build(); + final EncryptionMaterials materials = + new DefaultCryptoMaterialsManager(masterKeyProvider) + .getMaterialsForEncrypt(requestForMaterialsWithCommitment); + final List requireWritePolicies = + Arrays.asList( + CommitmentPolicy.RequireEncryptAllowDecrypt, + CommitmentPolicy.RequireEncryptRequireDecrypt); + + for (CommitmentPolicy policy : requireWritePolicies) { + EncryptionHandler handler = new EncryptionHandler(frameSize_, materials, policy); + assertNotNull(handler); + assertEquals(algorithm, handler.getHeaders().getCryptoAlgoId()); } + } + + @Test + public void setMaxInputLength() { + byte[] plaintext = "Don't warn the tadpoles".getBytes(); + final EncryptionHandler encryptionHandler = + new EncryptionHandler(frameSize_, testResult, commitmentPolicy); + encryptionHandler.setMaxInputLength(plaintext.length - 1); + + assertEquals(encryptionHandler.getMaxInputLength(), (long) plaintext.length - 1); + + final byte[] out = new byte[1]; + assertThrows( + IllegalStateException.class, + "Plaintext size exceeds max input size limit", + () -> encryptionHandler.processBytes(plaintext, 0, plaintext.length, out, 0)); + } + + @Test + public void setMaxInputLengthThrowsIfAlreadyOver() { + byte[] plaintext = "Don't warn the tadpoles".getBytes(); + final EncryptionHandler encryptionHandler = + new EncryptionHandler(frameSize_, testResult, commitmentPolicy); + final byte[] out = new byte[1024]; + encryptionHandler.processBytes(plaintext, 0, plaintext.length - 1, out, 0); + assertFalse(encryptionHandler.isComplete()); + + assertThrows( + IllegalStateException.class, + "Plaintext size exceeds max input size limit", + () -> encryptionHandler.setMaxInputLength(plaintext.length - 2)); + } + + @Test + public void setMaxInputLengthAcceptsSmallerValue() { + final EncryptionHandler encryptionHandler = + new EncryptionHandler(frameSize_, testResult, commitmentPolicy); + encryptionHandler.setMaxInputLength(100); + assertEquals(encryptionHandler.getMaxInputLength(), 100); + + encryptionHandler.setMaxInputLength(10); + assertEquals(encryptionHandler.getMaxInputLength(), 10); + } + + @Test + public void setMaxInputLengthIgnoresLargerValue() { + final EncryptionHandler encryptionHandler = + new EncryptionHandler(frameSize_, testResult, commitmentPolicy); + encryptionHandler.setMaxInputLength(10); + assertEquals(encryptionHandler.getMaxInputLength(), 10); + + encryptionHandler.setMaxInputLength(100); + assertEquals(encryptionHandler.getMaxInputLength(), 10); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/FrameDecryptionHandlerTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/FrameDecryptionHandlerTest.java index b996f84c0..1cf4d12eb 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/FrameDecryptionHandlerTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/FrameDecryptionHandlerTest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -15,104 +15,94 @@ import static org.junit.Assert.assertTrue; +import com.amazonaws.encryptionsdk.AwsCrypto; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.TestUtils; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import java.nio.ByteBuffer; import java.security.SecureRandom; - import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; - -import com.amazonaws.encryptionsdk.TestUtils; -import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import org.junit.Before; import org.junit.Test; -import com.amazonaws.encryptionsdk.AwsCrypto; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; - public class FrameDecryptionHandlerTest { - private static final SecureRandom RND = new SecureRandom(); - private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG; - private final byte[] messageId_ = new byte[cryptoAlgorithm_.getMessageIdLength()]; - private final byte nonceLen_ = cryptoAlgorithm_.getNonceLen(); - private final byte[] dataKeyBytes_ = new byte[cryptoAlgorithm_.getKeyLength()]; - private final SecretKey dataKey_ = new SecretKeySpec(dataKeyBytes_, "AES"); - private final int frameSize_ = AwsCrypto.getDefaultFrameSize(); - - private final FrameDecryptionHandler frameDecryptionHandler_ = new FrameDecryptionHandler( - dataKey_, - nonceLen_, - cryptoAlgorithm_, - messageId_, - frameSize_); - - @Before - public void setup() { - RND.nextBytes(messageId_); - RND.nextBytes(dataKeyBytes_); - } - - @Test - public void estimateOutputSize() { - final int inLen = 1; - final int outSize = frameDecryptionHandler_.estimateOutputSize(inLen); - - // the estimated output size must at least be equal to inLen. - assertTrue(outSize >= inLen); - } - - @Test(expected = AwsCryptoException.class) - public void decryptMaxContentLength() { - // Create input of size 1 byte: 1 byte of the sequence number, - // Only 1 byte of the sequence number is provided because this - // forces the frame decryption handler to buffer that 1 byte while - // waiting for the remaining bytes of the sequence number. We do this so - // we can specify an input of max value and the total bytes to parse - // will become max value + 1. - final byte[] in = new byte[1]; - final byte[] out = new byte[1]; - - frameDecryptionHandler_.processBytes(in, 0, in.length, out, 0); - frameDecryptionHandler_.processBytes(in, 0, Integer.MAX_VALUE, out, 0); - } - - @Test(expected = BadCiphertextException.class) - public void finalFrameLengthTooLarge() { - - final ByteBuffer byteBuffer = ByteBuffer.allocate(25); - byteBuffer.put(TestUtils.unsignedBytesToSignedBytes( - new int[] {255, 255, 255, 255, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})); - byteBuffer.putInt(AwsCrypto.getDefaultFrameSize() + 1); - - final byte[] in = byteBuffer.array(); - final byte[] out = new byte[in.length]; - - frameDecryptionHandler_.processBytes(in, 0, in.length, out, 0); - } - - @Test(expected = BadCiphertextException.class) - public void doFinalCalledWhileNotComplete() { - frameDecryptionHandler_.doFinal(new byte[1], 0); - } - - @Test(expected = AwsCryptoException.class) - public void processBytesCalledWhileComplete() { - final FrameEncryptionHandler frameEncryptionHandler = new FrameEncryptionHandler( - dataKey_, - nonceLen_, - cryptoAlgorithm_, - messageId_, - frameSize_); - final byte[] in = new byte[0]; - final int outLen = frameEncryptionHandler.estimateOutputSize(in.length); - final byte[] out = new byte[outLen]; - - frameEncryptionHandler.processBytes(in, 0, in.length, out, 0); - frameEncryptionHandler.doFinal(out, 0); - - final byte[] decryptedOut = new byte[outLen]; - - frameDecryptionHandler_.processBytes(out, 0, out.length, decryptedOut, 0); - frameDecryptionHandler_.processBytes(out, 0, out.length, decryptedOut, 0); - } + private static final SecureRandom RND = new SecureRandom(); + private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG; + private final byte[] messageId_ = new byte[cryptoAlgorithm_.getMessageIdLength()]; + private final byte nonceLen_ = cryptoAlgorithm_.getNonceLen(); + private final byte[] dataKeyBytes_ = new byte[cryptoAlgorithm_.getKeyLength()]; + private final SecretKey dataKey_ = new SecretKeySpec(dataKeyBytes_, "AES"); + private final int frameSize_ = AwsCrypto.getDefaultFrameSize(); + + private final FrameDecryptionHandler frameDecryptionHandler_ = + new FrameDecryptionHandler(dataKey_, nonceLen_, cryptoAlgorithm_, messageId_, frameSize_); + + @Before + public void setup() { + RND.nextBytes(messageId_); + RND.nextBytes(dataKeyBytes_); + } + + @Test + public void estimateOutputSize() { + final int inLen = 1; + final int outSize = frameDecryptionHandler_.estimateOutputSize(inLen); + + // the estimated output size must at least be equal to inLen. + assertTrue(outSize >= inLen); + } + + @Test(expected = AwsCryptoException.class) + public void decryptMaxContentLength() { + // Create input of size 1 byte: 1 byte of the sequence number, + // Only 1 byte of the sequence number is provided because this + // forces the frame decryption handler to buffer that 1 byte while + // waiting for the remaining bytes of the sequence number. We do this so + // we can specify an input of max value and the total bytes to parse + // will become max value + 1. + final byte[] in = new byte[1]; + final byte[] out = new byte[1]; + + frameDecryptionHandler_.processBytes(in, 0, in.length, out, 0); + frameDecryptionHandler_.processBytes(in, 0, Integer.MAX_VALUE, out, 0); + } + + @Test(expected = BadCiphertextException.class) + public void finalFrameLengthTooLarge() { + + final ByteBuffer byteBuffer = ByteBuffer.allocate(25); + byteBuffer.put( + TestUtils.unsignedBytesToSignedBytes( + new int[] {255, 255, 255, 255, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})); + byteBuffer.putInt(AwsCrypto.getDefaultFrameSize() + 1); + + final byte[] in = byteBuffer.array(); + final byte[] out = new byte[in.length]; + + frameDecryptionHandler_.processBytes(in, 0, in.length, out, 0); + } + + @Test(expected = BadCiphertextException.class) + public void doFinalCalledWhileNotComplete() { + frameDecryptionHandler_.doFinal(new byte[1], 0); + } + + @Test(expected = AwsCryptoException.class) + public void processBytesCalledWhileComplete() { + final FrameEncryptionHandler frameEncryptionHandler = + new FrameEncryptionHandler(dataKey_, nonceLen_, cryptoAlgorithm_, messageId_, frameSize_); + final byte[] in = new byte[0]; + final int outLen = frameEncryptionHandler.estimateOutputSize(in.length); + final byte[] out = new byte[outLen]; + + frameEncryptionHandler.processBytes(in, 0, in.length, out, 0); + frameEncryptionHandler.doFinal(out, 0); + + final byte[] decryptedOut = new byte[outLen]; + + frameDecryptionHandler_.processBytes(out, 0, out.length, decryptedOut, 0); + frameDecryptionHandler_.processBytes(out, 0, out.length, decryptedOut, 0); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerTest.java index ed662f004..da6bf08f0 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerTest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -17,111 +17,104 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import com.amazonaws.encryptionsdk.AwsCrypto; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.TestUtils; +import com.amazonaws.encryptionsdk.model.CipherFrameHeaders; +import java.lang.reflect.Field; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; - -import java.lang.reflect.Field; - -import com.amazonaws.encryptionsdk.TestUtils; import org.bouncycastle.util.encoders.Hex; import org.junit.Before; import org.junit.Test; -import com.amazonaws.encryptionsdk.AwsCrypto; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.model.CipherFrameHeaders; - public class FrameEncryptionHandlerTest { - private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG; - private final byte[] messageId_ = RandomBytesGenerator.generate(cryptoAlgorithm_.getMessageIdLength()); - private final byte nonceLen_ = cryptoAlgorithm_.getNonceLen(); - private final byte[] dataKeyBytes_ = RandomBytesGenerator.generate(cryptoAlgorithm_.getKeyLength()); - private final SecretKey encryptionKey_ = new SecretKeySpec(dataKeyBytes_, "AES"); - private final int frameSize_ = AwsCrypto.getDefaultFrameSize(); - - private FrameEncryptionHandler frameEncryptionHandler_; - - @Before - public void setUp() throws Exception { - frameEncryptionHandler_ = new FrameEncryptionHandler( - encryptionKey_, - nonceLen_, - cryptoAlgorithm_, - messageId_, - frameSize_ - ); - } - - @Test - public void emptyOutBytes() { - final int outLen = 0; - final byte[] out = new byte[outLen]; - final int processedLen = frameEncryptionHandler_.doFinal(out, 0); - assertEquals(outLen, processedLen); + private final CryptoAlgorithm cryptoAlgorithm_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG; + private final byte[] messageId_ = + RandomBytesGenerator.generate(cryptoAlgorithm_.getMessageIdLength()); + private final byte nonceLen_ = cryptoAlgorithm_.getNonceLen(); + private final byte[] dataKeyBytes_ = + RandomBytesGenerator.generate(cryptoAlgorithm_.getKeyLength()); + private final SecretKey encryptionKey_ = new SecretKeySpec(dataKeyBytes_, "AES"); + private final int frameSize_ = AwsCrypto.getDefaultFrameSize(); + + private FrameEncryptionHandler frameEncryptionHandler_; + + @Before + public void setUp() throws Exception { + frameEncryptionHandler_ = + new FrameEncryptionHandler( + encryptionKey_, nonceLen_, cryptoAlgorithm_, messageId_, frameSize_); + } + + @Test + public void emptyOutBytes() { + final int outLen = 0; + final byte[] out = new byte[outLen]; + final int processedLen = frameEncryptionHandler_.doFinal(out, 0); + assertEquals(outLen, processedLen); + } + + @Test + public void correctIVsGenerated() throws Exception { + byte[] buf = new byte[frameSize_ + 1024]; + for (int i = 0; i <= 254; i++) { + byte[] expectedNonce = { + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, (byte) (i + 1) + }; + + generateTestBlock(buf); + assertHeaderNonce(expectedNonce, buf); } - @Test - public void correctIVsGenerated() throws Exception { - byte[] buf = new byte[frameSize_ + 1024]; - for (int i = 0; i <= 254; i++) { - byte[] expectedNonce = { - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, (byte) (i + 1) - }; - - generateTestBlock(buf); - assertHeaderNonce(expectedNonce, buf); - } - - generateTestBlock(buf); - assertHeaderNonce(new byte[] { - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 1, 0 - }, buf); - } - - @Test - public void encryptionHandlerEnforcesFrameLimits() throws Exception { - // Skip to the second-to-last frame first. Actually getting there by encrypting block would take a very long - // time, so we'll reflect in; for a test that legitimately gets there, see the - // FrameEncryptionHandlerVeryLongTest. - - Field f = FrameEncryptionHandler.class.getDeclaredField("frameNumber_"); - f.setAccessible(true); - f.set(frameEncryptionHandler_, 0xFFFF_FFFEL); - - byte[] buf = new byte[frameSize_ + 1024]; - // Writing frame 0xFFFF_FFFE should succeed. - generateTestBlock(buf); - assertHeaderNonce(Hex.decode("0000000000000000FFFFFFFE"), buf); - - byte[] oldBuf = buf.clone(); - // Writing the next frame must fail - assertThrows(() -> generateTestBlock(buf)); - // ... and must not produce any output - assertArrayEquals(oldBuf, buf); - - // However we can still finish the encryption - frameEncryptionHandler_.doFinal(buf, 0); - assertHeaderNonce(Hex.decode("0000000000000000FFFFFFFF"), buf); - } - - private void assertHeaderNonce(byte[] expectedNonce, byte[] buf) { - CipherFrameHeaders headers = new CipherFrameHeaders(); - headers.setNonceLength(cryptoAlgorithm_.getNonceLen()); - headers.deserialize(buf, 0); - - assertArrayEquals( - expectedNonce, - headers.getNonce() - ); - } - - private void generateTestBlock(byte[] buf) { - frameEncryptionHandler_.processBytes( - new byte[frameSize_], 0, frameSize_, buf, 0 - ); - } + generateTestBlock(buf); + assertHeaderNonce( + new byte[] { + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 1, 0 + }, + buf); + } + + @Test + public void encryptionHandlerEnforcesFrameLimits() throws Exception { + // Skip to the second-to-last frame first. Actually getting there by encrypting block would take + // a very long + // time, so we'll reflect in; for a test that legitimately gets there, see the + // FrameEncryptionHandlerVeryLongTest. + + Field f = FrameEncryptionHandler.class.getDeclaredField("frameNumber_"); + f.setAccessible(true); + f.set(frameEncryptionHandler_, 0xFFFF_FFFEL); + + byte[] buf = new byte[frameSize_ + 1024]; + // Writing frame 0xFFFF_FFFE should succeed. + generateTestBlock(buf); + assertHeaderNonce(Hex.decode("0000000000000000FFFFFFFE"), buf); + + byte[] oldBuf = buf.clone(); + // Writing the next frame must fail + assertThrows(() -> generateTestBlock(buf)); + // ... and must not produce any output + assertArrayEquals(oldBuf, buf); + + // However we can still finish the encryption + frameEncryptionHandler_.doFinal(buf, 0); + assertHeaderNonce(Hex.decode("0000000000000000FFFFFFFF"), buf); + } + + private void assertHeaderNonce(byte[] expectedNonce, byte[] buf) { + CipherFrameHeaders headers = new CipherFrameHeaders(); + headers.setNonceLength(cryptoAlgorithm_.getNonceLen()); + headers.deserialize(buf, 0); + + assertArrayEquals(expectedNonce, headers.getNonce()); + } + + private void generateTestBlock(byte[] buf) { + frameEncryptionHandler_.processBytes(new byte[frameSize_], 0, frameSize_, buf, 0); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerVeryLongTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerVeryLongTest.java index 228bcc28f..0b99292d6 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerVeryLongTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/FrameEncryptionHandlerVeryLongTest.java @@ -1,74 +1,70 @@ package com.amazonaws.encryptionsdk.internal; -import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.fail; -import javax.crypto.spec.SecretKeySpec; - +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.model.CipherFrameHeaders; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; - +import javax.crypto.spec.SecretKeySpec; import org.bouncycastle.util.encoders.Hex; -import org.bouncycastle.util.encoders.HexTranslator; import org.junit.Test; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.model.CipherFrameHeaders; - /* * This test exhaustively encrypts a 2^32 frame message, which takes approximately 2-3 hours on my hardware. Because of * this long test time, this test is not run as part of the normal suites. */ public class FrameEncryptionHandlerVeryLongTest { - @Test - public void exhaustiveIVCheck() throws Exception { - CryptoAlgorithm algorithm = CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_NO_KDF; - FrameEncryptionHandler frameEncryptionHandler_ = new FrameEncryptionHandler( - new SecretKeySpec(new byte[16], "AES"), - 12, - algorithm, - new byte[16], - 1 - ); + @Test + public void exhaustiveIVCheck() throws Exception { + CryptoAlgorithm algorithm = CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_NO_KDF; + FrameEncryptionHandler frameEncryptionHandler_ = + new FrameEncryptionHandler( + new SecretKeySpec(new byte[16], "AES"), 12, algorithm, new byte[16], 1); - byte[] buf = new byte[1024]; + byte[] buf = new byte[1024]; - ByteBuffer expectedNonce = ByteBuffer.allocate(12); - long lastIndex = 1; // starting index for the test - long lastTS = System.nanoTime(); - for (long i = lastIndex; i <= Constants.MAX_FRAME_NUMBER; i++) { - Utils.clear(expectedNonce); - expectedNonce.order(ByteOrder.BIG_ENDIAN); - expectedNonce.putInt(0); - expectedNonce.putLong(i); + ByteBuffer expectedNonce = ByteBuffer.allocate(12); + long lastIndex = 1; // starting index for the test + long lastTS = System.nanoTime(); + for (long i = lastIndex; i <= Constants.MAX_FRAME_NUMBER; i++) { + Utils.clear(expectedNonce); + expectedNonce.order(ByteOrder.BIG_ENDIAN); + expectedNonce.putInt(0); + expectedNonce.putLong(i); - if (i != Constants.MAX_FRAME_NUMBER) { - frameEncryptionHandler_.processBytes(buf, 0, 1, buf, 0); - } else { - frameEncryptionHandler_.doFinal(buf, 0); - } + if (i != Constants.MAX_FRAME_NUMBER) { + frameEncryptionHandler_.processBytes(buf, 0, 1, buf, 0); + } else { + frameEncryptionHandler_.doFinal(buf, 0); + } - CipherFrameHeaders headers = new CipherFrameHeaders(); - headers.setNonceLength(algorithm.getNonceLen()); - headers.deserialize(buf, 0); + CipherFrameHeaders headers = new CipherFrameHeaders(); + headers.setNonceLength(algorithm.getNonceLen()); + headers.deserialize(buf, 0); - byte[] nonce = headers.getNonce(); - byte[] expectedArray = expectedNonce.array(); - if (!Arrays.equals(nonce, expectedArray)) { - fail(String.format("Index %08x bytes %s != %s", i, new String(Hex.encode(nonce)), new String(Hex.encode(expectedArray)))); - } + byte[] nonce = headers.getNonce(); + byte[] expectedArray = expectedNonce.array(); + if (!Arrays.equals(nonce, expectedArray)) { + fail( + String.format( + "Index %08x bytes %s != %s", + i, new String(Hex.encode(nonce)), new String(Hex.encode(expectedArray)))); + } - if ((i & 0xFFFFF) == 0) { - // Print progress messages, since this test takes a _very_ long time to run. - System.out.print(String.format("%05.2f%% complete", 100*(double)i/(double)Constants.MAX_FRAME_NUMBER)); - long newTS = System.nanoTime(); - System.out.println( - String.format(" at a rate of %f/sec\n", (i - lastIndex)/((newTS - lastTS)/1_000_000_000.0)) - ); - lastTS = newTS; - lastIndex = i; - } - } + if ((i & 0xFFFFF) == 0) { + // Print progress messages, since this test takes a _very_ long time to run. + System.out.print( + String.format( + "%05.2f%% complete", 100 * (double) i / (double) Constants.MAX_FRAME_NUMBER)); + long newTS = System.nanoTime(); + System.out.println( + String.format( + " at a rate of %f/sec\n", (i - lastIndex) / ((newTS - lastTS) / 1_000_000_000.0))); + lastTS = newTS; + lastIndex = i; + } } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/HmacKeyDerivationFunctionTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/HmacKeyDerivationFunctionTest.java index 5481cd442..7cea2c616 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/HmacKeyDerivationFunctionTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/HmacKeyDerivationFunctionTest.java @@ -12,189 +12,207 @@ */ package com.amazonaws.encryptionsdk.internal; +import static org.junit.Assert.assertArrayEquals; + import com.amazonaws.util.StringUtils; import org.junit.Test; -import static org.junit.Assert.assertArrayEquals; - public class HmacKeyDerivationFunctionTest { - private static final testCase[] testCases = new testCase[]{ - new testCase( - "HmacSHA256", - fromCHex("\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b" - + "\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b"), - fromCHex("\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c"), - fromCHex("\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9"), - fromHex("3CB25F25FAACD57A90434F64D0362F2A2D2D0A90CF1A5A4C5DB02D56ECC4C5BF34007208D5B887185865")), - new testCase( - "HmacSHA256", - fromCHex("\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d" - + "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b" - + "\\x1c\\x1d\\x1e\\x1f\\x20\\x21\\x22\\x23\\x24\\x25\\x26\\x27\\x28\\x29" - + "\\x2a\\x2b\\x2c\\x2d\\x2e\\x2f\\x30\\x31\\x32\\x33\\x34\\x35\\x36\\x37" - + "\\x38\\x39\\x3a\\x3b\\x3c\\x3d\\x3e\\x3f\\x40\\x41\\x42\\x43\\x44\\x45" - + "\\x46\\x47\\x48\\x49\\x4a\\x4b\\x4c\\x4d\\x4e\\x4f"), - fromCHex("\\x60\\x61\\x62\\x63\\x64\\x65\\x66\\x67\\x68\\x69\\x6a\\x6b\\x6c\\x6d" - + "\\x6e\\x6f\\x70\\x71\\x72\\x73\\x74\\x75\\x76\\x77\\x78\\x79\\x7a\\x7b" - + "\\x7c\\x7d\\x7e\\x7f\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89" - + "\\x8a\\x8b\\x8c\\x8d\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97" - + "\\x98\\x99\\x9a\\x9b\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5" - + "\\xa6\\xa7\\xa8\\xa9\\xaa\\xab\\xac\\xad\\xae\\xaf"), - fromCHex("\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7\\xb8\\xb9\\xba\\xbb\\xbc\\xbd" - + "\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5\\xc6\\xc7\\xc8\\xc9\\xca\\xcb" - + "\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9" - + "\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7" - + "\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5" - + "\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb\\xfc\\xfd\\xfe\\xff"), - fromHex("B11E398DC80327A1C8E7F78C596A4934" - + "4F012EDA2D4EFAD8A050CC4C19AFA97C" - + "59045A99CAC7827271CB41C65E590E09" - + "DA3275600C2F09B8367793A9ACA3DB71" - + "CC30C58179EC3E87C14C01D5C1F3434F" + "1D87")), - new testCase( - "HmacSHA256", - fromCHex("\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b" - + "\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b"), - new byte[0], new byte[0], - fromHex("8DA4E775A563C18F715F802A063C5A31" - + "B8A11F5C5EE1879EC3454E5F3C738D2D" - + "9D201395FAA4B61A96C8")), - new testCase( - "HmacSHA1", - fromCHex("\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b"), - fromCHex("\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c"), - fromCHex("\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9"), - fromHex("085A01EA1B10F36933068B56EFA5AD81" - + "A4F14B822F5B091568A9CDD4F155FDA2" - + "C22E422478D305F3F896")), - new testCase( - "HmacSHA1", - fromCHex("\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d" - + "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b" - + "\\x1c\\x1d\\x1e\\x1f\\x20\\x21\\x22\\x23\\x24\\x25\\x26\\x27\\x28\\x29" - + "\\x2a\\x2b\\x2c\\x2d\\x2e\\x2f\\x30\\x31\\x32\\x33\\x34\\x35\\x36\\x37" - + "\\x38\\x39\\x3a\\x3b\\x3c\\x3d\\x3e\\x3f\\x40\\x41\\x42\\x43\\x44\\x45" - + "\\x46\\x47\\x48\\x49\\x4a\\x4b\\x4c\\x4d\\x4e\\x4f"), - fromCHex("\\x60\\x61\\x62\\x63\\x64\\x65\\x66\\x67\\x68\\x69\\x6A\\x6B\\x6C\\x6D" - + "\\x6E\\x6F\\x70\\x71\\x72\\x73\\x74\\x75\\x76\\x77\\x78\\x79\\x7A\\x7B" - + "\\x7C\\x7D\\x7E\\x7F\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89" - + "\\x8A\\x8B\\x8C\\x8D\\x8E\\x8F\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97" - + "\\x98\\x99\\x9A\\x9B\\x9C\\x9D\\x9E\\x9F\\xA0\\xA1\\xA2\\xA3\\xA4\\xA5" - + "\\xA6\\xA7\\xA8\\xA9\\xAA\\xAB\\xAC\\xAD\\xAE\\xAF"), - fromCHex("\\xB0\\xB1\\xB2\\xB3\\xB4\\xB5\\xB6\\xB7\\xB8\\xB9\\xBA\\xBB\\xBC\\xBD" - + "\\xBE\\xBF\\xC0\\xC1\\xC2\\xC3\\xC4\\xC5\\xC6\\xC7\\xC8\\xC9\\xCA\\xCB" - + "\\xCC\\xCD\\xCE\\xCF\\xD0\\xD1\\xD2\\xD3\\xD4\\xD5\\xD6\\xD7\\xD8\\xD9" - + "\\xDA\\xDB\\xDC\\xDD\\xDE\\xDF\\xE0\\xE1\\xE2\\xE3\\xE4\\xE5\\xE6\\xE7" - + "\\xE8\\xE9\\xEA\\xEB\\xEC\\xED\\xEE\\xEF\\xF0\\xF1\\xF2\\xF3\\xF4\\xF5" - + "\\xF6\\xF7\\xF8\\xF9\\xFA\\xFB\\xFC\\xFD\\xFE\\xFF"), - fromHex("0BD770A74D1160F7C9F12CD5912A06EB" - + "FF6ADCAE899D92191FE4305673BA2FFE" - + "8FA3F1A4E5AD79F3F334B3B202B2173C" - + "486EA37CE3D397ED034C7F9DFEB15C5E" - + "927336D0441F4C4300E2CFF0D0900B52D3B4")), - new testCase( - "HmacSHA1", - fromCHex("\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b" - + "\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b"), - new byte[0], new byte[0], - fromHex("0AC1AF7002B3D761D1E55298DA9D0506" - + "B9AE52057220A306E07B6B87E8DF21D0")), - new testCase( - "HmacSHA1", - fromCHex("\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c" - + "\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c"), - null, new byte[0], - fromHex("2C91117204D745F3500D636A62F64F0A" - + "B3BAE548AA53D423B0D1F27EBBA6F5E5" - + "673A081D70CCE7ACFC48"))}; + private static final testCase[] testCases = + new testCase[] { + new testCase( + "HmacSHA256", + fromCHex( + "\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b" + + "\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b"), + fromCHex("\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c"), + fromCHex("\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9"), + fromHex( + "3CB25F25FAACD57A90434F64D0362F2A2D2D0A90CF1A5A4C5DB02D56ECC4C5BF34007208D5B887185865")), + new testCase( + "HmacSHA256", + fromCHex( + "\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d" + + "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b" + + "\\x1c\\x1d\\x1e\\x1f\\x20\\x21\\x22\\x23\\x24\\x25\\x26\\x27\\x28\\x29" + + "\\x2a\\x2b\\x2c\\x2d\\x2e\\x2f\\x30\\x31\\x32\\x33\\x34\\x35\\x36\\x37" + + "\\x38\\x39\\x3a\\x3b\\x3c\\x3d\\x3e\\x3f\\x40\\x41\\x42\\x43\\x44\\x45" + + "\\x46\\x47\\x48\\x49\\x4a\\x4b\\x4c\\x4d\\x4e\\x4f"), + fromCHex( + "\\x60\\x61\\x62\\x63\\x64\\x65\\x66\\x67\\x68\\x69\\x6a\\x6b\\x6c\\x6d" + + "\\x6e\\x6f\\x70\\x71\\x72\\x73\\x74\\x75\\x76\\x77\\x78\\x79\\x7a\\x7b" + + "\\x7c\\x7d\\x7e\\x7f\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89" + + "\\x8a\\x8b\\x8c\\x8d\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97" + + "\\x98\\x99\\x9a\\x9b\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5" + + "\\xa6\\xa7\\xa8\\xa9\\xaa\\xab\\xac\\xad\\xae\\xaf"), + fromCHex( + "\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7\\xb8\\xb9\\xba\\xbb\\xbc\\xbd" + + "\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5\\xc6\\xc7\\xc8\\xc9\\xca\\xcb" + + "\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9" + + "\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7" + + "\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5" + + "\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb\\xfc\\xfd\\xfe\\xff"), + fromHex( + "B11E398DC80327A1C8E7F78C596A4934" + + "4F012EDA2D4EFAD8A050CC4C19AFA97C" + + "59045A99CAC7827271CB41C65E590E09" + + "DA3275600C2F09B8367793A9ACA3DB71" + + "CC30C58179EC3E87C14C01D5C1F3434F" + + "1D87")), + new testCase( + "HmacSHA256", + fromCHex( + "\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b" + + "\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b"), + new byte[0], + new byte[0], + fromHex( + "8DA4E775A563C18F715F802A063C5A31" + + "B8A11F5C5EE1879EC3454E5F3C738D2D" + + "9D201395FAA4B61A96C8")), + new testCase( + "HmacSHA1", + fromCHex("\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b"), + fromCHex("\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c"), + fromCHex("\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9"), + fromHex( + "085A01EA1B10F36933068B56EFA5AD81" + + "A4F14B822F5B091568A9CDD4F155FDA2" + + "C22E422478D305F3F896")), + new testCase( + "HmacSHA1", + fromCHex( + "\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d" + + "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b" + + "\\x1c\\x1d\\x1e\\x1f\\x20\\x21\\x22\\x23\\x24\\x25\\x26\\x27\\x28\\x29" + + "\\x2a\\x2b\\x2c\\x2d\\x2e\\x2f\\x30\\x31\\x32\\x33\\x34\\x35\\x36\\x37" + + "\\x38\\x39\\x3a\\x3b\\x3c\\x3d\\x3e\\x3f\\x40\\x41\\x42\\x43\\x44\\x45" + + "\\x46\\x47\\x48\\x49\\x4a\\x4b\\x4c\\x4d\\x4e\\x4f"), + fromCHex( + "\\x60\\x61\\x62\\x63\\x64\\x65\\x66\\x67\\x68\\x69\\x6A\\x6B\\x6C\\x6D" + + "\\x6E\\x6F\\x70\\x71\\x72\\x73\\x74\\x75\\x76\\x77\\x78\\x79\\x7A\\x7B" + + "\\x7C\\x7D\\x7E\\x7F\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89" + + "\\x8A\\x8B\\x8C\\x8D\\x8E\\x8F\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97" + + "\\x98\\x99\\x9A\\x9B\\x9C\\x9D\\x9E\\x9F\\xA0\\xA1\\xA2\\xA3\\xA4\\xA5" + + "\\xA6\\xA7\\xA8\\xA9\\xAA\\xAB\\xAC\\xAD\\xAE\\xAF"), + fromCHex( + "\\xB0\\xB1\\xB2\\xB3\\xB4\\xB5\\xB6\\xB7\\xB8\\xB9\\xBA\\xBB\\xBC\\xBD" + + "\\xBE\\xBF\\xC0\\xC1\\xC2\\xC3\\xC4\\xC5\\xC6\\xC7\\xC8\\xC9\\xCA\\xCB" + + "\\xCC\\xCD\\xCE\\xCF\\xD0\\xD1\\xD2\\xD3\\xD4\\xD5\\xD6\\xD7\\xD8\\xD9" + + "\\xDA\\xDB\\xDC\\xDD\\xDE\\xDF\\xE0\\xE1\\xE2\\xE3\\xE4\\xE5\\xE6\\xE7" + + "\\xE8\\xE9\\xEA\\xEB\\xEC\\xED\\xEE\\xEF\\xF0\\xF1\\xF2\\xF3\\xF4\\xF5" + + "\\xF6\\xF7\\xF8\\xF9\\xFA\\xFB\\xFC\\xFD\\xFE\\xFF"), + fromHex( + "0BD770A74D1160F7C9F12CD5912A06EB" + + "FF6ADCAE899D92191FE4305673BA2FFE" + + "8FA3F1A4E5AD79F3F334B3B202B2173C" + + "486EA37CE3D397ED034C7F9DFEB15C5E" + + "927336D0441F4C4300E2CFF0D0900B52D3B4")), + new testCase( + "HmacSHA1", + fromCHex( + "\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b" + + "\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b\\x0b"), + new byte[0], + new byte[0], + fromHex("0AC1AF7002B3D761D1E55298DA9D0506" + "B9AE52057220A306E07B6B87E8DF21D0")), + new testCase( + "HmacSHA1", + fromCHex( + "\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c" + + "\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c\\x0c"), + null, + new byte[0], + fromHex( + "2C91117204D745F3500D636A62F64F0A" + + "B3BAE548AA53D423B0D1F27EBBA6F5E5" + + "673A081D70CCE7ACFC48")) + }; - @Test - public void rfc5869Tests() throws Exception { - for (int x = 0; x < testCases.length; x++) { - testCase trial = testCases[x]; - System.out.println("Test case A." + (x + 1)); - HmacKeyDerivationFunction kdf = HmacKeyDerivationFunction.getInstance(trial.algo); - kdf.init(trial.ikm, trial.salt); - byte[] result = kdf.deriveKey(trial.info, trial.expected.length); - assertArrayEquals("Trial A." + x, trial.expected, result); - } + @Test + public void rfc5869Tests() throws Exception { + for (int x = 0; x < testCases.length; x++) { + testCase trial = testCases[x]; + System.out.println("Test case A." + (x + 1)); + HmacKeyDerivationFunction kdf = HmacKeyDerivationFunction.getInstance(trial.algo); + kdf.init(trial.ikm, trial.salt); + byte[] result = kdf.deriveKey(trial.info, trial.expected.length); + assertArrayEquals("Trial A." + x, trial.expected, result); } + } - @Test - public void nullTests() throws Exception { - testCase trial = testCases[0]; - HmacKeyDerivationFunction kdf = HmacKeyDerivationFunction.getInstance(trial.algo); - kdf.init(trial.ikm, trial.salt); - // Just ensuring no exceptions are thrown - kdf.deriveKey(null, 16); - } + @Test + public void nullTests() throws Exception { + testCase trial = testCases[0]; + HmacKeyDerivationFunction kdf = HmacKeyDerivationFunction.getInstance(trial.algo); + kdf.init(trial.ikm, trial.salt); + // Just ensuring no exceptions are thrown + kdf.deriveKey(null, 16); + } - @Test(expected = IllegalArgumentException.class) - public void invalidLength() throws Exception { - testCase trial = testCases[0]; - HmacKeyDerivationFunction kdf = HmacKeyDerivationFunction.getInstance(trial.algo); - kdf.init(trial.ikm, trial.salt); - kdf.deriveKey(trial.info, -1); - } + @Test(expected = IllegalArgumentException.class) + public void invalidLength() throws Exception { + testCase trial = testCases[0]; + HmacKeyDerivationFunction kdf = HmacKeyDerivationFunction.getInstance(trial.algo); + kdf.init(trial.ikm, trial.salt); + kdf.deriveKey(trial.info, -1); + } - @Test - public void defaultSalt() throws Exception { - // Tests all the different ways to get the default salt + @Test + public void defaultSalt() throws Exception { + // Tests all the different ways to get the default salt - testCase trial = testCases[0]; - HmacKeyDerivationFunction kdf1 = HmacKeyDerivationFunction.getInstance(trial.algo); - kdf1.init(trial.ikm, null); - HmacKeyDerivationFunction kdf2 = HmacKeyDerivationFunction.getInstance(trial.algo); - kdf2.init(trial.ikm, new byte[0]); - HmacKeyDerivationFunction kdf3 = HmacKeyDerivationFunction.getInstance(trial.algo); - kdf3.init(trial.ikm); - HmacKeyDerivationFunction kdf4 = HmacKeyDerivationFunction.getInstance(trial.algo); - kdf4.init(trial.ikm, new byte[32]); + testCase trial = testCases[0]; + HmacKeyDerivationFunction kdf1 = HmacKeyDerivationFunction.getInstance(trial.algo); + kdf1.init(trial.ikm, null); + HmacKeyDerivationFunction kdf2 = HmacKeyDerivationFunction.getInstance(trial.algo); + kdf2.init(trial.ikm, new byte[0]); + HmacKeyDerivationFunction kdf3 = HmacKeyDerivationFunction.getInstance(trial.algo); + kdf3.init(trial.ikm); + HmacKeyDerivationFunction kdf4 = HmacKeyDerivationFunction.getInstance(trial.algo); + kdf4.init(trial.ikm, new byte[32]); - byte[] testBytes = "Test".getBytes(StringUtils.UTF8); - byte[] key1 = kdf1.deriveKey(testBytes, 16); - byte[] key2 = kdf2.deriveKey(testBytes, 16); - byte[] key3 = kdf3.deriveKey(testBytes, 16); - byte[] key4 = kdf4.deriveKey(testBytes, 16); + byte[] testBytes = "Test".getBytes(StringUtils.UTF8); + byte[] key1 = kdf1.deriveKey(testBytes, 16); + byte[] key2 = kdf2.deriveKey(testBytes, 16); + byte[] key3 = kdf3.deriveKey(testBytes, 16); + byte[] key4 = kdf4.deriveKey(testBytes, 16); - assertArrayEquals(key1, key2); - assertArrayEquals(key1, key3); - assertArrayEquals(key1, key4); - } + assertArrayEquals(key1, key2); + assertArrayEquals(key1, key3); + assertArrayEquals(key1, key4); + } - private static byte[] fromHex(String data) { - byte[] result = new byte[data.length() / 2]; - for (int x = 0; x < result.length; x++) { - result[x] = (byte) Integer.parseInt( - data.substring(2 * x, 2 * x + 2), 16); - } - return result; + private static byte[] fromHex(String data) { + byte[] result = new byte[data.length() / 2]; + for (int x = 0; x < result.length; x++) { + result[x] = (byte) Integer.parseInt(data.substring(2 * x, 2 * x + 2), 16); } + return result; + } - private static byte[] fromCHex(String data) { - byte[] result = new byte[data.length() / 4]; - for (int x = 0; x < result.length; x++) { - result[x] = (byte) Integer.parseInt( - data.substring(4 * x + 2, 4 * x + 4), 16); - } - return result; + private static byte[] fromCHex(String data) { + byte[] result = new byte[data.length() / 4]; + for (int x = 0; x < result.length; x++) { + result[x] = (byte) Integer.parseInt(data.substring(4 * x + 2, 4 * x + 4), 16); } + return result; + } - private static class testCase { - public final String algo; - public final byte[] ikm; - public final byte[] salt; - public final byte[] info; - public final byte[] expected; + private static class testCase { + public final String algo; + public final byte[] ikm; + public final byte[] salt; + public final byte[] info; + public final byte[] expected; - testCase(String algo, byte[] ikm, byte[] salt, byte[] info, - byte[] expected) { - super(); - this.algo = algo; - this.ikm = ikm; - this.salt = salt; - this.info = info; - this.expected = expected; - } + testCase(String algo, byte[] ikm, byte[] salt, byte[] info, byte[] expected) { + super(); + this.algo = algo; + this.ikm = ikm; + this.salt = salt; + this.info = info; + this.expected = expected; } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/PrimitivesParserTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/PrimitivesParserTest.java index 0d5b8532d..62b3570b7 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/PrimitivesParserTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/PrimitivesParserTest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -18,86 +18,80 @@ import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; - import org.junit.Test; public class PrimitivesParserTest { - @Test - public void testParseLong() throws IOException { - final long[] tests = new long[] { - Long.MIN_VALUE, - Long.MAX_VALUE, - -1, - 0, - 1, - Long.MIN_VALUE + 1, - Long.MAX_VALUE - 1 + @Test + public void testParseLong() throws IOException { + final long[] tests = + new long[] { + Long.MIN_VALUE, Long.MAX_VALUE, -1, 0, 1, Long.MIN_VALUE + 1, Long.MAX_VALUE - 1 }; - for (long x : tests) { - try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - final DataOutputStream dos = new DataOutputStream(baos)) { - dos.writeLong(x); - dos.close(); - assertEquals(x, PrimitivesParser.parseLong(baos.toByteArray(), 0)); - } - } + for (long x : tests) { + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final DataOutputStream dos = new DataOutputStream(baos)) { + dos.writeLong(x); + dos.close(); + assertEquals(x, PrimitivesParser.parseLong(baos.toByteArray(), 0)); + } } + } - @Test - public void testParseInt() throws IOException { - final int[] tests = new int []{ - Integer.MIN_VALUE, - Integer.MAX_VALUE, - -1, - 0, - 1, - Integer.MIN_VALUE + 1, - Integer.MAX_VALUE - 1 + @Test + public void testParseInt() throws IOException { + final int[] tests = + new int[] { + Integer.MIN_VALUE, + Integer.MAX_VALUE, + -1, + 0, + 1, + Integer.MIN_VALUE + 1, + Integer.MAX_VALUE - 1 }; - for (int x : tests) { - try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - final DataOutputStream dos = new DataOutputStream(baos)) { - dos.writeInt(x); - dos.close(); - assertEquals(x, PrimitivesParser.parseInt(baos.toByteArray(), 0)); - } - } + for (int x : tests) { + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final DataOutputStream dos = new DataOutputStream(baos)) { + dos.writeInt(x); + dos.close(); + assertEquals(x, PrimitivesParser.parseInt(baos.toByteArray(), 0)); + } } + } - @Test - public void testParseShort() throws IOException { - for (int x = Short.MIN_VALUE; x < Short.MAX_VALUE; x++) { - try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - final DataOutputStream dos = new DataOutputStream(baos)) { - dos.writeShort(x); - dos.close(); - assertEquals((short) x, PrimitivesParser.parseShort(baos.toByteArray(), 0)); - } - } + @Test + public void testParseShort() throws IOException { + for (int x = Short.MIN_VALUE; x < Short.MAX_VALUE; x++) { + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final DataOutputStream dos = new DataOutputStream(baos)) { + dos.writeShort(x); + dos.close(); + assertEquals((short) x, PrimitivesParser.parseShort(baos.toByteArray(), 0)); + } } + } - @Test - public void testParseUnsignedShort() throws IOException { - for (int x = 0; x < Constants.UNSIGNED_SHORT_MAX_VAL; x++) { - try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - final DataOutputStream dos = new DataOutputStream(baos)) { - PrimitivesParser.writeUnsignedShort(dos, x); - assertEquals(x, PrimitivesParser.parseUnsignedShort(baos.toByteArray(), 0)); - } - } + @Test + public void testParseUnsignedShort() throws IOException { + for (int x = 0; x < Constants.UNSIGNED_SHORT_MAX_VAL; x++) { + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final DataOutputStream dos = new DataOutputStream(baos)) { + PrimitivesParser.writeUnsignedShort(dos, x); + assertEquals(x, PrimitivesParser.parseUnsignedShort(baos.toByteArray(), 0)); + } } + } - @Test - public void testParseByte() throws IOException { - for (int x = Byte.MIN_VALUE; x < Byte.MAX_VALUE; x++) { - try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - final DataOutputStream dos = new DataOutputStream(baos)) { - dos.writeByte(x); - dos.close(); - assertEquals((byte) x, PrimitivesParser.parseByte(baos.toByteArray(), 0)); - } - } + @Test + public void testParseByte() throws IOException { + for (int x = Byte.MIN_VALUE; x < Byte.MAX_VALUE; x++) { + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final DataOutputStream dos = new DataOutputStream(baos)) { + dos.writeByte(x); + dos.close(); + assertEquals((byte) x, PrimitivesParser.parseByte(baos.toByteArray(), 0)); + } } - + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/RandomBytesGenerator.java b/src/test/java/com/amazonaws/encryptionsdk/internal/RandomBytesGenerator.java index 45718b230..db3ba12c7 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/RandomBytesGenerator.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/RandomBytesGenerator.java @@ -16,33 +16,32 @@ import java.security.SecureRandom; public class RandomBytesGenerator { - private static final SecureRandom RND = new SecureRandom(); - - /* Some Providers (such as the FIPS certified Bouncy Castle) enforce a - * maximum number of bytes that may be requested from SecureRandom. If - * the requested len is larger than this value, the Secure Random will - * be called multiple times to achieve the requested total length. */ - private static final int MAX_BYTES = 1 << 15; - - /** - * Generates a byte array of random data of the given length. - * - * @param len The length of the byte array. - * @return The byte array. - */ - public static byte[] generate(final int len) { - final byte[] result = new byte[len]; - int bytesGenerated = 0; - - while (bytesGenerated < len) { - final int requestSize = Math.min(MAX_BYTES, len - bytesGenerated); - final byte[] request = new byte[requestSize]; - RND.nextBytes(request); - System.arraycopy(request, 0, result, bytesGenerated, requestSize); - bytesGenerated += requestSize; - } - - return result; + private static final SecureRandom RND = new SecureRandom(); + + /* Some Providers (such as the FIPS certified Bouncy Castle) enforce a + * maximum number of bytes that may be requested from SecureRandom. If + * the requested len is larger than this value, the Secure Random will + * be called multiple times to achieve the requested total length. */ + private static final int MAX_BYTES = 1 << 15; + + /** + * Generates a byte array of random data of the given length. + * + * @param len The length of the byte array. + * @return The byte array. + */ + public static byte[] generate(final int len) { + final byte[] result = new byte[len]; + int bytesGenerated = 0; + + while (bytesGenerated < len) { + final int requestSize = Math.min(MAX_BYTES, len - bytesGenerated); + final byte[] request = new byte[requestSize]; + RND.nextBytes(request); + System.arraycopy(request, 0, result, bytesGenerated, requestSize); + bytesGenerated += requestSize; } + return result; + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/StaticMasterKey.java b/src/test/java/com/amazonaws/encryptionsdk/internal/StaticMasterKey.java index b4d5b66db..21b1bd71c 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/StaticMasterKey.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/StaticMasterKey.java @@ -1,5 +1,11 @@ package com.amazonaws.encryptionsdk.internal; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import com.amazonaws.encryptionsdk.DataKey; +import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.MasterKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; +import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; import java.security.KeyFactory; @@ -12,7 +18,6 @@ import java.util.Collection; import java.util.Map; import java.util.Objects; - import javax.annotation.Nonnull; import javax.annotation.concurrent.NotThreadSafe; import javax.crypto.Cipher; @@ -20,190 +25,165 @@ import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.DataKey; -import com.amazonaws.encryptionsdk.EncryptedDataKey; -import com.amazonaws.encryptionsdk.MasterKey; -import com.amazonaws.encryptionsdk.exception.AwsCryptoException; -import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; - /** - * Static implementation of the {@link MasterKey} interface that should only - * used for unit-tests. - *

- * Contains a statically defined asymmetric master key-pair that can be used - * to encrypt and decrypt (randomly generated) symmetric data key. + * Static implementation of the {@link MasterKey} interface that should only used for unit-tests. + * + *

Contains a statically defined asymmetric master key-pair that can be used to encrypt and + * decrypt (randomly generated) symmetric data key. + * *

- * + * * @author patye */ @NotThreadSafe public class StaticMasterKey extends MasterKey { - private static final String PROVIDER_ID = "static_provider"; - - /** - * Generates random strings that can be used to create data keys. - */ - private static final SecureRandom SRAND = new SecureRandom(); - - /** - * Encryption algorithm for the master key-pair - */ - private static final String MASTER_KEY_ENCRYPTION_ALGORITHM = "RSA/ECB/PKCS1Padding"; - - /** - * Encryption algorithm for the KeyFactory - */ - private static final String MASTER_KEY_ALGORITHM = "RSA"; - - /** - * Encryption algorithm for the randomly generated data key - */ - private static final String DATA_KEY_ENCRYPTION_ALGORITHM = "AES"; - - /** - * The ID of the master key - */ - @Nonnull - private String keyId_; - - /** - * The {@link Cipher} object created with the public part of - * the master-key. It's used to encrypt data keys. - */ - @Nonnull - private final Cipher masterKeyEncryptionCipher_; - - /** - * The {@link Cipher} object created with the private part of - * the master-key. It's used to decrypt encrypted data keys. - */ - @Nonnull - private final Cipher masterKeyDecryptionCipher_; - - /** - * Generates random data keys. - */ - @Nonnull - private KeyGenerator keyGenerator_; - - /** - * Creates a new object that encrypts the data key with a master key - * whose id is {@code keyId}. - *

- * The value of {@code keyId} does not affect how the data key will be - * generated or encrypted. The {@code keyId} forms part of the header - * of the encrypted data, and is used to ensure that the header cannot - * be tempered with. - */ - public StaticMasterKey(@Nonnull final String keyId) { - this.keyId_ = Objects.requireNonNull(keyId); - - try { - KeyFactory keyFactory = KeyFactory.getInstance(MASTER_KEY_ALGORITHM); - KeySpec publicKeySpec = new X509EncodedKeySpec(publicKey_v1); - PublicKey pubKey = keyFactory.generatePublic(publicKeySpec); - KeySpec privateKeySpec = new PKCS8EncodedKeySpec(privateKey_v1); - PrivateKey privKey = keyFactory.generatePrivate(privateKeySpec); - - masterKeyEncryptionCipher_ = Cipher.getInstance(MASTER_KEY_ENCRYPTION_ALGORITHM); - masterKeyEncryptionCipher_.init(Cipher.ENCRYPT_MODE, pubKey); - - masterKeyDecryptionCipher_ = Cipher.getInstance(MASTER_KEY_ENCRYPTION_ALGORITHM); - masterKeyDecryptionCipher_.init(Cipher.DECRYPT_MODE, privKey); - - } catch (GeneralSecurityException ex) { - throw new RuntimeException(ex); - } - } - - /** - * Changes the {@link #keyId_}. This method is expected to be used - * to test that header of an encrypted message cannot be tempered with. - */ - public void setKeyId(@Nonnull String keyId) { - this.keyId_ = Objects.requireNonNull(keyId); - } - - @Override - public String getProviderId() { - return PROVIDER_ID; - } - - @Override - public String getKeyId() { - return keyId_; + private static final String PROVIDER_ID = "static_provider"; + + /** Generates random strings that can be used to create data keys. */ + private static final SecureRandom SRAND = new SecureRandom(); + + /** Encryption algorithm for the master key-pair */ + private static final String MASTER_KEY_ENCRYPTION_ALGORITHM = "RSA/ECB/PKCS1Padding"; + + /** Encryption algorithm for the KeyFactory */ + private static final String MASTER_KEY_ALGORITHM = "RSA"; + + /** Encryption algorithm for the randomly generated data key */ + private static final String DATA_KEY_ENCRYPTION_ALGORITHM = "AES"; + + /** The ID of the master key */ + @Nonnull private String keyId_; + + /** + * The {@link Cipher} object created with the public part of the master-key. It's used to encrypt + * data keys. + */ + @Nonnull private final Cipher masterKeyEncryptionCipher_; + + /** + * The {@link Cipher} object created with the private part of the master-key. It's used to decrypt + * encrypted data keys. + */ + @Nonnull private final Cipher masterKeyDecryptionCipher_; + + /** Generates random data keys. */ + @Nonnull private KeyGenerator keyGenerator_; + + /** + * Creates a new object that encrypts the data key with a master key whose id is {@code keyId}. + * + *

The value of {@code keyId} does not affect how the data key will be generated or encrypted. + * The {@code keyId} forms part of the header of the encrypted data, and is used to ensure that + * the header cannot be tempered with. + */ + public StaticMasterKey(@Nonnull final String keyId) { + this.keyId_ = Objects.requireNonNull(keyId); + + try { + KeyFactory keyFactory = KeyFactory.getInstance(MASTER_KEY_ALGORITHM); + KeySpec publicKeySpec = new X509EncodedKeySpec(publicKey_v1); + PublicKey pubKey = keyFactory.generatePublic(publicKeySpec); + KeySpec privateKeySpec = new PKCS8EncodedKeySpec(privateKey_v1); + PrivateKey privKey = keyFactory.generatePrivate(privateKeySpec); + + masterKeyEncryptionCipher_ = Cipher.getInstance(MASTER_KEY_ENCRYPTION_ALGORITHM); + masterKeyEncryptionCipher_.init(Cipher.ENCRYPT_MODE, pubKey); + + masterKeyDecryptionCipher_ = Cipher.getInstance(MASTER_KEY_ENCRYPTION_ALGORITHM); + masterKeyDecryptionCipher_.init(Cipher.DECRYPT_MODE, privKey); + + } catch (GeneralSecurityException ex) { + throw new RuntimeException(ex); } - - @Override - public DataKey generateDataKey(CryptoAlgorithm algorithm, - Map encryptionContext) { - try { - this.keyGenerator_ = KeyGenerator.getInstance(DATA_KEY_ENCRYPTION_ALGORITHM); - this.keyGenerator_.init(algorithm.getDataKeyLength() * 8, SRAND); - SecretKey key = new SecretKeySpec(keyGenerator_.generateKey().getEncoded(), algorithm.getDataKeyAlgo()); - byte[] encryptedKey = masterKeyEncryptionCipher_.doFinal(key.getEncoded()); - return new DataKey<>(key, encryptedKey, keyId_.getBytes(StandardCharsets.UTF_8), this); - } catch (GeneralSecurityException ex) { - throw new RuntimeException(ex); - } + } + + /** + * Changes the {@link #keyId_}. This method is expected to be used to test that header of an + * encrypted message cannot be tempered with. + */ + public void setKeyId(@Nonnull String keyId) { + this.keyId_ = Objects.requireNonNull(keyId); + } + + @Override + public String getProviderId() { + return PROVIDER_ID; + } + + @Override + public String getKeyId() { + return keyId_; + } + + @Override + public DataKey generateDataKey( + CryptoAlgorithm algorithm, Map encryptionContext) { + try { + this.keyGenerator_ = KeyGenerator.getInstance(DATA_KEY_ENCRYPTION_ALGORITHM); + this.keyGenerator_.init(algorithm.getDataKeyLength() * 8, SRAND); + SecretKey key = + new SecretKeySpec(keyGenerator_.generateKey().getEncoded(), algorithm.getDataKeyAlgo()); + byte[] encryptedKey = masterKeyEncryptionCipher_.doFinal(key.getEncoded()); + return new DataKey<>(key, encryptedKey, keyId_.getBytes(StandardCharsets.UTF_8), this); + } catch (GeneralSecurityException ex) { + throw new RuntimeException(ex); } - - @Override - public DataKey encryptDataKey(CryptoAlgorithm algorithm, - Map encryptionContext, DataKey dataKey) { - try { - byte[] unencryptedKey = dataKey.getKey().getEncoded(); - byte[] encryptedKey = masterKeyEncryptionCipher_.doFinal(unencryptedKey); - SecretKey newKey = new SecretKeySpec(dataKey.getKey().getEncoded(), algorithm.getDataKeyAlgo()); - return new DataKey<>(newKey, encryptedKey, keyId_.getBytes(StandardCharsets.UTF_8), this); - } catch (GeneralSecurityException ex) { - throw new RuntimeException(ex); - } + } + + @Override + public DataKey encryptDataKey( + CryptoAlgorithm algorithm, Map encryptionContext, DataKey dataKey) { + try { + byte[] unencryptedKey = dataKey.getKey().getEncoded(); + byte[] encryptedKey = masterKeyEncryptionCipher_.doFinal(unencryptedKey); + SecretKey newKey = + new SecretKeySpec(dataKey.getKey().getEncoded(), algorithm.getDataKeyAlgo()); + return new DataKey<>(newKey, encryptedKey, keyId_.getBytes(StandardCharsets.UTF_8), this); + } catch (GeneralSecurityException ex) { + throw new RuntimeException(ex); } - - @Override - public DataKey decryptDataKey(CryptoAlgorithm algorithm, - Collection encryptedDataKeys, - Map encryptionContext) - throws UnsupportedProviderException, AwsCryptoException { - try { - for (EncryptedDataKey edk :encryptedDataKeys) { - if (keyId_.equals(new String(edk.getProviderInformation(), StandardCharsets.UTF_8))) { - byte[] unencryptedDataKey = masterKeyDecryptionCipher_.doFinal(edk.getEncryptedDataKey()); - SecretKey key = new SecretKeySpec(unencryptedDataKey, algorithm.getDataKeyAlgo()); - return new DataKey<>(key, edk.getEncryptedDataKey(), edk.getProviderInformation(), this); - } - } - } catch (GeneralSecurityException ex) { - throw new RuntimeException(ex); + } + + @Override + public DataKey decryptDataKey( + CryptoAlgorithm algorithm, + Collection encryptedDataKeys, + Map encryptionContext) + throws UnsupportedProviderException, AwsCryptoException { + try { + for (EncryptedDataKey edk : encryptedDataKeys) { + if (keyId_.equals(new String(edk.getProviderInformation(), StandardCharsets.UTF_8))) { + byte[] unencryptedDataKey = masterKeyDecryptionCipher_.doFinal(edk.getEncryptedDataKey()); + SecretKey key = new SecretKeySpec(unencryptedDataKey, algorithm.getDataKeyAlgo()); + return new DataKey<>(key, edk.getEncryptedDataKey(), edk.getProviderInformation(), this); } - return null; + } + } catch (GeneralSecurityException ex) { + throw new RuntimeException(ex); } - - /** - * Statically configured private key. - */ - private static final byte[] privateKey_v1 = Utils.decodeBase64String( - "MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAKLpwqjYtYExVilW/Hg0ogWv9xZ+" - + "THj4IzvISLlPtK8W6KXMcqukfuxdYmndPv8UD1DbdHFYSSistdqoBN32vVQOQnJZyYm45i2TDOV0" - + "M2DtHtR6aMMlBLGtdPeeaT88nQfI1ORjRDyR1byMwomvmKifZYga6FjLt/sgqfSE9BUnAgMBAAEC" - + "gYAqnewGL2qLuVRIzDCPYXVg938zqyZmHsNYyDP+BhPGGcASX0FAFW/+dQ9hkjcAk0bOaBo17Fp3" - + "AXcxE/Lx/bHY+GWZ0wOJfl3aJBVJOpW8J6kwu68BUCmuFtRgbLSFu5+fbey3pKafYSptbX1fAI+z" - + "hTx+a9B8pnn79ad4ziJ2QQJBAM+YHPGAEbr5qcNkwyy0xZgR/TLlcW2NQUt8HZpmErdX6d328iBC" - + "SPb8+whXxCXZC3Mr+35IZ1pxxf0go/zGQv0CQQDI5oH0z1CKxoT6ErswNzB0oHxq/wD5mhutyqHa" - + "mxbG5G3fN7I2IclwaXEA2eutIKxFMQNZYsX5mNYsrveSKivzAkABiujUJpZ7JDXNvObyYxmAyslt" - + "4mSYYs9UZ0S1DAMhl6amPpqIANYX98NJyZUsjtNV9MK2qoUSF/xXqDFvxG1lAkBhP5Ow2Zn3U1mT" - + "Y/XQxSZjjjwr3vyt1neHjQsEMwa3iGPXJbLSmVBVZfUZoGOBDsvVQoCIiFOlGuKyBpA45MkZAkAH" - + "ksUrS9xLrDIUOI2BzMNRsK0bH7KJ+PFxm2SBgJOF9+Uf2A9LIP4IvESZq+ufp6c8YaqgR6Id1vws" - + "7rUyGoa5"); - - /** - * Statically configured public key. - */ - private static final byte[] publicKey_v1 = Utils.decodeBase64String( - "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCi6cKo2LWBMVYpVvx4NKIFr/cWfkx4+CM7yEi5" - + "T7SvFuilzHKrpH7sXWJp3T7/FA9Q23RxWEkorLXaqATd9r1UDkJyWcmJuOYtkwzldDNg7R7UemjD" - + "JQSxrXT3nmk/PJ0HyNTkY0Q8kdW8jMKJr5ion2WIGuhYy7f7IKn0hPQVJwIDAQAB"); - + return null; + } + + /** Statically configured private key. */ + private static final byte[] privateKey_v1 = + Utils.decodeBase64String( + "MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAKLpwqjYtYExVilW/Hg0ogWv9xZ+" + + "THj4IzvISLlPtK8W6KXMcqukfuxdYmndPv8UD1DbdHFYSSistdqoBN32vVQOQnJZyYm45i2TDOV0" + + "M2DtHtR6aMMlBLGtdPeeaT88nQfI1ORjRDyR1byMwomvmKifZYga6FjLt/sgqfSE9BUnAgMBAAEC" + + "gYAqnewGL2qLuVRIzDCPYXVg938zqyZmHsNYyDP+BhPGGcASX0FAFW/+dQ9hkjcAk0bOaBo17Fp3" + + "AXcxE/Lx/bHY+GWZ0wOJfl3aJBVJOpW8J6kwu68BUCmuFtRgbLSFu5+fbey3pKafYSptbX1fAI+z" + + "hTx+a9B8pnn79ad4ziJ2QQJBAM+YHPGAEbr5qcNkwyy0xZgR/TLlcW2NQUt8HZpmErdX6d328iBC" + + "SPb8+whXxCXZC3Mr+35IZ1pxxf0go/zGQv0CQQDI5oH0z1CKxoT6ErswNzB0oHxq/wD5mhutyqHa" + + "mxbG5G3fN7I2IclwaXEA2eutIKxFMQNZYsX5mNYsrveSKivzAkABiujUJpZ7JDXNvObyYxmAyslt" + + "4mSYYs9UZ0S1DAMhl6amPpqIANYX98NJyZUsjtNV9MK2qoUSF/xXqDFvxG1lAkBhP5Ow2Zn3U1mT" + + "Y/XQxSZjjjwr3vyt1neHjQsEMwa3iGPXJbLSmVBVZfUZoGOBDsvVQoCIiFOlGuKyBpA45MkZAkAH" + + "ksUrS9xLrDIUOI2BzMNRsK0bH7KJ+PFxm2SBgJOF9+Uf2A9LIP4IvESZq+ufp6c8YaqgR6Id1vws" + + "7rUyGoa5"); + + /** Statically configured public key. */ + private static final byte[] publicKey_v1 = + Utils.decodeBase64String( + "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCi6cKo2LWBMVYpVvx4NKIFr/cWfkx4+CM7yEi5" + + "T7SvFuilzHKrpH7sXWJp3T7/FA9Q23RxWEkorLXaqATd9r1UDkJyWcmJuOYtkwzldDNg7R7UemjD" + + "JQSxrXT3nmk/PJ0HyNTkY0Q8kdW8jMKJr5ion2WIGuhYy7f7IKn0hPQVJwIDAQAB"); } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/TestIOUtils.java b/src/test/java/com/amazonaws/encryptionsdk/internal/TestIOUtils.java index 792a8708b..e2681e8ec 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/TestIOUtils.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/TestIOUtils.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -26,101 +26,99 @@ import java.util.Random; public class TestIOUtils { - private static final SecureRandom rng_ = new SecureRandom(); + private static final SecureRandom rng_ = new SecureRandom(); - public static byte[] generateRandomPlaintext(final int size) { - return RandomBytesGenerator.generate(size); - } + public static byte[] generateRandomPlaintext(final int size) { + return RandomBytesGenerator.generate(size); + } - /** - * Generates and returns a string of the given {@code length}. - *

- * This function can be replaced by the RandomStringUtil class - * from Apache Commons. - *

- * This method re-implemented here to keep this library's dependency - * to a minimum which would reduce friction when it's consumed - * by other packages. - */ - public static String generateRandomString(final int size) { - StringBuilder sb = new StringBuilder(); - Random rand = new Random(); - for (int i = 0; i < size; i++) { - int c = rand.nextInt(Byte.MAX_VALUE); - c = c == 0 ? (c + 1) : c; - sb.append((char)c); - } - return sb.toString(); + /** + * Generates and returns a string of the given {@code length}. + * + *

This function can be replaced by the RandomStringUtil class from Apache Commons. + * + *

This method re-implemented here to keep this library's dependency to a minimum which would + * reduce friction when it's consumed by other packages. + */ + public static String generateRandomString(final int size) { + StringBuilder sb = new StringBuilder(); + Random rand = new Random(); + for (int i = 0; i < size; i++) { + int c = rand.nextInt(Byte.MAX_VALUE); + c = c == 0 ? (c + 1) : c; + sb.append((char) c); } + return sb.toString(); + } - public static byte[] computeFileDigest(final String fileName) throws IOException { - try { - final FileInputStream fis = new FileInputStream(fileName); - final MessageDigest md = MessageDigest.getInstance("SHA-256"); - final DigestInputStream dis = new DigestInputStream(fis, md); - - final int readLen = 128; - final byte[] readBytes = new byte[readLen]; - while (dis.read(readBytes) != -1) { - } - dis.close(); + public static byte[] computeFileDigest(final String fileName) throws IOException { + try { + final FileInputStream fis = new FileInputStream(fileName); + final MessageDigest md = MessageDigest.getInstance("SHA-256"); + final DigestInputStream dis = new DigestInputStream(fis, md); - return md.digest(); - } catch (NoSuchAlgorithmException e) { - // shouldn't get here since we hardcode the algorithm. - } + final int readLen = 128; + final byte[] readBytes = new byte[readLen]; + while (dis.read(readBytes) != -1) {} + dis.close(); - return null; + return md.digest(); + } catch (NoSuchAlgorithmException e) { + // shouldn't get here since we hardcode the algorithm. } - public static byte[] getSha256Hash(final byte[] input) { - MessageDigest md = null; - try { - md = MessageDigest.getInstance("SHA-256"); - } catch (NoSuchAlgorithmException e) { - // should never get here. - } - return md.digest(input); - } + return null; + } - public static void generateFile(final String fileName, final long fileSize) throws IOException { - final FileOutputStream fs = new FileOutputStream(fileName); - final byte[] fileBytes = new byte[(int) fileSize]; - rng_.nextBytes(fileBytes); - fs.write(fileBytes); - fs.close(); + public static byte[] getSha256Hash(final byte[] input) { + MessageDigest md = null; + try { + md = MessageDigest.getInstance("SHA-256"); + } catch (NoSuchAlgorithmException e) { + // should never get here. } + return md.digest(input); + } - public static void copyInStreamToOutStream(final InputStream inStream, final OutputStream outStream, - final int readLen) throws IOException { - final byte[] readBuffer = new byte[readLen]; - int actualRead = 0; - while (actualRead >= 0) { - outStream.write(readBuffer, 0, actualRead); - actualRead = inStream.read(readBuffer); - } - inStream.close(); - outStream.close(); + public static void generateFile(final String fileName, final long fileSize) throws IOException { + final FileOutputStream fs = new FileOutputStream(fileName); + final byte[] fileBytes = new byte[(int) fileSize]; + rng_.nextBytes(fileBytes); + fs.write(fileBytes); + fs.close(); + } + + public static void copyInStreamToOutStream( + final InputStream inStream, final OutputStream outStream, final int readLen) + throws IOException { + final byte[] readBuffer = new byte[readLen]; + int actualRead = 0; + while (actualRead >= 0) { + outStream.write(readBuffer, 0, actualRead); + actualRead = inStream.read(readBuffer); } + inStream.close(); + outStream.close(); + } - public static void deleteDir(final File filePath) { - if (filePath.exists()) { - File[] files = filePath.listFiles(); - for (int i = 0; i < files.length; i++) { - if (files[i].isFile()) { - files[i].delete(); - } else { - deleteDir(files[i]); - } - } + public static void deleteDir(final File filePath) { + if (filePath.exists()) { + File[] files = filePath.listFiles(); + for (int i = 0; i < files.length; i++) { + if (files[i].isFile()) { + files[i].delete(); + } else { + deleteDir(files[i]); } - - filePath.delete(); + } } - public static void copyInStreamToOutStream(final InputStream inStream, final OutputStream outStream) - throws IOException { - final int readLen = 1024; // 1KB - copyInStreamToOutStream(inStream, outStream, readLen); - } + filePath.delete(); + } + + public static void copyInStreamToOutStream( + final InputStream inStream, final OutputStream outStream) throws IOException { + final int readLen = 1024; // 1KB + copyInStreamToOutStream(inStream, outStream, readLen); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/TrailingSignatureAlgorithmTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/TrailingSignatureAlgorithmTest.java index babce1455..65103685b 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/TrailingSignatureAlgorithmTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/TrailingSignatureAlgorithmTest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,10 +13,11 @@ package com.amazonaws.encryptionsdk.internal; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.TestUtils; -import org.junit.Test; - import java.math.BigInteger; import java.security.AlgorithmParameters; import java.security.KeyFactory; @@ -26,145 +27,177 @@ import java.security.spec.ECParameterSpec; import java.security.spec.ECPoint; import java.security.spec.ECPublicKeySpec; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import org.junit.Test; public class TrailingSignatureAlgorithmTest { - private static final int[] secp256r1PublicFixture_X = new int[] { - 163, 132, 202, 41, 50, 135, 193, 159, 67, 19, 186, - 212, 0, 129, 16, 182, 186, 176, 124, 94, 242, 139, - 48, 143, 158, 96, 51, 133, 188, 144, 137, 148}; - - private static final int[] secp256r1PublicFixture_Y = new int[] { - 71, 234, 253, 112, 131, 106, 243, 169, 143, 58, 39, - 222, 47, 211, 230, 90, 139, 163, 54, 249, 187, 115, - 209, 203, 239, 98, 26, 47, 101, 213, 140, 212}; - - private static final int[] secp2561CompressedFixture = new int[] { - 2, - 163, 132, 202, 41, 50, 135, 193, 159, 67, 19, 186, - 212, 0, 129, 16, 182, 186, 176, 124, 94, 242, 139, - 48, 143, 158, 96, 51, 133, 188, 144, 137, 148}; - - private static final int[] secp384r1PublicFixture_X = new int[] { - 207, 62, 215, 143, 116, 128, 174, 103, 1, 81, 127, - 212, 163, 19, 165, 220, 74, 144, 26, 59, 87, 0, - 214, 47, 66, 73, 152, 227, 196, 81, 14, 28, 58, - 221, 178, 63, 150, 119, 62, 195, 99, 63, 60, 42, - 223, 207, 28, 65}; - - private static final int[] secp384r1PublicFixture_Y = new int[] { - 180, 143, 190, 5, 150, 247, 225, 240, 153, 150, 119, - 109, 210, 243, 151, 206, 217, 120, 2, 171, 75, - 180, 31, 4, 91, 78, 206, 217, 241, 119, 55, 230, - 216, 23, 237, 101, 21, 89, 132, 84, 100, 3, 255, - 90, 197, 237, 139, 209}; - - private static final int[] secp384r1CompressedFixture = new int[] { - 3, - 207, 62, 215, 143, 116, 128, 174, 103, 1, 81, 127, - 212, 163, 19, 165, 220, 74, 144, 26, 59, 87, 0, - 214, 47, 66, 73, 152, 227, 196, 81, 14, 28, 58, - 221, 178, 63, 150, 119, 62, 195, 99, 63, 60, 42, - 223, 207, 28, 65 - }; - - - @Test - public void serializationEquality() throws Exception { - CryptoAlgorithm algorithm = CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256; - - PublicKey publicKey = TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).generateKey().getPublic(); - - String serializedPublicKey = TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).serializePublicKey(publicKey); - PublicKey deserializedPublicKey = TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).deserializePublicKey(serializedPublicKey); - - assertEquals(publicKey, deserializedPublicKey); - - algorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; - - publicKey = TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).generateKey().getPublic(); - - serializedPublicKey = TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).serializePublicKey(publicKey); - deserializedPublicKey = TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).deserializePublicKey(serializedPublicKey); - - assertEquals(publicKey, deserializedPublicKey); - } - - @Test - public void deserializeSecp384() { - testDeserialization(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, - secp384r1CompressedFixture, secp384r1PublicFixture_X, secp384r1PublicFixture_Y); - } - - @Test - public void serializeSecp384() throws Exception { - testSerialization(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, - "secp384r1", secp384r1PublicFixture_X, secp384r1PublicFixture_Y, secp384r1CompressedFixture); - } - - @Test - public void deserializeSecp256() { - testDeserialization(CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, - secp2561CompressedFixture, secp256r1PublicFixture_X, secp256r1PublicFixture_Y); - } - - @Test - public void serializeSecp256() throws Exception { - testSerialization(CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, - "secp256r1", secp256r1PublicFixture_X, secp256r1PublicFixture_Y, secp2561CompressedFixture); - } - - @Test(expected = IllegalArgumentException.class) - public void testBadPoint() { - byte[] bytes = TestUtils.unsignedBytesToSignedBytes(secp384r1CompressedFixture); - bytes[20]++; - - String publicKey = Utils.encodeBase64String(bytes); - - TrailingSignatureAlgorithm - .forCryptoAlgorithm(CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384) - .deserializePublicKey(publicKey); - } - - private void testSerialization(CryptoAlgorithm algorithm, String curveName, int[] x, int[] y, int[] expected) throws Exception { - byte[] xBytes = TestUtils.unsignedBytesToSignedBytes(x); - byte[] yBytes = TestUtils.unsignedBytesToSignedBytes(y); - - final AlgorithmParameters parameters = AlgorithmParameters.getInstance("EC"); - parameters.init(new ECGenParameterSpec(curveName)); - ECParameterSpec ecParameterSpec = parameters.getParameterSpec(ECParameterSpec.class); - - PublicKey publicKey = KeyFactory.getInstance("EC").generatePublic( - new ECPublicKeySpec(new ECPoint(new BigInteger(1, xBytes), new BigInteger(1, yBytes)), ecParameterSpec)); - - int[] result = TestUtils.signedBytesToUnsignedBytes(Utils.decodeBase64String(TrailingSignatureAlgorithm - .forCryptoAlgorithm(algorithm) - .serializePublicKey(publicKey))); - - assertArrayEquals(expected, result); - } - - private void testDeserialization(CryptoAlgorithm algorithm, int[] compressedKey, int[] expectedX, int[] expectedY) { - byte[] bytes = TestUtils.unsignedBytesToSignedBytes(compressedKey); - - String publicKey = Utils.encodeBase64String(bytes); - - PublicKey publicKeyDeserialized = TrailingSignatureAlgorithm - .forCryptoAlgorithm(algorithm) - .deserializePublicKey(publicKey); - - ECPublicKey desKey = (ECPublicKey) publicKeyDeserialized; - - BigInteger x = desKey.getW().getAffineX(); - BigInteger y = desKey.getW().getAffineY(); - - BigInteger expectedXBigInteger = new BigInteger(1, TestUtils.unsignedBytesToSignedBytes(expectedX)); - BigInteger expectedYBigInteger = new BigInteger(1, TestUtils.unsignedBytesToSignedBytes(expectedY)); - - assertEquals(expectedXBigInteger, x); - assertEquals(expectedYBigInteger, y); - } + private static final int[] secp256r1PublicFixture_X = + new int[] { + 163, 132, 202, 41, 50, 135, 193, 159, 67, 19, 186, + 212, 0, 129, 16, 182, 186, 176, 124, 94, 242, 139, + 48, 143, 158, 96, 51, 133, 188, 144, 137, 148 + }; + + private static final int[] secp256r1PublicFixture_Y = + new int[] { + 71, 234, 253, 112, 131, 106, 243, 169, 143, 58, 39, + 222, 47, 211, 230, 90, 139, 163, 54, 249, 187, 115, + 209, 203, 239, 98, 26, 47, 101, 213, 140, 212 + }; + + private static final int[] secp2561CompressedFixture = + new int[] { + 2, 163, 132, 202, 41, 50, 135, 193, 159, 67, 19, 186, 212, 0, 129, 16, 182, 186, 176, 124, + 94, 242, 139, 48, 143, 158, 96, 51, 133, 188, 144, 137, 148 + }; + + private static final int[] secp384r1PublicFixture_X = + new int[] { + 207, 62, 215, 143, 116, 128, 174, 103, 1, 81, 127, + 212, 163, 19, 165, 220, 74, 144, 26, 59, 87, 0, + 214, 47, 66, 73, 152, 227, 196, 81, 14, 28, 58, + 221, 178, 63, 150, 119, 62, 195, 99, 63, 60, 42, + 223, 207, 28, 65 + }; + + private static final int[] secp384r1PublicFixture_Y = + new int[] { + 180, 143, 190, 5, 150, 247, 225, 240, 153, 150, 119, 109, 210, 243, 151, 206, 217, 120, 2, + 171, 75, 180, 31, 4, 91, 78, 206, 217, 241, 119, 55, 230, 216, 23, 237, 101, 21, 89, 132, + 84, 100, 3, 255, 90, 197, 237, 139, 209 + }; + + private static final int[] secp384r1CompressedFixture = + new int[] { + 3, 207, 62, 215, 143, 116, 128, 174, 103, 1, 81, 127, 212, 163, 19, 165, 220, 74, 144, 26, + 59, 87, 0, 214, 47, 66, 73, 152, 227, 196, 81, 14, 28, 58, 221, 178, 63, 150, 119, 62, 195, + 99, 63, 60, 42, 223, 207, 28, 65 + }; + + @Test + public void serializationEquality() throws Exception { + CryptoAlgorithm algorithm = CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256; + + PublicKey publicKey = + TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).generateKey().getPublic(); + + String serializedPublicKey = + TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).serializePublicKey(publicKey); + PublicKey deserializedPublicKey = + TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm) + .deserializePublicKey(serializedPublicKey); + + assertEquals(publicKey, deserializedPublicKey); + + algorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; + + publicKey = TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).generateKey().getPublic(); + + serializedPublicKey = + TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).serializePublicKey(publicKey); + deserializedPublicKey = + TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm) + .deserializePublicKey(serializedPublicKey); + + assertEquals(publicKey, deserializedPublicKey); + } + + @Test + public void deserializeSecp384() { + testDeserialization( + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + secp384r1CompressedFixture, + secp384r1PublicFixture_X, + secp384r1PublicFixture_Y); + } + + @Test + public void serializeSecp384() throws Exception { + testSerialization( + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + "secp384r1", + secp384r1PublicFixture_X, + secp384r1PublicFixture_Y, + secp384r1CompressedFixture); + } + + @Test + public void deserializeSecp256() { + testDeserialization( + CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, + secp2561CompressedFixture, + secp256r1PublicFixture_X, + secp256r1PublicFixture_Y); + } + + @Test + public void serializeSecp256() throws Exception { + testSerialization( + CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, + "secp256r1", + secp256r1PublicFixture_X, + secp256r1PublicFixture_Y, + secp2561CompressedFixture); + } + + @Test(expected = IllegalArgumentException.class) + public void testBadPoint() { + byte[] bytes = TestUtils.unsignedBytesToSignedBytes(secp384r1CompressedFixture); + bytes[20]++; + + String publicKey = Utils.encodeBase64String(bytes); + + TrailingSignatureAlgorithm.forCryptoAlgorithm( + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384) + .deserializePublicKey(publicKey); + } + + private void testSerialization( + CryptoAlgorithm algorithm, String curveName, int[] x, int[] y, int[] expected) + throws Exception { + byte[] xBytes = TestUtils.unsignedBytesToSignedBytes(x); + byte[] yBytes = TestUtils.unsignedBytesToSignedBytes(y); + + final AlgorithmParameters parameters = AlgorithmParameters.getInstance("EC"); + parameters.init(new ECGenParameterSpec(curveName)); + ECParameterSpec ecParameterSpec = parameters.getParameterSpec(ECParameterSpec.class); + + PublicKey publicKey = + KeyFactory.getInstance("EC") + .generatePublic( + new ECPublicKeySpec( + new ECPoint(new BigInteger(1, xBytes), new BigInteger(1, yBytes)), + ecParameterSpec)); + + int[] result = + TestUtils.signedBytesToUnsignedBytes( + Utils.decodeBase64String( + TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm) + .serializePublicKey(publicKey))); + + assertArrayEquals(expected, result); + } + + private void testDeserialization( + CryptoAlgorithm algorithm, int[] compressedKey, int[] expectedX, int[] expectedY) { + byte[] bytes = TestUtils.unsignedBytesToSignedBytes(compressedKey); + + String publicKey = Utils.encodeBase64String(bytes); + + PublicKey publicKeyDeserialized = + TrailingSignatureAlgorithm.forCryptoAlgorithm(algorithm).deserializePublicKey(publicKey); + + ECPublicKey desKey = (ECPublicKey) publicKeyDeserialized; + + BigInteger x = desKey.getW().getAffineX(); + BigInteger y = desKey.getW().getAffineY(); + + BigInteger expectedXBigInteger = + new BigInteger(1, TestUtils.unsignedBytesToSignedBytes(expectedX)); + BigInteger expectedYBigInteger = + new BigInteger(1, TestUtils.unsignedBytesToSignedBytes(expectedY)); + + assertEquals(expectedXBigInteger, x); + assertEquals(expectedYBigInteger, y); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/UtilsTest.java b/src/test/java/com/amazonaws/encryptionsdk/internal/UtilsTest.java index 1a002f417..16f18f824 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/UtilsTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/internal/UtilsTest.java @@ -11,129 +11,124 @@ import java.math.BigInteger; import java.nio.charset.StandardCharsets; - import org.junit.Test; -/** - * Unit tests for {@link Utils} - */ +/** Unit tests for {@link Utils} */ public class UtilsTest { - @Test - public void compareObjectIdentityTest() { - assertNotEquals(0, Utils.compareObjectIdentity(null, new Object())); - assertNotEquals(0, Utils.compareObjectIdentity(new Object(), null)); - - assertEquals(0, Utils.compareObjectIdentity(Utils.class, Utils.class)); - assertNotEquals(0, Utils.compareObjectIdentity(new Object(), new Object())); + @Test + public void compareObjectIdentityTest() { + assertNotEquals(0, Utils.compareObjectIdentity(null, new Object())); + assertNotEquals(0, Utils.compareObjectIdentity(new Object(), null)); + + assertEquals(0, Utils.compareObjectIdentity(Utils.class, Utils.class)); + assertNotEquals(0, Utils.compareObjectIdentity(new Object(), new Object())); + } + + @Test + public void compareObjectIdentity_handlesHashCodeCollisions() { + // With this large of an array, it is overwhelmingly likely that we will see two objects with + // identical + // identity hash codes. + Object[] testArray = new Object[512_000]; + + for (int i = 0; i < testArray.length; i++) { + testArray[i] = new Object(); } - @Test - public void compareObjectIdentity_handlesHashCodeCollisions() { - // With this large of an array, it is overwhelmingly likely that we will see two objects with identical - // identity hash codes. - Object[] testArray = new Object[512_000]; - - for (int i = 0; i < testArray.length; i++) { - testArray[i] = new Object(); - } - - java.util.Arrays.sort(testArray, Utils::compareObjectIdentity); + java.util.Arrays.sort(testArray, Utils::compareObjectIdentity); - // Verify that we do not have any objects that are equal (compare to zero) in the array. - // We know the primary sort is by hashcode, so we'll just do exhaustive comparison within each hashcode. + // Verify that we do not have any objects that are equal (compare to zero) in the array. + // We know the primary sort is by hashcode, so we'll just do exhaustive comparison within each + // hashcode. - boolean sawCollison = false; - for (int i = 0; i ct = crypto.encryptData(mkp, PLAINTEXT); - assertEquals(1, ct.getMasterKeyIds().size()); - final CryptoResult result = crypto.decryptData(mkp, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only the first found key should be used - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk1, result.getMasterKeys().get(0)); - } - - @Test - public void singleKeyOaepSha1() throws Exception { - addEntry("key1"); - final KeyStoreProvider mkp = new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-1AndMGF1Padding", - "key1"); - final JceMasterKey mk1 = mkp.getMasterKey("key1"); - final AwsCrypto crypto = AwsCrypto.standard(); - final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); - assertEquals(1, ct.getMasterKeyIds().size()); - final CryptoResult result = crypto.decryptData(mkp, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only the first found key should be used - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk1, result.getMasterKeys().get(0)); - } - - @Test - public void singleKeyOaepSha256() throws Exception { - addEntry("key1"); - final KeyStoreProvider mkp = new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", - "key1"); - final JceMasterKey mk1 = mkp.getMasterKey("key1"); - final AwsCrypto crypto = AwsCrypto.standard(); - final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); - assertEquals(1, ct.getMasterKeyIds().size()); - final CryptoResult result = crypto.decryptData(mkp, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only the first found key should be used - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk1, result.getMasterKeys().get(0)); - } - - @Test - public void multipleKeys() throws Exception { - addEntry("key1"); - addEntry("key2"); - final KeyStoreProvider mkp = new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", - "key1", - "key2"); - @SuppressWarnings("unused") - final JceMasterKey mk1 = mkp.getMasterKey("key1"); - final JceMasterKey mk2 = mkp.getMasterKey("key2"); - final AwsCrypto crypto = AwsCrypto.standard(); - final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); - assertEquals(2, ct.getMasterKeyIds().size()); - CryptoResult result = crypto.decryptData(mkp, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Order is non-deterministic - assertEquals(1, result.getMasterKeys().size()); - - // Delete the first key and see if it works - ks.deleteEntry("key1"); - result = crypto.decryptData(mkp, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only the first found key should be used - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk2, result.getMasterKeys().get(0)); - } - - @Test(expected = CannotUnwrapDataKeyException.class) - public void encryptOnly() throws Exception { - addPublicEntry("key1"); - final KeyStoreProvider mkp = new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", - "key1"); - final AwsCrypto crypto = AwsCrypto.standard(); - final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); - assertEquals(1, ct.getMasterKeyIds().size()); - crypto.decryptData(mkp, ct.getResult()); - } - - @Test - public void escrowAndSymmetric() throws Exception { - addPublicEntry("key1"); - addEntry("key2"); - final KeyStoreProvider mkp = new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", - "key1", - "key2"); - @SuppressWarnings("unused") - final JceMasterKey mk1 = mkp.getMasterKey("key1"); - final JceMasterKey mk2 = mkp.getMasterKey("key2"); - final AwsCrypto crypto = AwsCrypto.standard(); - final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); - assertEquals(2, ct.getMasterKeyIds().size()); - CryptoResult result = crypto.decryptData(mkp, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only could have decrypted with the keypair - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk2, result.getMasterKeys().get(0)); - - // Delete the first key and see if it works - ks.deleteEntry("key1"); - result = crypto.decryptData(mkp, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only the first found key should be used - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk2, result.getMasterKeys().get(0)); - } - - @Test - public void escrowAndSymmetricSecondProvider() throws GeneralSecurityException, IOException { - addPublicEntry("key1"); - addEntry("key2"); - final KeyStoreProvider mkp = new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", - "key1", - "key2"); - @SuppressWarnings("unused") - final JceMasterKey mk1 = mkp.getMasterKey("key1"); - final JceMasterKey mk2 = mkp.getMasterKey("key2"); - final AwsCrypto crypto = AwsCrypto.standard(); - final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); - assertEquals(2, ct.getMasterKeyIds().size()); - - final KeyStoreProvider mkp2 = new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", - "key1"); - CryptoResult result = crypto.decryptData(mkp2, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only could have decrypted with the keypair - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk2, result.getMasterKeys().get(0)); - } - - @Test - public void escrowCase() throws GeneralSecurityException, IOException { - addEntry("escrowKey"); - KeyStore ks2 = KeyStore.getInstance(KeyStore.getDefaultType()); - ks2.load(null, PASSWORD); - copyPublicPart(ks, ks2, "escrowKey"); - - final KeyStoreProvider mkp = new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", - "escrowKey"); - final KeyStoreProvider escrowProvider = new KeyStoreProvider(ks2, PP, "KeyStore", - "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", - "escrowKey"); - - final JceMasterKey mk1 = escrowProvider.getMasterKey("escrowKey"); - final AwsCrypto crypto = AwsCrypto.standard(); - final CryptoResult ct = crypto.encryptData(escrowProvider, PLAINTEXT); - assertEquals(1, ct.getMasterKeyIds().size()); - - try { - crypto.decryptData(escrowProvider, ct.getResult()); - fail("Expected CannotUnwrapDataKeyException"); - } catch (final CannotUnwrapDataKeyException ex) { - // expected - } - CryptoResult result = crypto.decryptData(mkp, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only could have decrypted with the keypair - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk1, result.getMasterKeys().get(0)); - } - - @Test - public void keystoreAndRawProvider() throws GeneralSecurityException, IOException { - addEntry("key1"); - final SecretKeySpec k1 = new SecretKeySpec(generate(32), "AES"); - final JceMasterKey jcep = JceMasterKey.getInstance(k1, "jce", "1", "AES/GCM/NoPadding"); - final KeyStoreProvider ksp = new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", - "key1"); - - MasterKeyProvider multiProvider = MultipleProviderFactory.buildMultiProvider(JceMasterKey.class, - jcep, ksp); - - assertEquals(jcep, multiProvider.getMasterKey("jce", "1")); - - final AwsCrypto crypto = AwsCrypto.standard(); - final CryptoResult ct = crypto.encryptData(multiProvider, PLAINTEXT); - assertEquals(2, ct.getMasterKeyIds().size()); - CryptoResult result = crypto.decryptData(multiProvider, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - assertEquals(jcep, result.getMasterKeys().get(0)); - - // Decrypt just using each individually - assertArrayEquals(PLAINTEXT, crypto.decryptData(jcep, ct.getResult()).getResult()); - assertArrayEquals(PLAINTEXT, crypto.decryptData(ksp, ct.getResult()).getResult()); - } - - private void addEntry(final String alias) throws GeneralSecurityException, IOException { - final KeyPair pair = KG.generateKeyPair(); - ks.setEntry(alias, new KeyStore.PrivateKeyEntry(pair.getPrivate(), - new X509Certificate[] { generateCertificate(pair, alias) }), PP); - } - - private void addPublicEntry(final String alias) throws GeneralSecurityException, IOException { - final KeyPair pair = KG.generateKeyPair(); - ks.setEntry(alias, new KeyStore.TrustedCertificateEntry(generateCertificate(pair, alias)), null); - } +/* These internal sun classes are included solely for test purposes as +this test cannot use BouncyCastle cert generation, as there are incompatibilities +between how standard BC and FIPS BC perform cert generation. */ - private X509Certificate generateCertificate(final KeyPair pair, final String alias) throws GeneralSecurityException, IOException { - final X509CertInfo info = new X509CertInfo(); - final X500Name name = new X500Name("dc=" + alias); - info.set(X509CertInfo.SERIAL_NUMBER, new CertificateSerialNumber(new BigInteger(256, RND))); - info.set(X509CertInfo.SUBJECT, name); - info.set(X509CertInfo.ISSUER, name); - info.set(X509CertInfo.VALIDITY, - new CertificateValidity(Date.from(Instant.now().minus(1, ChronoUnit.DAYS)), - Date.from(Instant.now().plus(730, ChronoUnit.DAYS)))); - info.set(X509CertInfo.KEY, new CertificateX509Key(pair.getPublic())); - info.set(X509CertInfo.ALGORITHM_ID, - new CertificateAlgorithmId(new AlgorithmId(AlgorithmId.sha256WithRSAEncryption_oid))); - - final X509CertImpl cert = new X509CertImpl(info); - cert.sign(pair.getPrivate(), AlgorithmId.sha256WithRSAEncryption_oid.toString()); - - return cert; +public class KeyStoreProviderTest { + private static final SecureRandom RND = new SecureRandom(); + private static final KeyPairGenerator KG; + private static final byte[] PLAINTEXT = generate(1024); + private static final char[] PASSWORD = "Password".toCharArray(); + private static final KeyStore.PasswordProtection PP = new PasswordProtection(PASSWORD); + private KeyStore ks; + + static { + try { + KG = KeyPairGenerator.getInstance("RSA"); + KG.initialize(2048); + } catch (Exception ex) { + throw new RuntimeException(ex); } - - private void copyPublicPart(final KeyStore src, final KeyStore dst, final String alias) throws KeyStoreException { - Certificate cert = src.getCertificate(alias); - dst.setCertificateEntry(alias, cert); + } + + @Before + public void setup() throws Exception { + ks = KeyStore.getInstance(KeyStore.getDefaultType()); + ks.load(null, PASSWORD); + } + + @Test + public void singleKeyPkcs1() throws Exception { + addEntry("key1"); + final KeyStoreProvider mkp = + new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/PKCS1Padding", "key1"); + final JceMasterKey mk1 = mkp.getMasterKey("key1"); + final AwsCrypto crypto = AwsCrypto.standard(); + final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); + assertEquals(1, ct.getMasterKeyIds().size()); + final CryptoResult result = crypto.decryptData(mkp, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only the first found key should be used + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk1, result.getMasterKeys().get(0)); + } + + @Test + public void singleKeyOaepSha1() throws Exception { + addEntry("key1"); + final KeyStoreProvider mkp = + new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-1AndMGF1Padding", "key1"); + final JceMasterKey mk1 = mkp.getMasterKey("key1"); + final AwsCrypto crypto = AwsCrypto.standard(); + final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); + assertEquals(1, ct.getMasterKeyIds().size()); + final CryptoResult result = crypto.decryptData(mkp, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only the first found key should be used + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk1, result.getMasterKeys().get(0)); + } + + @Test + public void singleKeyOaepSha256() throws Exception { + addEntry("key1"); + final KeyStoreProvider mkp = + new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", "key1"); + final JceMasterKey mk1 = mkp.getMasterKey("key1"); + final AwsCrypto crypto = AwsCrypto.standard(); + final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); + assertEquals(1, ct.getMasterKeyIds().size()); + final CryptoResult result = crypto.decryptData(mkp, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only the first found key should be used + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk1, result.getMasterKeys().get(0)); + } + + @Test + public void multipleKeys() throws Exception { + addEntry("key1"); + addEntry("key2"); + final KeyStoreProvider mkp = + new KeyStoreProvider( + ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", "key1", "key2"); + @SuppressWarnings("unused") + final JceMasterKey mk1 = mkp.getMasterKey("key1"); + final JceMasterKey mk2 = mkp.getMasterKey("key2"); + final AwsCrypto crypto = AwsCrypto.standard(); + final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); + assertEquals(2, ct.getMasterKeyIds().size()); + CryptoResult result = crypto.decryptData(mkp, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Order is non-deterministic + assertEquals(1, result.getMasterKeys().size()); + + // Delete the first key and see if it works + ks.deleteEntry("key1"); + result = crypto.decryptData(mkp, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only the first found key should be used + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk2, result.getMasterKeys().get(0)); + } + + @Test(expected = CannotUnwrapDataKeyException.class) + public void encryptOnly() throws Exception { + addPublicEntry("key1"); + final KeyStoreProvider mkp = + new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", "key1"); + final AwsCrypto crypto = AwsCrypto.standard(); + final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); + assertEquals(1, ct.getMasterKeyIds().size()); + crypto.decryptData(mkp, ct.getResult()); + } + + @Test + public void escrowAndSymmetric() throws Exception { + addPublicEntry("key1"); + addEntry("key2"); + final KeyStoreProvider mkp = + new KeyStoreProvider( + ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", "key1", "key2"); + @SuppressWarnings("unused") + final JceMasterKey mk1 = mkp.getMasterKey("key1"); + final JceMasterKey mk2 = mkp.getMasterKey("key2"); + final AwsCrypto crypto = AwsCrypto.standard(); + final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); + assertEquals(2, ct.getMasterKeyIds().size()); + CryptoResult result = crypto.decryptData(mkp, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only could have decrypted with the keypair + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk2, result.getMasterKeys().get(0)); + + // Delete the first key and see if it works + ks.deleteEntry("key1"); + result = crypto.decryptData(mkp, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only the first found key should be used + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk2, result.getMasterKeys().get(0)); + } + + @Test + public void escrowAndSymmetricSecondProvider() throws GeneralSecurityException, IOException { + addPublicEntry("key1"); + addEntry("key2"); + final KeyStoreProvider mkp = + new KeyStoreProvider( + ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", "key1", "key2"); + @SuppressWarnings("unused") + final JceMasterKey mk1 = mkp.getMasterKey("key1"); + final JceMasterKey mk2 = mkp.getMasterKey("key2"); + final AwsCrypto crypto = AwsCrypto.standard(); + final CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); + assertEquals(2, ct.getMasterKeyIds().size()); + + final KeyStoreProvider mkp2 = + new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", "key1"); + CryptoResult result = crypto.decryptData(mkp2, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only could have decrypted with the keypair + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk2, result.getMasterKeys().get(0)); + } + + @Test + public void escrowCase() throws GeneralSecurityException, IOException { + addEntry("escrowKey"); + KeyStore ks2 = KeyStore.getInstance(KeyStore.getDefaultType()); + ks2.load(null, PASSWORD); + copyPublicPart(ks, ks2, "escrowKey"); + + final KeyStoreProvider mkp = + new KeyStoreProvider( + ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", "escrowKey"); + final KeyStoreProvider escrowProvider = + new KeyStoreProvider( + ks2, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", "escrowKey"); + + final JceMasterKey mk1 = escrowProvider.getMasterKey("escrowKey"); + final AwsCrypto crypto = AwsCrypto.standard(); + final CryptoResult ct = crypto.encryptData(escrowProvider, PLAINTEXT); + assertEquals(1, ct.getMasterKeyIds().size()); + + try { + crypto.decryptData(escrowProvider, ct.getResult()); + fail("Expected CannotUnwrapDataKeyException"); + } catch (final CannotUnwrapDataKeyException ex) { + // expected } + CryptoResult result = crypto.decryptData(mkp, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only could have decrypted with the keypair + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk1, result.getMasterKeys().get(0)); + } + + @Test + public void keystoreAndRawProvider() throws GeneralSecurityException, IOException { + addEntry("key1"); + final SecretKeySpec k1 = new SecretKeySpec(generate(32), "AES"); + final JceMasterKey jcep = JceMasterKey.getInstance(k1, "jce", "1", "AES/GCM/NoPadding"); + final KeyStoreProvider ksp = + new KeyStoreProvider(ks, PP, "KeyStore", "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", "key1"); + + MasterKeyProvider multiProvider = + MultipleProviderFactory.buildMultiProvider(JceMasterKey.class, jcep, ksp); + + assertEquals(jcep, multiProvider.getMasterKey("jce", "1")); + + final AwsCrypto crypto = AwsCrypto.standard(); + final CryptoResult ct = crypto.encryptData(multiProvider, PLAINTEXT); + assertEquals(2, ct.getMasterKeyIds().size()); + CryptoResult result = crypto.decryptData(multiProvider, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + assertEquals(jcep, result.getMasterKeys().get(0)); + + // Decrypt just using each individually + assertArrayEquals(PLAINTEXT, crypto.decryptData(jcep, ct.getResult()).getResult()); + assertArrayEquals(PLAINTEXT, crypto.decryptData(ksp, ct.getResult()).getResult()); + } + + private void addEntry(final String alias) throws GeneralSecurityException, IOException { + final KeyPair pair = KG.generateKeyPair(); + ks.setEntry( + alias, + new KeyStore.PrivateKeyEntry( + pair.getPrivate(), new X509Certificate[] {generateCertificate(pair, alias)}), + PP); + } + + private void addPublicEntry(final String alias) throws GeneralSecurityException, IOException { + final KeyPair pair = KG.generateKeyPair(); + ks.setEntry( + alias, new KeyStore.TrustedCertificateEntry(generateCertificate(pair, alias)), null); + } + + private X509Certificate generateCertificate(final KeyPair pair, final String alias) + throws GeneralSecurityException, IOException { + final X509CertInfo info = new X509CertInfo(); + final X500Name name = new X500Name("dc=" + alias); + info.set(X509CertInfo.SERIAL_NUMBER, new CertificateSerialNumber(new BigInteger(256, RND))); + info.set(X509CertInfo.SUBJECT, name); + info.set(X509CertInfo.ISSUER, name); + info.set( + X509CertInfo.VALIDITY, + new CertificateValidity( + Date.from(Instant.now().minus(1, ChronoUnit.DAYS)), + Date.from(Instant.now().plus(730, ChronoUnit.DAYS)))); + info.set(X509CertInfo.KEY, new CertificateX509Key(pair.getPublic())); + info.set( + X509CertInfo.ALGORITHM_ID, + new CertificateAlgorithmId(new AlgorithmId(AlgorithmId.sha256WithRSAEncryption_oid))); + + final X509CertImpl cert = new X509CertImpl(info); + cert.sign(pair.getPrivate(), AlgorithmId.sha256WithRSAEncryption_oid.toString()); + + return cert; + } + + private void copyPublicPart(final KeyStore src, final KeyStore dst, final String alias) + throws KeyStoreException { + Certificate cert = src.getCertificate(alias); + dst.setCertificateEntry(alias, cert); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyProviderTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyProviderTest.java index 3bc65a63c..f896dce35 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyProviderTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyProviderTest.java @@ -3,6 +3,12 @@ package com.amazonaws.encryptionsdk.kms; +import static com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo.parseInfoFromKeyArn; +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.spy; + import com.amazonaws.AmazonServiceException; import com.amazonaws.auth.BasicAWSCredentials; import com.amazonaws.encryptionsdk.*; @@ -10,876 +16,840 @@ import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; import com.amazonaws.encryptionsdk.exception.NoSuchMasterKeyException; import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; - -import static com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo.parseInfoFromKeyArn; - import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.RegionalClientSupplier; import com.amazonaws.encryptionsdk.model.KeyBlob; import com.amazonaws.services.kms.AWSKMS; import com.amazonaws.services.kms.AWSKMSClientBuilder; import com.amazonaws.services.kms.model.DecryptRequest; import com.amazonaws.services.kms.model.DecryptResult; -import com.amazonaws.services.kms.model.GenerateDataKeyRequest; -import org.junit.Test; -import org.junit.experimental.runners.Enclosed; -import org.junit.jupiter.api.DisplayName; -import org.junit.runner.RunWith; -import org.mockito.ArgumentCaptor; - import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.ConcurrentHashMap; - -import static org.junit.Assert.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; -import static org.mockito.Mockito.spy; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.jupiter.api.DisplayName; +import org.junit.runner.RunWith; @RunWith(Enclosed.class) public class AwsKmsMrkAwareMasterKeyProviderTest { - static public class getResourceForResourceTypeKey { - @Test - @DisplayName("Postcondition: Return the key id.") - public void basic_use() { - assertEquals( - "mrk-edb7fe6942894d32ac46dbb1c922d574", - AwsKmsMrkAwareMasterKeyProvider - .getResourceForResourceTypeKey("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574")); - } - - @Test - @DisplayName("Check for early return (Postcondition): Non-ARNs may be raw resources.") - public void not_an_arn() { - assertEquals( - "mrk-edb7fe6942894d32ac46dbb1c922d574", - AwsKmsMrkAwareMasterKeyProvider - .getResourceForResourceTypeKey("mrk-edb7fe6942894d32ac46dbb1c922d574")); - final String malformed = "aws:kms:us-west-2::key/garbage"; - assertEquals( - malformed, - AwsKmsMrkAwareMasterKeyProvider - .getResourceForResourceTypeKey(malformed)); - } - - @Test - @DisplayName("Check for early return (Postcondition): Return the identifier for non-key resource types.") - public void not_a_key() { - final String alias = "arn:aws:kms:us-west-2:658956600833:alias/EncryptDecrypt"; - assertEquals( - alias, - AwsKmsMrkAwareMasterKeyProvider - .getResourceForResourceTypeKey(alias)); - } - } - - static public class assertMrksAreUnique { - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 - //= type=test - //# The caller MUST provide: - public void basic_use() { - AwsKmsMrkAwareMasterKeyProvider - .assertMrksAreUnique(Arrays.asList( - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" - )); - } - - @Test - public void no_duplicates() { - //= compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 - //= type=test - //# If there are zero duplicate resource ids between the multi-region - //# keys, this function MUST exit successfully - AwsKmsMrkAwareMasterKeyProvider - .assertMrksAreUnique(Arrays.asList( - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f" - )); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 - //= type=test - //# If the list does not contain any multi-Region keys (aws-kms-key- - //# arn.md#identifying-an-aws-kms-multi-region-key) this function MUST - //# exit successfully. - public void no_mrks_at_all() { - AwsKmsMrkAwareMasterKeyProvider - .assertMrksAreUnique(Arrays.asList( - "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", - "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f" - )); - } - - @Test - @DisplayName("Postcondition: Filter out duplicate resources that are not multi-region keys.") - public void non_mrk_duplicates_ok() { - AwsKmsMrkAwareMasterKeyProvider - .assertMrksAreUnique(Arrays.asList( - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", - "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", - "arn:aws:kms:us-west-2:658956600833:alias/EncryptDecrypt", - "arn:aws:kms:us-west-2:658956600833:alias/EncryptDecrypt" - )); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 - //= type=test - //# If any duplicate multi-region resource ids exist, this function MUST - //# yield an error that includes all identifiers with duplicate resource - //# ids not only the first duplicate found. - public void no_duplicate_mrks() { - assertThrows( - IllegalArgumentException.class, - () -> AwsKmsMrkAwareMasterKeyProvider - .assertMrksAreUnique(Arrays.asList( - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574" - ))); - } - } - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //= type=test - //# On initialization the caller MUST provide: - static public class AwsKmsMrkAwareMasterKeyProviderBuilderTests { - @Test - public void basic_use() { - final AwsKmsMrkAwareMasterKeyProvider strict = AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildStrict("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"); - final AwsKmsMrkAwareMasterKeyProvider discovery = AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildDiscovery(); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.5 - //= type=test - //# MUST implement the Master Key Provider Interface (../master-key- - //# provider-interface.md#interface) - assertTrue(MasterKeyProvider.class.isInstance(strict)); - assertTrue(MasterKeyProvider.class.isInstance(discovery)); - - // These are not testable because of how the builder is structured. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //= type=test - //# A discovery filter MUST NOT be configured in strict mode. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //= type=test - //# A default MRK Region MUST NOT be configured in strict mode. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //= type=test - //# In - //# discovery mode if a default MRK Region is not configured the AWS SDK - //# Default Region MUST be used. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //= type=test - //# The key id list MUST be empty in discovery mode. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //= type=test - //# The regional client - //# supplier MUST be defined in discovery mode. - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //= type=test - //# The key id list MUST NOT be empty or null in strict mode. - public void no_noop() { - assertThrows(IllegalArgumentException.class, () -> AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildStrict()); - assertThrows(IllegalArgumentException.class, () -> AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildStrict(new ArrayList())); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //= type=test - //# The key id - //# list MUST NOT contain any null or empty string values. - public void no_null_identifiers() { - assertThrows(IllegalArgumentException.class, () -> AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildStrict("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", "")); - - assertThrows(IllegalArgumentException.class, () -> AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildStrict("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", null)); - - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //= type=test - //# All AWS KMS - //# key identifiers are be passed to Assert AWS KMS MRK are unique (aws- - //# kms-mrk-are-unique.md#Implementation) and the function MUST return - //# success. - public void no_duplicate_mrks() { - assertThrows(IllegalArgumentException.class, () -> AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildStrict( - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574")); - } - - @Test - @DisplayName("Precondition: A region is required to contact AWS KMS.") - public void always_need_a_region() { - assertThrows(AwsCryptoException.class, () -> AwsKmsMrkAwareMasterKeyProvider - .builder() - .withDefaultRegion(null) - .buildStrict( - "mrk-edb7fe6942894d32ac46dbb1c922d574")); - - AwsKmsMrkAwareMasterKeyProvider - .builder() - .withDefaultRegion("us-east-1") - .buildStrict( - "mrk-edb7fe6942894d32ac46dbb1c922d574"); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 - //= type=test - //# If an AWS SDK Default Region can not be - //# obtained initialization MUST fail. - public void discovery_region_can_not_be_null() { - assertThrows(IllegalArgumentException.class, () -> AwsKmsMrkAwareMasterKeyProvider - .builder() - // need to force the default region to `null` - // otherwise it may pick one up from the environment. - .withDefaultRegion(null) - .withDiscoveryMrkRegion(null) - .buildDiscovery()); - } - - @Test - public void basic_credentials_and_builder() { - BasicAWSCredentials creds = new BasicAWSCredentials("asdf", "qwer"); - AwsKmsMrkAwareMasterKeyProvider - .builder() - .withClientBuilder(AWSKMSClientBuilder.standard()) - .withCredentials(creds) - .buildDiscovery(); - } - } - - static public class extractRegion { - - @Test - public void basic_use() { - final String test = AwsKmsMrkAwareMasterKeyProvider - .extractRegion( - "us-east-1", - "us-east-2", - Optional.of("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), - parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), - false - ); - - assertEquals("us-west-2", test); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# If the requested AWS KMS key identifier is not a well formed ARN the - //# AWS Region MUST be the configured default region this SHOULD be - //# obtained from the AWS SDK. - public void not_an_arn() { - final String test = AwsKmsMrkAwareMasterKeyProvider - .extractRegion( - "us-east-1", - "us-east-2", - Optional.empty(), - parseInfoFromKeyArn("mrk-edb7fe6942894d32ac46dbb1c922d574"), - false - ); - - assertEquals("us-east-1", test); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# Otherwise if the requested AWS KMS key - //# identifier is identified as a multi-Region key (aws-kms-key- - //# arn.md#identifying-an-aws-kms-multi-region-key), then AWS Region MUST - //# be the region from the AWS KMS key ARN stored in the provider info - //# from the encrypted data key. - public void not_an_mrk() { - final String test = AwsKmsMrkAwareMasterKeyProvider - .extractRegion( - "us-east-1", - "us-east-2", - Optional.of("arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"), - parseInfoFromKeyArn("arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"), - false - ); - - assertEquals("us-west-2", test); - - final String test2 = AwsKmsMrkAwareMasterKeyProvider - .extractRegion( - "us-east-1", - "us-east-2", - Optional.of("arn:aws:kms:us-west-2:658956600833:alias/mrk-nasty"), - parseInfoFromKeyArn("arn:aws:kms:us-west-2:658956600833:alias/mrk-nasty"), - false - ); - - assertEquals("us-west-2", test2); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# Otherwise if the mode is discovery then - //# the AWS Region MUST be the discovery MRK region. - public void mrk_in_discovery() { - final String test = AwsKmsMrkAwareMasterKeyProvider - .extractRegion( - "us-east-1", - "us-east-2", - Optional.empty(), - parseInfoFromKeyArn("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), - true - ); - - assertEquals("us-east-2", test); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# Finally if the - //# provider info is identified as a multi-Region key (aws-kms-key- - //# arn.md#identifying-an-aws-kms-multi-region-key) the AWS Region MUST - //# be the region from the AWS KMS key in the configured key ids matched - //# to the requested AWS KMS key by using AWS KMS MRK Match for Decrypt - //# (aws-kms-mrk-match-for-decrypt.md#implementation). - public void fuzzy_match_mrk() { - final String test = AwsKmsMrkAwareMasterKeyProvider - .extractRegion( - "us-east-1", - "us-east-2", - Optional.of("arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), - parseInfoFromKeyArn("arn:aws:kms:us-west-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), - false - ); - - assertEquals("us-west-2", test); - } - } - - static public class getMasterKey { - @Test - public void basic_use() { - final String identifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final AWSKMS client = spy(new MockKMSClient()); - final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() - .withCustomClientFactory(supplier) - .buildStrict(identifier); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# The input MUST be the same as the Master Key Provider Get Master Key - //# (../master-key-provider-interface.md#get-master-key) interface. - AwsKmsMrkAwareMasterKey test = mkp.getMasterKey( - "aws-kms", - identifier); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# The output MUST be the same as the Master Key Provider Get Master Key - //# (../master-key-provider-interface.md#get-master-key) interface. - assertTrue(AwsKmsMrkAwareMasterKey.class.isInstance((test))); - - assertEquals(identifier, test.getKeyId()); - verify(supplier, times(1)).getClient("us-west-2"); - } - - @Test - public void basic_mrk_use() { - final String configuredIdentifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final String requestedIdentifier = "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final AWSKMS client = spy(new MockKMSClient()); - final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() - .withCustomClientFactory(supplier) - .buildStrict(configuredIdentifier); - - AwsKmsMrkAwareMasterKey test = mkp.getMasterKey( - "aws-kms", - requestedIdentifier); - - assertEquals(configuredIdentifier, test.getKeyId()); - verify(supplier, times(1)).getClient("us-west-2"); - } - - @Test - public void other_basic_uses() { - final AWSKMS client = spy(new MockKMSClient()); - final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - // A raw alias is a valid configuration for encryption - final String rawAliasIdentifier = "alias/my-alias"; - AwsKmsMrkAwareMasterKeyProvider - .builder() + public static class getResourceForResourceTypeKey { + @Test + @DisplayName("Postcondition: Return the key id.") + public void basic_use() { + assertEquals( + "mrk-edb7fe6942894d32ac46dbb1c922d574", + AwsKmsMrkAwareMasterKeyProvider.getResourceForResourceTypeKey( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574")); + } + + @Test + @DisplayName("Check for early return (Postcondition): Non-ARNs may be raw resources.") + public void not_an_arn() { + assertEquals( + "mrk-edb7fe6942894d32ac46dbb1c922d574", + AwsKmsMrkAwareMasterKeyProvider.getResourceForResourceTypeKey( + "mrk-edb7fe6942894d32ac46dbb1c922d574")); + final String malformed = "aws:kms:us-west-2::key/garbage"; + assertEquals( + malformed, AwsKmsMrkAwareMasterKeyProvider.getResourceForResourceTypeKey(malformed)); + } + + @Test + @DisplayName( + "Check for early return (Postcondition): Return the identifier for non-key resource types.") + public void not_a_key() { + final String alias = "arn:aws:kms:us-west-2:658956600833:alias/EncryptDecrypt"; + assertEquals(alias, AwsKmsMrkAwareMasterKeyProvider.getResourceForResourceTypeKey(alias)); + } + } + + public static class assertMrksAreUnique { + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 + // = type=test + // # The caller MUST provide: + public void basic_use() { + AwsKmsMrkAwareMasterKeyProvider.assertMrksAreUnique( + Arrays.asList( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574")); + } + + @Test + public void no_duplicates() { + // = compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 + // = type=test + // # If there are zero duplicate resource ids between the multi-region + // # keys, this function MUST exit successfully + AwsKmsMrkAwareMasterKeyProvider.assertMrksAreUnique( + Arrays.asList( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f")); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 + // = type=test + // # If the list does not contain any multi-Region keys (aws-kms-key- + // # arn.md#identifying-an-aws-kms-multi-region-key) this function MUST + // # exit successfully. + public void no_mrks_at_all() { + AwsKmsMrkAwareMasterKeyProvider.assertMrksAreUnique( + Arrays.asList( + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f")); + } + + @Test + @DisplayName("Postcondition: Filter out duplicate resources that are not multi-region keys.") + public void non_mrk_duplicates_ok() { + AwsKmsMrkAwareMasterKeyProvider.assertMrksAreUnique( + Arrays.asList( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", + "arn:aws:kms:us-west-2:658956600833:alias/EncryptDecrypt", + "arn:aws:kms:us-west-2:658956600833:alias/EncryptDecrypt")); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-are-unique.txt#2.5 + // = type=test + // # If any duplicate multi-region resource ids exist, this function MUST + // # yield an error that includes all identifiers with duplicate resource + // # ids not only the first duplicate found. + public void no_duplicate_mrks() { + assertThrows( + IllegalArgumentException.class, + () -> + AwsKmsMrkAwareMasterKeyProvider.assertMrksAreUnique( + Arrays.asList( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"))); + } + } + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // = type=test + // # On initialization the caller MUST provide: + public static class AwsKmsMrkAwareMasterKeyProviderBuilderTests { + @Test + public void basic_use() { + final AwsKmsMrkAwareMasterKeyProvider strict = + AwsKmsMrkAwareMasterKeyProvider.builder() + .buildStrict( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"); + final AwsKmsMrkAwareMasterKeyProvider discovery = + AwsKmsMrkAwareMasterKeyProvider.builder().buildDiscovery(); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.5 + // = type=test + // # MUST implement the Master Key Provider Interface (../master-key- + // # provider-interface.md#interface) + assertTrue(MasterKeyProvider.class.isInstance(strict)); + assertTrue(MasterKeyProvider.class.isInstance(discovery)); + + // These are not testable because of how the builder is structured. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // = type=test + // # A discovery filter MUST NOT be configured in strict mode. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // = type=test + // # A default MRK Region MUST NOT be configured in strict mode. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // = type=test + // # In + // # discovery mode if a default MRK Region is not configured the AWS SDK + // # Default Region MUST be used. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // = type=test + // # The key id list MUST be empty in discovery mode. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // = type=test + // # The regional client + // # supplier MUST be defined in discovery mode. + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // = type=test + // # The key id list MUST NOT be empty or null in strict mode. + public void no_noop() { + assertThrows( + IllegalArgumentException.class, + () -> AwsKmsMrkAwareMasterKeyProvider.builder().buildStrict()); + assertThrows( + IllegalArgumentException.class, + () -> AwsKmsMrkAwareMasterKeyProvider.builder().buildStrict(new ArrayList())); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // = type=test + // # The key id + // # list MUST NOT contain any null or empty string values. + public void no_null_identifiers() { + assertThrows( + IllegalArgumentException.class, + () -> + AwsKmsMrkAwareMasterKeyProvider.builder() + .buildStrict( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "")); + + assertThrows( + IllegalArgumentException.class, + () -> + AwsKmsMrkAwareMasterKeyProvider.builder() + .buildStrict( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + null)); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // = type=test + // # All AWS KMS + // # key identifiers are be passed to Assert AWS KMS MRK are unique (aws- + // # kms-mrk-are-unique.md#Implementation) and the function MUST return + // # success. + public void no_duplicate_mrks() { + assertThrows( + IllegalArgumentException.class, + () -> + AwsKmsMrkAwareMasterKeyProvider.builder() + .buildStrict( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574")); + } + + @Test + @DisplayName("Precondition: A region is required to contact AWS KMS.") + public void always_need_a_region() { + assertThrows( + AwsCryptoException.class, + () -> + AwsKmsMrkAwareMasterKeyProvider.builder() + .withDefaultRegion(null) + .buildStrict("mrk-edb7fe6942894d32ac46dbb1c922d574")); + + AwsKmsMrkAwareMasterKeyProvider.builder() + .withDefaultRegion("us-east-1") + .buildStrict("mrk-edb7fe6942894d32ac46dbb1c922d574"); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.6 + // = type=test + // # If an AWS SDK Default Region can not be + // # obtained initialization MUST fail. + public void discovery_region_can_not_be_null() { + assertThrows( + IllegalArgumentException.class, + () -> + AwsKmsMrkAwareMasterKeyProvider.builder() + // need to force the default region to `null` + // otherwise it may pick one up from the environment. + .withDefaultRegion(null) + .withDiscoveryMrkRegion(null) + .buildDiscovery()); + } + + @Test + public void basic_credentials_and_builder() { + BasicAWSCredentials creds = new BasicAWSCredentials("asdf", "qwer"); + AwsKmsMrkAwareMasterKeyProvider.builder() + .withClientBuilder(AWSKMSClientBuilder.standard()) + .withCredentials(creds) + .buildDiscovery(); + } + } + + public static class extractRegion { + + @Test + public void basic_use() { + final String test = + AwsKmsMrkAwareMasterKeyProvider.extractRegion( + "us-east-1", + "us-east-2", + Optional.of( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + parseInfoFromKeyArn( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + false); + + assertEquals("us-west-2", test); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # If the requested AWS KMS key identifier is not a well formed ARN the + // # AWS Region MUST be the configured default region this SHOULD be + // # obtained from the AWS SDK. + public void not_an_arn() { + final String test = + AwsKmsMrkAwareMasterKeyProvider.extractRegion( + "us-east-1", + "us-east-2", + Optional.empty(), + parseInfoFromKeyArn("mrk-edb7fe6942894d32ac46dbb1c922d574"), + false); + + assertEquals("us-east-1", test); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # Otherwise if the requested AWS KMS key + // # identifier is identified as a multi-Region key (aws-kms-key- + // # arn.md#identifying-an-aws-kms-multi-region-key), then AWS Region MUST + // # be the region from the AWS KMS key ARN stored in the provider info + // # from the encrypted data key. + public void not_an_mrk() { + final String test = + AwsKmsMrkAwareMasterKeyProvider.extractRegion( + "us-east-1", + "us-east-2", + Optional.of( + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"), + parseInfoFromKeyArn( + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"), + false); + + assertEquals("us-west-2", test); + + final String test2 = + AwsKmsMrkAwareMasterKeyProvider.extractRegion( + "us-east-1", + "us-east-2", + Optional.of("arn:aws:kms:us-west-2:658956600833:alias/mrk-nasty"), + parseInfoFromKeyArn("arn:aws:kms:us-west-2:658956600833:alias/mrk-nasty"), + false); + + assertEquals("us-west-2", test2); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # Otherwise if the mode is discovery then + // # the AWS Region MUST be the discovery MRK region. + public void mrk_in_discovery() { + final String test = + AwsKmsMrkAwareMasterKeyProvider.extractRegion( + "us-east-1", + "us-east-2", + Optional.empty(), + parseInfoFromKeyArn( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + true); + + assertEquals("us-east-2", test); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # Finally if the + // # provider info is identified as a multi-Region key (aws-kms-key- + // # arn.md#identifying-an-aws-kms-multi-region-key) the AWS Region MUST + // # be the region from the AWS KMS key in the configured key ids matched + // # to the requested AWS KMS key by using AWS KMS MRK Match for Decrypt + // # (aws-kms-mrk-match-for-decrypt.md#implementation). + public void fuzzy_match_mrk() { + final String test = + AwsKmsMrkAwareMasterKeyProvider.extractRegion( + "us-east-1", + "us-east-2", + Optional.of( + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + parseInfoFromKeyArn( + "arn:aws:kms:us-west-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"), + false); + + assertEquals("us-west-2", test); + } + } + + public static class getMasterKey { + @Test + public void basic_use() { + final String identifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final AWSKMS client = spy(new MockKMSClient()); + final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildStrict(identifier); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # The input MUST be the same as the Master Key Provider Get Master Key + // # (../master-key-provider-interface.md#get-master-key) interface. + AwsKmsMrkAwareMasterKey test = mkp.getMasterKey("aws-kms", identifier); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # The output MUST be the same as the Master Key Provider Get Master Key + // # (../master-key-provider-interface.md#get-master-key) interface. + assertTrue(AwsKmsMrkAwareMasterKey.class.isInstance((test))); + + assertEquals(identifier, test.getKeyId()); + verify(supplier, times(1)).getClient("us-west-2"); + } + + @Test + public void basic_mrk_use() { + final String configuredIdentifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final String requestedIdentifier = + "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final AWSKMS client = spy(new MockKMSClient()); + final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildStrict(configuredIdentifier); + + AwsKmsMrkAwareMasterKey test = mkp.getMasterKey("aws-kms", requestedIdentifier); + + assertEquals(configuredIdentifier, test.getKeyId()); + verify(supplier, times(1)).getClient("us-west-2"); + } + + @Test + public void other_basic_uses() { + final AWSKMS client = spy(new MockKMSClient()); + final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + // A raw alias is a valid configuration for encryption + final String rawAliasIdentifier = "alias/my-alias"; + AwsKmsMrkAwareMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildStrict(rawAliasIdentifier) + .getMasterKey("aws-kms", rawAliasIdentifier); + + // A raw alias is a valid configuration for encryption + final String rawKeyIdentifier = "mrk-edb7fe6942894d32ac46dbb1c922d574"; + AwsKmsMrkAwareMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildStrict(rawKeyIdentifier) + .getMasterKey("aws-kms", rawKeyIdentifier); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # The function MUST only provide master keys if the input provider id + // # equals "aws-kms". + public void only_this_provider() { + final String identifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final AWSKMS client = spy(new MockKMSClient()); + final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildStrict(identifier); + + assertThrows( + UnsupportedProviderException.class, () -> mkp.getMasterKey("not-aws-kms", identifier)); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # In strict mode, the requested AWS KMS key ARN MUST + // # match a member of the configured key ids by using AWS KMS MRK Match + // # for Decrypt (aws-kms-mrk-match-for-decrypt.md#implementation) + // # otherwise this function MUST error. + public void no_key_id_match() { + final String identifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final AWSKMS client = spy(new MockKMSClient()); + final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + final AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder() .withCustomClientFactory(supplier) - .buildStrict(rawAliasIdentifier) - .getMasterKey( - "aws-kms", - rawAliasIdentifier); - - // A raw alias is a valid configuration for encryption - final String rawKeyIdentifier = "mrk-edb7fe6942894d32ac46dbb1c922d574"; - AwsKmsMrkAwareMasterKeyProvider - .builder() + .buildStrict(identifier); + + assertThrows( + NoSuchMasterKeyException.class, + () -> mkp.getMasterKey("aws-kms", "does-not-match-configured")); + } + + @Test + @DisplayName("Precondition: Discovery mode requires requestedKeyArn be an ARN.") + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # In discovery mode, the requested + // # AWS KMS key identifier MUST be a well formed AWS KMS ARN. + public void discovery_request_must_be_arn() { + AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder().buildDiscovery(); + + assertThrows( + NoSuchMasterKeyException.class, + () -> mkp.getMasterKey("aws-kms", "mrk-edb7fe6942894d32ac46dbb1c922d574")); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # In + // # discovery mode if a discovery filter is configured the requested AWS + // # KMS key ARN's "partition" MUST match the discovery filter's + // # "partition" and the AWS KMS key ARN's "account" MUST exist in the + // # discovery filter's account id set. + public void discovery_filter_must_match() { + final String identifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final AWSKMS client = spy(new MockKMSClient()); + final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + assertThrows( + NoSuchMasterKeyException.class, + () -> + AwsKmsMrkAwareMasterKeyProvider.builder() + .buildDiscovery(new DiscoveryFilter("aws", Arrays.asList("not-111122223333"))) + .getMasterKey("aws-kms", identifier)); + + assertThrows( + NoSuchMasterKeyException.class, + () -> + AwsKmsMrkAwareMasterKeyProvider.builder() + .buildDiscovery(new DiscoveryFilter("not-aws", Arrays.asList("111122223333"))) + .getMasterKey("aws-kms", identifier)); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # In discovery mode a AWS KMS MRK Aware Master Key (aws-kms-mrk-aware- + // # master-key.md) MUST be returned configured with + public void discovery_magic_to_make_the_region_match() { + final String identifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final AWSKMS client = spy(new MockKMSClient()); + final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder() .withCustomClientFactory(supplier) - .buildStrict(rawKeyIdentifier) - .getMasterKey( - "aws-kms", - rawKeyIdentifier); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# The function MUST only provide master keys if the input provider id - //# equals "aws-kms". - public void only_this_provider() { - final String identifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final AWSKMS client = spy(new MockKMSClient()); - final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() - .withCustomClientFactory(supplier) - .buildStrict(identifier); - - assertThrows(UnsupportedProviderException.class, () -> mkp.getMasterKey( - "not-aws-kms", - identifier)); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# In strict mode, the requested AWS KMS key ARN MUST - //# match a member of the configured key ids by using AWS KMS MRK Match - //# for Decrypt (aws-kms-mrk-match-for-decrypt.md#implementation) - //# otherwise this function MUST error. - public void no_key_id_match() { - final String identifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final AWSKMS client = spy(new MockKMSClient()); - final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - final AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() - .withCustomClientFactory(supplier) - .buildStrict(identifier); - - assertThrows(NoSuchMasterKeyException.class, () -> mkp.getMasterKey( - "aws-kms", - "does-not-match-configured")); - } - - @Test - @DisplayName("Precondition: Discovery mode requires requestedKeyArn be an ARN.") - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# In discovery mode, the requested - //# AWS KMS key identifier MUST be a well formed AWS KMS ARN. - public void discovery_request_must_be_arn() { - AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildDiscovery(); - - assertThrows(NoSuchMasterKeyException.class, - () -> mkp.getMasterKey( - "aws-kms", - "mrk-edb7fe6942894d32ac46dbb1c922d574")); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# In - //# discovery mode if a discovery filter is configured the requested AWS - //# KMS key ARN's "partition" MUST match the discovery filter's - //# "partition" and the AWS KMS key ARN's "account" MUST exist in the - //# discovery filter's account id set. - public void discovery_filter_must_match() { - final String identifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final AWSKMS client = spy(new MockKMSClient()); - final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - assertThrows(NoSuchMasterKeyException.class, () -> AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildDiscovery(new DiscoveryFilter("aws", Arrays.asList("not-111122223333"))) - .getMasterKey( - "aws-kms", - identifier) - ); - - assertThrows(NoSuchMasterKeyException.class, () -> AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildDiscovery(new DiscoveryFilter("not-aws", Arrays.asList("111122223333"))) - .getMasterKey( - "aws-kms", - identifier) - ); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# In discovery mode a AWS KMS MRK Aware Master Key (aws-kms-mrk-aware- - //# master-key.md) MUST be returned configured with - public void discovery_magic_to_make_the_region_match() { - final String identifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final AWSKMS client = spy(new MockKMSClient()); - final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() - .withCustomClientFactory(supplier) - .withDiscoveryMrkRegion("my-region") - .buildDiscovery(); - - AwsKmsMrkAwareMasterKey test = mkp.getMasterKey( - "aws-kms", - identifier); - - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# An AWS KMS client - //# MUST be obtained by calling the regional client supplier with this - //# AWS Region. - assertEquals( - "arn:aws:kms:my-region:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - test.getKeyId() - ); - verify(supplier, times(1)).getClient("my-region"); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 - //= type=test - //# In strict mode a AWS KMS MRK Aware Master Key (aws-kms-mrk-aware- - //# master-key.md) MUST be returned configured with - public void strict_mrk_region_match() { - final String identifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final String configIdentifier = "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final AWSKMS client = spy(new MockKMSClient()); - final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() + .withDiscoveryMrkRegion("my-region") + .buildDiscovery(); + + AwsKmsMrkAwareMasterKey test = mkp.getMasterKey("aws-kms", identifier); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # An AWS KMS client + // # MUST be obtained by calling the regional client supplier with this + // # AWS Region. + assertEquals( + "arn:aws:kms:my-region:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + test.getKeyId()); + verify(supplier, times(1)).getClient("my-region"); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.7 + // = type=test + // # In strict mode a AWS KMS MRK Aware Master Key (aws-kms-mrk-aware- + // # master-key.md) MUST be returned configured with + public void strict_mrk_region_match() { + final String identifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final String configIdentifier = + "arn:aws:kms:us-east-1:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final AWSKMS client = spy(new MockKMSClient()); + final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder() .withCustomClientFactory(supplier) .buildStrict(configIdentifier); - AwsKmsMrkAwareMasterKey test = mkp.getMasterKey( - "aws-kms", - identifier); - - assertEquals( - configIdentifier, - test.getKeyId() - ); - verify(supplier, times(1)).getClient("us-east-1"); - } - } - - static public class decryptDataKey { - - @Test - public void basic_use() { - final String identifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final byte[] cipherText = new byte[10]; - final EncryptedDataKey edk1 = new KeyBlob( - "aws-kms", - identifier.getBytes(StandardCharsets.UTF_8), - cipherText); - final EncryptedDataKey edk2 = new KeyBlob( - "aws-kms", - identifier.getBytes(StandardCharsets.UTF_8), - cipherText); - - final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - final AWSKMS client = mock(AWSKMS.class); - when(client.decrypt(any())) - .thenReturn(new DecryptResult() - .withKeyId(identifier) - .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()))); - when(supplier.getClient(any())).thenReturn(client); - - AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() + AwsKmsMrkAwareMasterKey test = mkp.getMasterKey("aws-kms", identifier); + + assertEquals(configIdentifier, test.getKeyId()); + verify(supplier, times(1)).getClient("us-east-1"); + } + } + + public static class decryptDataKey { + + @Test + public void basic_use() { + final String identifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final byte[] cipherText = new byte[10]; + final EncryptedDataKey edk1 = + new KeyBlob("aws-kms", identifier.getBytes(StandardCharsets.UTF_8), cipherText); + final EncryptedDataKey edk2 = + new KeyBlob("aws-kms", identifier.getBytes(StandardCharsets.UTF_8), cipherText); + + final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + final AWSKMS client = mock(AWSKMS.class); + when(client.decrypt(any())) + .thenReturn( + new DecryptResult() + .withKeyId(identifier) + .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()))); + when(supplier.getClient(any())).thenReturn(client); + + AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder() .withCustomClientFactory(supplier) .buildStrict(identifier) .withGrantTokens(GRANT_TOKENS); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //= type=test - //# The input MUST be the same as the Master Key Provider Decrypt Data - //# Key (../master-key-provider-interface.md#decrypt-data-key) interface. - final DataKey test = mkp - .decryptDataKey( - ALGORITHM_SUITE, - Arrays.asList(edk1, edk2), - ENCRYPTION_CONTEXT); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //= type=test - //# For each encrypted data key in the filtered set, one at a time, the - //# master key provider MUST call Get Master Key (aws-kms-mrk-aware- - //# master-key-provider.md#get-master-key) with the encrypted data key's - //# provider info as the AWS KMS key ARN. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //= type=test - //# It MUST call Decrypt Data Key - //# (aws-kms-mrk-aware-master-key.md#decrypt-data-key) on this master key - //# with the input algorithm, this single encrypted data key, and the - //# input encryption context. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //= type=test - //# If the decrypt data key call is - //# successful, then this function MUST return this result and not - //# attempt to decrypt any more encrypted data keys. - verify(client, times((1))).decrypt(new DecryptRequest() - .withGrantTokens(GRANT_TOKENS) - .withEncryptionContext(ENCRYPTION_CONTEXT) - .withKeyId(identifier) - .withCiphertextBlob(ByteBuffer.wrap(cipherText)) - ); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //= type=test - //# The output MUST be the same as the Master Key Provider Decrypt Data - //# Key (../master-key-provider-interface.md#decrypt-data-key) interface. - assertTrue(DataKey.class.isInstance(test)); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //= type=test - //# The set of encrypted data keys MUST first be filtered to match this - //# master key's configuration. - public void only_if_providers_match() { - final String identifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final EncryptedDataKey edk = new KeyBlob( - "not-aws-kms", - "not the identifier".getBytes(StandardCharsets.UTF_8), - new byte[10]); - - AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildStrict(identifier); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //= type=test - //# To match the encrypted data key's - //# provider ID MUST exactly match the value "aws-kms". - final CannotUnwrapDataKeyException test = assertThrows( - "Unable to decrypt any data keys", - CannotUnwrapDataKeyException.class, () -> mkp - .decryptDataKey( - ALGORITHM_SUITE, - Arrays.asList(edk), - ENCRYPTION_CONTEXT)); - assertEquals(0, test.getSuppressed().length); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //= type=test - //# Additionally - //# each provider info MUST be a valid AWS KMS ARN (aws-kms-key-arn.md#a- - //# valid-aws-kms-arn) with a resource type of "key". - public void provider_info_must_be_arn() { - final String identifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final String aliasArn = "arn:aws:kms:us-west-2:111122223333:alias/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final EncryptedDataKey edk = new KeyBlob( - "aws-kms", - aliasArn.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildStrict(identifier); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // = type=test + // # The input MUST be the same as the Master Key Provider Decrypt Data + // # Key (../master-key-provider-interface.md#decrypt-data-key) interface. + final DataKey test = + mkp.decryptDataKey(ALGORITHM_SUITE, Arrays.asList(edk1, edk2), ENCRYPTION_CONTEXT); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // = type=test + // # For each encrypted data key in the filtered set, one at a time, the + // # master key provider MUST call Get Master Key (aws-kms-mrk-aware- + // # master-key-provider.md#get-master-key) with the encrypted data key's + // # provider info as the AWS KMS key ARN. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // = type=test + // # It MUST call Decrypt Data Key + // # (aws-kms-mrk-aware-master-key.md#decrypt-data-key) on this master key + // # with the input algorithm, this single encrypted data key, and the + // # input encryption context. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // = type=test + // # If the decrypt data key call is + // # successful, then this function MUST return this result and not + // # attempt to decrypt any more encrypted data keys. + verify(client, times((1))) + .decrypt( + new DecryptRequest() + .withGrantTokens(GRANT_TOKENS) + .withEncryptionContext(ENCRYPTION_CONTEXT) + .withKeyId(identifier) + .withCiphertextBlob(ByteBuffer.wrap(cipherText))); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // = type=test + // # The output MUST be the same as the Master Key Provider Decrypt Data + // # Key (../master-key-provider-interface.md#decrypt-data-key) interface. + assertTrue(DataKey.class.isInstance(test)); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // = type=test + // # The set of encrypted data keys MUST first be filtered to match this + // # master key's configuration. + public void only_if_providers_match() { + final String identifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final EncryptedDataKey edk = + new KeyBlob( + "not-aws-kms", "not the identifier".getBytes(StandardCharsets.UTF_8), new byte[10]); + + AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder().buildStrict(identifier); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // = type=test + // # To match the encrypted data key's + // # provider ID MUST exactly match the value "aws-kms". + final CannotUnwrapDataKeyException test = + assertThrows( + "Unable to decrypt any data keys", + CannotUnwrapDataKeyException.class, + () -> mkp.decryptDataKey(ALGORITHM_SUITE, Arrays.asList(edk), ENCRYPTION_CONTEXT)); + assertEquals(0, test.getSuppressed().length); + } - final IllegalStateException test = assertThrows( + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // = type=test + // # Additionally + // # each provider info MUST be a valid AWS KMS ARN (aws-kms-key-arn.md#a- + // # valid-aws-kms-arn) with a resource type of "key". + public void provider_info_must_be_arn() { + final String identifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final String aliasArn = + "arn:aws:kms:us-west-2:111122223333:alias/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final EncryptedDataKey edk = + new KeyBlob("aws-kms", aliasArn.getBytes(StandardCharsets.UTF_8), new byte[10]); + + AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder().buildStrict(identifier); + + final IllegalStateException test = + assertThrows( "Invalid provider info in message.", - IllegalStateException.class, () -> mkp - .decryptDataKey( - ALGORITHM_SUITE, - Arrays.asList(edk), - ENCRYPTION_CONTEXT)); - assertEquals(0, test.getSuppressed().length); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //= type=test - //# If this attempt results in an error, then - //# these errors MUST be collected. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 - //= type=test - //# If all the input encrypted data keys have been processed then this - //# function MUST yield an error that includes all the collected errors. - public void exception_wrapped() { - final String identifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final EncryptedDataKey edk = new KeyBlob( - "aws-kms", - identifier.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - final AWSKMS client = mock(AWSKMS.class); - final String clientErrMsg = "asdf"; - when(client.decrypt(any())).thenThrow(new AmazonServiceException(clientErrMsg)); - when(supplier.getClient(any())).thenReturn(client); - - AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() - .withCustomClientFactory(supplier) - .buildStrict(identifier); - - CannotUnwrapDataKeyException test = assertThrows( - "Unable to decrypt any data keys", - CannotUnwrapDataKeyException.class, () -> mkp - .decryptDataKey( - ALGORITHM_SUITE, - Arrays.asList(edk), - ENCRYPTION_CONTEXT)); - assertEquals(1, test.getSuppressed().length); - Throwable fromMasterKey = Arrays.stream(test.getSuppressed()).findFirst().get(); - assertTrue(fromMasterKey instanceof CannotUnwrapDataKeyException); - assertEquals(1, fromMasterKey.getSuppressed().length); - Throwable fromClient = Arrays.stream(fromMasterKey.getSuppressed()).findFirst().get(); - assertTrue(fromClient instanceof AmazonServiceException); - assertTrue(fromClient.getMessage().startsWith(clientErrMsg)); - } - } - - static public class clientFactory { - @Test - public void basic_use() { - final ConcurrentHashMap cache = spy(new ConcurrentHashMap<>()); - - final AWSKMS test = AwsKmsMrkAwareMasterKeyProvider - .Builder - .clientFactory(cache, null) - .getClient("asdf"); - assertNotEquals(null, test); - verify(cache, times(1)).containsKey("asdf"); - } - - @Test - @DisplayName("Check for early return (Postcondition): If a client already exists, use that.") - public void use_clients_that_exist() { - final String region = "asdf"; - final ConcurrentHashMap cache = spy(new ConcurrentHashMap<>()); - // Add something so we can verify that we get it - final AWSKMS client = mock(AWSKMS.class); - cache.put(region, client); - - final AWSKMS test = AwsKmsMrkAwareMasterKeyProvider - .Builder - .clientFactory(cache, null) - .getClient(region); - - assertEquals(client, test); - } - } - - static public class getMasterKeysForEncryption { - @Test - public void basic_use() { - final String identifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final AWSKMS client = spy(new MockKMSClient()); - final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient("us-west-2")).thenReturn(client); - final MasterKeyRequest request = MasterKeyRequest.newBuilder().build(); - - final AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() - .withCustomClientFactory(supplier) - .buildStrict(identifier); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 - //= type=test - //# The input MUST be the same as the Master Key Provider Get Master Keys - //# For Encryption (../master-key-provider-interface.md#get-master-keys- - //# for-encryption) interface. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 - //= type=test - //# The output MUST be the same as the Master Key Provider Get Master - //# Keys For Encryption (../master-key-provider-interface.md#get-master- - //# keys-for-encryption) interface. - final List test = mkp.getMasterKeysForEncryption(request); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 - //= type=test - //# If the configured mode is strict this function MUST return a - //# list of master keys obtained by calling Get Master Key (aws-kms-mrk- - //# aware-master-key-provider.md#get-master-key) for each AWS KMS key - //# identifier in the configured key ids - assertEquals(1, test.size()); - assertEquals(identifier, test.get(0).getKeyId()); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 - //= type=test - //# If the configured mode is discovery the function MUST return an empty - //# list. - public void no_keys_is_empty_list() { - final AwsKmsMrkAwareMasterKeyProvider mkp = AwsKmsMrkAwareMasterKeyProvider - .builder() - .buildDiscovery(); - - final List test = mkp.getMasterKeysForEncryption(MasterKeyRequest.newBuilder().build()); - assertEquals(0, test.size()); - } + IllegalStateException.class, + () -> mkp.decryptDataKey(ALGORITHM_SUITE, Arrays.asList(edk), ENCRYPTION_CONTEXT)); + assertEquals(0, test.getSuppressed().length); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // = type=test + // # If this attempt results in an error, then + // # these errors MUST be collected. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.9 + // = type=test + // # If all the input encrypted data keys have been processed then this + // # function MUST yield an error that includes all the collected errors. + public void exception_wrapped() { + final String identifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final EncryptedDataKey edk = + new KeyBlob("aws-kms", identifier.getBytes(StandardCharsets.UTF_8), new byte[10]); + + final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + final AWSKMS client = mock(AWSKMS.class); + final String clientErrMsg = "asdf"; + when(client.decrypt(any())).thenThrow(new AmazonServiceException(clientErrMsg)); + when(supplier.getClient(any())).thenReturn(client); + + AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildStrict(identifier); + + CannotUnwrapDataKeyException test = + assertThrows( + "Unable to decrypt any data keys", + CannotUnwrapDataKeyException.class, + () -> mkp.decryptDataKey(ALGORITHM_SUITE, Arrays.asList(edk), ENCRYPTION_CONTEXT)); + assertEquals(1, test.getSuppressed().length); + Throwable fromMasterKey = Arrays.stream(test.getSuppressed()).findFirst().get(); + assertTrue(fromMasterKey instanceof CannotUnwrapDataKeyException); + assertEquals(1, fromMasterKey.getSuppressed().length); + Throwable fromClient = Arrays.stream(fromMasterKey.getSuppressed()).findFirst().get(); + assertTrue(fromClient instanceof AmazonServiceException); + assertTrue(fromClient.getMessage().startsWith(clientErrMsg)); + } + } + + public static class clientFactory { + @Test + public void basic_use() { + final ConcurrentHashMap cache = spy(new ConcurrentHashMap<>()); + + final AWSKMS test = + AwsKmsMrkAwareMasterKeyProvider.Builder.clientFactory(cache, null).getClient("asdf"); + assertNotEquals(null, test); + verify(cache, times(1)).containsKey("asdf"); + } + + @Test + @DisplayName("Check for early return (Postcondition): If a client already exists, use that.") + public void use_clients_that_exist() { + final String region = "asdf"; + final ConcurrentHashMap cache = spy(new ConcurrentHashMap<>()); + // Add something so we can verify that we get it + final AWSKMS client = mock(AWSKMS.class); + cache.put(region, client); + + final AWSKMS test = + AwsKmsMrkAwareMasterKeyProvider.Builder.clientFactory(cache, null).getClient(region); + + assertEquals(client, test); + } + } + + public static class getMasterKeysForEncryption { + @Test + public void basic_use() { + final String identifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final AWSKMS client = spy(new MockKMSClient()); + final RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient("us-west-2")).thenReturn(client); + final MasterKeyRequest request = MasterKeyRequest.newBuilder().build(); + + final AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildStrict(identifier); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 + // = type=test + // # The input MUST be the same as the Master Key Provider Get Master Keys + // # For Encryption (../master-key-provider-interface.md#get-master-keys- + // # for-encryption) interface. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 + // = type=test + // # The output MUST be the same as the Master Key Provider Get Master + // # Keys For Encryption (../master-key-provider-interface.md#get-master- + // # keys-for-encryption) interface. + final List test = mkp.getMasterKeysForEncryption(request); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 + // = type=test + // # If the configured mode is strict this function MUST return a + // # list of master keys obtained by calling Get Master Key (aws-kms-mrk- + // # aware-master-key-provider.md#get-master-key) for each AWS KMS key + // # identifier in the configured key ids + assertEquals(1, test.size()); + assertEquals(identifier, test.get(0).getKeyId()); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key-provider.txt#2.8 + // = type=test + // # If the configured mode is discovery the function MUST return an empty + // # list. + public void no_keys_is_empty_list() { + final AwsKmsMrkAwareMasterKeyProvider mkp = + AwsKmsMrkAwareMasterKeyProvider.builder().buildDiscovery(); + + final List test = + mkp.getMasterKeysForEncryption(MasterKeyRequest.newBuilder().build()); + assertEquals(0, test.size()); } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyTest.java index b87527b8b..933e375e4 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/AwsKmsMrkAwareMasterKeyTest.java @@ -3,23 +3,20 @@ package com.amazonaws.encryptionsdk.kms; +import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate; +import static org.junit.Assert.*; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; + import com.amazonaws.AmazonServiceException; import com.amazonaws.RequestClientOptions; import com.amazonaws.encryptionsdk.*; import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; -import com.amazonaws.encryptionsdk.internal.AwsKmsCmkArnInfo; import com.amazonaws.encryptionsdk.internal.VersionInfo; import com.amazonaws.encryptionsdk.model.KeyBlob; import com.amazonaws.services.kms.AWSKMS; import com.amazonaws.services.kms.model.*; -import org.junit.Test; -import org.junit.experimental.runners.Enclosed; -import org.junit.jupiter.api.DisplayName; -import org.junit.runner.RunWith; -import org.mockito.ArgumentCaptor; - -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; import java.lang.reflect.Method; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -27,887 +24,888 @@ import java.util.Collections; import java.util.List; import java.util.Map; - -import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate; -import static org.junit.Assert.*; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.*; -import static org.mockito.Mockito.mock; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.jupiter.api.DisplayName; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; @RunWith(Enclosed.class) public class AwsKmsMrkAwareMasterKeyTest { - public static class getInstance { - - @Test - public void basic_use() { - AWSKMS client = spy(new MockKMSClient()); - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 - //= type=test - //# On initialization, the caller MUST provide: - final AwsKmsMrkAwareMasterKey test = AwsKmsMrkAwareMasterKey - .getInstance( - client, - "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", - mkp); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.5 - //= type=test - //# MUST implement the Master Key Interface (../master-key- - //# interface.md#interface) - assertTrue(MasterKey.class.isInstance(test)); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 - //= type=test - //# The AWS KMS key identifier MUST NOT be null or empty. - public void requires_valid_identifiers() { - AWSKMS client = spy(new MockKMSClient()); - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - - assertThrows(IllegalArgumentException.class, () -> AwsKmsMrkAwareMasterKey - .getInstance( - client, - "", - mkp)); - assertThrows(IllegalArgumentException.class, () -> AwsKmsMrkAwareMasterKey - .getInstance( - client, - null, - mkp)); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 - //= type=test - //# The AWS KMS - //# key identifier MUST be a valid identifier (aws-kms-key-arn.md#a- - //# valid-aws-kms-identifier). - assertThrows(IllegalArgumentException.class, () -> AwsKmsMrkAwareMasterKey - .getInstance( - client, - "arn:aws:dynamodb:us-east-2:123456789012:table/myDynamoDBTable", - mkp)); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 - //= type=test - //# The AWS KMS SDK client MUST not be null. - public void requires_valid_client() { - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - - assertThrows(IllegalArgumentException.class, () -> AwsKmsMrkAwareMasterKey - .getInstance( - null, - "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", - mkp)); - } - - @Test - @DisplayName("Precondition: A provider is required.") - public void requires_valid_provider() { - AWSKMS client = spy(new MockKMSClient()); - - assertThrows(IllegalArgumentException.class, () -> AwsKmsMrkAwareMasterKey - .getInstance( - client, - "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", - null)); - } + public static class getInstance { + + @Test + public void basic_use() { + AWSKMS client = spy(new MockKMSClient()); + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 + // = type=test + // # On initialization, the caller MUST provide: + final AwsKmsMrkAwareMasterKey test = + AwsKmsMrkAwareMasterKey.getInstance( + client, + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", + mkp); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.5 + // = type=test + // # MUST implement the Master Key Interface (../master-key- + // # interface.md#interface) + assertTrue(MasterKey.class.isInstance(test)); } - public static class generateDataKey { - - @Test - public void basic_use() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final ByteBuffer udk = ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()); - final ByteBuffer ciphertext = ByteBuffer.allocate(10); - - final AWSKMS client = mock(AWSKMS.class); - when(client.generateDataKey(any())) - .thenReturn(new GenerateDataKeyResult() - .withPlaintext(udk) - .withKeyId(keyIdentifier) - .withCiphertextBlob(ciphertext)); - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); - AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 - //= type=test - //# The master key MUST be able to be configured with an optional list of - //# Grant Tokens. - masterKey.setGrantTokens(GRANT_TOKENS); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //= type=test - //# The inputs MUST be the same as the Master Key Generate Data Key - //# (../master-key-interface.md#generate-data-key) interface. - DataKey test = masterKey.generateDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT); - ArgumentCaptor gr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //= type=test - //# This - //# master key MUST use the configured AWS KMS client to make an AWS KMS - //# GenerateDatakey (https://docs.aws.amazon.com/kms/latest/APIReference/ - //# API_GenerateDataKey.html) request constructed as follows: - verify(client, times(1)).generateDataKey(gr.capture()); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //= type=test - //# The output MUST be the same as the Master Key Generate Data Key - //# (../master-key-interface.md#generate-data-key) interface. - assertTrue(DataKey.class.isInstance(test)); - - GenerateDataKeyRequest actualRequest = gr.getValue(); - - assertEquals(keyIdentifier, actualRequest.getKeyId()); - assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); - assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); - assertEquals(ALGORITHM_SUITE.getDataKeyLength(), actualRequest.getNumberOfBytes().longValue()); - assertTrue(actualRequest.getRequestClientOptions().getClientMarker(RequestClientOptions.Marker.USER_AGENT) - .contains(VersionInfo.loadUserAgent())); - - assertNotNull(test.getKey()); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //= type=test - //# The response's "Plaintext" MUST be the plaintext in - //# the output. - assertEquals(ALGORITHM_SUITE.getDataKeyLength(), test.getKey().getEncoded().length); - assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), test.getKey().getAlgorithm()); - assertNotNull(test.getEncryptedDataKey()); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //= type=test - //# The response's cipher text blob MUST be used as the - //# returned as the ciphertext for the encrypted data key in the output. - assertEquals(10, test.getEncryptedDataKey().length); - - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //= type=test - //# If the call succeeds the AWS KMS Generate Data Key response's - //# "Plaintext" MUST match the key derivation input length specified by - //# the algorithm suite included in the input. - public void length_must_match() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - // I use more, because less _should_ trigger an underflow... but the condition should _always_ fail - final int wrongLength = ALGORITHM_SUITE.getDataKeyLength() + 1; - - final AWSKMS client = mock(AWSKMS.class); - when(client.generateDataKey(any())) - .thenReturn(new GenerateDataKeyResult() - .withPlaintext(ByteBuffer.allocate(wrongLength)) - .withKeyId(keyIdentifier) - .withCiphertextBlob(ByteBuffer.allocate(10))); - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); - AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - - assertThrows(IllegalStateException.class, - () -> masterKey.generateDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); - } - - @Test - @DisplayName("Exceptional Postcondition: Must have an AWS KMS ARN from AWS KMS generateDataKey.") - public void need_an_arn() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - - final AWSKMS client = mock(AWSKMS.class); - when(client.generateDataKey(any())) - .thenReturn(new GenerateDataKeyResult() - .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())) - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 - //= type=test - //# The response's "KeyId" - //# MUST be valid. - .withKeyId("b3537ef1-d8dc-4780-9f5a-55776cbb2f7f") - .withCiphertextBlob(ByteBuffer.allocate(10))); - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); - AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - - assertThrows(IllegalStateException.class, - () -> masterKey.generateDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); - } + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 + // = type=test + // # The AWS KMS key identifier MUST NOT be null or empty. + public void requires_valid_identifiers() { + AWSKMS client = spy(new MockKMSClient()); + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + + assertThrows( + IllegalArgumentException.class, + () -> AwsKmsMrkAwareMasterKey.getInstance(client, "", mkp)); + assertThrows( + IllegalArgumentException.class, + () -> AwsKmsMrkAwareMasterKey.getInstance(client, null, mkp)); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 + // = type=test + // # The AWS KMS + // # key identifier MUST be a valid identifier (aws-kms-key-arn.md#a- + // # valid-aws-kms-identifier). + assertThrows( + IllegalArgumentException.class, + () -> + AwsKmsMrkAwareMasterKey.getInstance( + client, "arn:aws:dynamodb:us-east-2:123456789012:table/myDynamoDBTable", mkp)); } - public static class encryptDataKey { - @Test - public void basic_use() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final SecretKey SECRET_KEY = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo()); - - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); - - final DataKey dataKey = new DataKey(SECRET_KEY, new byte[0], - "aws-kms".getBytes(StandardCharsets.UTF_8), mock(MasterKey.class)); - - final AWSKMS client = mock(AWSKMS.class); - when(client.encrypt(any())) - .thenReturn(new EncryptResult() - .withKeyId(keyIdentifier) - .withCiphertextBlob(ByteBuffer.allocate(10))); - - AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - masterKey.setGrantTokens(GRANT_TOKENS); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 - //= type=test - //# The inputs MUST be the same as the Master Key Encrypt Data Key - //# (../master-key-interface.md#encrypt-data-key) interface. - DataKey test = masterKey.encryptDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey); - - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 - //= type=test - //# The output MUST be the same as the Master Key Encrypt Data Key - //# (../master-key-interface.md#encrypt-data-key) interface. - assertTrue(DataKey.class.isInstance(test)); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 - //= type=test - //# The master - //# key MUST use the configured AWS KMS client to make an AWS KMS Encrypt - //# (https://docs.aws.amazon.com/kms/latest/APIReference/ - //# API_Encrypt.html) request constructed as follows: - verify(client, times(1)).encrypt(any()); - ArgumentCaptor gr = ArgumentCaptor.forClass(EncryptRequest.class); - verify(client, times(1)).encrypt(gr.capture()); - - final EncryptRequest actualRequest = gr.getValue(); - - assertEquals(keyIdentifier, actualRequest.getKeyId()); - assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); - assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); - assertTrue(actualRequest.getRequestClientOptions().getClientMarker(RequestClientOptions.Marker.USER_AGENT) - .contains(VersionInfo.loadUserAgent())); - - assertNotNull(test.getKey()); - assertEquals(ALGORITHM_SUITE.getDataKeyLength(), test.getKey().getEncoded().length); - assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), test.getKey().getAlgorithm()); - assertNotNull(test.getEncryptedDataKey()); - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 - //= type=test - //# The - //# response's cipher text blob MUST be used as the "ciphertext" for the - //# encrypted data key. - assertEquals(10, test.getEncryptedDataKey().length); - } - - @Test - @DisplayName("Precondition: The key format MUST be RAW.") - public void secret_key_must_be_raw() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); - - // Test "stuff" here - final SecretKey SECRET_KEY = mock(SecretKeySpec.class); - when(SECRET_KEY.getFormat()).thenReturn("NOT-RAW"); - - final DataKey dataKey = new DataKey(SECRET_KEY, new byte[0], - "aws-kms".getBytes(StandardCharsets.UTF_8), mock(MasterKey.class)); - - final AWSKMS client = mock(AWSKMS.class); - when(client.encrypt(any())) - .thenReturn(new EncryptResult() - .withKeyId(keyIdentifier) - .withCiphertextBlob(ByteBuffer.allocate(10))); - - AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - masterKey.setGrantTokens(GRANT_TOKENS); - - assertThrows( - "Only RAW encoded keys are supported", - IllegalArgumentException.class, - () -> masterKey.encryptDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey)); - } - - @Test - @DisplayName("Postcondition: Must have an AWS KMS ARN from AWS KMS encrypt.") - public void need_an_arn() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final SecretKey SECRET_KEY = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo()); - - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); - - final DataKey dataKey = new DataKey(SECRET_KEY, new byte[0], - "aws-kms".getBytes(StandardCharsets.UTF_8), mock(MasterKey.class)); - - final AWSKMS client = mock(AWSKMS.class); - when(client.encrypt(any())) - .thenReturn(new EncryptResult() - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 - //= type=test - //# The AWS KMS Encrypt response MUST contain a valid "KeyId". - .withKeyId("b3537ef1-d8dc-4780-9f5a-55776cbb2f7f") - .withCiphertextBlob(ByteBuffer.allocate(10))); - - AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - masterKey.setGrantTokens(GRANT_TOKENS); - - assertThrows(IllegalStateException.class, - () -> masterKey.encryptDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey)); - } + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 + // = type=test + // # The AWS KMS SDK client MUST not be null. + public void requires_valid_client() { + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + + assertThrows( + IllegalArgumentException.class, + () -> + AwsKmsMrkAwareMasterKey.getInstance( + null, + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", + mkp)); } - public static class filterEncryptedDataKeys { - @Test - public void basic_use() { - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final String providerId = "aws-kms"; - final EncryptedDataKey edk = new KeyBlob( - providerId, - keyIdentifier.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - assertTrue(AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys( - providerId, - keyIdentifier, - edk)); - - } - - @Test - public void mrk_specific() { - /* This may be overkill, - * but the whole point - * of multi-region optimization - * is this fuzzy match. - */ - final String configuredIdentifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final String ekdIdentifier = "arn:aws:kms:us-east-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - - final String providerId = "aws-kms"; - final EncryptedDataKey edk = new KeyBlob( - providerId, - ekdIdentifier.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - assertTrue(AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys( - providerId, - configuredIdentifier, - edk)); - - } - - @Test - public void provider_info_must_be_arn() { - final String configuredIdentifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final String rawKeyId = "mrk-edb7fe6942894d32ac46dbb1c922d574"; - final String alias = "arn:aws:kms:us-west-2:111122223333:alias/mrk-edb7fe6942894d32ac46dbb1c922d574"; - - final String providerId = "aws-kms"; - final EncryptedDataKey edkNotArn = new KeyBlob( - providerId, - rawKeyId.getBytes(StandardCharsets.UTF_8), - new byte[10]); + @Test + @DisplayName("Precondition: A provider is required.") + public void requires_valid_provider() { + AWSKMS client = spy(new MockKMSClient()); + + assertThrows( + IllegalArgumentException.class, + () -> + AwsKmsMrkAwareMasterKey.getInstance( + client, + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", + null)); + } + } + + public static class generateDataKey { + + @Test + public void basic_use() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final ByteBuffer udk = ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()); + final ByteBuffer ciphertext = ByteBuffer.allocate(10); + + final AWSKMS client = mock(AWSKMS.class); + when(client.generateDataKey(any())) + .thenReturn( + new GenerateDataKeyResult() + .withPlaintext(udk) + .withKeyId(keyIdentifier) + .withCiphertextBlob(ciphertext)); + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); + AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.6 + // = type=test + // # The master key MUST be able to be configured with an optional list of + // # Grant Tokens. + masterKey.setGrantTokens(GRANT_TOKENS); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // = type=test + // # The inputs MUST be the same as the Master Key Generate Data Key + // # (../master-key-interface.md#generate-data-key) interface. + DataKey test = + masterKey.generateDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT); + ArgumentCaptor gr = + ArgumentCaptor.forClass(GenerateDataKeyRequest.class); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // = type=test + // # This + // # master key MUST use the configured AWS KMS client to make an AWS KMS + // # GenerateDatakey (https://docs.aws.amazon.com/kms/latest/APIReference/ + // # API_GenerateDataKey.html) request constructed as follows: + verify(client, times(1)).generateDataKey(gr.capture()); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // = type=test + // # The output MUST be the same as the Master Key Generate Data Key + // # (../master-key-interface.md#generate-data-key) interface. + assertTrue(DataKey.class.isInstance(test)); + + GenerateDataKeyRequest actualRequest = gr.getValue(); + + assertEquals(keyIdentifier, actualRequest.getKeyId()); + assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertEquals( + ALGORITHM_SUITE.getDataKeyLength(), actualRequest.getNumberOfBytes().longValue()); + assertTrue( + actualRequest + .getRequestClientOptions() + .getClientMarker(RequestClientOptions.Marker.USER_AGENT) + .contains(VersionInfo.loadUserAgent())); + + assertNotNull(test.getKey()); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // = type=test + // # The response's "Plaintext" MUST be the plaintext in + // # the output. + assertEquals(ALGORITHM_SUITE.getDataKeyLength(), test.getKey().getEncoded().length); + assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), test.getKey().getAlgorithm()); + assertNotNull(test.getEncryptedDataKey()); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // = type=test + // # The response's cipher text blob MUST be used as the + // # returned as the ciphertext for the encrypted data key in the output. + assertEquals(10, test.getEncryptedDataKey().length); + } - final EncryptedDataKey edkAliasArn = new KeyBlob( - providerId, - rawKeyId.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# Additionally each provider info MUST be a valid AWS KMS ARN - //# (aws-kms-key-arn.md#a-valid-aws-kms-arn) with a resource type of - //# "key". - assertThrows( - IllegalStateException.class, - () -> AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys( + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // = type=test + // # If the call succeeds the AWS KMS Generate Data Key response's + // # "Plaintext" MUST match the key derivation input length specified by + // # the algorithm suite included in the input. + public void length_must_match() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + // I use more, because less _should_ trigger an underflow... but the condition should _always_ + // fail + final int wrongLength = ALGORITHM_SUITE.getDataKeyLength() + 1; + + final AWSKMS client = mock(AWSKMS.class); + when(client.generateDataKey(any())) + .thenReturn( + new GenerateDataKeyResult() + .withPlaintext(ByteBuffer.allocate(wrongLength)) + .withKeyId(keyIdentifier) + .withCiphertextBlob(ByteBuffer.allocate(10))); + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); + AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + + assertThrows( + IllegalStateException.class, + () -> masterKey.generateDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + } + + @Test + @DisplayName( + "Exceptional Postcondition: Must have an AWS KMS ARN from AWS KMS generateDataKey.") + public void need_an_arn() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + + final AWSKMS client = mock(AWSKMS.class); + when(client.generateDataKey(any())) + .thenReturn( + new GenerateDataKeyResult() + .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())) + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.10 + // = type=test + // # The response's "KeyId" + // # MUST be valid. + .withKeyId("b3537ef1-d8dc-4780-9f5a-55776cbb2f7f") + .withCiphertextBlob(ByteBuffer.allocate(10))); + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); + AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + + assertThrows( + IllegalStateException.class, + () -> masterKey.generateDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + } + } + + public static class encryptDataKey { + @Test + public void basic_use() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final SecretKey SECRET_KEY = + new SecretKeySpec( + generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo()); + + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); + + final DataKey dataKey = + new DataKey( + SECRET_KEY, + new byte[0], + "aws-kms".getBytes(StandardCharsets.UTF_8), + mock(MasterKey.class)); + + final AWSKMS client = mock(AWSKMS.class); + when(client.encrypt(any())) + .thenReturn( + new EncryptResult() + .withKeyId(keyIdentifier) + .withCiphertextBlob(ByteBuffer.allocate(10))); + + AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + masterKey.setGrantTokens(GRANT_TOKENS); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 + // = type=test + // # The inputs MUST be the same as the Master Key Encrypt Data Key + // # (../master-key-interface.md#encrypt-data-key) interface. + DataKey test = + masterKey.encryptDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 + // = type=test + // # The output MUST be the same as the Master Key Encrypt Data Key + // # (../master-key-interface.md#encrypt-data-key) interface. + assertTrue(DataKey.class.isInstance(test)); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 + // = type=test + // # The master + // # key MUST use the configured AWS KMS client to make an AWS KMS Encrypt + // # (https://docs.aws.amazon.com/kms/latest/APIReference/ + // # API_Encrypt.html) request constructed as follows: + verify(client, times(1)).encrypt(any()); + ArgumentCaptor gr = ArgumentCaptor.forClass(EncryptRequest.class); + verify(client, times(1)).encrypt(gr.capture()); + + final EncryptRequest actualRequest = gr.getValue(); + + assertEquals(keyIdentifier, actualRequest.getKeyId()); + assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertTrue( + actualRequest + .getRequestClientOptions() + .getClientMarker(RequestClientOptions.Marker.USER_AGENT) + .contains(VersionInfo.loadUserAgent())); + + assertNotNull(test.getKey()); + assertEquals(ALGORITHM_SUITE.getDataKeyLength(), test.getKey().getEncoded().length); + assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), test.getKey().getAlgorithm()); + assertNotNull(test.getEncryptedDataKey()); + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 + // = type=test + // # The + // # response's cipher text blob MUST be used as the "ciphertext" for the + // # encrypted data key. + assertEquals(10, test.getEncryptedDataKey().length); + } + + @Test + @DisplayName("Precondition: The key format MUST be RAW.") + public void secret_key_must_be_raw() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); + + // Test "stuff" here + final SecretKey SECRET_KEY = mock(SecretKeySpec.class); + when(SECRET_KEY.getFormat()).thenReturn("NOT-RAW"); + + final DataKey dataKey = + new DataKey( + SECRET_KEY, + new byte[0], + "aws-kms".getBytes(StandardCharsets.UTF_8), + mock(MasterKey.class)); + + final AWSKMS client = mock(AWSKMS.class); + when(client.encrypt(any())) + .thenReturn( + new EncryptResult() + .withKeyId(keyIdentifier) + .withCiphertextBlob(ByteBuffer.allocate(10))); + + AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + masterKey.setGrantTokens(GRANT_TOKENS); + + assertThrows( + "Only RAW encoded keys are supported", + IllegalArgumentException.class, + () -> masterKey.encryptDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey)); + } + + @Test + @DisplayName("Postcondition: Must have an AWS KMS ARN from AWS KMS encrypt.") + public void need_an_arn() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final SecretKey SECRET_KEY = + new SecretKeySpec( + generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo()); + + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn("aws-kms"); + + final DataKey dataKey = + new DataKey( + SECRET_KEY, + new byte[0], + "aws-kms".getBytes(StandardCharsets.UTF_8), + mock(MasterKey.class)); + + final AWSKMS client = mock(AWSKMS.class); + when(client.encrypt(any())) + .thenReturn( + new EncryptResult() + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.11 + // = type=test + // # The AWS KMS Encrypt response MUST contain a valid "KeyId". + .withKeyId("b3537ef1-d8dc-4780-9f5a-55776cbb2f7f") + .withCiphertextBlob(ByteBuffer.allocate(10))); + + AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + masterKey.setGrantTokens(GRANT_TOKENS); + + assertThrows( + IllegalStateException.class, + () -> masterKey.encryptDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey)); + } + } + + public static class filterEncryptedDataKeys { + @Test + public void basic_use() { + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final String providerId = "aws-kms"; + final EncryptedDataKey edk = + new KeyBlob(providerId, keyIdentifier.getBytes(StandardCharsets.UTF_8), new byte[10]); + + assertTrue(AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys(providerId, keyIdentifier, edk)); + } + + @Test + public void mrk_specific() { + /* This may be overkill, + * but the whole point + * of multi-region optimization + * is this fuzzy match. + */ + final String configuredIdentifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final String ekdIdentifier = + "arn:aws:kms:us-east-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + + final String providerId = "aws-kms"; + final EncryptedDataKey edk = + new KeyBlob(providerId, ekdIdentifier.getBytes(StandardCharsets.UTF_8), new byte[10]); + + assertTrue( + AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys(providerId, configuredIdentifier, edk)); + } + + @Test + public void provider_info_must_be_arn() { + final String configuredIdentifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final String rawKeyId = "mrk-edb7fe6942894d32ac46dbb1c922d574"; + final String alias = + "arn:aws:kms:us-west-2:111122223333:alias/mrk-edb7fe6942894d32ac46dbb1c922d574"; + + final String providerId = "aws-kms"; + final EncryptedDataKey edkNotArn = + new KeyBlob(providerId, rawKeyId.getBytes(StandardCharsets.UTF_8), new byte[10]); + + final EncryptedDataKey edkAliasArn = + new KeyBlob(providerId, rawKeyId.getBytes(StandardCharsets.UTF_8), new byte[10]); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # Additionally each provider info MUST be a valid AWS KMS ARN + // # (aws-kms-key-arn.md#a-valid-aws-kms-arn) with a resource type of + // # "key". + assertThrows( + IllegalStateException.class, + () -> + AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys( + providerId, configuredIdentifier, edkNotArn)); + assertThrows( + IllegalStateException.class, + () -> + AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys( + providerId, configuredIdentifier, edkAliasArn)); + } + + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # To match the encrypted data key's + // # provider ID MUST exactly match the value "aws-kms" and the the + // # function AWS KMS MRK Match for Decrypt (aws-kms-mrk-match-for- + // # decrypt.md#implementation) called with the configured AWS KMS key + // # identifier and the encrypted data key's provider info MUST return + // # "true". + public void may_not_match() { + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final String providerId = "aws-kms"; + final EncryptedDataKey edk = + new KeyBlob(providerId, keyIdentifier.getBytes(StandardCharsets.UTF_8), new byte[10]); + + assertFalse( + AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys("not-aws-kms", keyIdentifier, edk)); + + assertFalse( + AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys( providerId, - configuredIdentifier, - edkNotArn)); - assertThrows( - IllegalStateException.class, - () -> AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys( - providerId, - configuredIdentifier, - edkAliasArn)); - - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# To match the encrypted data key's - //# provider ID MUST exactly match the value "aws-kms" and the the - //# function AWS KMS MRK Match for Decrypt (aws-kms-mrk-match-for- - //# decrypt.md#implementation) called with the configured AWS KMS key - //# identifier and the encrypted data key's provider info MUST return - //# "true". - public void may_not_match() { - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final String providerId = "aws-kms"; - final EncryptedDataKey edk = new KeyBlob( - providerId, - keyIdentifier.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - assertFalse(AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys( - "not-aws-kms", - keyIdentifier, - edk)); - - assertFalse(AwsKmsMrkAwareMasterKey.filterEncryptedDataKeys( - providerId, - "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", - edk)); - } + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574", + edk)); + } + } + + public static class decryptSingleEncryptedDataKey { + @Test + public void basic_use() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final String providerId = "aws-kms"; + final EncryptedDataKey edk = + new KeyBlob(providerId, keyIdentifier.getBytes(StandardCharsets.UTF_8), new byte[10]); + + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(providerId); + + final AWSKMS client = mock(AWSKMS.class); + when(client.decrypt(any())) + .thenReturn( + new DecryptResult() + .withKeyId(keyIdentifier) + .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()))); + + AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + masterKey.setGrantTokens(GRANT_TOKENS); + + DataKey test = + AwsKmsMrkAwareMasterKey.decryptSingleEncryptedDataKey( + any(), client, keyIdentifier, GRANT_TOKENS, ALGORITHM_SUITE, edk, ENCRYPTION_CONTEXT); + + verify(client, times(1)).decrypt(any()); + ArgumentCaptor gr = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(gr.capture()); + + final DecryptRequest actualRequest = gr.getValue(); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # To decrypt the encrypted data key this master key MUST use the + // # configured AWS KMS client to make an AWS KMS Decrypt + // # (https://docs.aws.amazon.com/kms/latest/APIReference/ + // # API_Decrypt.html) request constructed as follows: + assertEquals(keyIdentifier, actualRequest.getKeyId()); + assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertTrue( + actualRequest + .getRequestClientOptions() + .getClientMarker(RequestClientOptions.Marker.USER_AGENT) + .contains(VersionInfo.loadUserAgent())); + + assertNotNull(test.getKey()); + assertEquals(ALGORITHM_SUITE.getDataKeyLength(), test.getKey().getEncoded().length); + assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), test.getKey().getAlgorithm()); } - public static class decryptSingleEncryptedDataKey { - @Test - public void basic_use() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final String providerId = "aws-kms"; - final EncryptedDataKey edk = new KeyBlob( - providerId, - keyIdentifier.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(providerId); - - final AWSKMS client = mock(AWSKMS.class); - when(client.decrypt(any())) - .thenReturn(new DecryptResult() - .withKeyId(keyIdentifier) - .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()))); - - AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - masterKey.setGrantTokens(GRANT_TOKENS); - - DataKey test = AwsKmsMrkAwareMasterKey.decryptSingleEncryptedDataKey( - any(), - client, - keyIdentifier, - GRANT_TOKENS, - ALGORITHM_SUITE, - edk, - ENCRYPTION_CONTEXT - ); - - verify(client, times(1)).decrypt(any()); - ArgumentCaptor gr = ArgumentCaptor.forClass(DecryptRequest.class); - verify(client, times(1)).decrypt(gr.capture()); - - final DecryptRequest actualRequest = gr.getValue(); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# To decrypt the encrypted data key this master key MUST use the - //# configured AWS KMS client to make an AWS KMS Decrypt - //# (https://docs.aws.amazon.com/kms/latest/APIReference/ - //# API_Decrypt.html) request constructed as follows: - assertEquals(keyIdentifier, actualRequest.getKeyId()); - assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); - assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); - assertTrue(actualRequest.getRequestClientOptions().getClientMarker(RequestClientOptions.Marker.USER_AGENT) - .contains(VersionInfo.loadUserAgent())); - - assertNotNull(test.getKey()); - assertEquals(ALGORITHM_SUITE.getDataKeyLength(), test.getKey().getEncoded().length); - assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), test.getKey().getAlgorithm()); - } - - - @Test - @DisplayName("Exceptional Postcondition: Must have a CMK ARN from AWS KMS to match.") - public void expect_key_arn() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final String providerId = "aws-kms"; - final EncryptedDataKey edk = new KeyBlob( - providerId, - keyIdentifier.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(providerId); - - final AWSKMS client = mock(AWSKMS.class); - when(client.decrypt(any())) - .thenReturn(new DecryptResult() - .withKeyId(null) - .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()))); - - AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - masterKey.setGrantTokens(GRANT_TOKENS); - - assertThrows(IllegalStateException.class, () -> AwsKmsMrkAwareMasterKey.decryptSingleEncryptedDataKey( - any(), - client, - keyIdentifier, - GRANT_TOKENS, - ALGORITHM_SUITE, - edk, - ENCRYPTION_CONTEXT - )); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# If the call succeeds then the response's "KeyId" MUST be equal to the - //# configured AWS KMS key identifier otherwise the function MUST collect - //# an error. - public void returned_arn_must_match() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final String providerId = "aws-kms"; - final EncryptedDataKey edk = new KeyBlob( - providerId, - keyIdentifier.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(providerId); - - final AWSKMS client = mock(AWSKMS.class); - when(client.decrypt(any())) - .thenReturn(new DecryptResult() - .withKeyId("arn:aws:kms:us-west-2:658956600833:key/something-else") - .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()))); - - AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - masterKey.setGrantTokens(GRANT_TOKENS); - - assertThrows(IllegalStateException.class, () -> AwsKmsMrkAwareMasterKey.decryptSingleEncryptedDataKey( - any(), - client, - keyIdentifier, - GRANT_TOKENS, - ALGORITHM_SUITE, - edk, - ENCRYPTION_CONTEXT - )); - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# The response's "Plaintext"'s length MUST equal the length - //# required by the requested algorithm suite otherwise the function MUST - //# collect an error. - public void key_length_must_match() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final String providerId = "aws-kms"; - // I use more, because less _should_ trigger an underflow... but the condition should _always_ fail - final int wrongLength = ALGORITHM_SUITE.getDataKeyLength() + 1; - final EncryptedDataKey edk = new KeyBlob( - providerId, - keyIdentifier.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(providerId); - - final AWSKMS client = mock(AWSKMS.class); - when(client.decrypt(any())) - .thenReturn(new DecryptResult() - .withKeyId(keyIdentifier) - .withPlaintext(ByteBuffer.allocate(wrongLength))); - - AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - masterKey.setGrantTokens(GRANT_TOKENS); - - assertThrows(IllegalStateException.class, () -> AwsKmsMrkAwareMasterKey.decryptSingleEncryptedDataKey( - any(), - client, - keyIdentifier, - GRANT_TOKENS, - ALGORITHM_SUITE, - edk, - ENCRYPTION_CONTEXT - )); - } + @Test + @DisplayName("Exceptional Postcondition: Must have a CMK ARN from AWS KMS to match.") + public void expect_key_arn() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final String providerId = "aws-kms"; + final EncryptedDataKey edk = + new KeyBlob(providerId, keyIdentifier.getBytes(StandardCharsets.UTF_8), new byte[10]); + + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(providerId); + + final AWSKMS client = mock(AWSKMS.class); + when(client.decrypt(any())) + .thenReturn( + new DecryptResult() + .withKeyId(null) + .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()))); + + AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + masterKey.setGrantTokens(GRANT_TOKENS); + + assertThrows( + IllegalStateException.class, + () -> + AwsKmsMrkAwareMasterKey.decryptSingleEncryptedDataKey( + any(), + client, + keyIdentifier, + GRANT_TOKENS, + ALGORITHM_SUITE, + edk, + ENCRYPTION_CONTEXT)); } - public static class decryptDataKey { - - - @Test - public void basic_use() { - final String keyIdentifier = "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final byte[] cipherText = new byte[10]; - final String providerId = "aws-kms"; - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(providerId); - - final EncryptedDataKey edk1 = new KeyBlob( - "aws-kms", - keyIdentifier.getBytes(StandardCharsets.UTF_8), - cipherText); - final EncryptedDataKey edk2 = new KeyBlob( - "aws-kms", - keyIdentifier.getBytes(StandardCharsets.UTF_8), - cipherText); - - final AWSKMS client = mock(AWSKMS.class); - when(client.decrypt(any())) - .thenReturn(new DecryptResult() - .withKeyId(keyIdentifier) - .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()))); - - final AwsKmsMrkAwareMasterKey mk = AwsKmsMrkAwareMasterKey - .getInstance(client, - keyIdentifier, - mkp); - mk.setGrantTokens(GRANT_TOKENS); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# The inputs MUST be the same as the Master Key Decrypt Data Key - //# (../master-key-interface.md#decrypt-data-key) interface. - final DataKey test = mk - .decryptDataKey( - ALGORITHM_SUITE, - Arrays.asList(edk1, edk2), - ENCRYPTION_CONTEXT); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# For each encrypted data key in the filtered set, one at a time, the - //# master key MUST attempt to decrypt the data key. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# If the AWS KMS response satisfies the requirements then it MUST be - //# use and this function MUST return and not attempt to decrypt any more - //# encrypted data keys. - verify(client, times((1))).decrypt(new DecryptRequest() - .withGrantTokens(GRANT_TOKENS) - .withEncryptionContext(ENCRYPTION_CONTEXT) - .withKeyId(keyIdentifier) - .withCiphertextBlob(ByteBuffer.wrap(cipherText)) - ); - - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# The output MUST be the same as the Master Key Decrypt Data Key - //# (../master-key-interface.md#decrypt-data-key) interface. - assertTrue(DataKey.class.isInstance(test)); - - } - - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# The set of encrypted data keys MUST first be filtered to match this - //# master key's configuration. - public void edk_match() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final String providerId = "aws-kms"; - final String clientErrMsg = "asdf"; - final EncryptedDataKey edk1 = new KeyBlob( - "not-aws-kms", - keyIdentifier.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - final EncryptedDataKey edk2 = new KeyBlob( - providerId, - "not-key-identifier".getBytes(StandardCharsets.UTF_8), - new byte[10]); - - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(providerId); - - final AWSKMS client = mock(AWSKMS.class); - when(client.decrypt(any())).thenThrow(new AmazonServiceException(clientErrMsg)); - final KmsMasterKeyProvider.RegionalClientSupplier supplier = mock(KmsMasterKeyProvider.RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - final AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - masterKey.setGrantTokens(GRANT_TOKENS); - - final CannotUnwrapDataKeyException testProviderNotMatch = assertThrows( - "Unable to decrypt any data keys", - CannotUnwrapDataKeyException.class, () -> masterKey.decryptDataKey( - ALGORITHM_SUITE, - Arrays.asList(edk1), - ENCRYPTION_CONTEXT - )); - assertEquals(0, testProviderNotMatch.getSuppressed().length); - - final IllegalStateException testArnNotMatch = assertThrows( - "Unable to decrypt any data keys", - IllegalStateException.class, () -> masterKey.decryptDataKey( - ALGORITHM_SUITE, - Arrays.asList(edk2), - ENCRYPTION_CONTEXT - )); - assertEquals(0, testArnNotMatch.getSuppressed().length); - } - - @Test - @DisplayName("Exceptional Postcondition: Master key was unable to decrypt.") - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# If this attempt - //# results in an error, then these errors MUST be collected. - // - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 - //= type=test - //# If all the input encrypted data keys have been processed then this - //# function MUST yield an error that includes all the collected errors. - public void exception_wrapped() { - final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - final String keyIdentifier = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - final String providerId = "aws-kms"; - final String clientErrMsg = "asdf"; - final EncryptedDataKey edk = new KeyBlob( - providerId, - keyIdentifier.getBytes(StandardCharsets.UTF_8), - new byte[10]); - - final MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(providerId); - - final AWSKMS client = mock(AWSKMS.class); - when(client.decrypt(any())).thenThrow(new AmazonServiceException(clientErrMsg)); - - KmsMasterKeyProvider.RegionalClientSupplier supplier = mock(KmsMasterKeyProvider.RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - AwsKmsMrkAwareMasterKey masterKey = AwsKmsMrkAwareMasterKey - .getInstance( - client, - keyIdentifier, - mkp); - - masterKey.setGrantTokens(GRANT_TOKENS); - - final CannotUnwrapDataKeyException test = assertThrows( - "Unable to decrypt any data keys", - CannotUnwrapDataKeyException.class, () -> masterKey.decryptDataKey( - ALGORITHM_SUITE, - Arrays.asList(edk), - ENCRYPTION_CONTEXT - )); - assertEquals(1, test.getSuppressed().length); - Throwable fromClient = Arrays.stream(test.getSuppressed()).findFirst().get(); - assertTrue(fromClient instanceof AmazonServiceException); - assertTrue(fromClient.getMessage().startsWith(clientErrMsg)); - } + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # If the call succeeds then the response's "KeyId" MUST be equal to the + // # configured AWS KMS key identifier otherwise the function MUST collect + // # an error. + public void returned_arn_must_match() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final String providerId = "aws-kms"; + final EncryptedDataKey edk = + new KeyBlob(providerId, keyIdentifier.getBytes(StandardCharsets.UTF_8), new byte[10]); + + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(providerId); + + final AWSKMS client = mock(AWSKMS.class); + when(client.decrypt(any())) + .thenReturn( + new DecryptResult() + .withKeyId("arn:aws:kms:us-west-2:658956600833:key/something-else") + .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()))); + + AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + masterKey.setGrantTokens(GRANT_TOKENS); + + assertThrows( + IllegalStateException.class, + () -> + AwsKmsMrkAwareMasterKey.decryptSingleEncryptedDataKey( + any(), + client, + keyIdentifier, + GRANT_TOKENS, + ALGORITHM_SUITE, + edk, + ENCRYPTION_CONTEXT)); } - public static class getMasterKey { - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.7 - //= type=test - //# MUST be unchanged from the Master Key interface. - public void test_get_master_key() throws NoSuchMethodException { - String methodName = "getMasterKey"; - Class[] parameterTypes = new Class[]{ String.class, String.class }; - // Make sure the signature is correct by fetching the base method - Method baseMethod = MasterKey.class.getDeclaredMethod(methodName, parameterTypes); - assertNotNull(baseMethod); - // Assert AwsKmsMrkAwareMasterKey does not declare the same method directly - assertThrows(NoSuchMethodException.class, () -> AwsKmsMrkAwareMasterKey.class.getDeclaredMethod(methodName, parameterTypes)); - } + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # The response's "Plaintext"'s length MUST equal the length + // # required by the requested algorithm suite otherwise the function MUST + // # collect an error. + public void key_length_must_match() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final String providerId = "aws-kms"; + // I use more, because less _should_ trigger an underflow... but the condition should _always_ + // fail + final int wrongLength = ALGORITHM_SUITE.getDataKeyLength() + 1; + final EncryptedDataKey edk = + new KeyBlob(providerId, keyIdentifier.getBytes(StandardCharsets.UTF_8), new byte[10]); + + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(providerId); + + final AWSKMS client = mock(AWSKMS.class); + when(client.decrypt(any())) + .thenReturn( + new DecryptResult() + .withKeyId(keyIdentifier) + .withPlaintext(ByteBuffer.allocate(wrongLength))); + + AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + masterKey.setGrantTokens(GRANT_TOKENS); + + assertThrows( + IllegalStateException.class, + () -> + AwsKmsMrkAwareMasterKey.decryptSingleEncryptedDataKey( + any(), + client, + keyIdentifier, + GRANT_TOKENS, + ALGORITHM_SUITE, + edk, + ENCRYPTION_CONTEXT)); + } + } + + public static class decryptDataKey { + + @Test + public void basic_use() { + final String keyIdentifier = + "arn:aws:kms:us-west-2:111122223333:key/mrk-edb7fe6942894d32ac46dbb1c922d574"; + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final byte[] cipherText = new byte[10]; + final String providerId = "aws-kms"; + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(providerId); + + final EncryptedDataKey edk1 = + new KeyBlob("aws-kms", keyIdentifier.getBytes(StandardCharsets.UTF_8), cipherText); + final EncryptedDataKey edk2 = + new KeyBlob("aws-kms", keyIdentifier.getBytes(StandardCharsets.UTF_8), cipherText); + + final AWSKMS client = mock(AWSKMS.class); + when(client.decrypt(any())) + .thenReturn( + new DecryptResult() + .withKeyId(keyIdentifier) + .withPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength()))); + + final AwsKmsMrkAwareMasterKey mk = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + mk.setGrantTokens(GRANT_TOKENS); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # The inputs MUST be the same as the Master Key Decrypt Data Key + // # (../master-key-interface.md#decrypt-data-key) interface. + final DataKey test = + mk.decryptDataKey(ALGORITHM_SUITE, Arrays.asList(edk1, edk2), ENCRYPTION_CONTEXT); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # For each encrypted data key in the filtered set, one at a time, the + // # master key MUST attempt to decrypt the data key. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # If the AWS KMS response satisfies the requirements then it MUST be + // # use and this function MUST return and not attempt to decrypt any more + // # encrypted data keys. + verify(client, times((1))) + .decrypt( + new DecryptRequest() + .withGrantTokens(GRANT_TOKENS) + .withEncryptionContext(ENCRYPTION_CONTEXT) + .withKeyId(keyIdentifier) + .withCiphertextBlob(ByteBuffer.wrap(cipherText))); + + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # The output MUST be the same as the Master Key Decrypt Data Key + // # (../master-key-interface.md#decrypt-data-key) interface. + assertTrue(DataKey.class.isInstance(test)); } + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # The set of encrypted data keys MUST first be filtered to match this + // # master key's configuration. + public void edk_match() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final String providerId = "aws-kms"; + final String clientErrMsg = "asdf"; + final EncryptedDataKey edk1 = + new KeyBlob("not-aws-kms", keyIdentifier.getBytes(StandardCharsets.UTF_8), new byte[10]); + + final EncryptedDataKey edk2 = + new KeyBlob( + providerId, "not-key-identifier".getBytes(StandardCharsets.UTF_8), new byte[10]); + + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(providerId); + + final AWSKMS client = mock(AWSKMS.class); + when(client.decrypt(any())).thenThrow(new AmazonServiceException(clientErrMsg)); + final KmsMasterKeyProvider.RegionalClientSupplier supplier = + mock(KmsMasterKeyProvider.RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + final AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + masterKey.setGrantTokens(GRANT_TOKENS); + + final CannotUnwrapDataKeyException testProviderNotMatch = + assertThrows( + "Unable to decrypt any data keys", + CannotUnwrapDataKeyException.class, + () -> + masterKey.decryptDataKey( + ALGORITHM_SUITE, Arrays.asList(edk1), ENCRYPTION_CONTEXT)); + assertEquals(0, testProviderNotMatch.getSuppressed().length); + + final IllegalStateException testArnNotMatch = + assertThrows( + "Unable to decrypt any data keys", + IllegalStateException.class, + () -> + masterKey.decryptDataKey( + ALGORITHM_SUITE, Arrays.asList(edk2), ENCRYPTION_CONTEXT)); + assertEquals(0, testArnNotMatch.getSuppressed().length); + } - public static class getMasterKeysForEncryption { - @Test - //= compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.8 - //= type=test - //# MUST be unchanged from the Master Key interface. - public void test_getMasterKeysForEncryption() throws NoSuchMethodException { - String methodName = "getMasterKeysForEncryption"; - Class[] parameterTypes = new Class[]{ MasterKeyRequest.class }; - - // Make sure the signature is correct by fetching the base method - Method baseMethod = MasterKey.class.getDeclaredMethod(methodName, parameterTypes); - assertNotNull(baseMethod); - // Assert AwsKmsMrkAwareMasterKey does no declare the same method directly - assertThrows(NoSuchMethodException.class, () -> AwsKmsMrkAwareMasterKey.class.getDeclaredMethod(methodName, parameterTypes)); - } + @Test + @DisplayName("Exceptional Postcondition: Master key was unable to decrypt.") + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # If this attempt + // # results in an error, then these errors MUST be collected. + // + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.9 + // = type=test + // # If all the input encrypted data keys have been processed then this + // # function MUST yield an error that includes all the collected errors. + public void exception_wrapped() { + final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); + final String keyIdentifier = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; + final String providerId = "aws-kms"; + final String clientErrMsg = "asdf"; + final EncryptedDataKey edk = + new KeyBlob(providerId, keyIdentifier.getBytes(StandardCharsets.UTF_8), new byte[10]); + + final MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(providerId); + + final AWSKMS client = mock(AWSKMS.class); + when(client.decrypt(any())).thenThrow(new AmazonServiceException(clientErrMsg)); + + KmsMasterKeyProvider.RegionalClientSupplier supplier = + mock(KmsMasterKeyProvider.RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + AwsKmsMrkAwareMasterKey masterKey = + AwsKmsMrkAwareMasterKey.getInstance(client, keyIdentifier, mkp); + + masterKey.setGrantTokens(GRANT_TOKENS); + + final CannotUnwrapDataKeyException test = + assertThrows( + "Unable to decrypt any data keys", + CannotUnwrapDataKeyException.class, + () -> + masterKey.decryptDataKey( + ALGORITHM_SUITE, Arrays.asList(edk), ENCRYPTION_CONTEXT)); + assertEquals(1, test.getSuppressed().length); + Throwable fromClient = Arrays.stream(test.getSuppressed()).findFirst().get(); + assertTrue(fromClient instanceof AmazonServiceException); + assertTrue(fromClient.getMessage().startsWith(clientErrMsg)); + } + } + + public static class getMasterKey { + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.7 + // = type=test + // # MUST be unchanged from the Master Key interface. + public void test_get_master_key() throws NoSuchMethodException { + String methodName = "getMasterKey"; + Class[] parameterTypes = new Class[] {String.class, String.class}; + // Make sure the signature is correct by fetching the base method + Method baseMethod = MasterKey.class.getDeclaredMethod(methodName, parameterTypes); + assertNotNull(baseMethod); + // Assert AwsKmsMrkAwareMasterKey does not declare the same method directly + assertThrows( + NoSuchMethodException.class, + () -> AwsKmsMrkAwareMasterKey.class.getDeclaredMethod(methodName, parameterTypes)); + } + } + + public static class getMasterKeysForEncryption { + @Test + // = compliance/framework/aws-kms/aws-kms-mrk-aware-master-key.txt#2.8 + // = type=test + // # MUST be unchanged from the Master Key interface. + public void test_getMasterKeysForEncryption() throws NoSuchMethodException { + String methodName = "getMasterKeysForEncryption"; + Class[] parameterTypes = new Class[] {MasterKeyRequest.class}; + + // Make sure the signature is correct by fetching the base method + Method baseMethod = MasterKey.class.getDeclaredMethod(methodName, parameterTypes); + assertNotNull(baseMethod); + // Assert AwsKmsMrkAwareMasterKey does no declare the same method directly + assertThrows( + NoSuchMethodException.class, + () -> AwsKmsMrkAwareMasterKey.class.getDeclaredMethod(methodName, parameterTypes)); } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/DiscoveryFilterTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/DiscoveryFilterTest.java index ef3675565..2e1f899a5 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/DiscoveryFilterTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/DiscoveryFilterTest.java @@ -3,65 +3,65 @@ package com.amazonaws.encryptionsdk.kms; -import org.junit.Test; +import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; +import static org.junit.Assert.assertNotNull; + import java.util.Arrays; import java.util.Collections; import java.util.List; - -import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; -import static org.junit.Assert.assertNotNull; +import org.junit.Test; public class DiscoveryFilterTest { - @Test - public void testValidConstruct() throws Exception { - DiscoveryFilter filter = new DiscoveryFilter("partition", Arrays.asList("accountId")); - assertNotNull(filter); + @Test + public void testValidConstruct() throws Exception { + DiscoveryFilter filter = new DiscoveryFilter("partition", Arrays.asList("accountId")); + assertNotNull(filter); - DiscoveryFilter filter2 = new DiscoveryFilter("partition", "accountId1", "accountId2"); - assertNotNull(filter2); - } + DiscoveryFilter filter2 = new DiscoveryFilter("partition", "accountId1", "accountId2"); + assertNotNull(filter2); + } - @Test - public void testConstructWithEmptyPartition() throws Exception { - assertThrows(IllegalArgumentException.class, () -> - new DiscoveryFilter("", Arrays.asList("accountId"))); - assertThrows(IllegalArgumentException.class, () -> - new DiscoveryFilter("", "accountId")); - } + @Test + public void testConstructWithEmptyPartition() throws Exception { + assertThrows( + IllegalArgumentException.class, () -> new DiscoveryFilter("", Arrays.asList("accountId"))); + assertThrows(IllegalArgumentException.class, () -> new DiscoveryFilter("", "accountId")); + } - @Test - public void testConstructWithNullPartition() throws Exception { - assertThrows(IllegalArgumentException.class, () -> - new DiscoveryFilter(null, Arrays.asList("accountId"))); - assertThrows(IllegalArgumentException.class, () -> - new DiscoveryFilter(null, "accountId")); - } + @Test + public void testConstructWithNullPartition() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> new DiscoveryFilter(null, Arrays.asList("accountId"))); + assertThrows(IllegalArgumentException.class, () -> new DiscoveryFilter(null, "accountId")); + } - @Test - public void testConstructWithEmptyIds() throws Exception { - assertThrows(IllegalArgumentException.class, () -> - new DiscoveryFilter("aws", Collections.emptyList())); - } + @Test + public void testConstructWithEmptyIds() throws Exception { + assertThrows( + IllegalArgumentException.class, () -> new DiscoveryFilter("aws", Collections.emptyList())); + } - @Test - public void testConstructWithNullIds() throws Exception { - assertThrows(IllegalArgumentException.class, () -> - new DiscoveryFilter("aws", (List) null)); - } + @Test + public void testConstructWithNullIds() throws Exception { + assertThrows( + IllegalArgumentException.class, () -> new DiscoveryFilter("aws", (List) null)); + } - @Test - public void testConstructWithIdsContainingEmptyId() throws Exception { - assertThrows(IllegalArgumentException.class, () -> - new DiscoveryFilter("aws", Arrays.asList("accountId", ""))); - assertThrows(IllegalArgumentException.class, () -> - new DiscoveryFilter("aws", "accountId", "")); - } + @Test + public void testConstructWithIdsContainingEmptyId() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> new DiscoveryFilter("aws", Arrays.asList("accountId", ""))); + assertThrows(IllegalArgumentException.class, () -> new DiscoveryFilter("aws", "accountId", "")); + } - @Test - public void testConstructWithIdsContainingNullId() throws Exception { - assertThrows(IllegalArgumentException.class, () -> - new DiscoveryFilter("aws", Arrays.asList("accountId", null))); - assertThrows(IllegalArgumentException.class, () -> - new DiscoveryFilter("aws", "accountId", null)); - } + @Test + public void testConstructWithIdsContainingNullId() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> new DiscoveryFilter("aws", Arrays.asList("accountId", null))); + assertThrows( + IllegalArgumentException.class, () -> new DiscoveryFilter("aws", "accountId", null)); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderIntegrationTests.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderIntegrationTests.java index a76ce2137..c30d2163a 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderIntegrationTests.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderIntegrationTests.java @@ -16,18 +16,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.Collections; -import java.util.Map; -import java.util.HashMap; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; - import com.amazonaws.AbortedException; import com.amazonaws.ClientConfiguration; import com.amazonaws.Request; @@ -43,391 +31,415 @@ import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; import com.amazonaws.encryptionsdk.internal.VersionInfo; -import com.amazonaws.encryptionsdk.CommitmentPolicy; -import com.amazonaws.encryptionsdk.model.KeyBlob; import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.RegionalClientSupplier; +import com.amazonaws.encryptionsdk.model.KeyBlob; import com.amazonaws.handlers.RequestHandler2; import com.amazonaws.http.exception.HttpRequestTimeoutException; import com.amazonaws.services.kms.AWSKMS; import com.amazonaws.services.kms.AWSKMSClientBuilder; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; public class KMSProviderBuilderIntegrationTests { - private static final String AWS_KMS_PROVIDER_ID = "aws-kms"; - - private AWSKMS testUSWestClient__; - private AWSKMS testEUCentralClient__; - private RegionalClientSupplier testClientSupplier__; - - @Before - public void setup() { - testUSWestClient__ = spy(AWSKMSClientBuilder.standard().withRegion("us-west-2").build()); - testEUCentralClient__ = spy(AWSKMSClientBuilder.standard().withRegion("eu-central-1").build()); - testClientSupplier__ = regionName -> { - if (regionName.equals("us-west-2")) { - return testUSWestClient__; - } else if (regionName.equals("eu-central-1")) { - return testEUCentralClient__; - } else { - throw new AwsCryptoException("test supplier only configured for us-west-2 and eu-central-1"); - } + private static final String AWS_KMS_PROVIDER_ID = "aws-kms"; + + private AWSKMS testUSWestClient__; + private AWSKMS testEUCentralClient__; + private RegionalClientSupplier testClientSupplier__; + + @Before + public void setup() { + testUSWestClient__ = spy(AWSKMSClientBuilder.standard().withRegion("us-west-2").build()); + testEUCentralClient__ = spy(AWSKMSClientBuilder.standard().withRegion("eu-central-1").build()); + testClientSupplier__ = + regionName -> { + if (regionName.equals("us-west-2")) { + return testUSWestClient__; + } else if (regionName.equals("eu-central-1")) { + return testEUCentralClient__; + } else { + throw new AwsCryptoException( + "test supplier only configured for us-west-2 and eu-central-1"); + } }; - } + } - @Test - public void whenBogusRegionsDecrypted_doesNotLeakClients() throws Exception { - AtomicReference> kmsCache = new AtomicReference<>(); + @Test + public void whenBogusRegionsDecrypted_doesNotLeakClients() throws Exception { + AtomicReference> kmsCache = new AtomicReference<>(); - KmsMasterKeyProvider mkp = (new KmsMasterKeyProvider.Builder() { - @Override protected void snoopClientCache( - final ConcurrentHashMap map - ) { + KmsMasterKeyProvider mkp = + (new KmsMasterKeyProvider.Builder() { + @Override + protected void snoopClientCache(final ConcurrentHashMap map) { kmsCache.set(map); - } - }).buildDiscovery(); - - try { - mkp.decryptDataKey( - CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256, - Collections.singleton( - new KeyBlob("aws-kms", - "arn:aws:kms:us-bogus-1:123456789010:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f" - .getBytes(StandardCharsets.UTF_8), - new byte[40] - ) - ), - new HashMap<>() - ); - fail("Expected CannotUnwrapDataKeyException"); - } catch (CannotUnwrapDataKeyException e) { - // ok - } - - assertTrue(kmsCache.get().isEmpty()); + } + }) + .buildDiscovery(); + + try { + mkp.decryptDataKey( + CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256, + Collections.singleton( + new KeyBlob( + "aws-kms", + "arn:aws:kms:us-bogus-1:123456789010:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f" + .getBytes(StandardCharsets.UTF_8), + new byte[40])), + new HashMap<>()); + fail("Expected CannotUnwrapDataKeyException"); + } catch (CannotUnwrapDataKeyException e) { + // ok } - @Test - public void whenOperationSuccessful_clientIsCached() { - AtomicReference> kmsCache = new AtomicReference<>(); - - KmsMasterKeyProvider mkp = (new KmsMasterKeyProvider.Builder() { - @Override protected void snoopClientCache( - final ConcurrentHashMap map - ) { - kmsCache.set(map); - } - }).buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - - AwsCrypto.standard().encryptData(mkp, new byte[1]); - - AWSKMS kms = kmsCache.get().get("us-west-2"); - assertNotNull(kms); + assertTrue(kmsCache.get().isEmpty()); + } - AwsCrypto.standard().encryptData(mkp, new byte[1]); + @Test + public void whenOperationSuccessful_clientIsCached() { + AtomicReference> kmsCache = new AtomicReference<>(); - // Cache entry should stay the same - assertEquals(kms, kmsCache.get().get("us-west-2")); - } - - @Test - public void whenConstructedWithoutArguments_canUseMultipleRegions() throws Exception { - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder().buildDiscovery(); - - for (String key : KMSTestFixtures.TEST_KEY_IDS) { - byte[] ciphertext = - AwsCrypto.standard().encryptData( - KmsMasterKeyProvider.builder() - .buildStrict(key), - new byte[1] - ).getResult(); - - AwsCrypto.standard().decryptData(mkp, ciphertext); - } - } + KmsMasterKeyProvider mkp = + (new KmsMasterKeyProvider.Builder() { + @Override + protected void snoopClientCache(final ConcurrentHashMap map) { + kmsCache.set(map); + } + }) + .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - @Test - public void whenConstructedInStrictMode_encryptDecrypt() throws Exception { - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withCustomClientFactory(testClientSupplier__) - .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); + AwsCrypto.standard().encryptData(mkp, new byte[1]); - byte[] ciphertext = AwsCrypto.standard().encryptData(mkp, new byte[1]).getResult(); - verify(testUSWestClient__, times(1)).generateDataKey(any()); + AWSKMS kms = kmsCache.get().get("us-west-2"); + assertNotNull(kms); - AwsCrypto.standard().decryptData(mkp, ciphertext); - verify(testUSWestClient__, times(1)).decrypt(any()); - } + AwsCrypto.standard().encryptData(mkp, new byte[1]); - @Test - public void whenConstructedInStrictMode_encryptDecryptMultipleCmks() throws Exception { - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withCustomClientFactory(testClientSupplier__) - .buildStrict( - KMSTestFixtures.US_WEST_2_KEY_ID, - KMSTestFixtures.EU_CENTRAL_1_KEY_ID); + // Cache entry should stay the same + assertEquals(kms, kmsCache.get().get("us-west-2")); + } - byte[] ciphertext = AwsCrypto.standard().encryptData(mkp, new byte[1]).getResult(); - verify(testUSWestClient__, times(1)).generateDataKey(any()); - verify(testEUCentralClient__, times(1)).encrypt(any()); + @Test + public void whenConstructedWithoutArguments_canUseMultipleRegions() throws Exception { + KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder().buildDiscovery(); - AwsCrypto.standard().decryptData(mkp, ciphertext); - verify(testUSWestClient__, times(1)).decrypt(any()); - } + for (String key : KMSTestFixtures.TEST_KEY_IDS) { + byte[] ciphertext = + AwsCrypto.standard() + .encryptData(KmsMasterKeyProvider.builder().buildStrict(key), new byte[1]) + .getResult(); - @Test - public void whenConstructedInStrictMode_encryptSingleBadKeyIdFails() throws Exception { - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withCustomClientFactory(testClientSupplier__) - .withDefaultRegion("us-west-2") - .buildStrict( - KMSTestFixtures.US_WEST_2_KEY_ID, - "badKeyId"); - - assertThrows(AwsCryptoException.class, () -> AwsCrypto.standard().encryptData(mkp, new byte[1]).getResult()); - verify(testUSWestClient__, times(1)).generateDataKey(any()); - verify(testUSWestClient__, times(1)).encrypt(any()); + AwsCrypto.standard().decryptData(mkp, ciphertext); } + } - @Test - public void whenConstructedInStrictMode_decryptBadEDKFails() throws Exception { - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withCustomClientFactory(testClientSupplier__) - .withDefaultRegion("us-west-2") - .buildStrict("badKeyId"); - - final CryptoAlgorithm algSuite = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final Map encCtx = Collections.singletonMap("myKey", "myValue"); - final EncryptedDataKey badEDK = new KeyBlob(AWS_KMS_PROVIDER_ID, - "badKeyId".getBytes(StandardCharsets.UTF_8), new byte[algSuite.getDataKeyLength()]); - - assertThrows(CannotUnwrapDataKeyException.class, () -> - mkp.decryptDataKey(algSuite, Collections.singletonList(badEDK), encCtx)); - verify(testUSWestClient__, times(1)).decrypt(any()); - } + @Test + public void whenConstructedInStrictMode_encryptDecrypt() throws Exception { + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(testClientSupplier__) + .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - @Test - public void whenConstructedInDiscoveryMode_decrypt() throws Exception { - KmsMasterKeyProvider singleCmkMkp = KmsMasterKeyProvider.builder() - .withCustomClientFactory(testClientSupplier__) - .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - byte[] singleCmkCiphertext = AwsCrypto.standard().encryptData(singleCmkMkp, new byte[1]).getResult(); - - KmsMasterKeyProvider mkpToTest = KmsMasterKeyProvider.builder() - .withCustomClientFactory(testClientSupplier__) - .buildDiscovery(); - AwsCrypto.standard().decryptData(mkpToTest, singleCmkCiphertext); - verify(testUSWestClient__, times(1)).decrypt(any()); - } + byte[] ciphertext = AwsCrypto.standard().encryptData(mkp, new byte[1]).getResult(); + verify(testUSWestClient__, times(1)).generateDataKey(any()); - @Test - public void whenConstructedInDiscoveryMode_decryptBadEDKFails() throws Exception { - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withCustomClientFactory(testClientSupplier__) - .withDefaultRegion("us-west-2") - .buildDiscovery(); - - final CryptoAlgorithm algSuite = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final Map encCtx = Collections.singletonMap("myKey", "myValue"); - final EncryptedDataKey badEDK = new KeyBlob(AWS_KMS_PROVIDER_ID, - "badKeyId".getBytes(StandardCharsets.UTF_8), new byte[algSuite.getDataKeyLength()]); - - assertThrows(CannotUnwrapDataKeyException.class, () -> - mkp.decryptDataKey(algSuite, Collections.singletonList(badEDK), encCtx)); - verify(testUSWestClient__, times(1)).decrypt(any()); - } + AwsCrypto.standard().decryptData(mkp, ciphertext); + verify(testUSWestClient__, times(1)).decrypt(any()); + } + @Test + public void whenConstructedInStrictMode_encryptDecryptMultipleCmks() throws Exception { + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(testClientSupplier__) + .buildStrict(KMSTestFixtures.US_WEST_2_KEY_ID, KMSTestFixtures.EU_CENTRAL_1_KEY_ID); - @Test - public void whenConstructedWithDiscoveryFilter_decrypt() throws Exception { - KmsMasterKeyProvider singleCmkMkp = KmsMasterKeyProvider.builder() - .withCustomClientFactory(testClientSupplier__) - .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); + byte[] ciphertext = AwsCrypto.standard().encryptData(mkp, new byte[1]).getResult(); + verify(testUSWestClient__, times(1)).generateDataKey(any()); + verify(testEUCentralClient__, times(1)).encrypt(any()); - byte[] singleCmkCiphertext = AwsCrypto.standard().encryptData(singleCmkMkp, new byte[1]).getResult(); + AwsCrypto.standard().decryptData(mkp, ciphertext); + verify(testUSWestClient__, times(1)).decrypt(any()); + } - KmsMasterKeyProvider mkpToTest = KmsMasterKeyProvider.builder() - .withCustomClientFactory(testClientSupplier__) - .buildDiscovery(new DiscoveryFilter( - KMSTestFixtures.PARTITION, - Arrays.asList(KMSTestFixtures.ACCOUNT_ID))); + @Test + public void whenConstructedInStrictMode_encryptSingleBadKeyIdFails() throws Exception { + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(testClientSupplier__) + .withDefaultRegion("us-west-2") + .buildStrict(KMSTestFixtures.US_WEST_2_KEY_ID, "badKeyId"); + + assertThrows( + AwsCryptoException.class, + () -> AwsCrypto.standard().encryptData(mkp, new byte[1]).getResult()); + verify(testUSWestClient__, times(1)).generateDataKey(any()); + verify(testUSWestClient__, times(1)).encrypt(any()); + } + + @Test + public void whenConstructedInStrictMode_decryptBadEDKFails() throws Exception { + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(testClientSupplier__) + .withDefaultRegion("us-west-2") + .buildStrict("badKeyId"); + + final CryptoAlgorithm algSuite = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final Map encCtx = Collections.singletonMap("myKey", "myValue"); + final EncryptedDataKey badEDK = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + "badKeyId".getBytes(StandardCharsets.UTF_8), + new byte[algSuite.getDataKeyLength()]); + + assertThrows( + CannotUnwrapDataKeyException.class, + () -> mkp.decryptDataKey(algSuite, Collections.singletonList(badEDK), encCtx)); + verify(testUSWestClient__, times(1)).decrypt(any()); + } + + @Test + public void whenConstructedInDiscoveryMode_decrypt() throws Exception { + KmsMasterKeyProvider singleCmkMkp = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(testClientSupplier__) + .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); + byte[] singleCmkCiphertext = + AwsCrypto.standard().encryptData(singleCmkMkp, new byte[1]).getResult(); - AwsCrypto.standard().decryptData(mkpToTest, singleCmkCiphertext); - verify(testUSWestClient__, times(1)).decrypt(any()); - } + KmsMasterKeyProvider mkpToTest = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(testClientSupplier__) + .buildDiscovery(); + AwsCrypto.standard().decryptData(mkpToTest, singleCmkCiphertext); + verify(testUSWestClient__, times(1)).decrypt(any()); + } + + @Test + public void whenConstructedInDiscoveryMode_decryptBadEDKFails() throws Exception { + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(testClientSupplier__) + .withDefaultRegion("us-west-2") + .buildDiscovery(); + + final CryptoAlgorithm algSuite = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final Map encCtx = Collections.singletonMap("myKey", "myValue"); + final EncryptedDataKey badEDK = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + "badKeyId".getBytes(StandardCharsets.UTF_8), + new byte[algSuite.getDataKeyLength()]); + + assertThrows( + CannotUnwrapDataKeyException.class, + () -> mkp.decryptDataKey(algSuite, Collections.singletonList(badEDK), encCtx)); + verify(testUSWestClient__, times(1)).decrypt(any()); + } + + @Test + public void whenConstructedWithDiscoveryFilter_decrypt() throws Exception { + KmsMasterKeyProvider singleCmkMkp = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(testClientSupplier__) + .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - @Test - public void whenConstructedWithDiscoveryFilter_decryptBadEDKFails() throws Exception { - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withCustomClientFactory(testClientSupplier__) - .withDefaultRegion("us-west-2") - .buildDiscovery(new DiscoveryFilter( - KMSTestFixtures.PARTITION, - Arrays.asList(KMSTestFixtures.ACCOUNT_ID))); - - final CryptoAlgorithm algSuite = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - final Map encCtx = Collections.singletonMap("myKey", "myValue"); - final String badARN = "arn:aws:kms:us-west-2:658956600833:key/badID"; - final EncryptedDataKey badEDK = new KeyBlob(AWS_KMS_PROVIDER_ID, - badARN.getBytes(StandardCharsets.UTF_8), new byte[algSuite.getDataKeyLength()]); - - assertThrows(CannotUnwrapDataKeyException.class, () -> - mkp.decryptDataKey(algSuite, Collections.singletonList(badEDK), encCtx)); - verify(testUSWestClient__, times(1)).decrypt(any()); - } + byte[] singleCmkCiphertext = + AwsCrypto.standard().encryptData(singleCmkMkp, new byte[1]).getResult(); - @Test - public void whenHandlerConfigured_handlerIsInvoked() throws Exception { - RequestHandler2 handler = spy(new RequestHandler2() {}); - KmsMasterKeyProvider mkp = - KmsMasterKeyProvider.builder() - .withClientBuilder( - AWSKMSClientBuilder.standard() - .withRequestHandlers(handler) - ) - .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); + KmsMasterKeyProvider mkpToTest = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(testClientSupplier__) + .buildDiscovery( + new DiscoveryFilter( + KMSTestFixtures.PARTITION, Arrays.asList(KMSTestFixtures.ACCOUNT_ID))); + + AwsCrypto.standard().decryptData(mkpToTest, singleCmkCiphertext); + verify(testUSWestClient__, times(1)).decrypt(any()); + } + + @Test + public void whenConstructedWithDiscoveryFilter_decryptBadEDKFails() throws Exception { + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(testClientSupplier__) + .withDefaultRegion("us-west-2") + .buildDiscovery( + new DiscoveryFilter( + KMSTestFixtures.PARTITION, Arrays.asList(KMSTestFixtures.ACCOUNT_ID))); + + final CryptoAlgorithm algSuite = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + final Map encCtx = Collections.singletonMap("myKey", "myValue"); + final String badARN = "arn:aws:kms:us-west-2:658956600833:key/badID"; + final EncryptedDataKey badEDK = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + badARN.getBytes(StandardCharsets.UTF_8), + new byte[algSuite.getDataKeyLength()]); + + assertThrows( + CannotUnwrapDataKeyException.class, + () -> mkp.decryptDataKey(algSuite, Collections.singletonList(badEDK), encCtx)); + verify(testUSWestClient__, times(1)).decrypt(any()); + } + + @Test + public void whenHandlerConfigured_handlerIsInvoked() throws Exception { + RequestHandler2 handler = spy(new RequestHandler2() {}); + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withClientBuilder(AWSKMSClientBuilder.standard().withRequestHandlers(handler)) + .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - AwsCrypto.standard().encryptData(mkp, new byte[1]); + AwsCrypto.standard().encryptData(mkp, new byte[1]); - verify(handler).beforeRequest(any()); - } + verify(handler).beforeRequest(any()); + } - @Test - public void whenShortTimeoutSet_timesOut() throws Exception { - // By setting a timeout of 1ms, it's not physically possible to complete both the us-west-2 and eu-central-1 - // requests due to speed of light limits. - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withClientBuilder( - AWSKMSClientBuilder.standard() - .withClientConfiguration( - new ClientConfiguration() - .withRequestTimeout(1) - ) - ) - .buildStrict(Arrays.asList(KMSTestFixtures.TEST_KEY_IDS)); - - try { - AwsCrypto.standard().encryptData(mkp, new byte[1]); - fail("Expected exception"); - } catch (Exception e) { - if (e instanceof AbortedException) { - // ok - one manifestation of a timeout - } else if (e.getCause() instanceof HttpRequestTimeoutException) { - // ok - another kind of timeout - } else { - throw e; - } - } + @Test + public void whenShortTimeoutSet_timesOut() throws Exception { + // By setting a timeout of 1ms, it's not physically possible to complete both the us-west-2 and + // eu-central-1 + // requests due to speed of light limits. + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withClientBuilder( + AWSKMSClientBuilder.standard() + .withClientConfiguration(new ClientConfiguration().withRequestTimeout(1))) + .buildStrict(Arrays.asList(KMSTestFixtures.TEST_KEY_IDS)); + + try { + AwsCrypto.standard().encryptData(mkp, new byte[1]); + fail("Expected exception"); + } catch (Exception e) { + if (e instanceof AbortedException) { + // ok - one manifestation of a timeout + } else if (e.getCause() instanceof HttpRequestTimeoutException) { + // ok - another kind of timeout + } else { + throw e; + } } + } - @Test - public void whenCustomCredentialsSet_theyAreUsed() throws Exception { - AWSCredentialsProvider customProvider = spy(new DefaultAWSCredentialsProviderChain()); - - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withCredentials(customProvider) - .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - - AwsCrypto.standard().encryptData(mkp, new byte[1]); + @Test + public void whenCustomCredentialsSet_theyAreUsed() throws Exception { + AWSCredentialsProvider customProvider = spy(new DefaultAWSCredentialsProviderChain()); - verify(customProvider, atLeastOnce()).getCredentials(); + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withCredentials(customProvider) + .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - AWSCredentials customCredentials = spy(customProvider.getCredentials()); + AwsCrypto.standard().encryptData(mkp, new byte[1]); - mkp = KmsMasterKeyProvider.builder() - .withCredentials(customCredentials) - .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); + verify(customProvider, atLeastOnce()).getCredentials(); - AwsCrypto.standard().encryptData(mkp, new byte[1]); + AWSCredentials customCredentials = spy(customProvider.getCredentials()); - verify(customCredentials, atLeastOnce()).getAWSSecretKey(); - } + mkp = + KmsMasterKeyProvider.builder() + .withCredentials(customCredentials) + .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - @Test - public void whenBuilderCloned_configurationIsRetained() throws Exception { - AWSCredentialsProvider customProvider1 = spy(new DefaultAWSCredentialsProviderChain()); - AWSCredentialsProvider customProvider2 = spy(new DefaultAWSCredentialsProviderChain()); + AwsCrypto.standard().encryptData(mkp, new byte[1]); - KmsMasterKeyProvider.Builder builder = KmsMasterKeyProvider.builder() - .withCredentials(customProvider1); + verify(customCredentials, atLeastOnce()).getAWSSecretKey(); + } - KmsMasterKeyProvider.Builder builder2 = builder.clone(); + @Test + public void whenBuilderCloned_configurationIsRetained() throws Exception { + AWSCredentialsProvider customProvider1 = spy(new DefaultAWSCredentialsProviderChain()); + AWSCredentialsProvider customProvider2 = spy(new DefaultAWSCredentialsProviderChain()); - // This will mutate the first builder to change the creds, but leave the clone unchanged. - MasterKeyProvider mkp2 = builder.withCredentials(customProvider2) - .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - MasterKeyProvider mkp1 = builder2.buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); + KmsMasterKeyProvider.Builder builder = + KmsMasterKeyProvider.builder().withCredentials(customProvider1); - CryptoResult result = AwsCrypto.standard().encryptData(mkp1, new byte[0]); + KmsMasterKeyProvider.Builder builder2 = builder.clone(); - verify(customProvider1, atLeastOnce()).getCredentials(); - verify(customProvider2, never()).getCredentials(); + // This will mutate the first builder to change the creds, but leave the clone unchanged. + MasterKeyProvider mkp2 = + builder.withCredentials(customProvider2).buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); + MasterKeyProvider mkp1 = builder2.buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - reset(customProvider1, customProvider2); + CryptoResult result = AwsCrypto.standard().encryptData(mkp1, new byte[0]); - result = AwsCrypto.standard().encryptData(mkp2, new byte[0]); + verify(customProvider1, atLeastOnce()).getCredentials(); + verify(customProvider2, never()).getCredentials(); - verify(customProvider1, never()).getCredentials(); - verify(customProvider2, atLeastOnce()).getCredentials(); - } + reset(customProvider1, customProvider2); - @Test - public void whenBuilderCloned_clientBuilderCustomizationIsRetained() throws Exception { - RequestHandler2 handler = spy(new RequestHandler2() {}); + result = AwsCrypto.standard().encryptData(mkp2, new byte[0]); - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withClientBuilder( - AWSKMSClientBuilder.standard().withRequestHandlers(handler) - ) - .clone().buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); + verify(customProvider1, never()).getCredentials(); + verify(customProvider2, atLeastOnce()).getCredentials(); + } - AwsCrypto.standard().encryptData(mkp, new byte[0]); + @Test + public void whenBuilderCloned_clientBuilderCustomizationIsRetained() throws Exception { + RequestHandler2 handler = spy(new RequestHandler2() {}); - verify(handler, atLeastOnce()).beforeRequest(any()); - } - - @Test(expected = IllegalArgumentException.class) - public void whenBogusEndpointIsSet_constructionFails() throws Exception { + KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withClientBuilder( - AWSKMSClientBuilder.standard() - .withEndpointConfiguration( - new AwsClientBuilder.EndpointConfiguration( - "https://this.does.not.exist.example.com", - "bad-region") - ) - ); - } - - @Test - public void whenUserAgentsOverridden_originalUAsPreserved() throws Exception { - RequestHandler2 handler = spy(new RequestHandler2() {}); - - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withClientBuilder( - AWSKMSClientBuilder.standard().withRequestHandlers(handler) - .withClientConfiguration( - new ClientConfiguration() - .withUserAgentPrefix("TEST-UA-PREFIX") - .withUserAgentSuffix("TEST-UA-SUFFIX") - ) - ) - .clone().buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); - - AwsCrypto.standard().encryptData(mkp, new byte[0]); - - ArgumentCaptor captor = ArgumentCaptor.forClass(Request.class); - verify(handler, atLeastOnce()).beforeRequest(captor.capture()); - - String ua = (String)captor.getValue().getHeaders().get("User-Agent"); - - assertTrue(ua.contains("TEST-UA-PREFIX")); - assertTrue(ua.contains("TEST-UA-SUFFIX")); - assertTrue(ua.contains(VersionInfo.loadUserAgent())); - } - - @Test - public void whenDefaultRegionSet_itIsUsedForBareKeyIds() throws Exception { - // TODO: Need to set up a role to assume as bare key IDs are relative to the caller account - } + .withClientBuilder(AWSKMSClientBuilder.standard().withRequestHandlers(handler)) + .clone() + .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); + + AwsCrypto.standard().encryptData(mkp, new byte[0]); + + verify(handler, atLeastOnce()).beforeRequest(any()); + } + + @Test(expected = IllegalArgumentException.class) + public void whenBogusEndpointIsSet_constructionFails() throws Exception { + KmsMasterKeyProvider.builder() + .withClientBuilder( + AWSKMSClientBuilder.standard() + .withEndpointConfiguration( + new AwsClientBuilder.EndpointConfiguration( + "https://this.does.not.exist.example.com", "bad-region"))); + } + + @Test + public void whenUserAgentsOverridden_originalUAsPreserved() throws Exception { + RequestHandler2 handler = spy(new RequestHandler2() {}); + + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withClientBuilder( + AWSKMSClientBuilder.standard() + .withRequestHandlers(handler) + .withClientConfiguration( + new ClientConfiguration() + .withUserAgentPrefix("TEST-UA-PREFIX") + .withUserAgentSuffix("TEST-UA-SUFFIX"))) + .clone() + .buildStrict(KMSTestFixtures.TEST_KEY_IDS[0]); + + AwsCrypto.standard().encryptData(mkp, new byte[0]); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Request.class); + verify(handler, atLeastOnce()).beforeRequest(captor.capture()); + + String ua = (String) captor.getValue().getHeaders().get("User-Agent"); + + assertTrue(ua.contains("TEST-UA-PREFIX")); + assertTrue(ua.contains("TEST-UA-SUFFIX")); + assertTrue(ua.contains(VersionInfo.loadUserAgent())); + } + + @Test + public void whenDefaultRegionSet_itIsUsedForBareKeyIds() throws Exception { + // TODO: Need to set up a role to assume as bare key IDs are relative to the caller account + } } - diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderMockTests.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderMockTests.java index c3f7bb74c..2f3c82ce6 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderMockTests.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderMockTests.java @@ -6,7 +6,6 @@ import static com.amazonaws.encryptionsdk.multi.MultipleProviderFactory.buildMultiProvider; import static java.util.Collections.singletonList; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.notNull; @@ -18,171 +17,176 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -import java.util.Arrays; - -import org.junit.Test; -import org.mockito.ArgumentCaptor; - import com.amazonaws.AmazonWebServiceRequest; import com.amazonaws.RequestClientOptions; import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.MasterKeyProvider; import com.amazonaws.encryptionsdk.internal.VersionInfo; -import com.amazonaws.encryptionsdk.CommitmentPolicy; import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.RegionalClientSupplier; import com.amazonaws.services.kms.model.CreateAliasRequest; import com.amazonaws.services.kms.model.DecryptRequest; import com.amazonaws.services.kms.model.EncryptRequest; import com.amazonaws.services.kms.model.GenerateDataKeyRequest; +import java.util.Arrays; +import org.junit.Test; +import org.mockito.ArgumentCaptor; public class KMSProviderBuilderMockTests { - @Test - public void testBareAliasMapping() { - MockKMSClient client = spy(new MockKMSClient()); + @Test + public void testBareAliasMapping() { + MockKMSClient client = spy(new MockKMSClient()); - RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(notNull())).thenReturn(client); + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(notNull())).thenReturn(client); - String key1 = client.createKey().getKeyMetadata().getKeyId(); - client.createAlias(new CreateAliasRequest() - .withAliasName("foo") - .withTargetKeyId(key1) - ); + String key1 = client.createKey().getKeyMetadata().getKeyId(); + client.createAlias(new CreateAliasRequest().withAliasName("foo").withTargetKeyId(key1)); - KmsMasterKeyProvider mkp0 = KmsMasterKeyProvider.builder() - .withCustomClientFactory(supplier) - .withDefaultRegion("us-west-2") - .buildStrict("alias/foo"); + KmsMasterKeyProvider mkp0 = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .withDefaultRegion("us-west-2") + .buildStrict("alias/foo"); - AwsCrypto.standard().encryptData(mkp0, new byte[0]); - } + AwsCrypto.standard().encryptData(mkp0, new byte[0]); + } - @Test - public void testGrantTokenPassthrough_usingMKsetCall() throws Exception { - MockKMSClient client = spy(new MockKMSClient()); + @Test + public void testGrantTokenPassthrough_usingMKsetCall() throws Exception { + MockKMSClient client = spy(new MockKMSClient()); - RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); - String key1 = client.createKey().getKeyMetadata().getArn(); - String key2 = client.createKey().getKeyMetadata().getArn(); + String key1 = client.createKey().getKeyMetadata().getArn(); + String key2 = client.createKey().getKeyMetadata().getArn(); - KmsMasterKeyProvider mkp0 = KmsMasterKeyProvider.builder() - .withDefaultRegion("us-west-2") - .withCustomClientFactory(supplier) - .buildStrict(key1, key2); - KmsMasterKey mk1 = mkp0.getMasterKey(key1); - KmsMasterKey mk2 = mkp0.getMasterKey(key2); + KmsMasterKeyProvider mkp0 = + KmsMasterKeyProvider.builder() + .withDefaultRegion("us-west-2") + .withCustomClientFactory(supplier) + .buildStrict(key1, key2); + KmsMasterKey mk1 = mkp0.getMasterKey(key1); + KmsMasterKey mk2 = mkp0.getMasterKey(key2); - mk1.setGrantTokens(singletonList("foo")); - mk2.setGrantTokens(singletonList("foo")); + mk1.setGrantTokens(singletonList("foo")); + mk2.setGrantTokens(singletonList("foo")); - MasterKeyProvider mkp = buildMultiProvider(mk1, mk2); + MasterKeyProvider mkp = buildMultiProvider(mk1, mk2); - byte[] ciphertext = AwsCrypto.standard().encryptData(mkp, new byte[0]).getResult(); + byte[] ciphertext = AwsCrypto.standard().encryptData(mkp, new byte[0]).getResult(); - ArgumentCaptor gdkr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class); - verify(client, times(1)).generateDataKey(gdkr.capture()); + ArgumentCaptor gdkr = + ArgumentCaptor.forClass(GenerateDataKeyRequest.class); + verify(client, times(1)).generateDataKey(gdkr.capture()); - assertEquals(key1, gdkr.getValue().getKeyId()); - assertEquals(1, gdkr.getValue().getGrantTokens().size()); - assertEquals("foo", gdkr.getValue().getGrantTokens().get(0)); + assertEquals(key1, gdkr.getValue().getKeyId()); + assertEquals(1, gdkr.getValue().getGrantTokens().size()); + assertEquals("foo", gdkr.getValue().getGrantTokens().get(0)); - ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); - verify(client, times(1)).encrypt(er.capture()); + ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); + verify(client, times(1)).encrypt(er.capture()); - assertEquals(key2, er.getValue().getKeyId()); - assertEquals(1, er.getValue().getGrantTokens().size()); - assertEquals("foo", er.getValue().getGrantTokens().get(0)); + assertEquals(key2, er.getValue().getKeyId()); + assertEquals(1, er.getValue().getGrantTokens().size()); + assertEquals("foo", er.getValue().getGrantTokens().get(0)); - AwsCrypto.standard().decryptData(mkp, ciphertext); + AwsCrypto.standard().decryptData(mkp, ciphertext); - ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); - verify(client, times(1)).decrypt(decrypt.capture()); + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); - assertEquals(1, decrypt.getValue().getGrantTokens().size()); - assertEquals("foo", decrypt.getValue().getGrantTokens().get(0)); + assertEquals(1, decrypt.getValue().getGrantTokens().size()); + assertEquals("foo", decrypt.getValue().getGrantTokens().get(0)); - verify(supplier, atLeastOnce()).getClient("us-west-2"); - verifyNoMoreInteractions(supplier); - } + verify(supplier, atLeastOnce()).getClient("us-west-2"); + verifyNoMoreInteractions(supplier); + } - @Test - public void testGrantTokenPassthrough_usingMKPWithers() throws Exception { - MockKMSClient client = spy(new MockKMSClient()); + @Test + public void testGrantTokenPassthrough_usingMKPWithers() throws Exception { + MockKMSClient client = spy(new MockKMSClient()); - RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); - String key1 = client.createKey().getKeyMetadata().getArn(); - String key2 = client.createKey().getKeyMetadata().getArn(); + String key1 = client.createKey().getKeyMetadata().getArn(); + String key2 = client.createKey().getKeyMetadata().getArn(); - KmsMasterKeyProvider mkp0 = KmsMasterKeyProvider.builder() - .withDefaultRegion("us-west-2") - .withCustomClientFactory(supplier) - .buildStrict(key1, key2); + KmsMasterKeyProvider mkp0 = + KmsMasterKeyProvider.builder() + .withDefaultRegion("us-west-2") + .withCustomClientFactory(supplier) + .buildStrict(key1, key2); - MasterKeyProvider mkp = mkp0.withGrantTokens("foo"); + MasterKeyProvider mkp = mkp0.withGrantTokens("foo"); - byte[] ciphertext = AwsCrypto.standard().encryptData(mkp, new byte[0]).getResult(); + byte[] ciphertext = AwsCrypto.standard().encryptData(mkp, new byte[0]).getResult(); - ArgumentCaptor gdkr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class); - verify(client, times(1)).generateDataKey(gdkr.capture()); + ArgumentCaptor gdkr = + ArgumentCaptor.forClass(GenerateDataKeyRequest.class); + verify(client, times(1)).generateDataKey(gdkr.capture()); - assertEquals(key1, gdkr.getValue().getKeyId()); - assertEquals(1, gdkr.getValue().getGrantTokens().size()); - assertEquals("foo", gdkr.getValue().getGrantTokens().get(0)); + assertEquals(key1, gdkr.getValue().getKeyId()); + assertEquals(1, gdkr.getValue().getGrantTokens().size()); + assertEquals("foo", gdkr.getValue().getGrantTokens().get(0)); - ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); - verify(client, times(1)).encrypt(er.capture()); + ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); + verify(client, times(1)).encrypt(er.capture()); - assertEquals(key2, er.getValue().getKeyId()); - assertEquals(1, er.getValue().getGrantTokens().size()); - assertEquals("foo", er.getValue().getGrantTokens().get(0)); + assertEquals(key2, er.getValue().getKeyId()); + assertEquals(1, er.getValue().getGrantTokens().size()); + assertEquals("foo", er.getValue().getGrantTokens().get(0)); - mkp = mkp0.withGrantTokens(Arrays.asList("bar")); + mkp = mkp0.withGrantTokens(Arrays.asList("bar")); - AwsCrypto.standard().decryptData(mkp, ciphertext); + AwsCrypto.standard().decryptData(mkp, ciphertext); - ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); - verify(client, times(1)).decrypt(decrypt.capture()); + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); - assertEquals(1, decrypt.getValue().getGrantTokens().size()); - assertEquals("bar", decrypt.getValue().getGrantTokens().get(0)); + assertEquals(1, decrypt.getValue().getGrantTokens().size()); + assertEquals("bar", decrypt.getValue().getGrantTokens().get(0)); - verify(supplier, atLeastOnce()).getClient("us-west-2"); - verifyNoMoreInteractions(supplier); - } + verify(supplier, atLeastOnce()).getClient("us-west-2"); + verifyNoMoreInteractions(supplier); + } - @Test - public void testUserAgentPassthrough() throws Exception { - MockKMSClient client = spy(new MockKMSClient()); + @Test + public void testUserAgentPassthrough() throws Exception { + MockKMSClient client = spy(new MockKMSClient()); - String key1 = client.createKey().getKeyMetadata().getArn(); - String key2 = client.createKey().getKeyMetadata().getArn(); + String key1 = client.createKey().getKeyMetadata().getArn(); + String key2 = client.createKey().getKeyMetadata().getArn(); - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withCustomClientFactory(ignored -> client) - .buildStrict(key1, key2); + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(ignored -> client) + .buildStrict(key1, key2); - AwsCrypto.standard().decryptData(mkp, AwsCrypto.standard().encryptData(mkp, new byte[0]).getResult()); + AwsCrypto.standard() + .decryptData(mkp, AwsCrypto.standard().encryptData(mkp, new byte[0]).getResult()); - ArgumentCaptor gdkr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class); - verify(client, times(1)).generateDataKey(gdkr.capture()); - assertTrue(getUA(gdkr.getValue()).contains(VersionInfo.loadUserAgent())); + ArgumentCaptor gdkr = + ArgumentCaptor.forClass(GenerateDataKeyRequest.class); + verify(client, times(1)).generateDataKey(gdkr.capture()); + assertTrue(getUA(gdkr.getValue()).contains(VersionInfo.loadUserAgent())); - ArgumentCaptor encr = ArgumentCaptor.forClass(EncryptRequest.class); - verify(client, times(1)).encrypt(encr.capture()); - assertTrue(getUA(encr.getValue()).contains(VersionInfo.loadUserAgent())); + ArgumentCaptor encr = ArgumentCaptor.forClass(EncryptRequest.class); + verify(client, times(1)).encrypt(encr.capture()); + assertTrue(getUA(encr.getValue()).contains(VersionInfo.loadUserAgent())); - ArgumentCaptor decr = ArgumentCaptor.forClass(DecryptRequest.class); - verify(client, times(1)).decrypt(decr.capture()); - assertTrue(getUA(decr.getValue()).contains(VersionInfo.loadUserAgent())); - } + ArgumentCaptor decr = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decr.capture()); + assertTrue(getUA(decr.getValue()).contains(VersionInfo.loadUserAgent())); + } - private String getUA(AmazonWebServiceRequest request) { - // Note: This test may break in future versions of the AWS SDK, as Marker is documented as being for internal - // use only. - return request.getRequestClientOptions().getClientMarker(RequestClientOptions.Marker.USER_AGENT); - } + private String getUA(AmazonWebServiceRequest request) { + // Note: This test may break in future versions of the AWS SDK, as Marker is documented as being + // for internal + // use only. + return request + .getRequestClientOptions() + .getClientMarker(RequestClientOptions.Marker.USER_AGENT); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KMSTestFixtures.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KMSTestFixtures.java index 0484397d7..641e0c788 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/KMSTestFixtures.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KMSTestFixtures.java @@ -4,29 +4,30 @@ package com.amazonaws.encryptionsdk.kms; public final class KMSTestFixtures { - private KMSTestFixtures() { - throw new UnsupportedOperationException( - "This class exists to hold static constants and cannot be instantiated." - ); - } + private KMSTestFixtures() { + throw new UnsupportedOperationException( + "This class exists to hold static constants and cannot be instantiated."); + } - /** - * These special test keys have been configured to allow Encrypt, Decrypt, and GenerateDataKey operations from any - * AWS principal and should be used when adding new KMS tests. - * - * This should go without saying, but never use these keys for production purposes (as anyone in the world can - * decrypt data encrypted using them). - */ - public static final String US_WEST_2_KEY_ID = "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - public static final String EU_CENTRAL_1_KEY_ID = "arn:aws:kms:eu-central-1:658956600833:key/75414c93-5285-4b57-99c9-30c1cf0a22c2"; - public static final String US_EAST_1_MULTI_REGION_KEY_ID = "arn:aws:kms:us-east-1:658956600833:key/mrk-80bd8ecdcd4342aebd84b7dc9da498a7"; - public static final String US_WEST_2_MULTI_REGION_KEY_ID = "arn:aws:kms:us-west-2:658956600833:key/mrk-80bd8ecdcd4342aebd84b7dc9da498a7"; - public static final String ACCOUNT_ID = "658956600833"; - public static final String PARTITION = "aws"; - public static final String US_WEST_2 = "us-west-2"; + /** + * These special test keys have been configured to allow Encrypt, Decrypt, and GenerateDataKey + * operations from any AWS principal and should be used when adding new KMS tests. + * + *

This should go without saying, but never use these keys for production purposes (as anyone + * in the world can decrypt data encrypted using them). + */ + public static final String US_WEST_2_KEY_ID = + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"; - public static final String[] TEST_KEY_IDS = new String[] { - US_WEST_2_KEY_ID, - EU_CENTRAL_1_KEY_ID - }; + public static final String EU_CENTRAL_1_KEY_ID = + "arn:aws:kms:eu-central-1:658956600833:key/75414c93-5285-4b57-99c9-30c1cf0a22c2"; + public static final String US_EAST_1_MULTI_REGION_KEY_ID = + "arn:aws:kms:us-east-1:658956600833:key/mrk-80bd8ecdcd4342aebd84b7dc9da498a7"; + public static final String US_WEST_2_MULTI_REGION_KEY_ID = + "arn:aws:kms:us-west-2:658956600833:key/mrk-80bd8ecdcd4342aebd84b7dc9da498a7"; + public static final String ACCOUNT_ID = "658956600833"; + public static final String PARTITION = "aws"; + public static final String US_WEST_2 = "us-west-2"; + + public static final String[] TEST_KEY_IDS = new String[] {US_WEST_2_KEY_ID, EU_CENTRAL_1_KEY_ID}; } diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProviderTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProviderTest.java index 89eac92ff..66dd72922 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProviderTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProviderTest.java @@ -18,404 +18,559 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import org.junit.Test; -import org.junit.experimental.runners.Enclosed; -import org.junit.runners.Parameterized; -import org.junit.runner.RunWith; -import org.mockito.ArgumentCaptor; - import com.amazonaws.AmazonServiceException; import com.amazonaws.AmazonWebServiceRequest; import com.amazonaws.RequestClientOptions; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSStaticCredentialsProvider; import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.DataKey; import com.amazonaws.encryptionsdk.EncryptedDataKey; -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSStaticCredentialsProvider; import com.amazonaws.encryptionsdk.MasterKeyProvider; import com.amazonaws.encryptionsdk.MasterKeyRequest; import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; import com.amazonaws.encryptionsdk.internal.VersionInfo; -import com.amazonaws.encryptionsdk.model.KeyBlob; -import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.RegionalClientSupplier; import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.Builder; +import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.RegionalClientSupplier; +import com.amazonaws.encryptionsdk.model.KeyBlob; import com.amazonaws.services.kms.model.DecryptRequest; import com.amazonaws.services.kms.model.DecryptResult; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.mockito.ArgumentCaptor; @RunWith(Enclosed.class) public class KmsMasterKeyProviderTest { - private static final String AWS_PARTITION = "aws"; - private static final String AWS_KMS_PROVIDER_ID = "aws-kms"; - private static final String OTHER_PARTITION = "not-aws"; - private static final String OTHER_PROVIDER_ID = "not-aws-kms"; - private static final String ACCOUNT_ID = "999999999999"; - private static final String OTHER_ACCOUNT_ID = "000000000000"; - - private static final String KEY_ID_1 = "arn:" + AWS_PARTITION + ":kms:us-east-1:" + ACCOUNT_ID + ":key/01234567-89ab-cdef-fedc-ba9876543210"; - private static final String KEY_ID_2 = "arn:" + AWS_PARTITION + ":kms:us-east-1:" + ACCOUNT_ID + ":key/01234567-89ab-cdef-fedc-ba9876543211"; - private static final String KEY_ID_3 = "arn:" + AWS_PARTITION + ":kms:us-east-1:" + ACCOUNT_ID + ":key/01234567-89ab-cdef-fedc-ba9876543212"; - private static final String KEY_ID_4 = "arn:" + AWS_PARTITION + ":kms:us-east-1:" + OTHER_ACCOUNT_ID + ":key/01234567-89ab-cdef-fedc-ba9876543210"; - private static final String KEY_ID_5 = "arn:" + OTHER_PARTITION + ":kms:us-east-1:" + ACCOUNT_ID + ":key/01234567-89ab-cdef-fedc-ba9876543210"; - - private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - - private static final EncryptedDataKey EDK_ID_1 = new KeyBlob(AWS_KMS_PROVIDER_ID, - KEY_ID_1.getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - private static final EncryptedDataKey EDK_ID_1_OTHER_CIPHERTEXT = new KeyBlob(AWS_KMS_PROVIDER_ID, - KEY_ID_1.getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - private static final EncryptedDataKey EDK_ID_2 = new KeyBlob(AWS_KMS_PROVIDER_ID, - KEY_ID_2.getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - private static final EncryptedDataKey EDK_ID_3 = new KeyBlob(AWS_KMS_PROVIDER_ID, - KEY_ID_3.getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - private static final EncryptedDataKey EDK_NON_ARN = new KeyBlob(AWS_KMS_PROVIDER_ID, - "someAlias".getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - private static final EncryptedDataKey EDK_EMPTY_PROVIDER = new KeyBlob("", - "someId".getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - private static final EncryptedDataKey EDK_OTHER_PROVIDER = new KeyBlob(OTHER_PROVIDER_ID, - "someId".getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - private static final EncryptedDataKey EDK_OTHER_ACCOUNT = new KeyBlob(AWS_KMS_PROVIDER_ID, - KEY_ID_4.getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - private static final EncryptedDataKey EDK_OTHER_PARTITION = new KeyBlob(AWS_KMS_PROVIDER_ID, - KEY_ID_5.getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - - @RunWith(Parameterized.class) - public static class ParameterizedDecryptTest { - MKPTestConfiguration mkpConfig; - List inputEDKs; - List decryptableEDKs; - - private static class MKPTestConfiguration { - // instance vars are public for easier access during testing - public boolean isDiscovery; - public DiscoveryFilter discoveryFilter; - public List keyIds; - - public MKPTestConfiguration(boolean isDiscovery, - DiscoveryFilter discoveryFilter, - List keyIds - ) { - this.isDiscovery = isDiscovery; - this.discoveryFilter = discoveryFilter; - this.keyIds = keyIds; - } - } - - public ParameterizedDecryptTest(MKPTestConfiguration mkpConfig, - List inputEDKs, - List decryptableEDKs - ) { - this.mkpConfig = mkpConfig; - this.inputEDKs = inputEDKs; - this.decryptableEDKs = decryptableEDKs; - } - - @Parameterized.Parameters(name = "{index}: mkpConfig={0}, inputEDKs={1}, decryptableEDKs={2}") - public static Collection testCases() { - // Create MKP configuration options to test against - MKPTestConfiguration strict_oneCMK = new MKPTestConfiguration(false, null, Arrays.asList(KEY_ID_1)); - MKPTestConfiguration strict_twoCMKs = new MKPTestConfiguration(false, null, Arrays.asList(KEY_ID_1, KEY_ID_2)); - MKPTestConfiguration explicitDiscovery = new MKPTestConfiguration(true, null, null); - MKPTestConfiguration explicitDiscovery_filter = new MKPTestConfiguration(true, - new DiscoveryFilter(AWS_PARTITION, Arrays.asList(ACCOUNT_ID)), null); - - // Define all test cases - Collection testCases = Arrays.asList(new Object[][]{ + private static final String AWS_PARTITION = "aws"; + private static final String AWS_KMS_PROVIDER_ID = "aws-kms"; + private static final String OTHER_PARTITION = "not-aws"; + private static final String OTHER_PROVIDER_ID = "not-aws-kms"; + private static final String ACCOUNT_ID = "999999999999"; + private static final String OTHER_ACCOUNT_ID = "000000000000"; + + private static final String KEY_ID_1 = + "arn:" + + AWS_PARTITION + + ":kms:us-east-1:" + + ACCOUNT_ID + + ":key/01234567-89ab-cdef-fedc-ba9876543210"; + private static final String KEY_ID_2 = + "arn:" + + AWS_PARTITION + + ":kms:us-east-1:" + + ACCOUNT_ID + + ":key/01234567-89ab-cdef-fedc-ba9876543211"; + private static final String KEY_ID_3 = + "arn:" + + AWS_PARTITION + + ":kms:us-east-1:" + + ACCOUNT_ID + + ":key/01234567-89ab-cdef-fedc-ba9876543212"; + private static final String KEY_ID_4 = + "arn:" + + AWS_PARTITION + + ":kms:us-east-1:" + + OTHER_ACCOUNT_ID + + ":key/01234567-89ab-cdef-fedc-ba9876543210"; + private static final String KEY_ID_5 = + "arn:" + + OTHER_PARTITION + + ":kms:us-east-1:" + + ACCOUNT_ID + + ":key/01234567-89ab-cdef-fedc-ba9876543210"; + + private static final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + private static final Map ENCRYPTION_CONTEXT = + Collections.singletonMap("myKey", "myValue"); + + private static final EncryptedDataKey EDK_ID_1 = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + KEY_ID_1.getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey EDK_ID_1_OTHER_CIPHERTEXT = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + KEY_ID_1.getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey EDK_ID_2 = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + KEY_ID_2.getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey EDK_ID_3 = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + KEY_ID_3.getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey EDK_NON_ARN = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + "someAlias".getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey EDK_EMPTY_PROVIDER = + new KeyBlob( + "", + "someId".getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey EDK_OTHER_PROVIDER = + new KeyBlob( + OTHER_PROVIDER_ID, + "someId".getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey EDK_OTHER_ACCOUNT = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + KEY_ID_4.getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + private static final EncryptedDataKey EDK_OTHER_PARTITION = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + KEY_ID_5.getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + + @RunWith(Parameterized.class) + public static class ParameterizedDecryptTest { + MKPTestConfiguration mkpConfig; + List inputEDKs; + List decryptableEDKs; + + private static class MKPTestConfiguration { + // instance vars are public for easier access during testing + public boolean isDiscovery; + public DiscoveryFilter discoveryFilter; + public List keyIds; + + public MKPTestConfiguration( + boolean isDiscovery, DiscoveryFilter discoveryFilter, List keyIds) { + this.isDiscovery = isDiscovery; + this.discoveryFilter = discoveryFilter; + this.keyIds = keyIds; + } + } + + public ParameterizedDecryptTest( + MKPTestConfiguration mkpConfig, + List inputEDKs, + List decryptableEDKs) { + this.mkpConfig = mkpConfig; + this.inputEDKs = inputEDKs; + this.decryptableEDKs = decryptableEDKs; + } + + @Parameterized.Parameters(name = "{index}: mkpConfig={0}, inputEDKs={1}, decryptableEDKs={2}") + public static Collection testCases() { + // Create MKP configuration options to test against + MKPTestConfiguration strict_oneCMK = + new MKPTestConfiguration(false, null, Arrays.asList(KEY_ID_1)); + MKPTestConfiguration strict_twoCMKs = + new MKPTestConfiguration(false, null, Arrays.asList(KEY_ID_1, KEY_ID_2)); + MKPTestConfiguration explicitDiscovery = new MKPTestConfiguration(true, null, null); + MKPTestConfiguration explicitDiscovery_filter = + new MKPTestConfiguration( + true, new DiscoveryFilter(AWS_PARTITION, Arrays.asList(ACCOUNT_ID)), null); + + // Define all test cases + Collection testCases = + Arrays.asList( + new Object[][] { // Test cases where no EDKs are expected to be decrypted {strict_oneCMK, Collections.emptyList(), Collections.emptyList()}, {strict_oneCMK, Arrays.asList(EDK_ID_2), Collections.emptyList()}, {strict_oneCMK, Arrays.asList(EDK_ID_2, EDK_ID_3), Collections.emptyList()}, - {strict_twoCMKs, Collections.emptyList(), Collections.emptyList()}, {strict_twoCMKs, Arrays.asList(EDK_ID_3), Collections.emptyList()}, - {strict_twoCMKs, Arrays.asList(EDK_ID_3, EDK_OTHER_PROVIDER), Collections.emptyList()}, - + { + strict_twoCMKs, + Arrays.asList(EDK_ID_3, EDK_OTHER_PROVIDER), + Collections.emptyList() + }, {explicitDiscovery, Collections.emptyList(), Collections.emptyList()}, {explicitDiscovery, Arrays.asList(EDK_OTHER_PROVIDER), Collections.emptyList()}, {explicitDiscovery, Arrays.asList(EDK_EMPTY_PROVIDER), Collections.emptyList()}, - {explicitDiscovery, Arrays.asList(EDK_OTHER_PROVIDER, EDK_EMPTY_PROVIDER), Collections.emptyList()}, - + { + explicitDiscovery, + Arrays.asList(EDK_OTHER_PROVIDER, EDK_EMPTY_PROVIDER), + Collections.emptyList() + }, {explicitDiscovery_filter, Collections.emptyList(), Collections.emptyList()}, - {explicitDiscovery_filter, Arrays.asList(EDK_OTHER_PROVIDER), Collections.emptyList()}, - {explicitDiscovery_filter, Arrays.asList(EDK_EMPTY_PROVIDER), Collections.emptyList()}, + { + explicitDiscovery_filter, + Arrays.asList(EDK_OTHER_PROVIDER), + Collections.emptyList() + }, + { + explicitDiscovery_filter, + Arrays.asList(EDK_EMPTY_PROVIDER), + Collections.emptyList() + }, {explicitDiscovery_filter, Arrays.asList(EDK_NON_ARN), Collections.emptyList()}, - {explicitDiscovery_filter, Arrays.asList(EDK_OTHER_PARTITION), Collections.emptyList()}, - {explicitDiscovery_filter, Arrays.asList(EDK_OTHER_ACCOUNT), Collections.emptyList()}, - {explicitDiscovery_filter, Arrays.asList(EDK_OTHER_PROVIDER, EDK_EMPTY_PROVIDER), Collections.emptyList()}, + { + explicitDiscovery_filter, + Arrays.asList(EDK_OTHER_PARTITION), + Collections.emptyList() + }, + { + explicitDiscovery_filter, + Arrays.asList(EDK_OTHER_ACCOUNT), + Collections.emptyList() + }, + { + explicitDiscovery_filter, + Arrays.asList(EDK_OTHER_PROVIDER, EDK_EMPTY_PROVIDER), + Collections.emptyList() + }, // Test cases where one EDK is expected to be decryptable {strict_oneCMK, Arrays.asList(EDK_ID_1), Arrays.asList(EDK_ID_1)}, {strict_oneCMK, Arrays.asList(EDK_ID_2, EDK_ID_1), Arrays.asList(EDK_ID_1)}, {strict_oneCMK, Arrays.asList(EDK_ID_1, EDK_ID_2), Arrays.asList(EDK_ID_1)}, - {strict_twoCMKs, Arrays.asList(EDK_ID_1), Arrays.asList(EDK_ID_1)}, {strict_twoCMKs, Arrays.asList(EDK_ID_2), Arrays.asList(EDK_ID_2)}, {strict_twoCMKs, Arrays.asList(EDK_ID_3, EDK_ID_1), Arrays.asList(EDK_ID_1)}, {strict_twoCMKs, Arrays.asList(EDK_ID_1, EDK_ID_3), Arrays.asList(EDK_ID_1)}, - {explicitDiscovery, Arrays.asList(EDK_ID_1), Arrays.asList(EDK_ID_1)}, - {explicitDiscovery, Arrays.asList(EDK_OTHER_PROVIDER, EDK_ID_1), Arrays.asList(EDK_ID_1)}, - {explicitDiscovery, Arrays.asList(EDK_ID_1, EDK_OTHER_PROVIDER), Arrays.asList(EDK_ID_1)}, - + { + explicitDiscovery, + Arrays.asList(EDK_OTHER_PROVIDER, EDK_ID_1), + Arrays.asList(EDK_ID_1) + }, + { + explicitDiscovery, + Arrays.asList(EDK_ID_1, EDK_OTHER_PROVIDER), + Arrays.asList(EDK_ID_1) + }, {explicitDiscovery_filter, Arrays.asList(EDK_ID_1), Arrays.asList(EDK_ID_1)}, - {explicitDiscovery_filter, Arrays.asList(EDK_OTHER_ACCOUNT, EDK_ID_1), Arrays.asList(EDK_ID_1)}, - {explicitDiscovery_filter, Arrays.asList(EDK_ID_1, EDK_OTHER_ACCOUNT), Arrays.asList(EDK_ID_1)}, + { + explicitDiscovery_filter, + Arrays.asList(EDK_OTHER_ACCOUNT, EDK_ID_1), + Arrays.asList(EDK_ID_1) + }, + { + explicitDiscovery_filter, + Arrays.asList(EDK_ID_1, EDK_OTHER_ACCOUNT), + Arrays.asList(EDK_ID_1) + }, // Test cases where multiple EDKs are expected to be decryptable - {strict_oneCMK, Arrays.asList(EDK_ID_1, EDK_ID_1_OTHER_CIPHERTEXT), Arrays.asList(EDK_ID_1, EDK_ID_1_OTHER_CIPHERTEXT)}, - {strict_twoCMKs, Arrays.asList(EDK_ID_1, EDK_ID_2), Arrays.asList(EDK_ID_1, EDK_ID_2)}, - {explicitDiscovery, Arrays.asList(EDK_ID_1, EDK_ID_2), Arrays.asList(EDK_ID_1, EDK_ID_2)}, - {explicitDiscovery_filter, Arrays.asList(EDK_ID_1, EDK_ID_2), Arrays.asList(EDK_ID_1, EDK_ID_2)}, - }); - return testCases; - } - - @SuppressWarnings("deprecation") - private KmsMasterKeyProvider constructMKPForTest(MKPTestConfiguration mkpConfig, RegionalClientSupplier supplier) { - Builder builder = KmsMasterKeyProvider.builder().withCustomClientFactory(supplier); - - KmsMasterKeyProvider mkp; - if (mkpConfig.isDiscovery && mkpConfig.discoveryFilter == null) { - mkp = builder.buildDiscovery(); - } else if (mkpConfig.isDiscovery) { - mkp = builder.buildDiscovery(mkpConfig.discoveryFilter); - } else { - mkp = builder.buildStrict(mkpConfig.keyIds); - } - - return mkp; - } - - @Test - public void testDecrypt() throws Exception { - MockKMSClient client = spy(new MockKMSClient()); - RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - // create MKP to test - KmsMasterKeyProvider mkp = constructMKPForTest(mkpConfig, supplier); - - // if we expect none of them to decrypt, just test that we get the correct - // failure and KMS was not called - if (decryptableEDKs.size() <= 0) { - assertThrows(CannotUnwrapDataKeyException.class, () -> mkp.decryptDataKey( - ALGORITHM_SUITE, inputEDKs, ENCRYPTION_CONTEXT)); - - ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); - verifyZeroInteractions(client); - return; - } - - // Test that the mkp calls KMS for the first expected EDK - EncryptedDataKey expectedEDK = decryptableEDKs.get(0); - - // mock KMS to return the KeyId for the expected EDK, - // we verify that we call KMS with this KeyId, so this is ok - DecryptResult decryptResult = new DecryptResult(); - decryptResult.setKeyId(new String(expectedEDK.getProviderInformation(), StandardCharsets.UTF_8)); - decryptResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); - doReturn(decryptResult).when(client).decrypt(isA(DecryptRequest.class)); - - DataKey dataKeyResult = mkp.decryptDataKey(ALGORITHM_SUITE, inputEDKs, ENCRYPTION_CONTEXT); - - ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); - verify(client, times(1)).decrypt(decrypt.capture()); - verifyNoMoreInteractions(client); - - DecryptRequest actualRequest = decrypt.getValue(); - assertArrayEquals(expectedEDK.getProviderInformation(), actualRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); - assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); - assertArrayEquals(expectedEDK.getEncryptedDataKey(), actualRequest.getCiphertextBlob().array()); - assertUserAgent(actualRequest); - - assertArrayEquals(expectedEDK.getProviderInformation(), dataKeyResult.getProviderInformation()); - assertArrayEquals(expectedEDK.getEncryptedDataKey(), dataKeyResult.getEncryptedDataKey()); - } - - @Test - public void testDecryptKMSFailsOnce() throws Exception { - MockKMSClient client = spy(new MockKMSClient()); - RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - // create MKP to test - KmsMasterKeyProvider mkp = constructMKPForTest(mkpConfig, supplier); - - // if we expect one or less KMS call, just test that we get the correct - // failure and KMS was called the expected number of times - if (decryptableEDKs.size() <= 1) { - // Mock KMS to fail - doThrow(new AmazonServiceException("fail")).when(client).decrypt(isA(DecryptRequest.class)); - - assertThrows(CannotUnwrapDataKeyException.class, () -> mkp.decryptDataKey( - ALGORITHM_SUITE, inputEDKs, ENCRYPTION_CONTEXT)); - - ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); - verify(client, times(decryptableEDKs.size())).decrypt(decrypt.capture()); - return; - } - - EncryptedDataKey expectedFailedEDK = decryptableEDKs.get(0); - EncryptedDataKey expectedSuccessfulEDK = decryptableEDKs.get(1); - - // Mock KMS to fail the first call then succeed for the second call - DecryptResult decryptResult = new DecryptResult(); - decryptResult.setKeyId(new String(expectedSuccessfulEDK.getProviderInformation(), StandardCharsets.UTF_8)); - decryptResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); - doThrow(new AmazonServiceException("fail")).doReturn(decryptResult).when(client).decrypt(isA(DecryptRequest.class)); - - DataKey dataKeyResult = mkp.decryptDataKey(ALGORITHM_SUITE, inputEDKs, ENCRYPTION_CONTEXT); - - ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); - verify(client, times(2)).decrypt(decrypt.capture()); - verifyNoMoreInteractions(client); - - List actualRequests = decrypt.getAllValues(); - DecryptRequest failedRequest = actualRequests.get(0); - assertArrayEquals(expectedFailedEDK.getProviderInformation(), failedRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); - assertEquals(ENCRYPTION_CONTEXT, failedRequest.getEncryptionContext()); - assertArrayEquals(expectedFailedEDK.getEncryptedDataKey(), failedRequest.getCiphertextBlob().array()); - assertUserAgent(failedRequest); - - DecryptRequest successfulRequest = actualRequests.get(1); - assertArrayEquals(expectedSuccessfulEDK.getProviderInformation(), successfulRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); - assertEquals(ENCRYPTION_CONTEXT, successfulRequest.getEncryptionContext()); - assertArrayEquals(expectedSuccessfulEDK.getEncryptedDataKey(), successfulRequest.getCiphertextBlob().array()); - assertUserAgent(successfulRequest); - - assertArrayEquals(expectedSuccessfulEDK.getProviderInformation(), dataKeyResult.getProviderInformation()); - assertArrayEquals(expectedSuccessfulEDK.getEncryptedDataKey(), dataKeyResult.getEncryptedDataKey()); - } - - private void assertUserAgent(AmazonWebServiceRequest request) { - assertTrue(request.getRequestClientOptions().getClientMarker(RequestClientOptions.Marker.USER_AGENT) - .contains(VersionInfo.loadUserAgent())); - } + { + strict_oneCMK, + Arrays.asList(EDK_ID_1, EDK_ID_1_OTHER_CIPHERTEXT), + Arrays.asList(EDK_ID_1, EDK_ID_1_OTHER_CIPHERTEXT) + }, + { + strict_twoCMKs, + Arrays.asList(EDK_ID_1, EDK_ID_2), + Arrays.asList(EDK_ID_1, EDK_ID_2) + }, + { + explicitDiscovery, + Arrays.asList(EDK_ID_1, EDK_ID_2), + Arrays.asList(EDK_ID_1, EDK_ID_2) + }, + { + explicitDiscovery_filter, + Arrays.asList(EDK_ID_1, EDK_ID_2), + Arrays.asList(EDK_ID_1, EDK_ID_2) + }, + }); + return testCases; } - public static class NonParameterized { - @Test - public void testBuildStrictWithNoCMKs() throws Exception { - RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - - assertThrows(IllegalArgumentException.class, () -> KmsMasterKeyProvider.builder() - .withCustomClientFactory(supplier) - .buildStrict()); - - assertThrows(IllegalArgumentException.class, () -> KmsMasterKeyProvider.builder() - .withCustomClientFactory(supplier) - .buildStrict(Collections.emptyList())); - - assertThrows(IllegalArgumentException.class, () -> KmsMasterKeyProvider.builder() - .withCustomClientFactory(supplier) - .buildStrict((List) null)); - } - - @Test - public void testBuildStrictWithNullCMK() throws Exception { - RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - - assertThrows(IllegalArgumentException.class, () -> KmsMasterKeyProvider.builder() - .withCustomClientFactory(supplier) - .buildStrict((String) null)); - - assertThrows(IllegalArgumentException.class, () -> KmsMasterKeyProvider.builder() - .withCustomClientFactory(supplier) - .buildStrict(Arrays.asList((String) null))); - } - - @Test - public void testBuildDiscoveryWithFilter() throws Exception { - RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - - KmsMasterKeyProvider mkp1 = KmsMasterKeyProvider.builder() - .withCustomClientFactory(supplier) - .buildDiscovery(new DiscoveryFilter("aws", Arrays.asList("accountId"))); - assertNotNull(mkp1); - } - - @Test - public void testBuildDiscoveryWithNullFilter() throws Exception { - RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - - assertThrows(IllegalArgumentException.class, () -> KmsMasterKeyProvider.builder() - .withCustomClientFactory(supplier) - .buildDiscovery(null)); - } - - @Test - public void testDecryptMismatchedKMSKeyIdResponse() throws Exception { - MockKMSClient client = spy(new MockKMSClient()); - RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); - when(supplier.getClient(any())).thenReturn(client); - - DecryptResult badResult = new DecryptResult(); - badResult.setKeyId(KEY_ID_2); - badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); - - doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); - - KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withCustomClientFactory(supplier) - .buildDiscovery(); - - assertThrows(CannotUnwrapDataKeyException.class, () -> mkp.decryptDataKey( - ALGORITHM_SUITE, - Arrays.asList(EDK_ID_1), - ENCRYPTION_CONTEXT)); - } + @SuppressWarnings("deprecation") + private KmsMasterKeyProvider constructMKPForTest( + MKPTestConfiguration mkpConfig, RegionalClientSupplier supplier) { + Builder builder = KmsMasterKeyProvider.builder().withCustomClientFactory(supplier); + + KmsMasterKeyProvider mkp; + if (mkpConfig.isDiscovery && mkpConfig.discoveryFilter == null) { + mkp = builder.buildDiscovery(); + } else if (mkpConfig.isDiscovery) { + mkp = builder.buildDiscovery(mkpConfig.discoveryFilter); + } else { + mkp = builder.buildStrict(mkpConfig.keyIds); + } + + return mkp; } @Test - public void testExplicitCredentials() throws Exception { - AWSCredentials creds = new AWSCredentials() { - @Override public String getAWSAccessKeyId() { - throw new UsedExplicitCredentials(); - } - - @Override public String getAWSSecretKey() { - throw new UsedExplicitCredentials(); - } - }; + public void testDecrypt() throws Exception { + MockKMSClient client = spy(new MockKMSClient()); + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + // create MKP to test + KmsMasterKeyProvider mkp = constructMKPForTest(mkpConfig, supplier); + + // if we expect none of them to decrypt, just test that we get the correct + // failure and KMS was not called + if (decryptableEDKs.size() <= 0) { + assertThrows( + CannotUnwrapDataKeyException.class, + () -> mkp.decryptDataKey(ALGORITHM_SUITE, inputEDKs, ENCRYPTION_CONTEXT)); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verifyZeroInteractions(client); + return; + } + + // Test that the mkp calls KMS for the first expected EDK + EncryptedDataKey expectedEDK = decryptableEDKs.get(0); + + // mock KMS to return the KeyId for the expected EDK, + // we verify that we call KMS with this KeyId, so this is ok + DecryptResult decryptResult = new DecryptResult(); + decryptResult.setKeyId( + new String(expectedEDK.getProviderInformation(), StandardCharsets.UTF_8)); + decryptResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); + doReturn(decryptResult).when(client).decrypt(isA(DecryptRequest.class)); + + DataKey dataKeyResult = + mkp.decryptDataKey(ALGORITHM_SUITE, inputEDKs, ENCRYPTION_CONTEXT); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); + verifyNoMoreInteractions(client); + + DecryptRequest actualRequest = decrypt.getValue(); + assertArrayEquals( + expectedEDK.getProviderInformation(), + actualRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertArrayEquals( + expectedEDK.getEncryptedDataKey(), actualRequest.getCiphertextBlob().array()); + assertUserAgent(actualRequest); + + assertArrayEquals( + expectedEDK.getProviderInformation(), dataKeyResult.getProviderInformation()); + assertArrayEquals(expectedEDK.getEncryptedDataKey(), dataKeyResult.getEncryptedDataKey()); + } - MasterKeyProvider mkp = KmsMasterKeyProvider.builder() - .withCredentials(creds) - .buildStrict("arn:aws:kms:us-east-1:012345678901:key/foo-bar"); - assertExplicitCredentialsUsed(mkp); + @Test + public void testDecryptKMSFailsOnce() throws Exception { + MockKMSClient client = spy(new MockKMSClient()); + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + // create MKP to test + KmsMasterKeyProvider mkp = constructMKPForTest(mkpConfig, supplier); + + // if we expect one or less KMS call, just test that we get the correct + // failure and KMS was called the expected number of times + if (decryptableEDKs.size() <= 1) { + // Mock KMS to fail + doThrow(new AmazonServiceException("fail")).when(client).decrypt(isA(DecryptRequest.class)); + + assertThrows( + CannotUnwrapDataKeyException.class, + () -> mkp.decryptDataKey(ALGORITHM_SUITE, inputEDKs, ENCRYPTION_CONTEXT)); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(decryptableEDKs.size())).decrypt(decrypt.capture()); + return; + } + + EncryptedDataKey expectedFailedEDK = decryptableEDKs.get(0); + EncryptedDataKey expectedSuccessfulEDK = decryptableEDKs.get(1); + + // Mock KMS to fail the first call then succeed for the second call + DecryptResult decryptResult = new DecryptResult(); + decryptResult.setKeyId( + new String(expectedSuccessfulEDK.getProviderInformation(), StandardCharsets.UTF_8)); + decryptResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); + doThrow(new AmazonServiceException("fail")) + .doReturn(decryptResult) + .when(client) + .decrypt(isA(DecryptRequest.class)); + + DataKey dataKeyResult = + mkp.decryptDataKey(ALGORITHM_SUITE, inputEDKs, ENCRYPTION_CONTEXT); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(2)).decrypt(decrypt.capture()); + verifyNoMoreInteractions(client); + + List actualRequests = decrypt.getAllValues(); + DecryptRequest failedRequest = actualRequests.get(0); + assertArrayEquals( + expectedFailedEDK.getProviderInformation(), + failedRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); + assertEquals(ENCRYPTION_CONTEXT, failedRequest.getEncryptionContext()); + assertArrayEquals( + expectedFailedEDK.getEncryptedDataKey(), failedRequest.getCiphertextBlob().array()); + assertUserAgent(failedRequest); + + DecryptRequest successfulRequest = actualRequests.get(1); + assertArrayEquals( + expectedSuccessfulEDK.getProviderInformation(), + successfulRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); + assertEquals(ENCRYPTION_CONTEXT, successfulRequest.getEncryptionContext()); + assertArrayEquals( + expectedSuccessfulEDK.getEncryptedDataKey(), + successfulRequest.getCiphertextBlob().array()); + assertUserAgent(successfulRequest); + + assertArrayEquals( + expectedSuccessfulEDK.getProviderInformation(), dataKeyResult.getProviderInformation()); + assertArrayEquals( + expectedSuccessfulEDK.getEncryptedDataKey(), dataKeyResult.getEncryptedDataKey()); + } + + private void assertUserAgent(AmazonWebServiceRequest request) { + assertTrue( + request + .getRequestClientOptions() + .getClientMarker(RequestClientOptions.Marker.USER_AGENT) + .contains(VersionInfo.loadUserAgent())); + } + } - mkp = KmsMasterKeyProvider.builder() - .withCredentials(new AWSStaticCredentialsProvider(creds)) - .buildStrict("arn:aws:kms:us-east-1:012345678901:key/foo-bar"); - assertExplicitCredentialsUsed(mkp); + public static class NonParameterized { + @Test + public void testBuildStrictWithNoCMKs() throws Exception { + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + + assertThrows( + IllegalArgumentException.class, + () -> KmsMasterKeyProvider.builder().withCustomClientFactory(supplier).buildStrict()); + + assertThrows( + IllegalArgumentException.class, + () -> + KmsMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildStrict(Collections.emptyList())); + + assertThrows( + IllegalArgumentException.class, + () -> + KmsMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildStrict((List) null)); } - private void assertExplicitCredentialsUsed(final MasterKeyProvider mkp) { - try { - MasterKeyRequest mkr = MasterKeyRequest.newBuilder() - .setEncryptionContext(Collections.emptyMap()) - .setStreaming(true) - .build(); - mkp.getMasterKeysForEncryption(mkr) - .forEach(mk -> mk.generateDataKey(ALGORITHM_SUITE, Collections.emptyMap())); - - fail("Expected exception"); - } catch (UsedExplicitCredentials e) { - // ok - } + @Test + public void testBuildStrictWithNullCMK() throws Exception { + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + + assertThrows( + IllegalArgumentException.class, + () -> + KmsMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildStrict((String) null)); + + assertThrows( + IllegalArgumentException.class, + () -> + KmsMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildStrict(Arrays.asList((String) null))); + } + + @Test + public void testBuildDiscoveryWithFilter() throws Exception { + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + + KmsMasterKeyProvider mkp1 = + KmsMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildDiscovery(new DiscoveryFilter("aws", Arrays.asList("accountId"))); + assertNotNull(mkp1); + } + + @Test + public void testBuildDiscoveryWithNullFilter() throws Exception { + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + + assertThrows( + IllegalArgumentException.class, + () -> + KmsMasterKeyProvider.builder() + .withCustomClientFactory(supplier) + .buildDiscovery(null)); + } + + @Test + public void testDecryptMismatchedKMSKeyIdResponse() throws Exception { + MockKMSClient client = spy(new MockKMSClient()); + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + DecryptResult badResult = new DecryptResult(); + badResult.setKeyId(KEY_ID_2); + badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); + + doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); + + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder().withCustomClientFactory(supplier).buildDiscovery(); + + assertThrows( + CannotUnwrapDataKeyException.class, + () -> mkp.decryptDataKey(ALGORITHM_SUITE, Arrays.asList(EDK_ID_1), ENCRYPTION_CONTEXT)); + } + } + + @Test + public void testExplicitCredentials() throws Exception { + AWSCredentials creds = + new AWSCredentials() { + @Override + public String getAWSAccessKeyId() { + throw new UsedExplicitCredentials(); + } + + @Override + public String getAWSSecretKey() { + throw new UsedExplicitCredentials(); + } + }; + + MasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withCredentials(creds) + .buildStrict("arn:aws:kms:us-east-1:012345678901:key/foo-bar"); + assertExplicitCredentialsUsed(mkp); + + mkp = + KmsMasterKeyProvider.builder() + .withCredentials(new AWSStaticCredentialsProvider(creds)) + .buildStrict("arn:aws:kms:us-east-1:012345678901:key/foo-bar"); + assertExplicitCredentialsUsed(mkp); + } + + private void assertExplicitCredentialsUsed(final MasterKeyProvider mkp) { + try { + MasterKeyRequest mkr = + MasterKeyRequest.newBuilder() + .setEncryptionContext(Collections.emptyMap()) + .setStreaming(true) + .build(); + mkp.getMasterKeysForEncryption(mkr) + .forEach(mk -> mk.generateDataKey(ALGORITHM_SUITE, Collections.emptyMap())); + + fail("Expected exception"); + } catch (UsedExplicitCredentials e) { + // ok } + } - private static class UsedExplicitCredentials extends RuntimeException {} + private static class UsedExplicitCredentials extends RuntimeException {} } diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyTest.java index eaab97464..f09691154 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyTest.java @@ -3,6 +3,20 @@ package com.amazonaws.encryptionsdk.kms; +import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; +import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + import com.amazonaws.AmazonWebServiceRequest; import com.amazonaws.RequestClientOptions; import com.amazonaws.encryptionsdk.CryptoAlgorithm; @@ -19,9 +33,6 @@ import com.amazonaws.services.kms.model.EncryptRequest; import com.amazonaws.services.kms.model.GenerateDataKeyRequest; import com.amazonaws.services.kms.model.GenerateDataKeyResult; - -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -29,337 +40,373 @@ import java.util.List; import java.util.Map; import java.util.function.Supplier; - +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; import org.junit.Test; import org.mockito.ArgumentCaptor; -import static com.amazonaws.encryptionsdk.TestUtils.assertThrows; -import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - public class KmsMasterKeyTest { - private static final String AWS_KMS_PROVIDER_ID = "aws-kms"; - private static final String OTHER_PROVIDER_ID = "not-aws-kms"; - - private static final CryptoAlgorithm ALGORITHM_SUITE = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - private static final SecretKey DATA_KEY = new SecretKeySpec(generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo()); - private static final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); - private static final Map ENCRYPTION_CONTEXT = Collections.singletonMap("myKey", "myValue"); - - @Test - public void testEncryptAndDecrypt() { - AWSKMS client = spy(new MockKMSClient()); - Supplier supplier = mock(Supplier.class); - when(supplier.get()).thenReturn(client); - - MasterKey otherMasterKey = mock(MasterKey.class); - when(otherMasterKey.getProviderId()).thenReturn(OTHER_PROVIDER_ID); - when(otherMasterKey.getKeyId()).thenReturn("someOtherId"); - DataKey dataKey = new DataKey(DATA_KEY, new byte[0], - OTHER_PROVIDER_ID.getBytes(StandardCharsets.UTF_8), otherMasterKey); - - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); - String keyId = client.createKey().getKeyMetadata().getArn(); - KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); - kmsMasterKey.setGrantTokens(GRANT_TOKENS); - - DataKey encryptDataKeyResult = kmsMasterKey.encryptDataKey( - ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey); - - ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); - verify(client, times(1)).encrypt(er.capture()); - - EncryptRequest actualRequest = er.getValue(); - assertEquals(keyId, actualRequest.getKeyId()); - assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); - assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); - assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.getPlaintext().array()); - assertUserAgent(actualRequest); - - assertEquals(encryptDataKeyResult.getMasterKey(), kmsMasterKey); - assertEquals(AWS_KMS_PROVIDER_ID, encryptDataKeyResult.getProviderId()); - assertArrayEquals(keyId.getBytes(StandardCharsets.UTF_8), encryptDataKeyResult.getProviderInformation()); - assertNotNull(encryptDataKeyResult.getEncryptedDataKey()); - - DataKey decryptDataKeyResult = kmsMasterKey.decryptDataKey( - ALGORITHM_SUITE, - Collections.singletonList(encryptDataKeyResult), - ENCRYPTION_CONTEXT); - - ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); - verify(client, times(1)).decrypt(decrypt.capture()); - - DecryptRequest actualDecryptRequest = decrypt.getValue(); - assertArrayEquals(encryptDataKeyResult.getProviderInformation(), actualDecryptRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); - assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens()); - assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext()); - assertArrayEquals(encryptDataKeyResult.getEncryptedDataKey(), actualDecryptRequest.getCiphertextBlob().array()); - assertUserAgent(actualDecryptRequest); - - assertEquals(DATA_KEY, decryptDataKeyResult.getKey()); - assertArrayEquals(keyId.getBytes(StandardCharsets.UTF_8), decryptDataKeyResult.getProviderInformation()); - } - - @Test - public void testGenerateAndDecrypt() { - AWSKMS client = spy(new MockKMSClient()); - Supplier supplier = mock(Supplier.class); - when(supplier.get()).thenReturn(client); - - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); - String keyId = client.createKey().getKeyMetadata().getArn(); - KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); - kmsMasterKey.setGrantTokens(GRANT_TOKENS); - - DataKey generateDataKeyResult = kmsMasterKey.generateDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT); - - ArgumentCaptor gr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class); - verify(client, times(1)).generateDataKey(gr.capture()); - - GenerateDataKeyRequest actualRequest = gr.getValue(); - - assertEquals(keyId, actualRequest.getKeyId()); - assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); - assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); - assertEquals(ALGORITHM_SUITE.getDataKeyLength(), actualRequest.getNumberOfBytes().longValue()); - assertUserAgent(actualRequest); - - assertNotNull(generateDataKeyResult.getKey()); - assertEquals(ALGORITHM_SUITE.getDataKeyLength(), generateDataKeyResult.getKey().getEncoded().length); - assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), generateDataKeyResult.getKey().getAlgorithm()); - assertNotNull(generateDataKeyResult.getEncryptedDataKey()); - - DataKey decryptDataKeyResult = kmsMasterKey.decryptDataKey( - ALGORITHM_SUITE, - Collections.singletonList(generateDataKeyResult), - ENCRYPTION_CONTEXT); - - ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); - verify(client, times(1)).decrypt(decrypt.capture()); - - DecryptRequest actualDecryptRequest = decrypt.getValue(); - assertArrayEquals(generateDataKeyResult.getProviderInformation(), actualDecryptRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); - assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens()); - assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext()); - assertArrayEquals(generateDataKeyResult.getEncryptedDataKey(), actualDecryptRequest.getCiphertextBlob().array()); - assertUserAgent(actualDecryptRequest); - - assertEquals(generateDataKeyResult.getKey(), decryptDataKeyResult.getKey()); - assertArrayEquals(keyId.getBytes(StandardCharsets.UTF_8), decryptDataKeyResult.getProviderInformation()); - } - - @Test - public void testEncryptWithRawKeyId() { - AWSKMS client = spy(new MockKMSClient()); - Supplier supplier = mock(Supplier.class); - when(supplier.get()).thenReturn(client); - - MasterKey otherMasterKey = mock(MasterKey.class); - when(otherMasterKey.getProviderId()).thenReturn(OTHER_PROVIDER_ID); - when(otherMasterKey.getKeyId()).thenReturn("someOtherId"); - DataKey dataKey = new DataKey(DATA_KEY, new byte[0], - OTHER_PROVIDER_ID.getBytes(StandardCharsets.UTF_8), otherMasterKey); - - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); - String keyId = client.createKey().getKeyMetadata().getArn(); - String rawKeyId = keyId.split("/")[1]; - KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, rawKeyId, mkp); - kmsMasterKey.setGrantTokens(GRANT_TOKENS); - - DataKey encryptDataKeyResult = kmsMasterKey.encryptDataKey( - ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey); - - ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); - verify(client, times(1)).encrypt(er.capture()); - - EncryptRequest actualRequest = er.getValue(); - - assertEquals(rawKeyId, actualRequest.getKeyId()); - assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); - assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); - assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.getPlaintext().array()); - assertUserAgent(actualRequest); - - assertEquals(AWS_KMS_PROVIDER_ID, encryptDataKeyResult.getProviderId()); - assertArrayEquals(keyId.getBytes(StandardCharsets.UTF_8), encryptDataKeyResult.getProviderInformation()); - assertNotNull(encryptDataKeyResult.getEncryptedDataKey()); - } - - @Test - public void testEncryptWrongKeyFormat() { - SecretKey key = mock(SecretKey.class); - when(key.getFormat()).thenReturn("BadFormat"); - - AWSKMS client = spy(new MockKMSClient()); - Supplier supplier = mock(Supplier.class); - when(supplier.get()).thenReturn(client); - - MasterKey otherMasterKey = mock(MasterKey.class); - when(otherMasterKey.getProviderId()).thenReturn(OTHER_PROVIDER_ID); - when(otherMasterKey.getKeyId()).thenReturn("someOtherId"); - DataKey dataKey = new DataKey(key, new byte[0], - OTHER_PROVIDER_ID.getBytes(StandardCharsets.UTF_8), otherMasterKey); - - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); - String keyId = client.createKey().getKeyMetadata().getArn(); - KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); - - assertThrows(IllegalArgumentException.class, () -> kmsMasterKey.encryptDataKey( - ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey)); - } - - @Test - public void testGenerateBadKmsKeyLength() { - AWSKMS client = spy(new MockKMSClient()); - Supplier supplier = mock(Supplier.class); - when(supplier.get()).thenReturn(client); - - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); - String keyId = client.createKey().getKeyMetadata().getArn(); - KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); - - GenerateDataKeyResult badResult = new GenerateDataKeyResult(); - badResult.setKeyId(keyId); - badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength() + 1)); - - doReturn(badResult).when(client).generateDataKey(isA(GenerateDataKeyRequest.class)); - - assertThrows(IllegalStateException.class, () -> kmsMasterKey.generateDataKey( - ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); - } - - @Test - public void testDecryptBadKmsKeyLength() { - AWSKMS client = spy(new MockKMSClient()); - Supplier supplier = mock(Supplier.class); - when(supplier.get()).thenReturn(client); - - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); - String keyId = client.createKey().getKeyMetadata().getArn(); - KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); - - DecryptResult badResult = new DecryptResult(); - badResult.setKeyId(keyId); - badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength() + 1)); - - doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); - - EncryptedDataKey edk = new KeyBlob(AWS_KMS_PROVIDER_ID, - keyId.getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - - assertThrows(IllegalStateException.class, () -> kmsMasterKey.decryptDataKey( - ALGORITHM_SUITE, - Collections.singletonList(edk), - ENCRYPTION_CONTEXT)); - } - - @Test - public void testDecryptMissingKmsKeyId() { - AWSKMS client = spy(new MockKMSClient()); - Supplier supplier = mock(Supplier.class); - when(supplier.get()).thenReturn(client); - - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); - String keyId = client.createKey().getKeyMetadata().getArn(); - KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); - - DecryptResult badResult = new DecryptResult(); - badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); - - doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); - - EncryptedDataKey edk = new KeyBlob(AWS_KMS_PROVIDER_ID, - keyId.getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - - assertThrows(IllegalStateException.class, - "Received an empty keyId from KMS", - () -> kmsMasterKey.decryptDataKey( - ALGORITHM_SUITE, - Collections.singletonList(edk), - ENCRYPTION_CONTEXT)); - } - - @Test - public void testDecryptMismatchedKmsKeyId() { - AWSKMS client = spy(new MockKMSClient()); - Supplier supplier = mock(Supplier.class); - when(supplier.get()).thenReturn(client); - - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); - String keyId = client.createKey().getKeyMetadata().getArn(); - KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); - - DecryptResult badResult = new DecryptResult(); - badResult.setKeyId("mismatchedID"); - badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); - - doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); - - EncryptedDataKey edk = new KeyBlob(AWS_KMS_PROVIDER_ID, - keyId.getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - - assertThrows(CannotUnwrapDataKeyException.class, () -> kmsMasterKey.decryptDataKey( - ALGORITHM_SUITE, - Collections.singletonList(edk), - ENCRYPTION_CONTEXT)); - } - - @Test - public void testDecryptSkipsMismatchedIdEDK() { - AWSKMS client = spy(new MockKMSClient()); - Supplier supplier = mock(Supplier.class); - when(supplier.get()).thenReturn(client); - - MasterKeyProvider mkp = mock(MasterKeyProvider.class); - when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); - String keyId = client.createKey().getKeyMetadata().getArn(); - KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); - - // Mock expected KMS response to verify success if second EDK is ok, - // and the mismatched EDK is skipped vs failing outright - DecryptResult kmsResponse = new DecryptResult(); - kmsResponse.setKeyId(keyId); - kmsResponse.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); - doReturn(kmsResponse).when(client).decrypt(isA(DecryptRequest.class)); - - EncryptedDataKey edk = new KeyBlob(AWS_KMS_PROVIDER_ID, - keyId.getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - EncryptedDataKey mismatchedEDK = new KeyBlob(AWS_KMS_PROVIDER_ID, - "mismatchedID".getBytes(StandardCharsets.UTF_8), generate(ALGORITHM_SUITE.getDataKeyLength())); - - DataKey decryptDataKeyResult = kmsMasterKey.decryptDataKey( - ALGORITHM_SUITE, - Arrays.asList(mismatchedEDK, edk), - ENCRYPTION_CONTEXT); - - ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); - verify(client, times(1)).decrypt(decrypt.capture()); - - DecryptRequest actualDecryptRequest = decrypt.getValue(); - assertArrayEquals(edk.getProviderInformation(), actualDecryptRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); - } - - private void assertUserAgent(AmazonWebServiceRequest request) { - assertTrue(request.getRequestClientOptions().getClientMarker(RequestClientOptions.Marker.USER_AGENT) - .contains(VersionInfo.loadUserAgent())); - } + private static final String AWS_KMS_PROVIDER_ID = "aws-kms"; + private static final String OTHER_PROVIDER_ID = "not-aws-kms"; + + private static final CryptoAlgorithm ALGORITHM_SUITE = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + private static final SecretKey DATA_KEY = + new SecretKeySpec( + generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo()); + private static final List GRANT_TOKENS = Collections.singletonList("testGrantToken"); + private static final Map ENCRYPTION_CONTEXT = + Collections.singletonMap("myKey", "myValue"); + + @Test + public void testEncryptAndDecrypt() { + AWSKMS client = spy(new MockKMSClient()); + Supplier supplier = mock(Supplier.class); + when(supplier.get()).thenReturn(client); + + MasterKey otherMasterKey = mock(MasterKey.class); + when(otherMasterKey.getProviderId()).thenReturn(OTHER_PROVIDER_ID); + when(otherMasterKey.getKeyId()).thenReturn("someOtherId"); + DataKey dataKey = + new DataKey( + DATA_KEY, + new byte[0], + OTHER_PROVIDER_ID.getBytes(StandardCharsets.UTF_8), + otherMasterKey); + + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); + String keyId = client.createKey().getKeyMetadata().getArn(); + KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); + kmsMasterKey.setGrantTokens(GRANT_TOKENS); + + DataKey encryptDataKeyResult = + kmsMasterKey.encryptDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey); + + ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); + verify(client, times(1)).encrypt(er.capture()); + + EncryptRequest actualRequest = er.getValue(); + assertEquals(keyId, actualRequest.getKeyId()); + assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.getPlaintext().array()); + assertUserAgent(actualRequest); + + assertEquals(encryptDataKeyResult.getMasterKey(), kmsMasterKey); + assertEquals(AWS_KMS_PROVIDER_ID, encryptDataKeyResult.getProviderId()); + assertArrayEquals( + keyId.getBytes(StandardCharsets.UTF_8), encryptDataKeyResult.getProviderInformation()); + assertNotNull(encryptDataKeyResult.getEncryptedDataKey()); + + DataKey decryptDataKeyResult = + kmsMasterKey.decryptDataKey( + ALGORITHM_SUITE, Collections.singletonList(encryptDataKeyResult), ENCRYPTION_CONTEXT); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); + + DecryptRequest actualDecryptRequest = decrypt.getValue(); + assertArrayEquals( + encryptDataKeyResult.getProviderInformation(), + actualDecryptRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); + assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext()); + assertArrayEquals( + encryptDataKeyResult.getEncryptedDataKey(), + actualDecryptRequest.getCiphertextBlob().array()); + assertUserAgent(actualDecryptRequest); + + assertEquals(DATA_KEY, decryptDataKeyResult.getKey()); + assertArrayEquals( + keyId.getBytes(StandardCharsets.UTF_8), decryptDataKeyResult.getProviderInformation()); + } + + @Test + public void testGenerateAndDecrypt() { + AWSKMS client = spy(new MockKMSClient()); + Supplier supplier = mock(Supplier.class); + when(supplier.get()).thenReturn(client); + + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); + String keyId = client.createKey().getKeyMetadata().getArn(); + KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); + kmsMasterKey.setGrantTokens(GRANT_TOKENS); + + DataKey generateDataKeyResult = + kmsMasterKey.generateDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT); + + ArgumentCaptor gr = + ArgumentCaptor.forClass(GenerateDataKeyRequest.class); + verify(client, times(1)).generateDataKey(gr.capture()); + + GenerateDataKeyRequest actualRequest = gr.getValue(); + + assertEquals(keyId, actualRequest.getKeyId()); + assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertEquals(ALGORITHM_SUITE.getDataKeyLength(), actualRequest.getNumberOfBytes().longValue()); + assertUserAgent(actualRequest); + + assertNotNull(generateDataKeyResult.getKey()); + assertEquals( + ALGORITHM_SUITE.getDataKeyLength(), generateDataKeyResult.getKey().getEncoded().length); + assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), generateDataKeyResult.getKey().getAlgorithm()); + assertNotNull(generateDataKeyResult.getEncryptedDataKey()); + + DataKey decryptDataKeyResult = + kmsMasterKey.decryptDataKey( + ALGORITHM_SUITE, Collections.singletonList(generateDataKeyResult), ENCRYPTION_CONTEXT); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); + + DecryptRequest actualDecryptRequest = decrypt.getValue(); + assertArrayEquals( + generateDataKeyResult.getProviderInformation(), + actualDecryptRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); + assertEquals(GRANT_TOKENS, actualDecryptRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualDecryptRequest.getEncryptionContext()); + assertArrayEquals( + generateDataKeyResult.getEncryptedDataKey(), + actualDecryptRequest.getCiphertextBlob().array()); + assertUserAgent(actualDecryptRequest); + + assertEquals(generateDataKeyResult.getKey(), decryptDataKeyResult.getKey()); + assertArrayEquals( + keyId.getBytes(StandardCharsets.UTF_8), decryptDataKeyResult.getProviderInformation()); + } + + @Test + public void testEncryptWithRawKeyId() { + AWSKMS client = spy(new MockKMSClient()); + Supplier supplier = mock(Supplier.class); + when(supplier.get()).thenReturn(client); + + MasterKey otherMasterKey = mock(MasterKey.class); + when(otherMasterKey.getProviderId()).thenReturn(OTHER_PROVIDER_ID); + when(otherMasterKey.getKeyId()).thenReturn("someOtherId"); + DataKey dataKey = + new DataKey( + DATA_KEY, + new byte[0], + OTHER_PROVIDER_ID.getBytes(StandardCharsets.UTF_8), + otherMasterKey); + + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); + String keyId = client.createKey().getKeyMetadata().getArn(); + String rawKeyId = keyId.split("/")[1]; + KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, rawKeyId, mkp); + kmsMasterKey.setGrantTokens(GRANT_TOKENS); + + DataKey encryptDataKeyResult = + kmsMasterKey.encryptDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey); + + ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); + verify(client, times(1)).encrypt(er.capture()); + + EncryptRequest actualRequest = er.getValue(); + + assertEquals(rawKeyId, actualRequest.getKeyId()); + assertEquals(GRANT_TOKENS, actualRequest.getGrantTokens()); + assertEquals(ENCRYPTION_CONTEXT, actualRequest.getEncryptionContext()); + assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.getPlaintext().array()); + assertUserAgent(actualRequest); + + assertEquals(AWS_KMS_PROVIDER_ID, encryptDataKeyResult.getProviderId()); + assertArrayEquals( + keyId.getBytes(StandardCharsets.UTF_8), encryptDataKeyResult.getProviderInformation()); + assertNotNull(encryptDataKeyResult.getEncryptedDataKey()); + } + + @Test + public void testEncryptWrongKeyFormat() { + SecretKey key = mock(SecretKey.class); + when(key.getFormat()).thenReturn("BadFormat"); + + AWSKMS client = spy(new MockKMSClient()); + Supplier supplier = mock(Supplier.class); + when(supplier.get()).thenReturn(client); + + MasterKey otherMasterKey = mock(MasterKey.class); + when(otherMasterKey.getProviderId()).thenReturn(OTHER_PROVIDER_ID); + when(otherMasterKey.getKeyId()).thenReturn("someOtherId"); + DataKey dataKey = + new DataKey( + key, new byte[0], OTHER_PROVIDER_ID.getBytes(StandardCharsets.UTF_8), otherMasterKey); + + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); + String keyId = client.createKey().getKeyMetadata().getArn(); + KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); + + assertThrows( + IllegalArgumentException.class, + () -> kmsMasterKey.encryptDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey)); + } + + @Test + public void testGenerateBadKmsKeyLength() { + AWSKMS client = spy(new MockKMSClient()); + Supplier supplier = mock(Supplier.class); + when(supplier.get()).thenReturn(client); + + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); + String keyId = client.createKey().getKeyMetadata().getArn(); + KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); + + GenerateDataKeyResult badResult = new GenerateDataKeyResult(); + badResult.setKeyId(keyId); + badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength() + 1)); + + doReturn(badResult).when(client).generateDataKey(isA(GenerateDataKeyRequest.class)); + + assertThrows( + IllegalStateException.class, + () -> kmsMasterKey.generateDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT)); + } + + @Test + public void testDecryptBadKmsKeyLength() { + AWSKMS client = spy(new MockKMSClient()); + Supplier supplier = mock(Supplier.class); + when(supplier.get()).thenReturn(client); + + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); + String keyId = client.createKey().getKeyMetadata().getArn(); + KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); + + DecryptResult badResult = new DecryptResult(); + badResult.setKeyId(keyId); + badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength() + 1)); + + doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); + + EncryptedDataKey edk = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + keyId.getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + + assertThrows( + IllegalStateException.class, + () -> + kmsMasterKey.decryptDataKey( + ALGORITHM_SUITE, Collections.singletonList(edk), ENCRYPTION_CONTEXT)); + } + + @Test + public void testDecryptMissingKmsKeyId() { + AWSKMS client = spy(new MockKMSClient()); + Supplier supplier = mock(Supplier.class); + when(supplier.get()).thenReturn(client); + + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); + String keyId = client.createKey().getKeyMetadata().getArn(); + KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); + + DecryptResult badResult = new DecryptResult(); + badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); + + doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); + + EncryptedDataKey edk = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + keyId.getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + + assertThrows( + IllegalStateException.class, + "Received an empty keyId from KMS", + () -> + kmsMasterKey.decryptDataKey( + ALGORITHM_SUITE, Collections.singletonList(edk), ENCRYPTION_CONTEXT)); + } + + @Test + public void testDecryptMismatchedKmsKeyId() { + AWSKMS client = spy(new MockKMSClient()); + Supplier supplier = mock(Supplier.class); + when(supplier.get()).thenReturn(client); + + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); + String keyId = client.createKey().getKeyMetadata().getArn(); + KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); + + DecryptResult badResult = new DecryptResult(); + badResult.setKeyId("mismatchedID"); + badResult.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); + + doReturn(badResult).when(client).decrypt(isA(DecryptRequest.class)); + + EncryptedDataKey edk = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + keyId.getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + + assertThrows( + CannotUnwrapDataKeyException.class, + () -> + kmsMasterKey.decryptDataKey( + ALGORITHM_SUITE, Collections.singletonList(edk), ENCRYPTION_CONTEXT)); + } + + @Test + public void testDecryptSkipsMismatchedIdEDK() { + AWSKMS client = spy(new MockKMSClient()); + Supplier supplier = mock(Supplier.class); + when(supplier.get()).thenReturn(client); + + MasterKeyProvider mkp = mock(MasterKeyProvider.class); + when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID); + String keyId = client.createKey().getKeyMetadata().getArn(); + KmsMasterKey kmsMasterKey = KmsMasterKey.getInstance(supplier, keyId, mkp); + + // Mock expected KMS response to verify success if second EDK is ok, + // and the mismatched EDK is skipped vs failing outright + DecryptResult kmsResponse = new DecryptResult(); + kmsResponse.setKeyId(keyId); + kmsResponse.setPlaintext(ByteBuffer.allocate(ALGORITHM_SUITE.getDataKeyLength())); + doReturn(kmsResponse).when(client).decrypt(isA(DecryptRequest.class)); + + EncryptedDataKey edk = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + keyId.getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + EncryptedDataKey mismatchedEDK = + new KeyBlob( + AWS_KMS_PROVIDER_ID, + "mismatchedID".getBytes(StandardCharsets.UTF_8), + generate(ALGORITHM_SUITE.getDataKeyLength())); + + DataKey decryptDataKeyResult = + kmsMasterKey.decryptDataKey( + ALGORITHM_SUITE, Arrays.asList(mismatchedEDK, edk), ENCRYPTION_CONTEXT); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); + + DecryptRequest actualDecryptRequest = decrypt.getValue(); + assertArrayEquals( + edk.getProviderInformation(), + actualDecryptRequest.getKeyId().getBytes(StandardCharsets.UTF_8)); + } + + private void assertUserAgent(AmazonWebServiceRequest request) { + assertTrue( + request + .getRequestClientOptions() + .getClientMarker(RequestClientOptions.Marker.USER_AGENT) + .contains(VersionInfo.loadUserAgent())); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/MaxEncryptedDataKeysIntegrationTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/MaxEncryptedDataKeysIntegrationTest.java index 67d71f6ac..05cacd958 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/MaxEncryptedDataKeysIntegrationTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/MaxEncryptedDataKeysIntegrationTest.java @@ -3,80 +3,84 @@ package com.amazonaws.encryptionsdk.kms; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.TestUtils; import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.services.kms.AWSKMS; import com.amazonaws.services.kms.AWSKMSClientBuilder; -import org.junit.Before; -import org.junit.Test; - import java.util.ArrayList; import java.util.List; - -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; +import org.junit.Before; +import org.junit.Test; public class MaxEncryptedDataKeysIntegrationTest { - private static final byte[] PLAINTEXT = {1, 2, 3, 4}; - private static final int MAX_EDKS = 3; + private static final byte[] PLAINTEXT = {1, 2, 3, 4}; + private static final int MAX_EDKS = 3; - private AWSKMS testClient_; - private KmsMasterKeyProvider.RegionalClientSupplier testClientSupplier_; - private AwsCrypto testCryptoClient_; + private AWSKMS testClient_; + private KmsMasterKeyProvider.RegionalClientSupplier testClientSupplier_; + private AwsCrypto testCryptoClient_; - @Before - public void setup() { - testClient_ = spy(AWSKMSClientBuilder.standard().withRegion("us-west-2").build()); - testClientSupplier_ = regionName -> { - if (regionName.equals("us-west-2")) { - return testClient_; - } - throw new AwsCryptoException("test supplier only configured for us-west-2 and eu-central-1"); + @Before + public void setup() { + testClient_ = spy(AWSKMSClientBuilder.standard().withRegion("us-west-2").build()); + testClientSupplier_ = + regionName -> { + if (regionName.equals("us-west-2")) { + return testClient_; + } + throw new AwsCryptoException( + "test supplier only configured for us-west-2 and eu-central-1"); }; - testCryptoClient_ = AwsCrypto.standard().toBuilder().withMaxEncryptedDataKeys(MAX_EDKS).build(); - } - - private KmsMasterKeyProvider providerWithEdks(int numKeys) { - List keyIds = new ArrayList<>(numKeys); - for (int i = 0; i < numKeys; i++) { - keyIds.add(KMSTestFixtures.US_WEST_2_KEY_ID); - } - return KmsMasterKeyProvider.builder() - .withCustomClientFactory(testClientSupplier_) - .buildStrict(keyIds); - } + testCryptoClient_ = AwsCrypto.standard().toBuilder().withMaxEncryptedDataKeys(MAX_EDKS).build(); + } - @Test - public void encryptDecryptWithLessThanMaxEdks() { - KmsMasterKeyProvider provider = providerWithEdks(MAX_EDKS - 1); - byte[] ciphertext = testCryptoClient_.encryptData(provider, PLAINTEXT).getResult(); - byte[] decrypted = testCryptoClient_.decryptData(provider, ciphertext).getResult(); - assertArrayEquals(decrypted, PLAINTEXT); + private KmsMasterKeyProvider providerWithEdks(int numKeys) { + List keyIds = new ArrayList<>(numKeys); + for (int i = 0; i < numKeys; i++) { + keyIds.add(KMSTestFixtures.US_WEST_2_KEY_ID); } + return KmsMasterKeyProvider.builder() + .withCustomClientFactory(testClientSupplier_) + .buildStrict(keyIds); + } - @Test - public void encryptDecryptWithMaxEdks() { - KmsMasterKeyProvider provider = providerWithEdks(MAX_EDKS); - byte[] ciphertext = testCryptoClient_.encryptData(provider, PLAINTEXT).getResult(); - byte[] decrypted = testCryptoClient_.decryptData(provider, ciphertext).getResult(); - assertArrayEquals(decrypted, PLAINTEXT); - } + @Test + public void encryptDecryptWithLessThanMaxEdks() { + KmsMasterKeyProvider provider = providerWithEdks(MAX_EDKS - 1); + byte[] ciphertext = testCryptoClient_.encryptData(provider, PLAINTEXT).getResult(); + byte[] decrypted = testCryptoClient_.decryptData(provider, ciphertext).getResult(); + assertArrayEquals(decrypted, PLAINTEXT); + } - @Test - public void noEncryptWithMoreThanMaxEdks() { - KmsMasterKeyProvider provider = providerWithEdks(MAX_EDKS + 1); - TestUtils.assertThrows(AwsCryptoException.class, "Encrypted data keys exceed maxEncryptedDataKeys", () -> - testCryptoClient_.encryptData(provider, PLAINTEXT)); - } + @Test + public void encryptDecryptWithMaxEdks() { + KmsMasterKeyProvider provider = providerWithEdks(MAX_EDKS); + byte[] ciphertext = testCryptoClient_.encryptData(provider, PLAINTEXT).getResult(); + byte[] decrypted = testCryptoClient_.decryptData(provider, ciphertext).getResult(); + assertArrayEquals(decrypted, PLAINTEXT); + } - @Test - public void noDecryptWithMoreThanMaxEdks() { - KmsMasterKeyProvider provider = providerWithEdks(MAX_EDKS + 1); - byte[] ciphertext = AwsCrypto.standard().encryptData(provider, PLAINTEXT).getResult(); - TestUtils.assertThrows(AwsCryptoException.class, "Ciphertext encrypted data keys exceed maxEncryptedDataKeys", () -> - testCryptoClient_.decryptData(provider, ciphertext)); - verify(testClient_, never()).decrypt(any()); - } + @Test + public void noEncryptWithMoreThanMaxEdks() { + KmsMasterKeyProvider provider = providerWithEdks(MAX_EDKS + 1); + TestUtils.assertThrows( + AwsCryptoException.class, + "Encrypted data keys exceed maxEncryptedDataKeys", + () -> testCryptoClient_.encryptData(provider, PLAINTEXT)); + } + @Test + public void noDecryptWithMoreThanMaxEdks() { + KmsMasterKeyProvider provider = providerWithEdks(MAX_EDKS + 1); + byte[] ciphertext = AwsCrypto.standard().encryptData(provider, PLAINTEXT).getResult(); + TestUtils.assertThrows( + AwsCryptoException.class, + "Ciphertext encrypted data keys exceed maxEncryptedDataKeys", + () -> testCryptoClient_.decryptData(provider, ciphertext)); + verify(testClient_, never()).decrypt(any()); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java b/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java index 37fe9cbff..c889823ee 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -13,16 +13,6 @@ package com.amazonaws.encryptionsdk.kms; -import java.nio.ByteBuffer; -import java.security.SecureRandom; -import java.util.Collections; -import java.util.Date; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; -import java.util.UUID; - import com.amazonaws.AmazonClientException; import com.amazonaws.AmazonServiceException; import com.amazonaws.AmazonWebServiceRequest; @@ -84,337 +74,374 @@ import com.amazonaws.services.kms.model.RevokeGrantResult; import com.amazonaws.services.kms.model.UpdateKeyDescriptionRequest; import com.amazonaws.services.kms.model.UpdateKeyDescriptionResult; +import java.nio.ByteBuffer; +import java.security.SecureRandom; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.UUID; public class MockKMSClient extends AWSKMSClient { - private static final SecureRandom rnd = new SecureRandom(); - private static final String ACCOUNT_ID = "01234567890"; - private final Map results_ = new HashMap<>(); - private final Set activeKeys = new HashSet<>(); - private final Map keyAliases = new HashMap<>(); - private Region region_ = Region.getRegion(Regions.DEFAULT_REGION); - - @Override - public CreateAliasResult createAlias(CreateAliasRequest arg0) throws AmazonServiceException, AmazonClientException { - assertExists(arg0.getTargetKeyId()); - - keyAliases.put( - "alias/" + arg0.getAliasName(), - keyAliases.get(arg0.getTargetKeyId()) - ); - - return new CreateAliasResult(); - } - - @Override - public CreateGrantResult createGrant(CreateGrantRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public CreateKeyResult createKey() throws AmazonServiceException, AmazonClientException { - return createKey(new CreateKeyRequest()); - } - - @Override - public CreateKeyResult createKey(CreateKeyRequest req) throws AmazonServiceException, AmazonClientException { - String keyId = UUID.randomUUID().toString(); - String arn = "arn:aws:kms:" + region_.getName() + ":" + ACCOUNT_ID + ":key/" + keyId; - activeKeys.add(arn); - keyAliases.put(keyId, arn); - keyAliases.put(arn, arn); - CreateKeyResult result = new CreateKeyResult(); - result.setKeyMetadata(new KeyMetadata().withAWSAccountId(ACCOUNT_ID).withCreationDate(new Date()) - .withDescription(req.getDescription()).withEnabled(true).withKeyId(keyId) - .withKeyUsage(KeyUsageType.ENCRYPT_DECRYPT).withArn(arn)); - return result; - } - - @Override - public DecryptResult decrypt(DecryptRequest req) throws AmazonServiceException, AmazonClientException { - DecryptResult result = results_.get(new DecryptMapKey(req)); - if (result != null) { - // Copy it to avoid external modification - DecryptResult copy = new DecryptResult(); - copy.setKeyId(retrieveArn(result.getKeyId())); - byte[] pt = new byte[result.getPlaintext().limit()]; - result.getPlaintext().get(pt); - result.getPlaintext().rewind(); - copy.setPlaintext(ByteBuffer.wrap(pt)); - return copy; - } else { - throw new InvalidCiphertextException("Invalid Ciphertext"); - } - } - - @Override - public DeleteAliasResult deleteAlias(DeleteAliasRequest arg0) throws AmazonServiceException, AmazonClientException { + private static final SecureRandom rnd = new SecureRandom(); + private static final String ACCOUNT_ID = "01234567890"; + private final Map results_ = new HashMap<>(); + private final Set activeKeys = new HashSet<>(); + private final Map keyAliases = new HashMap<>(); + private Region region_ = Region.getRegion(Regions.DEFAULT_REGION); + + @Override + public CreateAliasResult createAlias(CreateAliasRequest arg0) + throws AmazonServiceException, AmazonClientException { + assertExists(arg0.getTargetKeyId()); + + keyAliases.put("alias/" + arg0.getAliasName(), keyAliases.get(arg0.getTargetKeyId())); + + return new CreateAliasResult(); + } + + @Override + public CreateGrantResult createGrant(CreateGrantRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public CreateKeyResult createKey() throws AmazonServiceException, AmazonClientException { + return createKey(new CreateKeyRequest()); + } + + @Override + public CreateKeyResult createKey(CreateKeyRequest req) + throws AmazonServiceException, AmazonClientException { + String keyId = UUID.randomUUID().toString(); + String arn = "arn:aws:kms:" + region_.getName() + ":" + ACCOUNT_ID + ":key/" + keyId; + activeKeys.add(arn); + keyAliases.put(keyId, arn); + keyAliases.put(arn, arn); + CreateKeyResult result = new CreateKeyResult(); + result.setKeyMetadata( + new KeyMetadata() + .withAWSAccountId(ACCOUNT_ID) + .withCreationDate(new Date()) + .withDescription(req.getDescription()) + .withEnabled(true) + .withKeyId(keyId) + .withKeyUsage(KeyUsageType.ENCRYPT_DECRYPT) + .withArn(arn)); + return result; + } + + @Override + public DecryptResult decrypt(DecryptRequest req) + throws AmazonServiceException, AmazonClientException { + DecryptResult result = results_.get(new DecryptMapKey(req)); + if (result != null) { + // Copy it to avoid external modification + DecryptResult copy = new DecryptResult(); + copy.setKeyId(retrieveArn(result.getKeyId())); + byte[] pt = new byte[result.getPlaintext().limit()]; + result.getPlaintext().get(pt); + result.getPlaintext().rewind(); + copy.setPlaintext(ByteBuffer.wrap(pt)); + return copy; + } else { + throw new InvalidCiphertextException("Invalid Ciphertext"); + } + } + + @Override + public DeleteAliasResult deleteAlias(DeleteAliasRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public DescribeKeyResult describeKey(DescribeKeyRequest arg0) + throws AmazonServiceException, AmazonClientException { + final String arn = retrieveArn(arg0.getKeyId()); + + final KeyMetadata keyMetadata = new KeyMetadata().withArn(arn).withKeyId(arn); + final DescribeKeyResult describeKeyResult = + new DescribeKeyResult().withKeyMetadata(keyMetadata); + + return describeKeyResult; + } + + @Override + public DisableKeyResult disableKey(DisableKeyRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public DisableKeyRotationResult disableKeyRotation(DisableKeyRotationRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public EnableKeyResult enableKey(EnableKeyRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public EnableKeyRotationResult enableKeyRotation(EnableKeyRotationRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public EncryptResult encrypt(EncryptRequest req) + throws AmazonServiceException, AmazonClientException { + // We internally delegate to encrypt, so as to avoid mockito detecting extra calls to encrypt + // when spying on the + // MockKMSClient, we put the real logic into a separate function. + return encrypt0(req); + } + + private EncryptResult encrypt0(EncryptRequest req) + throws AmazonServiceException, AmazonClientException { + final byte[] cipherText = new byte[512]; + rnd.nextBytes(cipherText); + DecryptResult dec = new DecryptResult(); + dec.withKeyId(retrieveArn(req.getKeyId())).withPlaintext(req.getPlaintext().asReadOnlyBuffer()); + ByteBuffer ctBuff = ByteBuffer.wrap(cipherText); + + results_.put(new DecryptMapKey(ctBuff, req.getEncryptionContext()), dec); + + String arn = retrieveArn(req.getKeyId()); + return new EncryptResult().withCiphertextBlob(ctBuff).withKeyId(arn); + } + + @Override + public GenerateDataKeyResult generateDataKey(GenerateDataKeyRequest req) + throws AmazonServiceException, AmazonClientException { + byte[] pt; + if (req.getKeySpec() != null) { + if (req.getKeySpec().contains("256")) { + pt = new byte[32]; + } else if (req.getKeySpec().contains("128")) { + pt = new byte[16]; + } else { throw new java.lang.UnsupportedOperationException(); - } - - @Override - public DescribeKeyResult describeKey(DescribeKeyRequest arg0) throws AmazonServiceException, AmazonClientException { - final String arn = retrieveArn(arg0.getKeyId()); - - final KeyMetadata keyMetadata = new KeyMetadata().withArn(arn).withKeyId(arn); - final DescribeKeyResult describeKeyResult = new DescribeKeyResult().withKeyMetadata(keyMetadata); - - return describeKeyResult; - } - - @Override - public DisableKeyResult disableKey(DisableKeyRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public DisableKeyRotationResult disableKeyRotation(DisableKeyRotationRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public EnableKeyResult enableKey(EnableKeyRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public EnableKeyRotationResult enableKeyRotation(EnableKeyRotationRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public EncryptResult encrypt(EncryptRequest req) throws AmazonServiceException, AmazonClientException { - // We internally delegate to encrypt, so as to avoid mockito detecting extra calls to encrypt when spying on the - // MockKMSClient, we put the real logic into a separate function. - return encrypt0(req); - } - - private EncryptResult encrypt0(EncryptRequest req) throws AmazonServiceException, AmazonClientException { - final byte[] cipherText = new byte[512]; - rnd.nextBytes(cipherText); - DecryptResult dec = new DecryptResult(); - dec.withKeyId(retrieveArn(req.getKeyId())).withPlaintext(req.getPlaintext().asReadOnlyBuffer()); - ByteBuffer ctBuff = ByteBuffer.wrap(cipherText); - - results_.put(new DecryptMapKey(ctBuff, req.getEncryptionContext()), dec); - - String arn = retrieveArn(req.getKeyId()); - return new EncryptResult().withCiphertextBlob(ctBuff).withKeyId(arn); - } - - @Override - public GenerateDataKeyResult generateDataKey(GenerateDataKeyRequest req) throws AmazonServiceException, - AmazonClientException { - byte[] pt; - if (req.getKeySpec() != null) { - if (req.getKeySpec().contains("256")) { - pt = new byte[32]; - } else if (req.getKeySpec().contains("128")) { - pt = new byte[16]; - } else { - throw new java.lang.UnsupportedOperationException(); - } - } else { - pt = new byte[req.getNumberOfBytes()]; - } - rnd.nextBytes(pt); - ByteBuffer ptBuff = ByteBuffer.wrap(pt); - EncryptResult encryptResult = encrypt0(new EncryptRequest().withKeyId(req.getKeyId()).withPlaintext(ptBuff) + } + } else { + pt = new byte[req.getNumberOfBytes()]; + } + rnd.nextBytes(pt); + ByteBuffer ptBuff = ByteBuffer.wrap(pt); + EncryptResult encryptResult = + encrypt0( + new EncryptRequest() + .withKeyId(req.getKeyId()) + .withPlaintext(ptBuff) .withEncryptionContext(req.getEncryptionContext())); - String arn = retrieveArn(req.getKeyId()); - return new GenerateDataKeyResult().withKeyId(arn).withCiphertextBlob(encryptResult.getCiphertextBlob()) - .withPlaintext(ptBuff); - } - - @Override - public GenerateDataKeyWithoutPlaintextResult generateDataKeyWithoutPlaintext( - GenerateDataKeyWithoutPlaintextRequest req) throws AmazonServiceException, AmazonClientException { - GenerateDataKeyRequest generateDataKeyRequest = new GenerateDataKeyRequest().withEncryptionContext(req.getEncryptionContext()) - .withGrantTokens(req.getGrantTokens()) - .withKeyId(req.getKeyId()) - .withKeySpec(req.getKeySpec()) - .withNumberOfBytes(req.getNumberOfBytes()); - GenerateDataKeyResult generateDataKey = generateDataKey(generateDataKeyRequest); - String arn = retrieveArn(req.getKeyId()); - return new GenerateDataKeyWithoutPlaintextResult().withCiphertextBlob(generateDataKey.getCiphertextBlob()) - .withKeyId(arn); - } - - @Override - public GenerateRandomResult generateRandom() throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public GenerateRandomResult generateRandom(GenerateRandomRequest arg0) throws AmazonServiceException, - AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public ResponseMetadata getCachedResponseMetadata(AmazonWebServiceRequest arg0) { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public GetKeyPolicyResult getKeyPolicy(GetKeyPolicyRequest arg0) throws AmazonServiceException, - AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public GetKeyRotationStatusResult getKeyRotationStatus(GetKeyRotationStatusRequest arg0) - throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public ListAliasesResult listAliases() throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public ListAliasesResult listAliases(ListAliasesRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public ListGrantsResult listGrants(ListGrantsRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public ListKeyPoliciesResult listKeyPolicies(ListKeyPoliciesRequest arg0) throws AmazonServiceException, - AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public ListKeysResult listKeys() throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public ListKeysResult listKeys(ListKeysRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public PutKeyPolicyResult putKeyPolicy(PutKeyPolicyRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public ReEncryptResult reEncrypt(ReEncryptRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public RetireGrantResult retireGrant(RetireGrantRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public RevokeGrantResult revokeGrant(RevokeGrantRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - @Override - public void setEndpoint(String arg0) { - // Do nothing - } - - @Override - public void setRegion(Region arg0) { - region_ = arg0; - } - - @Override - public void shutdown() { - // Do nothing - } - - @Override - public UpdateKeyDescriptionResult updateKeyDescription(UpdateKeyDescriptionRequest arg0) throws AmazonServiceException, - AmazonClientException { - throw new java.lang.UnsupportedOperationException(); - } - - public void deleteKey(final String keyId) { - final String arn = retrieveArn(keyId); - activeKeys.remove(arn); - } - - private String retrieveArn(final String keyId) { - String arn = keyAliases.get(keyId); - assertExists(arn); - return arn; - } - - private void assertExists(String keyId) { - if (keyAliases.containsKey(keyId)) { - keyId = keyAliases.get(keyId); - } - if (keyId == null || !activeKeys.contains(keyId)) { - throw new NotFoundException("Key doesn't exist: " + keyId); - } - } - - private static class DecryptMapKey { - private final ByteBuffer cipherText; - private final Map ec; - - public DecryptMapKey(DecryptRequest req) { - cipherText = req.getCiphertextBlob().asReadOnlyBuffer(); - if (req.getEncryptionContext() != null) { - ec = Collections.unmodifiableMap(new HashMap(req.getEncryptionContext())); - } else { - ec = Collections.emptyMap(); - } - } - - public DecryptMapKey(ByteBuffer ctBuff, Map ec) { - cipherText = ctBuff.asReadOnlyBuffer(); - if (ec != null) { - this.ec = Collections.unmodifiableMap(new HashMap(ec)); - } else { - this.ec = Collections.emptyMap(); - } - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((cipherText == null) ? 0 : cipherText.hashCode()); - result = prime * result + ((ec == null) ? 0 : ec.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - DecryptMapKey other = (DecryptMapKey) obj; - if (cipherText == null) { - if (other.cipherText != null) - return false; - } else if (!cipherText.equals(other.cipherText)) - return false; - if (ec == null) { - if (other.ec != null) - return false; - } else if (!ec.equals(other.ec)) - return false; - return true; - } - - @Override - public String toString() { - return "DecryptMapKey [cipherText=" + cipherText + ", ec=" + ec + "]"; - } - } + String arn = retrieveArn(req.getKeyId()); + return new GenerateDataKeyResult() + .withKeyId(arn) + .withCiphertextBlob(encryptResult.getCiphertextBlob()) + .withPlaintext(ptBuff); + } + + @Override + public GenerateDataKeyWithoutPlaintextResult generateDataKeyWithoutPlaintext( + GenerateDataKeyWithoutPlaintextRequest req) + throws AmazonServiceException, AmazonClientException { + GenerateDataKeyRequest generateDataKeyRequest = + new GenerateDataKeyRequest() + .withEncryptionContext(req.getEncryptionContext()) + .withGrantTokens(req.getGrantTokens()) + .withKeyId(req.getKeyId()) + .withKeySpec(req.getKeySpec()) + .withNumberOfBytes(req.getNumberOfBytes()); + GenerateDataKeyResult generateDataKey = generateDataKey(generateDataKeyRequest); + String arn = retrieveArn(req.getKeyId()); + return new GenerateDataKeyWithoutPlaintextResult() + .withCiphertextBlob(generateDataKey.getCiphertextBlob()) + .withKeyId(arn); + } + + @Override + public GenerateRandomResult generateRandom() + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public GenerateRandomResult generateRandom(GenerateRandomRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public ResponseMetadata getCachedResponseMetadata(AmazonWebServiceRequest arg0) { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public GetKeyPolicyResult getKeyPolicy(GetKeyPolicyRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public GetKeyRotationStatusResult getKeyRotationStatus(GetKeyRotationStatusRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public ListAliasesResult listAliases() throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public ListAliasesResult listAliases(ListAliasesRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public ListGrantsResult listGrants(ListGrantsRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public ListKeyPoliciesResult listKeyPolicies(ListKeyPoliciesRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public ListKeysResult listKeys() throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public ListKeysResult listKeys(ListKeysRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public PutKeyPolicyResult putKeyPolicy(PutKeyPolicyRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public ReEncryptResult reEncrypt(ReEncryptRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public RetireGrantResult retireGrant(RetireGrantRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public RevokeGrantResult revokeGrant(RevokeGrantRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public void setEndpoint(String arg0) { + // Do nothing + } + + @Override + public void setRegion(Region arg0) { + region_ = arg0; + } + + @Override + public void shutdown() { + // Do nothing + } + + @Override + public UpdateKeyDescriptionResult updateKeyDescription(UpdateKeyDescriptionRequest arg0) + throws AmazonServiceException, AmazonClientException { + throw new java.lang.UnsupportedOperationException(); + } + + public void deleteKey(final String keyId) { + final String arn = retrieveArn(keyId); + activeKeys.remove(arn); + } + + private String retrieveArn(final String keyId) { + String arn = keyAliases.get(keyId); + assertExists(arn); + return arn; + } + + private void assertExists(String keyId) { + if (keyAliases.containsKey(keyId)) { + keyId = keyAliases.get(keyId); + } + if (keyId == null || !activeKeys.contains(keyId)) { + throw new NotFoundException("Key doesn't exist: " + keyId); + } + } + + private static class DecryptMapKey { + private final ByteBuffer cipherText; + private final Map ec; + + public DecryptMapKey(DecryptRequest req) { + cipherText = req.getCiphertextBlob().asReadOnlyBuffer(); + if (req.getEncryptionContext() != null) { + ec = Collections.unmodifiableMap(new HashMap(req.getEncryptionContext())); + } else { + ec = Collections.emptyMap(); + } + } + + public DecryptMapKey(ByteBuffer ctBuff, Map ec) { + cipherText = ctBuff.asReadOnlyBuffer(); + if (ec != null) { + this.ec = Collections.unmodifiableMap(new HashMap(ec)); + } else { + this.ec = Collections.emptyMap(); + } + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((cipherText == null) ? 0 : cipherText.hashCode()); + result = prime * result + ((ec == null) ? 0 : ec.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null) return false; + if (getClass() != obj.getClass()) return false; + DecryptMapKey other = (DecryptMapKey) obj; + if (cipherText == null) { + if (other.cipherText != null) return false; + } else if (!cipherText.equals(other.cipherText)) return false; + if (ec == null) { + if (other.ec != null) return false; + } else if (!ec.equals(other.ec)) return false; + return true; + } + + @Override + public String toString() { + return "DecryptMapKey [cipherText=" + cipherText + ", ec=" + ec + "]"; + } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/kms/XCompatKmsDecryptTest.java b/src/test/java/com/amazonaws/encryptionsdk/kms/XCompatKmsDecryptTest.java index 4d79dbbad..875c9c5f5 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/kms/XCompatKmsDecryptTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/kms/XCompatKmsDecryptTest.java @@ -5,6 +5,10 @@ import static org.junit.Assert.assertArrayEquals; +import com.amazonaws.encryptionsdk.AwsCrypto; +import com.amazonaws.encryptionsdk.CryptoResult; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import java.io.File; import java.nio.file.Files; import java.nio.file.Paths; @@ -13,94 +17,83 @@ import java.util.Collections; import java.util.List; import java.util.Map; - import org.apache.commons.lang3.StringUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; -import com.amazonaws.encryptionsdk.AwsCrypto; -import com.amazonaws.encryptionsdk.CryptoResult; -import com.amazonaws.encryptionsdk.CommitmentPolicy; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - @RunWith(Parameterized.class) public class XCompatKmsDecryptTest { - private String plaintextFileName; - private String ciphertextFileName; - private String kmsKeyId; + private String plaintextFileName; + private String ciphertextFileName; + private String kmsKeyId; - public XCompatKmsDecryptTest(String plaintextFileName, String ciphertextFileName, String kmsKeyId) { - this.plaintextFileName = plaintextFileName; - this.ciphertextFileName = ciphertextFileName; - this.kmsKeyId = kmsKeyId; - } + public XCompatKmsDecryptTest( + String plaintextFileName, String ciphertextFileName, String kmsKeyId) { + this.plaintextFileName = plaintextFileName; + this.ciphertextFileName = ciphertextFileName; + this.kmsKeyId = kmsKeyId; + } - @Parameters(name="{index}: testDecryptFromFile({0}, {1}, {2})") - public static Collection data() throws Exception { - String baseDirName; - baseDirName = System.getProperty("staticCompatibilityResourcesDir"); - if (baseDirName == null) { - baseDirName = - XCompatKmsDecryptTest.class.getProtectionDomain().getCodeSource().getLocation().getPath() + - "aws_encryption_sdk_resources"; - } + @Parameters(name = "{index}: testDecryptFromFile({0}, {1}, {2})") + public static Collection data() throws Exception { + String baseDirName; + baseDirName = System.getProperty("staticCompatibilityResourcesDir"); + if (baseDirName == null) { + baseDirName = + XCompatKmsDecryptTest.class.getProtectionDomain().getCodeSource().getLocation().getPath() + + "aws_encryption_sdk_resources"; + } - List testCases_ = new ArrayList(); + List testCases_ = new ArrayList(); - String ciphertextManifestName = StringUtils.join( - new String[]{ - baseDirName, - "manifests", - "ciphertext.manifest" - }, - File.separator - ); - File ciphertextManifestFile = new File(ciphertextManifestName); + String ciphertextManifestName = + StringUtils.join( + new String[] {baseDirName, "manifests", "ciphertext.manifest"}, File.separator); + File ciphertextManifestFile = new File(ciphertextManifestName); - if (!ciphertextManifestFile.exists()) { - return Collections.emptyList(); - } + if (!ciphertextManifestFile.exists()) { + return Collections.emptyList(); + } - ObjectMapper ciphertextManifestMapper = new ObjectMapper(); - Map ciphertextManifest = ciphertextManifestMapper.readValue( - ciphertextManifestFile, - new TypeReference>(){} - ); + ObjectMapper ciphertextManifestMapper = new ObjectMapper(); + Map ciphertextManifest = + ciphertextManifestMapper.readValue( + ciphertextManifestFile, new TypeReference>() {}); - List> testCases = (List>)ciphertextManifest.get("test_cases"); - for (Map testCase : testCases) { - Map plaintext = (Map)testCase.get("plaintext"); - Map ciphertext = (Map)testCase.get("ciphertext"); + List> testCases = + (List>) ciphertextManifest.get("test_cases"); + for (Map testCase : testCases) { + Map plaintext = (Map) testCase.get("plaintext"); + Map ciphertext = (Map) testCase.get("ciphertext"); - List> masterKeys = (List>)testCase.get("master_keys"); - for (Map masterKey : masterKeys) { - String providerId = (String) masterKey.get("provider_id"); - if (providerId.equals("aws-kms") && (boolean)masterKey.get("decryptable")) { - testCases_.add(new Object[] { - baseDirName + File.separator + plaintext.get("filename"), - baseDirName + File.separator + ciphertext.get("filename"), - (String)masterKey.get("key_id") - }); - break; - } - } + List> masterKeys = + (List>) testCase.get("master_keys"); + for (Map masterKey : masterKeys) { + String providerId = (String) masterKey.get("provider_id"); + if (providerId.equals("aws-kms") && (boolean) masterKey.get("decryptable")) { + testCases_.add( + new Object[] { + baseDirName + File.separator + plaintext.get("filename"), + baseDirName + File.separator + ciphertext.get("filename"), + (String) masterKey.get("key_id") + }); + break; } - return testCases_; + } } + return testCases_; + } - @Test - public void testDecryptFromFile() throws Exception { - AwsCrypto crypto = AwsCrypto.standard(); - final KmsMasterKeyProvider masterKeyProvider = KmsMasterKeyProvider.builder().buildStrict(kmsKeyId); - byte ciphertextBytes[] = Files.readAllBytes(Paths.get(ciphertextFileName)); - byte plaintextBytes[] = Files.readAllBytes(Paths.get(plaintextFileName)); - final CryptoResult decryptResult = crypto.decryptData( - masterKeyProvider, - ciphertextBytes - ); - assertArrayEquals(plaintextBytes, (byte[])decryptResult.getResult()); - } + @Test + public void testDecryptFromFile() throws Exception { + AwsCrypto crypto = AwsCrypto.standard(); + final KmsMasterKeyProvider masterKeyProvider = + KmsMasterKeyProvider.builder().buildStrict(kmsKeyId); + byte ciphertextBytes[] = Files.readAllBytes(Paths.get(ciphertextFileName)); + byte plaintextBytes[] = Files.readAllBytes(Paths.get(plaintextFileName)); + final CryptoResult decryptResult = crypto.decryptData(masterKeyProvider, ciphertextBytes); + assertArrayEquals(plaintextBytes, (byte[]) decryptResult.getResult()); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/model/ByteFormatCheckValues.java b/src/test/java/com/amazonaws/encryptionsdk/model/ByteFormatCheckValues.java index 4299eeec6..6de3d73b0 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/model/ByteFormatCheckValues.java +++ b/src/test/java/com/amazonaws/encryptionsdk/model/ByteFormatCheckValues.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -16,51 +16,55 @@ import com.amazonaws.encryptionsdk.internal.Utils; public class ByteFormatCheckValues { - private static final String base64MessageId_ = "NQ/NXvg4mMN5zm5JFZHUWw=="; - private static final String base64PlaintextKey_ = "N9vW5Ox5xh4BrUgeaL2gXg=="; - private static final String base64EncryptedKey_ = "Zg5VUzPfgD0/H92Fx7h0ew=="; + private static final String base64MessageId_ = "NQ/NXvg4mMN5zm5JFZHUWw=="; + private static final String base64PlaintextKey_ = "N9vW5Ox5xh4BrUgeaL2gXg=="; + private static final String base64EncryptedKey_ = "Zg5VUzPfgD0/H92Fx7h0ew=="; - private static final String base64Nonce_ = "3rktZhNbwrZBSaqt"; - private static final String base64Tag_ = "cBPLjSEz0fsWDToxTqMvfQ=="; + private static final String base64Nonce_ = "3rktZhNbwrZBSaqt"; + private static final String base64Tag_ = "cBPLjSEz0fsWDToxTqMvfQ=="; - private static final String base64CiphertextHeaderHash_ = "bCScP4wa25l9TLQZ4KLv7xqVCg9AN58lB1FHrl2yVes="; - private static final String base64BlockHeaderHash_ = "7q8fULz95XaJqrksEuzDoVpSYih54QbPC1+v833s/5Y="; - private static final String base64FrameHeaderHash_ = "tB/UmW+/hLJU5i2D9Or8guXrn8lP0uCiUaP1KkdyKGs="; - private static final String base64FinalFrameHeaderHash_ = "/b2fVFOxvnaM5vXDMGyyFPNTWMjuU/c/48qeH3uTHj0="; + private static final String base64CiphertextHeaderHash_ = + "bCScP4wa25l9TLQZ4KLv7xqVCg9AN58lB1FHrl2yVes="; + private static final String base64BlockHeaderHash_ = + "7q8fULz95XaJqrksEuzDoVpSYih54QbPC1+v833s/5Y="; + private static final String base64FrameHeaderHash_ = + "tB/UmW+/hLJU5i2D9Or8guXrn8lP0uCiUaP1KkdyKGs="; + private static final String base64FinalFrameHeaderHash_ = + "/b2fVFOxvnaM5vXDMGyyFPNTWMjuU/c/48qeH3uTHj0="; - public static byte[] getMessageId() { - return Utils.decodeBase64String(base64MessageId_); - } + public static byte[] getMessageId() { + return Utils.decodeBase64String(base64MessageId_); + } - public static byte[] getEncryptedKey() { - return Utils.decodeBase64String(base64EncryptedKey_); - } + public static byte[] getEncryptedKey() { + return Utils.decodeBase64String(base64EncryptedKey_); + } - public static byte[] getPlaintextKey() { - return Utils.decodeBase64String(base64PlaintextKey_); - } + public static byte[] getPlaintextKey() { + return Utils.decodeBase64String(base64PlaintextKey_); + } - public static byte[] getCiphertextHeaderHash() { - return Utils.decodeBase64String(base64CiphertextHeaderHash_); - } + public static byte[] getCiphertextHeaderHash() { + return Utils.decodeBase64String(base64CiphertextHeaderHash_); + } - public static byte[] getCipherBlockHeaderHash() { - return Utils.decodeBase64String(base64BlockHeaderHash_); - } + public static byte[] getCipherBlockHeaderHash() { + return Utils.decodeBase64String(base64BlockHeaderHash_); + } - public static byte[] getCipherFrameHeaderHash() { - return Utils.decodeBase64String(base64FrameHeaderHash_); - } + public static byte[] getCipherFrameHeaderHash() { + return Utils.decodeBase64String(base64FrameHeaderHash_); + } - public static byte[] getCipherFinalFrameHeaderHash() { - return Utils.decodeBase64String(base64FinalFrameHeaderHash_); - } + public static byte[] getCipherFinalFrameHeaderHash() { + return Utils.decodeBase64String(base64FinalFrameHeaderHash_); + } - public static byte[] getNonce() { - return Utils.decodeBase64String(base64Nonce_); - } + public static byte[] getNonce() { + return Utils.decodeBase64String(base64Nonce_); + } - public static byte[] getTag() { - return Utils.decodeBase64String(base64Tag_); - } + public static byte[] getTag() { + return Utils.decodeBase64String(base64Tag_); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/model/CipherBlockHeadersTest.java b/src/test/java/com/amazonaws/encryptionsdk/model/CipherBlockHeadersTest.java index 555ea4963..77edfdf61 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/model/CipherBlockHeadersTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/model/CipherBlockHeadersTest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -17,145 +17,144 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import java.nio.ByteBuffer; -import java.util.Arrays; - -import org.junit.Test; - import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.internal.Constants; import com.amazonaws.encryptionsdk.internal.RandomBytesGenerator; import com.amazonaws.encryptionsdk.internal.TestIOUtils; +import java.nio.ByteBuffer; +import java.util.Arrays; +import org.junit.Test; public class CipherBlockHeadersTest { - final int nonceLen_ = 12; - byte[] nonce_ = RandomBytesGenerator.generate(nonceLen_); - final int sampleContentLen_ = 1024; - - @Test - public void serializeDeserialize() { - final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); - - final byte[] headerBytes = cipherBlockHeaders.toByteArray(); - - final CipherBlockHeaders reconstructedHeaders = new CipherBlockHeaders(); - reconstructedHeaders.setNonceLength((short) nonceLen_); - reconstructedHeaders.deserialize(headerBytes, 0); - final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); - - assertArrayEquals(headerBytes, reconstructedHeaderBytes); + final int nonceLen_ = 12; + byte[] nonce_ = RandomBytesGenerator.generate(nonceLen_); + final int sampleContentLen_ = 1024; + + @Test + public void serializeDeserialize() { + final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); + + final byte[] headerBytes = cipherBlockHeaders.toByteArray(); + + final CipherBlockHeaders reconstructedHeaders = new CipherBlockHeaders(); + reconstructedHeaders.setNonceLength((short) nonceLen_); + reconstructedHeaders.deserialize(headerBytes, 0); + final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); + + assertArrayEquals(headerBytes, reconstructedHeaderBytes); + } + + private boolean serializeDeserializeStreaming(final boolean isFinalFrame) { + final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); + + final byte[] headerBytes = cipherBlockHeaders.toByteArray(); + final CipherBlockHeaders reconstructedHeaders = new CipherBlockHeaders(); + reconstructedHeaders.setNonceLength((short) nonceLen_); + + int totalParsedBytes = 0; + int bytesToParseLen = 1; + int bytesParsed; + + while (reconstructedHeaders.isComplete() == false) { + final byte[] bytesToParse = new byte[bytesToParseLen]; + System.arraycopy(headerBytes, totalParsedBytes, bytesToParse, 0, bytesToParse.length); + + bytesParsed = reconstructedHeaders.deserialize(bytesToParse, 0); + if (bytesParsed == 0) { + bytesToParseLen++; + } else { + totalParsedBytes += bytesParsed; + bytesToParseLen = 1; + } } - private boolean serializeDeserializeStreaming(final boolean isFinalFrame) { - final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); + final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); - final byte[] headerBytes = cipherBlockHeaders.toByteArray(); - final CipherBlockHeaders reconstructedHeaders = new CipherBlockHeaders(); - reconstructedHeaders.setNonceLength((short) nonceLen_); + return Arrays.equals(headerBytes, reconstructedHeaderBytes) ? true : false; + } - int totalParsedBytes = 0; - int bytesToParseLen = 1; - int bytesParsed; + @Test + public void serializeDeserializeStreamingFinalFrame() { + assertTrue(serializeDeserializeStreaming(true)); + } - while (reconstructedHeaders.isComplete() == false) { - final byte[] bytesToParse = new byte[bytesToParseLen]; - System.arraycopy(headerBytes, totalParsedBytes, bytesToParse, 0, bytesToParse.length); + @Test + public void deserializeNull() { + final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(); + final int deserializedBytes = cipherBlockHeaders.deserialize(null, 0); - bytesParsed = reconstructedHeaders.deserialize(bytesToParse, 0); - if (bytesParsed == 0) { - bytesToParseLen++; - } else { - totalParsedBytes += bytesParsed; - bytesToParseLen = 1; - } - } + assertEquals(0, deserializedBytes); + } - final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); + @Test + public void checkContentLen() { + final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); + final byte[] headerBytes = cipherBlockHeaders.toByteArray(); - return Arrays.equals(headerBytes, reconstructedHeaderBytes) ? true : false; - } + final CipherBlockHeaders reconstructedHeaders = new CipherBlockHeaders(); + reconstructedHeaders.setNonceLength((short) nonceLen_); + reconstructedHeaders.deserialize(headerBytes, 0); - @Test - public void serializeDeserializeStreamingFinalFrame() { - assertTrue(serializeDeserializeStreaming(true)); - } + assertEquals(sampleContentLen_, reconstructedHeaders.getContentLength()); + } - @Test - public void deserializeNull() { - final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(); - final int deserializedBytes = cipherBlockHeaders.deserialize(null, 0); + @Test + public void checkNonce() { + final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); + final byte[] headerBytes = cipherBlockHeaders.toByteArray(); - assertEquals(0, deserializedBytes); - } + final CipherBlockHeaders reconstructedHeaders = new CipherBlockHeaders(); + reconstructedHeaders.setNonceLength((short) nonceLen_); + reconstructedHeaders.deserialize(headerBytes, 0); - @Test - public void checkContentLen() { - final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); - final byte[] headerBytes = cipherBlockHeaders.toByteArray(); + assertArrayEquals(nonce_, reconstructedHeaders.getNonce()); + } - final CipherBlockHeaders reconstructedHeaders = new CipherBlockHeaders(); - reconstructedHeaders.setNonceLength((short) nonceLen_); - reconstructedHeaders.deserialize(headerBytes, 0); + @Test + public void checkNullNonce() { + final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(); + cipherBlockHeaders.setNonceLength((short) nonceLen_); - assertEquals(sampleContentLen_, reconstructedHeaders.getContentLength()); - } + assertArrayEquals(null, cipherBlockHeaders.getNonce()); + } - @Test - public void checkNonce() { - final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); - final byte[] headerBytes = cipherBlockHeaders.toByteArray(); + @Test(expected = AwsCryptoException.class) + public void nullNonce() { + new CipherBlockHeaders(null, sampleContentLen_); + } - final CipherBlockHeaders reconstructedHeaders = new CipherBlockHeaders(); - reconstructedHeaders.setNonceLength((short) nonceLen_); - reconstructedHeaders.deserialize(headerBytes, 0); + @Test(expected = AwsCryptoException.class) + public void invalidNonce() { + new CipherBlockHeaders(new byte[Constants.MAX_NONCE_LENGTH + 1], sampleContentLen_); + } - assertArrayEquals(nonce_, reconstructedHeaders.getNonce()); - } + @Test(expected = BadCiphertextException.class) + public void invalidContentLen() { + final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); + final byte[] headerBytes = cipherBlockHeaders.toByteArray(); + final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); - @Test - public void checkNullNonce() { - final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(); - cipherBlockHeaders.setNonceLength((short) nonceLen_); + // Pull out nonce to move to content len + final byte[] nonce = new byte[nonceLen_]; + headerBuff.get(nonce); - assertArrayEquals(null, cipherBlockHeaders.getNonce()); - } + // Set content length (of type long) to -1; + headerBuff.putLong(-1); - @Test(expected = AwsCryptoException.class) - public void nullNonce() { - new CipherBlockHeaders(null, sampleContentLen_); - } - - @Test(expected = AwsCryptoException.class) - public void invalidNonce() { - new CipherBlockHeaders(new byte[Constants.MAX_NONCE_LENGTH + 1], sampleContentLen_); - } + final CipherBlockHeaders reconstructedHeaders = new CipherBlockHeaders(); + reconstructedHeaders.setNonceLength((short) nonceLen_); + reconstructedHeaders.deserialize(headerBuff.array(), 0); + } - @Test(expected = BadCiphertextException.class) - public void invalidContentLen() { - final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); - final byte[] headerBytes = cipherBlockHeaders.toByteArray(); - final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); + @Test + public void byteFormatCheck() { + nonce_ = ByteFormatCheckValues.getNonce(); + final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); - // Pull out nonce to move to content len - final byte[] nonce = new byte[nonceLen_]; - headerBuff.get(nonce); + final byte[] cipherBlockHeaderHash = + TestIOUtils.getSha256Hash(cipherBlockHeaders.toByteArray()); - // Set content length (of type long) to -1; - headerBuff.putLong(-1); - - final CipherBlockHeaders reconstructedHeaders = new CipherBlockHeaders(); - reconstructedHeaders.setNonceLength((short) nonceLen_); - reconstructedHeaders.deserialize(headerBuff.array(), 0); - } - - @Test - public void byteFormatCheck() { - nonce_ = ByteFormatCheckValues.getNonce(); - final CipherBlockHeaders cipherBlockHeaders = new CipherBlockHeaders(nonce_, sampleContentLen_); - - final byte[] cipherBlockHeaderHash = TestIOUtils.getSha256Hash(cipherBlockHeaders.toByteArray()); - - assertArrayEquals(ByteFormatCheckValues.getCipherBlockHeaderHash(), cipherBlockHeaderHash); - } + assertArrayEquals(ByteFormatCheckValues.getCipherBlockHeaderHash(), cipherBlockHeaderHash); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/model/CipherFrameHeadersTest.java b/src/test/java/com/amazonaws/encryptionsdk/model/CipherFrameHeadersTest.java index 00acd6834..0e963eb8c 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/model/CipherFrameHeadersTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/model/CipherFrameHeadersTest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -17,218 +17,204 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import java.nio.ByteBuffer; -import java.util.Arrays; - import com.amazonaws.encryptionsdk.TestUtils; -import org.junit.Test; - -import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.BadCiphertextException; import com.amazonaws.encryptionsdk.internal.Constants; import com.amazonaws.encryptionsdk.internal.RandomBytesGenerator; import com.amazonaws.encryptionsdk.internal.TestIOUtils; +import java.nio.ByteBuffer; +import java.util.Arrays; +import org.junit.Test; public class CipherFrameHeadersTest { - final int nonceLen_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG.getNonceLen(); - final int testSeqNum_ = 1; - final int testFrameContentLen_ = 4096; - byte[] nonce_ = RandomBytesGenerator.generate(nonceLen_); - - @Test - public void serializeDeserializeTest() { - for (boolean includeContentLen : new boolean[] { true, false }) { - for (boolean isFinalFrame : new boolean[] { true, false }) { - assertTrue(serializeDeserialize(includeContentLen, isFinalFrame)); - } - } - } + final int nonceLen_ = TestUtils.DEFAULT_TEST_CRYPTO_ALG.getNonceLen(); + final int testSeqNum_ = 1; + final int testFrameContentLen_ = 4096; + byte[] nonce_ = RandomBytesGenerator.generate(nonceLen_); - @Test - public void serializeDeserializeStreamingTest() { - for (boolean includeContentLen : new boolean[] { true, false }) { - for (boolean isFinalFrame : new boolean[] { true, false }) { - assertTrue(serializeDeserializeStreaming(includeContentLen, isFinalFrame)); - } - } + @Test + public void serializeDeserializeTest() { + for (boolean includeContentLen : new boolean[] {true, false}) { + for (boolean isFinalFrame : new boolean[] {true, false}) { + assertTrue(serializeDeserialize(includeContentLen, isFinalFrame)); + } } + } - private byte[] createHeaderBytes(final boolean includeContentLen, final boolean isFinalFrame) { - final CipherFrameHeaders cipherFrameHeaders = new CipherFrameHeaders( - testSeqNum_, - nonce_, - testFrameContentLen_, - isFinalFrame); - cipherFrameHeaders.includeFrameSize(includeContentLen); - - return cipherFrameHeaders.toByteArray(); + @Test + public void serializeDeserializeStreamingTest() { + for (boolean includeContentLen : new boolean[] {true, false}) { + for (boolean isFinalFrame : new boolean[] {true, false}) { + assertTrue(serializeDeserializeStreaming(includeContentLen, isFinalFrame)); + } } + } - private CipherFrameHeaders deserialize(final boolean parseContentLen, final byte[] headerBytes) { - final CipherFrameHeaders reconstructedHeaders = new CipherFrameHeaders(); - reconstructedHeaders.setNonceLength((short) nonceLen_); - reconstructedHeaders.includeFrameSize(parseContentLen); - reconstructedHeaders.deserialize(headerBytes, 0); + private byte[] createHeaderBytes(final boolean includeContentLen, final boolean isFinalFrame) { + final CipherFrameHeaders cipherFrameHeaders = + new CipherFrameHeaders(testSeqNum_, nonce_, testFrameContentLen_, isFinalFrame); + cipherFrameHeaders.includeFrameSize(includeContentLen); - return reconstructedHeaders; - } + return cipherFrameHeaders.toByteArray(); + } - private boolean serializeDeserialize(final boolean includeContentLen, final boolean isFinalFrame) { - final byte[] headerBytes = createHeaderBytes(includeContentLen, isFinalFrame); - final CipherFrameHeaders reconstructedHeaders = deserialize(includeContentLen, headerBytes); + private CipherFrameHeaders deserialize(final boolean parseContentLen, final byte[] headerBytes) { + final CipherFrameHeaders reconstructedHeaders = new CipherFrameHeaders(); + reconstructedHeaders.setNonceLength((short) nonceLen_); + reconstructedHeaders.includeFrameSize(parseContentLen); + reconstructedHeaders.deserialize(headerBytes, 0); - final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); + return reconstructedHeaders; + } - return Arrays.equals(headerBytes, reconstructedHeaderBytes) ? true : false; - } + private boolean serializeDeserialize( + final boolean includeContentLen, final boolean isFinalFrame) { + final byte[] headerBytes = createHeaderBytes(includeContentLen, isFinalFrame); + final CipherFrameHeaders reconstructedHeaders = deserialize(includeContentLen, headerBytes); - private boolean serializeDeserializeStreaming(final boolean includeContentLen, final boolean isFinalFrame) { - final byte[] headerBytes = createHeaderBytes(includeContentLen, isFinalFrame); + final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); - final CipherFrameHeaders reconstructedHeaders = new CipherFrameHeaders(); - reconstructedHeaders.setNonceLength((short) nonceLen_); - reconstructedHeaders.includeFrameSize(includeContentLen); + return Arrays.equals(headerBytes, reconstructedHeaderBytes) ? true : false; + } - int totalParsedBytes = 0; - int bytesToParseLen = 1; - int bytesParsed; + private boolean serializeDeserializeStreaming( + final boolean includeContentLen, final boolean isFinalFrame) { + final byte[] headerBytes = createHeaderBytes(includeContentLen, isFinalFrame); - while (reconstructedHeaders.isComplete() == false) { - final byte[] bytesToParse = new byte[bytesToParseLen]; - System.arraycopy(headerBytes, totalParsedBytes, bytesToParse, 0, bytesToParse.length); + final CipherFrameHeaders reconstructedHeaders = new CipherFrameHeaders(); + reconstructedHeaders.setNonceLength((short) nonceLen_); + reconstructedHeaders.includeFrameSize(includeContentLen); - bytesParsed = reconstructedHeaders.deserialize(bytesToParse, 0); - if (bytesParsed == 0) { - bytesToParseLen++; - } else { - totalParsedBytes += bytesParsed; - bytesToParseLen = 1; - } - } + int totalParsedBytes = 0; + int bytesToParseLen = 1; + int bytesParsed; - final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); + while (reconstructedHeaders.isComplete() == false) { + final byte[] bytesToParse = new byte[bytesToParseLen]; + System.arraycopy(headerBytes, totalParsedBytes, bytesToParse, 0, bytesToParse.length); - return Arrays.equals(headerBytes, reconstructedHeaderBytes) ? true : false; + bytesParsed = reconstructedHeaders.deserialize(bytesToParse, 0); + if (bytesParsed == 0) { + bytesToParseLen++; + } else { + totalParsedBytes += bytesParsed; + bytesToParseLen = 1; + } } - @Test - public void deserializeNull() { - final CipherFrameHeaders cipherFrameHeaders = new CipherFrameHeaders(); - final int deserializedBytes = cipherFrameHeaders.deserialize(null, 0); + final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); - assertEquals(0, deserializedBytes); - } + return Arrays.equals(headerBytes, reconstructedHeaderBytes) ? true : false; + } - @Test - public void checkNullNonce() { - final CipherFrameHeaders cipherFrameHeaders = new CipherFrameHeaders(); - cipherFrameHeaders.setNonceLength((short) nonceLen_); + @Test + public void deserializeNull() { + final CipherFrameHeaders cipherFrameHeaders = new CipherFrameHeaders(); + final int deserializedBytes = cipherFrameHeaders.deserialize(null, 0); - assertEquals(null, cipherFrameHeaders.getNonce()); - } + assertEquals(0, deserializedBytes); + } - @Test - public void checkNonce() { - final byte[] headerBytes = createHeaderBytes(false, false); + @Test + public void checkNullNonce() { + final CipherFrameHeaders cipherFrameHeaders = new CipherFrameHeaders(); + cipherFrameHeaders.setNonceLength((short) nonceLen_); - final CipherFrameHeaders reconstructedHeaders = deserialize(false, headerBytes); + assertEquals(null, cipherFrameHeaders.getNonce()); + } - assertArrayEquals(nonce_, reconstructedHeaders.getNonce()); - } + @Test + public void checkNonce() { + final byte[] headerBytes = createHeaderBytes(false, false); - @Test - public void checkSeqNum() { - final byte[] headerBytes = createHeaderBytes(false, false); + final CipherFrameHeaders reconstructedHeaders = deserialize(false, headerBytes); - final CipherFrameHeaders reconstructedHeaders = deserialize(false, headerBytes); + assertArrayEquals(nonce_, reconstructedHeaders.getNonce()); + } - assertEquals(testSeqNum_, reconstructedHeaders.getSequenceNumber()); - } + @Test + public void checkSeqNum() { + final byte[] headerBytes = createHeaderBytes(false, false); - @Test - public void checkFrameLen() { - final boolean isFinalFrame = false; - final boolean includeContentLen = true; + final CipherFrameHeaders reconstructedHeaders = deserialize(false, headerBytes); - final byte[] headerBytes = createHeaderBytes(includeContentLen, isFinalFrame); + assertEquals(testSeqNum_, reconstructedHeaders.getSequenceNumber()); + } - final CipherFrameHeaders reconstructedHeaders = deserialize(includeContentLen, headerBytes); + @Test + public void checkFrameLen() { + final boolean isFinalFrame = false; + final boolean includeContentLen = true; - assertEquals(testFrameContentLen_, reconstructedHeaders.getFrameContentLength()); - } + final byte[] headerBytes = createHeaderBytes(includeContentLen, isFinalFrame); - @Test(expected = BadCiphertextException.class) - public void invalidFrameLen() { - final boolean isFinalFrame = false; - final boolean includeContentLen = true; + final CipherFrameHeaders reconstructedHeaders = deserialize(includeContentLen, headerBytes); - final byte[] headerBytes = createHeaderBytes(includeContentLen, isFinalFrame); - final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); + assertEquals(testFrameContentLen_, reconstructedHeaders.getFrameContentLength()); + } - // Pull out seq num to move to nonce - headerBuff.getInt(); + @Test(expected = BadCiphertextException.class) + public void invalidFrameLen() { + final boolean isFinalFrame = false; + final boolean includeContentLen = true; - // Pull out nonce to move to content len - final byte[] nonce = new byte[nonceLen_]; - headerBuff.get(nonce); + final byte[] headerBytes = createHeaderBytes(includeContentLen, isFinalFrame); + final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); - // Set content length (of type long) to -1; - headerBuff.putInt(-1); + // Pull out seq num to move to nonce + headerBuff.getInt(); - final CipherFrameHeaders reconstructedHeaders = new CipherFrameHeaders(); - reconstructedHeaders.setNonceLength((short) nonceLen_); - reconstructedHeaders.includeFrameSize(includeContentLen); - reconstructedHeaders.deserialize(headerBuff.array(), 0); - } + // Pull out nonce to move to content len + final byte[] nonce = new byte[nonceLen_]; + headerBuff.get(nonce); - @Test(expected = AwsCryptoException.class) - public void nullNonce() { - boolean isFinalFrame = false; - new CipherFrameHeaders( - testSeqNum_, - null, - testFrameContentLen_, - isFinalFrame); - } + // Set content length (of type long) to -1; + headerBuff.putInt(-1); - @Test(expected = AwsCryptoException.class) - public void invalidNonce() { - boolean isFinalFrame = false; - new CipherFrameHeaders( - testSeqNum_, - new byte[Constants.MAX_NONCE_LENGTH + 1], - testFrameContentLen_, - isFinalFrame); - } + final CipherFrameHeaders reconstructedHeaders = new CipherFrameHeaders(); + reconstructedHeaders.setNonceLength((short) nonceLen_); + reconstructedHeaders.includeFrameSize(includeContentLen); + reconstructedHeaders.deserialize(headerBuff.array(), 0); + } - @Test - public void byteFormatCheck() { - boolean isFinalFrame = false; - nonce_ = ByteFormatCheckValues.getNonce(); - final CipherFrameHeaders cipherFrameHeaders = new CipherFrameHeaders( - testSeqNum_, - nonce_, - testFrameContentLen_, - isFinalFrame); + @Test(expected = AwsCryptoException.class) + public void nullNonce() { + boolean isFinalFrame = false; + new CipherFrameHeaders(testSeqNum_, null, testFrameContentLen_, isFinalFrame); + } - final byte[] cipherFrameHeaderHash = TestIOUtils.getSha256Hash(cipherFrameHeaders.toByteArray()); + @Test(expected = AwsCryptoException.class) + public void invalidNonce() { + boolean isFinalFrame = false; + new CipherFrameHeaders( + testSeqNum_, new byte[Constants.MAX_NONCE_LENGTH + 1], testFrameContentLen_, isFinalFrame); + } - assertArrayEquals(ByteFormatCheckValues.getCipherFrameHeaderHash(), cipherFrameHeaderHash); - } + @Test + public void byteFormatCheck() { + boolean isFinalFrame = false; + nonce_ = ByteFormatCheckValues.getNonce(); + final CipherFrameHeaders cipherFrameHeaders = + new CipherFrameHeaders(testSeqNum_, nonce_, testFrameContentLen_, isFinalFrame); - @Test - public void byteFormatCheckforFinalFrame() { - boolean isFinalFrame = true; - nonce_ = ByteFormatCheckValues.getNonce(); - final CipherFrameHeaders cipherFinalFrameHeaders = new CipherFrameHeaders( - testSeqNum_, - nonce_, - testFrameContentLen_, - isFinalFrame); + final byte[] cipherFrameHeaderHash = + TestIOUtils.getSha256Hash(cipherFrameHeaders.toByteArray()); - final byte[] cipherFinalFrameHeaderHash = TestIOUtils.getSha256Hash(cipherFinalFrameHeaders.toByteArray()); + assertArrayEquals(ByteFormatCheckValues.getCipherFrameHeaderHash(), cipherFrameHeaderHash); + } - assertArrayEquals(ByteFormatCheckValues.getCipherFinalFrameHeaderHash(), cipherFinalFrameHeaderHash); - } + @Test + public void byteFormatCheckforFinalFrame() { + boolean isFinalFrame = true; + nonce_ = ByteFormatCheckValues.getNonce(); + final CipherFrameHeaders cipherFinalFrameHeaders = + new CipherFrameHeaders(testSeqNum_, nonce_, testFrameContentLen_, isFinalFrame); + + final byte[] cipherFinalFrameHeaderHash = + TestIOUtils.getSha256Hash(cipherFinalFrameHeaders.toByteArray()); + + assertArrayEquals( + ByteFormatCheckValues.getCipherFinalFrameHeaderHash(), cipherFinalFrameHeaderHash); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/model/CiphertextHeadersTest.java b/src/test/java/com/amazonaws/encryptionsdk/model/CiphertextHeadersTest.java index 2b0f23a3c..8dc1cac5d 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/model/CiphertextHeadersTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/model/CiphertextHeadersTest.java @@ -8,16 +8,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import com.amazonaws.encryptionsdk.TestUtils; -import org.junit.Test; - import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.exception.AwsCryptoException; @@ -25,531 +15,577 @@ import com.amazonaws.encryptionsdk.internal.Constants; import com.amazonaws.encryptionsdk.internal.EncryptionContextSerializer; import com.amazonaws.encryptionsdk.internal.RandomBytesGenerator; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.Test; public class CiphertextHeadersTest { - final CryptoAlgorithm cryptoAlgo_ = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; - final String keyProviderId_ = "None"; - final byte[] keyProviderInfo_ = "TestKeyID".getBytes(); - final byte version_ = cryptoAlgo_.getMessageFormatVersion(); - final byte invalidVersion_ = 0x00; - final CiphertextType type_ = CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA; - final int nonceLen_ = cryptoAlgo_.getNonceLen(); - final int tagLenBytes_ = cryptoAlgo_.getTagLen(); - final ContentType contentType_ = ContentType.FRAME; - final int frameSize_ = AwsCrypto.getDefaultFrameSize(); - - // A set of crypto algs that are representative of the different ciphertext header formats - final List testAlgs = Arrays.asList(CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, - CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256); - - byte[] encryptedKey_ = RandomBytesGenerator.generate(cryptoAlgo_.getKeyLength()); - - final KeyBlob keyBlob_ = new KeyBlob(keyProviderId_, keyProviderInfo_, encryptedKey_); - - byte[] headerNonce_ = RandomBytesGenerator.generate(nonceLen_); - byte[] headerTag_ = RandomBytesGenerator.generate(tagLenBytes_); - - @Test - public void serializeDeserialize() { - Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Test"); - - for (CryptoAlgorithm alg : testAlgs) { - final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); - - final byte[] headerBytes = ciphertextHeaders.toByteArray(); - final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); - reconstructedHeaders.deserialize(headerBytes, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); - - assertArrayEquals(headerBytes, reconstructedHeaderBytes); - } - } - - @Test - public void serializeDeserializeStreaming() { - Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - - for (CryptoAlgorithm alg : testAlgs) { - final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); - - final byte[] headerBytes = ciphertextHeaders.toByteArray(); - final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); - - int totalParsedBytes = 0; - int bytesToParseLen = 1; - int bytesParsed; - - while (reconstructedHeaders.isComplete() == false) { - final byte[] bytesToParse = new byte[bytesToParseLen]; - System.arraycopy(headerBytes, totalParsedBytes, bytesToParse, 0, bytesToParse.length); - - bytesParsed = reconstructedHeaders.deserialize(bytesToParse, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - if (bytesParsed == 0) { - bytesToParseLen++; - } else { - totalParsedBytes += bytesParsed; - bytesToParseLen = 1; - } - } - - final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); - - assertArrayEquals(headerBytes, reconstructedHeaderBytes); - } - } - - @Test - public void deserializeNull() { - final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders(); - final int deserializedBytes = ciphertextHeaders.deserialize(null, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - - assertEquals(0, deserializedBytes); - } - - @Test - public void overlyLargeEncryptionContext() { - final int size = Constants.UNSIGNED_SHORT_MAX_VAL + 1; - final byte[] encContextBytes = RandomBytesGenerator.generate(size); - for (CryptoAlgorithm alg : testAlgs) { - assertThrows(AwsCryptoException.class, () -> - createCiphertextHeaders(encContextBytes, alg)); - } - } - - @Test - public void serializeWithNullHeaderNonce() { - for (CryptoAlgorithm alg : testAlgs) { - final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders( - type_, - alg, - new byte[0], - Collections.singletonList(keyBlob_), - contentType_, - frameSize_); - ciphertextHeaders.setHeaderTag(headerTag_); - - assertThrows(AwsCryptoException.class, () -> - ciphertextHeaders.toByteArray()); - } - } - - @Test - public void serializeWithNullHeaderTag() { - for (CryptoAlgorithm alg : testAlgs) { - final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders( - type_, - alg, - new byte[0], - Collections.singletonList(keyBlob_), - contentType_, - frameSize_); - ciphertextHeaders.setHeaderNonce(headerNonce_); - - assertThrows(AwsCryptoException.class, () -> - ciphertextHeaders.toByteArray()); - } - } - - @Test - public void serializeWithNullSuiteData() { - // Only applicable for V2 algorithms - CryptoAlgorithm alg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; - final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders( - type_, - alg, - new byte[0], - Collections.singletonList(keyBlob_), - contentType_, - frameSize_); - ciphertextHeaders.setHeaderTag(headerTag_); - ciphertextHeaders.setHeaderNonce(headerNonce_); - - assertThrows(AwsCryptoException.class, () -> - ciphertextHeaders.toByteArray()); - } - - /* - * @Test - * public void byteFormatCheck() { - * testPlaintextKey_ = ByteFormatCheckValues.getPlaintextKey(); - * testKey_ = new SecretKeySpec(testPlaintextKey_, - * cryptoAlgo_.getKeySpec()); - * encryptedKey_ = ByteFormatCheckValues.getEncryptedKey(); - * dataKey_ = new AWSKMSDataKey(testKey_, encryptedKey_); - * headerNonce_ = ByteFormatCheckValues.getNonce(); - * headerTag_ = ByteFormatCheckValues.getTag(); - * - * Map encryptionContext = new HashMap(1); - * encryptionContext.put("ENC", "CiphertextHeader format check test"); - * - * final CiphertextHeaders ciphertextHeaders = - * createCiphertextHeaders(encryptionContext); - * //NOTE: this test will fail because of the line below. - * //That is, the message id is randomly generated in the ciphertext - * headers. - * messageId_ = ciphertextHeaders.getMessageId(); - * final byte[] ciphertextHeaderHash = - * TestIOUtils.getSha256Hash(ciphertextHeaders.toByteArray()); - * - * assertArrayEquals(ByteFormatCheckValues.getCiphertextHeaderHash(), - * ciphertextHeaderHash); - * } - */ - - private CiphertextHeaders createCiphertextHeaders(final byte[] encryptionContextBytes, CryptoAlgorithm cryptoAlg) { - final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders( - type_, - cryptoAlg, - encryptionContextBytes, - Collections.singletonList(keyBlob_), - contentType_, - frameSize_); - - ciphertextHeaders.setHeaderNonce(headerNonce_); - ciphertextHeaders.setHeaderTag(headerTag_); - - if (cryptoAlg.getMessageFormatVersion() == 2) { - ciphertextHeaders.setSuiteData(new byte[cryptoAlg.getSuiteDataLength()]); - } - - return ciphertextHeaders; - } - - private CiphertextHeaders createCiphertextHeaders(final Map encryptionContext, CryptoAlgorithm cryptoAlg) { - byte[] encryptionContextBytes = null; - if (encryptionContext != null) { - encryptionContextBytes = EncryptionContextSerializer.serialize(encryptionContext); - } - - return createCiphertextHeaders(encryptionContextBytes, cryptoAlg); - } - - @SuppressWarnings("deprecation") - @Test - public void legacyConstructCiphertextHeaders() { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - final byte[] encryptionContextBytes = EncryptionContextSerializer.serialize(encryptionContext); - - final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders( - version_, - type_, - cryptoAlgo_, - encryptionContextBytes, - Collections.singletonList(keyBlob_), - contentType_, - frameSize_); - - ciphertextHeaders.setHeaderNonce(headerNonce_); - ciphertextHeaders.setHeaderTag(headerTag_); - assertNotNull(ciphertextHeaders); - } - - @SuppressWarnings("deprecation") - @Test(expected = IllegalArgumentException.class) - public void legacyConstructCiphertextHeadersMismatchedVersion() { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - final byte[] encryptionContextBytes = EncryptionContextSerializer.serialize(encryptionContext); - - final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders( - invalidVersion_, - type_, - cryptoAlgo_, - encryptionContextBytes, - Collections.singletonList(keyBlob_), - contentType_, - frameSize_); - } - - @Test - public void checkEncContextLen() { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - final byte[] encryptionContextBytes = EncryptionContextSerializer.serialize(encryptionContext); - final int encryptionContextLen = encryptionContextBytes.length; - - for (CryptoAlgorithm alg: testAlgs) { - final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); - final byte[] headerBytes = ciphertextHeaders.toByteArray(); - - final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); - reconstructedHeaders.deserialize(headerBytes, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - - assertEquals(encryptionContextLen, reconstructedHeaders.getEncryptionContextLen()); - } - } - - @Test - public void checkKeyBlobCount() { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - - for (CryptoAlgorithm alg: testAlgs) { - final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); - final byte[] headerBytes = ciphertextHeaders.toByteArray(); - - final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); - reconstructedHeaders.deserialize(headerBytes, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); - - assertEquals(1, reconstructedHeaders.getEncryptedKeyBlobCount()); - } - } - - @Test - public void checkNullMessageId() { - final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders(); - - assertEquals(null, ciphertextHeaders.getMessageId()); - } - - @Test - public void checkNullHeaderNonce() { - final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders(); - - assertEquals(null, ciphertextHeaders.getHeaderNonce()); - } - - @Test - public void checkNullHeaderTag() { - final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders(); - - assertEquals(null, ciphertextHeaders.getHeaderTag()); - } - - private void readVersion(final ByteBuffer headerBuff) { - headerBuff.get(); - } - - private void readType(final ByteBuffer headerBuff) { - readVersion(headerBuff); - headerBuff.get(); - } - - private void readToAlgoId(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { - // Use the message format version to determine where the algorithm Id is in the - // header and how to get to it. - if (cryptoAlg.getMessageFormatVersion() == 1) { - readType(headerBuff); - } else { - readVersion(headerBuff); - } - } - - private void readAlgoId(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { - readToAlgoId(headerBuff, cryptoAlg); - headerBuff.getShort(); - } - - private void readEncContext(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { - readAlgoId(headerBuff, cryptoAlg); - - // pull out messageId to get to enc context len. - final byte[] msgId = new byte[cryptoAlg.getMessageIdLength()]; - headerBuff.get(msgId); - - // pull out enc context to get to key count. - final int encContextLen = headerBuff.getShort(); - final byte[] encContext = new byte[encContextLen]; - headerBuff.get(encContext); - } - - private void readKeyBlob(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { - readEncContext(headerBuff, cryptoAlg); - - headerBuff.getShort(); // get key count - final short keyProviderIdLen = headerBuff.getShort(); - final byte[] keyProviderId = new byte[keyProviderIdLen]; - headerBuff.get(keyProviderId); - final short keyProviderInfoLen = headerBuff.getShort(); - final byte[] keyProviderInfo = new byte[keyProviderInfoLen]; - headerBuff.get(keyProviderInfo); - final short keyLen = headerBuff.getShort(); - final byte[] key = new byte[keyLen]; - headerBuff.get(key); - } - - private void readToContentType(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { - readKeyBlob(headerBuff, cryptoAlg); - } - - private void readContentType(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { - readToContentType(headerBuff, cryptoAlg); - headerBuff.get(); - } - - private void readToReservedField(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { - readContentType(headerBuff, cryptoAlg); - } - - private void readReservedField(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { - readToReservedField(headerBuff, cryptoAlg); - headerBuff.getInt(); - } - - private void readToNonceLen(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { - readReservedField(headerBuff, cryptoAlg); - } - - private void readNonceLen(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { - readToNonceLen(headerBuff, cryptoAlg); - headerBuff.get(); - } - - private void readToFrameLen(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { - // Use the message format version to determine where the frame length is in the - // header and how to get to it. - if (cryptoAlg.getMessageFormatVersion() == 1) { - readNonceLen(headerBuff, cryptoAlg); + final CryptoAlgorithm cryptoAlgo_ = + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384; + final String keyProviderId_ = "None"; + final byte[] keyProviderInfo_ = "TestKeyID".getBytes(); + final byte version_ = cryptoAlgo_.getMessageFormatVersion(); + final byte invalidVersion_ = 0x00; + final CiphertextType type_ = CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA; + final int nonceLen_ = cryptoAlgo_.getNonceLen(); + final int tagLenBytes_ = cryptoAlgo_.getTagLen(); + final ContentType contentType_ = ContentType.FRAME; + final int frameSize_ = AwsCrypto.getDefaultFrameSize(); + + // A set of crypto algs that are representative of the different ciphertext header formats + final List testAlgs = + Arrays.asList( + CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256); + + byte[] encryptedKey_ = RandomBytesGenerator.generate(cryptoAlgo_.getKeyLength()); + + final KeyBlob keyBlob_ = new KeyBlob(keyProviderId_, keyProviderInfo_, encryptedKey_); + + byte[] headerNonce_ = RandomBytesGenerator.generate(nonceLen_); + byte[] headerTag_ = RandomBytesGenerator.generate(tagLenBytes_); + + @Test + public void serializeDeserialize() { + Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Test"); + + for (CryptoAlgorithm alg : testAlgs) { + final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); + + final byte[] headerBytes = ciphertextHeaders.toByteArray(); + final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); + reconstructedHeaders.deserialize( + headerBytes, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); + + assertArrayEquals(headerBytes, reconstructedHeaderBytes); + } + } + + @Test + public void serializeDeserializeStreaming() { + Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + + for (CryptoAlgorithm alg : testAlgs) { + final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); + + final byte[] headerBytes = ciphertextHeaders.toByteArray(); + final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); + + int totalParsedBytes = 0; + int bytesToParseLen = 1; + int bytesParsed; + + while (reconstructedHeaders.isComplete() == false) { + final byte[] bytesToParse = new byte[bytesToParseLen]; + System.arraycopy(headerBytes, totalParsedBytes, bytesToParse, 0, bytesToParse.length); + + bytesParsed = + reconstructedHeaders.deserialize( + bytesToParse, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + if (bytesParsed == 0) { + bytesToParseLen++; } else { - readContentType(headerBuff, cryptoAlg); + totalParsedBytes += bytesParsed; + bytesToParseLen = 1; } - } - - @Test - public void invalidVersion(){ - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - - for (CryptoAlgorithm alg: testAlgs) { - final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); - final byte[] headerBytes = ciphertextHeaders.toByteArray(); - final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); - - //set version to invalid type of 0. - headerBuff.put((byte) 0); - - final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); - assertThrows(BadCiphertextException.class, "Invalid version", - () -> reconstructedHeaders.deserialize(headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - } - } - - @Test - public void invalidType() { - // Only applicable for V1 algorithms - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - final CryptoAlgorithm cryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - - final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, cryptoAlgorithm); - final byte[] headerBytes = ciphertextHeaders.toByteArray(); - final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); - - readVersion(headerBuff); - - // set type to invalid value of 0. - headerBuff.put((byte) 0); - - final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); - assertThrows(BadCiphertextException.class, "Invalid ciphertext type", - () -> reconstructedHeaders.deserialize(headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - } - - @Test - public void invalidAlgoId() { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - - for (CryptoAlgorithm alg: testAlgs) { - final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); - final byte[] headerBytes = ciphertextHeaders.toByteArray(); - final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); - - readToAlgoId(headerBuff, alg); - - // set algoId to invalid value of 0. - headerBuff.putShort((short) 0); - - final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); - assertThrows(BadCiphertextException.class, "Invalid algorithm identifier in ciphertext", - () -> reconstructedHeaders.deserialize(headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - } - } - - @Test - public void invalidContentType() { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - - for (CryptoAlgorithm alg: testAlgs) { - final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); - final byte[] headerBytes = ciphertextHeaders.toByteArray(); - final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); - - readToContentType(headerBuff, alg); - - // set content type to an invalid value - headerBuff.put((byte) 10); - - final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); - assertThrows(BadCiphertextException.class, "Invalid content type in ciphertext.", - () -> reconstructedHeaders.deserialize(headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - } - } - - @Test - public void invalidReservedFieldLen() { - // Only applicable for V1 algorithms - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - final CryptoAlgorithm cryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - - final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, cryptoAlgorithm); - final byte[] headerBytes = ciphertextHeaders.toByteArray(); - final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); - - readToReservedField(headerBuff, cryptoAlgorithm); - - // set reserved field to an invalid value - headerBuff.putInt(-1); - - final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); - assertThrows(BadCiphertextException.class, "Invalid value for reserved field in ciphertext", - () -> reconstructedHeaders.deserialize(headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - } - - @Test - public void invalidNonceLen() { - // Only applicable for V1 algorithms - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - final CryptoAlgorithm cryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - - final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, cryptoAlgorithm); - final byte[] headerBytes = ciphertextHeaders.toByteArray(); - final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); - - readToNonceLen(headerBuff, cryptoAlgorithm); - - // set nonce len to an invalid value - headerBuff.put((byte) -1); - - final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); - assertThrows(BadCiphertextException.class, "Invalid nonce length in ciphertext", - () -> reconstructedHeaders.deserialize(headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - } - - @Test - public void invalidFrameLength() { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); - - for (CryptoAlgorithm alg : testAlgs) { - final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); - final byte[] headerBytes = ciphertextHeaders.toByteArray(); - final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); - - readToFrameLen(headerBuff, alg); - - // set frame len to an invalid value - headerBuff.putInt(-1); - - final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); - assertThrows(BadCiphertextException.class, "Invalid frame length in ciphertext", - () -> reconstructedHeaders.deserialize(headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); - } - } + } + + final byte[] reconstructedHeaderBytes = reconstructedHeaders.toByteArray(); + + assertArrayEquals(headerBytes, reconstructedHeaderBytes); + } + } + + @Test + public void deserializeNull() { + final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders(); + final int deserializedBytes = + ciphertextHeaders.deserialize(null, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + + assertEquals(0, deserializedBytes); + } + + @Test + public void overlyLargeEncryptionContext() { + final int size = Constants.UNSIGNED_SHORT_MAX_VAL + 1; + final byte[] encContextBytes = RandomBytesGenerator.generate(size); + for (CryptoAlgorithm alg : testAlgs) { + assertThrows(AwsCryptoException.class, () -> createCiphertextHeaders(encContextBytes, alg)); + } + } + + @Test + public void serializeWithNullHeaderNonce() { + for (CryptoAlgorithm alg : testAlgs) { + final CiphertextHeaders ciphertextHeaders = + new CiphertextHeaders( + type_, + alg, + new byte[0], + Collections.singletonList(keyBlob_), + contentType_, + frameSize_); + ciphertextHeaders.setHeaderTag(headerTag_); + + assertThrows(AwsCryptoException.class, () -> ciphertextHeaders.toByteArray()); + } + } + + @Test + public void serializeWithNullHeaderTag() { + for (CryptoAlgorithm alg : testAlgs) { + final CiphertextHeaders ciphertextHeaders = + new CiphertextHeaders( + type_, + alg, + new byte[0], + Collections.singletonList(keyBlob_), + contentType_, + frameSize_); + ciphertextHeaders.setHeaderNonce(headerNonce_); + + assertThrows(AwsCryptoException.class, () -> ciphertextHeaders.toByteArray()); + } + } + + @Test + public void serializeWithNullSuiteData() { + // Only applicable for V2 algorithms + CryptoAlgorithm alg = CryptoAlgorithm.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY; + final CiphertextHeaders ciphertextHeaders = + new CiphertextHeaders( + type_, alg, new byte[0], Collections.singletonList(keyBlob_), contentType_, frameSize_); + ciphertextHeaders.setHeaderTag(headerTag_); + ciphertextHeaders.setHeaderNonce(headerNonce_); + + assertThrows(AwsCryptoException.class, () -> ciphertextHeaders.toByteArray()); + } + + /* + * @Test + * public void byteFormatCheck() { + * testPlaintextKey_ = ByteFormatCheckValues.getPlaintextKey(); + * testKey_ = new SecretKeySpec(testPlaintextKey_, + * cryptoAlgo_.getKeySpec()); + * encryptedKey_ = ByteFormatCheckValues.getEncryptedKey(); + * dataKey_ = new AWSKMSDataKey(testKey_, encryptedKey_); + * headerNonce_ = ByteFormatCheckValues.getNonce(); + * headerTag_ = ByteFormatCheckValues.getTag(); + * + * Map encryptionContext = new HashMap(1); + * encryptionContext.put("ENC", "CiphertextHeader format check test"); + * + * final CiphertextHeaders ciphertextHeaders = + * createCiphertextHeaders(encryptionContext); + * //NOTE: this test will fail because of the line below. + * //That is, the message id is randomly generated in the ciphertext + * headers. + * messageId_ = ciphertextHeaders.getMessageId(); + * final byte[] ciphertextHeaderHash = + * TestIOUtils.getSha256Hash(ciphertextHeaders.toByteArray()); + * + * assertArrayEquals(ByteFormatCheckValues.getCiphertextHeaderHash(), + * ciphertextHeaderHash); + * } + */ + + private CiphertextHeaders createCiphertextHeaders( + final byte[] encryptionContextBytes, CryptoAlgorithm cryptoAlg) { + final CiphertextHeaders ciphertextHeaders = + new CiphertextHeaders( + type_, + cryptoAlg, + encryptionContextBytes, + Collections.singletonList(keyBlob_), + contentType_, + frameSize_); + + ciphertextHeaders.setHeaderNonce(headerNonce_); + ciphertextHeaders.setHeaderTag(headerTag_); + + if (cryptoAlg.getMessageFormatVersion() == 2) { + ciphertextHeaders.setSuiteData(new byte[cryptoAlg.getSuiteDataLength()]); + } + + return ciphertextHeaders; + } + + private CiphertextHeaders createCiphertextHeaders( + final Map encryptionContext, CryptoAlgorithm cryptoAlg) { + byte[] encryptionContextBytes = null; + if (encryptionContext != null) { + encryptionContextBytes = EncryptionContextSerializer.serialize(encryptionContext); + } + + return createCiphertextHeaders(encryptionContextBytes, cryptoAlg); + } + + @SuppressWarnings("deprecation") + @Test + public void legacyConstructCiphertextHeaders() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + final byte[] encryptionContextBytes = EncryptionContextSerializer.serialize(encryptionContext); + + final CiphertextHeaders ciphertextHeaders = + new CiphertextHeaders( + version_, + type_, + cryptoAlgo_, + encryptionContextBytes, + Collections.singletonList(keyBlob_), + contentType_, + frameSize_); + + ciphertextHeaders.setHeaderNonce(headerNonce_); + ciphertextHeaders.setHeaderTag(headerTag_); + assertNotNull(ciphertextHeaders); + } + + @SuppressWarnings("deprecation") + @Test(expected = IllegalArgumentException.class) + public void legacyConstructCiphertextHeadersMismatchedVersion() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + final byte[] encryptionContextBytes = EncryptionContextSerializer.serialize(encryptionContext); + + final CiphertextHeaders ciphertextHeaders = + new CiphertextHeaders( + invalidVersion_, + type_, + cryptoAlgo_, + encryptionContextBytes, + Collections.singletonList(keyBlob_), + contentType_, + frameSize_); + } + + @Test + public void checkEncContextLen() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + final byte[] encryptionContextBytes = EncryptionContextSerializer.serialize(encryptionContext); + final int encryptionContextLen = encryptionContextBytes.length; + + for (CryptoAlgorithm alg : testAlgs) { + final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); + final byte[] headerBytes = ciphertextHeaders.toByteArray(); + + final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); + reconstructedHeaders.deserialize( + headerBytes, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + + assertEquals(encryptionContextLen, reconstructedHeaders.getEncryptionContextLen()); + } + } + + @Test + public void checkKeyBlobCount() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + + for (CryptoAlgorithm alg : testAlgs) { + final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); + final byte[] headerBytes = ciphertextHeaders.toByteArray(); + + final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); + reconstructedHeaders.deserialize( + headerBytes, 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS); + + assertEquals(1, reconstructedHeaders.getEncryptedKeyBlobCount()); + } + } + + @Test + public void checkNullMessageId() { + final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders(); + + assertEquals(null, ciphertextHeaders.getMessageId()); + } + + @Test + public void checkNullHeaderNonce() { + final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders(); + + assertEquals(null, ciphertextHeaders.getHeaderNonce()); + } + + @Test + public void checkNullHeaderTag() { + final CiphertextHeaders ciphertextHeaders = new CiphertextHeaders(); + + assertEquals(null, ciphertextHeaders.getHeaderTag()); + } + + private void readVersion(final ByteBuffer headerBuff) { + headerBuff.get(); + } + + private void readType(final ByteBuffer headerBuff) { + readVersion(headerBuff); + headerBuff.get(); + } + + private void readToAlgoId(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { + // Use the message format version to determine where the algorithm Id is in the + // header and how to get to it. + if (cryptoAlg.getMessageFormatVersion() == 1) { + readType(headerBuff); + } else { + readVersion(headerBuff); + } + } + + private void readAlgoId(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { + readToAlgoId(headerBuff, cryptoAlg); + headerBuff.getShort(); + } + + private void readEncContext(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { + readAlgoId(headerBuff, cryptoAlg); + + // pull out messageId to get to enc context len. + final byte[] msgId = new byte[cryptoAlg.getMessageIdLength()]; + headerBuff.get(msgId); + + // pull out enc context to get to key count. + final int encContextLen = headerBuff.getShort(); + final byte[] encContext = new byte[encContextLen]; + headerBuff.get(encContext); + } + + private void readKeyBlob(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { + readEncContext(headerBuff, cryptoAlg); + + headerBuff.getShort(); // get key count + final short keyProviderIdLen = headerBuff.getShort(); + final byte[] keyProviderId = new byte[keyProviderIdLen]; + headerBuff.get(keyProviderId); + final short keyProviderInfoLen = headerBuff.getShort(); + final byte[] keyProviderInfo = new byte[keyProviderInfoLen]; + headerBuff.get(keyProviderInfo); + final short keyLen = headerBuff.getShort(); + final byte[] key = new byte[keyLen]; + headerBuff.get(key); + } + + private void readToContentType(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { + readKeyBlob(headerBuff, cryptoAlg); + } + + private void readContentType(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { + readToContentType(headerBuff, cryptoAlg); + headerBuff.get(); + } + + private void readToReservedField(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { + readContentType(headerBuff, cryptoAlg); + } + + private void readReservedField(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { + readToReservedField(headerBuff, cryptoAlg); + headerBuff.getInt(); + } + + private void readToNonceLen(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { + readReservedField(headerBuff, cryptoAlg); + } + + private void readNonceLen(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { + readToNonceLen(headerBuff, cryptoAlg); + headerBuff.get(); + } + + private void readToFrameLen(final ByteBuffer headerBuff, final CryptoAlgorithm cryptoAlg) { + // Use the message format version to determine where the frame length is in the + // header and how to get to it. + if (cryptoAlg.getMessageFormatVersion() == 1) { + readNonceLen(headerBuff, cryptoAlg); + } else { + readContentType(headerBuff, cryptoAlg); + } + } + + @Test + public void invalidVersion() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + + for (CryptoAlgorithm alg : testAlgs) { + final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); + final byte[] headerBytes = ciphertextHeaders.toByteArray(); + final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); + + // set version to invalid type of 0. + headerBuff.put((byte) 0); + + final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); + assertThrows( + BadCiphertextException.class, + "Invalid version", + () -> + reconstructedHeaders.deserialize( + headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + } + } + + @Test + public void invalidType() { + // Only applicable for V1 algorithms + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + final CryptoAlgorithm cryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + + final CiphertextHeaders ciphertextHeaders = + createCiphertextHeaders(encryptionContext, cryptoAlgorithm); + final byte[] headerBytes = ciphertextHeaders.toByteArray(); + final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); + + readVersion(headerBuff); + + // set type to invalid value of 0. + headerBuff.put((byte) 0); + + final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); + assertThrows( + BadCiphertextException.class, + "Invalid ciphertext type", + () -> + reconstructedHeaders.deserialize( + headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + } + + @Test + public void invalidAlgoId() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + + for (CryptoAlgorithm alg : testAlgs) { + final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); + final byte[] headerBytes = ciphertextHeaders.toByteArray(); + final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); + + readToAlgoId(headerBuff, alg); + + // set algoId to invalid value of 0. + headerBuff.putShort((short) 0); + + final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); + assertThrows( + BadCiphertextException.class, + "Invalid algorithm identifier in ciphertext", + () -> + reconstructedHeaders.deserialize( + headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + } + } + + @Test + public void invalidContentType() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + + for (CryptoAlgorithm alg : testAlgs) { + final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); + final byte[] headerBytes = ciphertextHeaders.toByteArray(); + final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); + + readToContentType(headerBuff, alg); + + // set content type to an invalid value + headerBuff.put((byte) 10); + + final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); + assertThrows( + BadCiphertextException.class, + "Invalid content type in ciphertext.", + () -> + reconstructedHeaders.deserialize( + headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + } + } + + @Test + public void invalidReservedFieldLen() { + // Only applicable for V1 algorithms + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + final CryptoAlgorithm cryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + + final CiphertextHeaders ciphertextHeaders = + createCiphertextHeaders(encryptionContext, cryptoAlgorithm); + final byte[] headerBytes = ciphertextHeaders.toByteArray(); + final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); + + readToReservedField(headerBuff, cryptoAlgorithm); + + // set reserved field to an invalid value + headerBuff.putInt(-1); + + final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); + assertThrows( + BadCiphertextException.class, + "Invalid value for reserved field in ciphertext", + () -> + reconstructedHeaders.deserialize( + headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + } + + @Test + public void invalidNonceLen() { + // Only applicable for V1 algorithms + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + final CryptoAlgorithm cryptoAlgorithm = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + + final CiphertextHeaders ciphertextHeaders = + createCiphertextHeaders(encryptionContext, cryptoAlgorithm); + final byte[] headerBytes = ciphertextHeaders.toByteArray(); + final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); + + readToNonceLen(headerBuff, cryptoAlgorithm); + + // set nonce len to an invalid value + headerBuff.put((byte) -1); + + final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); + assertThrows( + BadCiphertextException.class, + "Invalid nonce length in ciphertext", + () -> + reconstructedHeaders.deserialize( + headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + } + + @Test + public void invalidFrameLength() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "CiphertextHeader Streaming Test"); + + for (CryptoAlgorithm alg : testAlgs) { + final CiphertextHeaders ciphertextHeaders = createCiphertextHeaders(encryptionContext, alg); + final byte[] headerBytes = ciphertextHeaders.toByteArray(); + final ByteBuffer headerBuff = ByteBuffer.wrap(headerBytes); + + readToFrameLen(headerBuff, alg); + + // set frame len to an invalid value + headerBuff.putInt(-1); + + final CiphertextHeaders reconstructedHeaders = new CiphertextHeaders(); + assertThrows( + BadCiphertextException.class, + "Invalid frame length in ciphertext", + () -> + reconstructedHeaders.deserialize( + headerBuff.array(), 0, CiphertextHeaders.NO_MAX_ENCRYPTED_DATA_KEYS)); + } + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/model/DecryptionMaterialsRequestTest.java b/src/test/java/com/amazonaws/encryptionsdk/model/DecryptionMaterialsRequestTest.java index 9f04304aa..5aa4c9d55 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/model/DecryptionMaterialsRequestTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/model/DecryptionMaterialsRequestTest.java @@ -1,11 +1,11 @@ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -15,35 +15,32 @@ import static org.junit.Assert.assertEquals; +import com.amazonaws.encryptionsdk.CryptoAlgorithm; +import java.util.ArrayList; import java.util.HashMap; -import java.util.Map; import java.util.List; -import java.util.ArrayList; - +import java.util.Map; import org.junit.Test; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; -import com.amazonaws.encryptionsdk.model.KeyBlob; - public class DecryptionMaterialsRequestTest { - @Test - public void build() { - CryptoAlgorithm alg = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; - Map encryptionContext = new HashMap(1); - encryptionContext.put("DMR", "DecryptionMaterialsRequest Test"); - List kbs = new ArrayList(); - - DecryptionMaterialsRequest request0 = DecryptionMaterialsRequest.newBuilder() + @Test + public void build() { + CryptoAlgorithm alg = CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256; + Map encryptionContext = new HashMap(1); + encryptionContext.put("DMR", "DecryptionMaterialsRequest Test"); + List kbs = new ArrayList(); + + DecryptionMaterialsRequest request0 = + DecryptionMaterialsRequest.newBuilder() .setAlgorithm(alg) .setEncryptionContext(encryptionContext) .setEncryptedDataKeys(kbs) .build(); - - DecryptionMaterialsRequest request1 = request0.toBuilder().build(); - assertEquals(request0.getAlgorithm(), request1.getAlgorithm()); - assertEquals(request0.getEncryptionContext().size(), request1.getEncryptionContext().size()); - assertEquals(request0.getEncryptedDataKeys().size(), request1.getEncryptedDataKeys().size()); - } + DecryptionMaterialsRequest request1 = request0.toBuilder().build(); + + assertEquals(request0.getAlgorithm(), request1.getAlgorithm()); + assertEquals(request0.getEncryptionContext().size(), request1.getEncryptionContext().size()); + assertEquals(request0.getEncryptedDataKeys().size(), request1.getEncryptedDataKeys().size()); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/model/EncryptionMaterialsRequestTest.java b/src/test/java/com/amazonaws/encryptionsdk/model/EncryptionMaterialsRequestTest.java index 1fe777d85..7678e4758 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/model/EncryptionMaterialsRequestTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/model/EncryptionMaterialsRequestTest.java @@ -11,22 +11,23 @@ public class EncryptionMaterialsRequestTest { - @Test(expected = IllegalArgumentException.class) - public void testConstructWithNullCommitmentPolicy() { - EncryptionMaterialsRequest.newBuilder().setCommitmentPolicy(null).build(); - } + @Test(expected = IllegalArgumentException.class) + public void testConstructWithNullCommitmentPolicy() { + EncryptionMaterialsRequest.newBuilder().setCommitmentPolicy(null).build(); + } - @Test(expected = IllegalArgumentException.class) - public void testConstructWithoutCommitmentPolicy() { - EncryptionMaterialsRequest.newBuilder().build(); - } + @Test(expected = IllegalArgumentException.class) + public void testConstructWithoutCommitmentPolicy() { + EncryptionMaterialsRequest.newBuilder().build(); + } - @Test - public void testConstructWithCommitmentPolicy() { - EncryptionMaterialsRequest req = EncryptionMaterialsRequest.newBuilder() - .setCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) - .build(); - assertNotNull(req); - assertEquals(CommitmentPolicy.ForbidEncryptAllowDecrypt, req.getCommitmentPolicy()); - } + @Test + public void testConstructWithCommitmentPolicy() { + EncryptionMaterialsRequest req = + EncryptionMaterialsRequest.newBuilder() + .setCommitmentPolicy(CommitmentPolicy.ForbidEncryptAllowDecrypt) + .build(); + assertNotNull(req); + assertEquals(CommitmentPolicy.ForbidEncryptAllowDecrypt, req.getCommitmentPolicy()); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/model/KeyBlobTest.java b/src/test/java/com/amazonaws/encryptionsdk/model/KeyBlobTest.java index d66f49657..3e91babff 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/model/KeyBlobTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/model/KeyBlobTest.java @@ -1,11 +1,11 @@ /* * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except * in compliance with the License. A copy of the License is located at - * + * * http://aws.amazon.com/apache2.0 - * + * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. @@ -16,204 +16,217 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; - -import org.junit.Before; -import org.junit.Test; - import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.DataKey; import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.internal.Constants; import com.amazonaws.encryptionsdk.internal.RandomBytesGenerator; import com.amazonaws.encryptionsdk.internal.StaticMasterKey; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; public class KeyBlobTest { - private static CryptoAlgorithm ALGORITHM = CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_NO_KDF; - final String providerId_ = "Test Key"; - final String providerInfo_ = "Test Info"; - private StaticMasterKey masterKeyProvider_; - - @Before - public void init() { - masterKeyProvider_ = new StaticMasterKey("testmaterial"); - } - - private byte[] createKeyBlobBytes() { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "Test Encryption Context"); - final DataKey mockDataKey_ = masterKeyProvider_.generateDataKey(ALGORITHM, encryptionContext); - - final KeyBlob keyBlob = new KeyBlob( - providerId_, - providerInfo_.getBytes(StandardCharsets.UTF_8), - mockDataKey_.getEncryptedDataKey()); - - return keyBlob.toByteArray(); - } - - private KeyBlob deserialize(final byte[] keyBlobBytes) { - final KeyBlob reconstructedKeyBlob = new KeyBlob(); - reconstructedKeyBlob.deserialize(keyBlobBytes, 0); - return reconstructedKeyBlob; - } - - @Test - public void serializeDeserialize() { - final byte[] keyBlobBytes = createKeyBlobBytes(); + private static CryptoAlgorithm ALGORITHM = CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_NO_KDF; + final String providerId_ = "Test Key"; + final String providerInfo_ = "Test Info"; + private StaticMasterKey masterKeyProvider_; + + @Before + public void init() { + masterKeyProvider_ = new StaticMasterKey("testmaterial"); + } + + private byte[] createKeyBlobBytes() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "Test Encryption Context"); + final DataKey mockDataKey_ = + masterKeyProvider_.generateDataKey(ALGORITHM, encryptionContext); + + final KeyBlob keyBlob = + new KeyBlob( + providerId_, + providerInfo_.getBytes(StandardCharsets.UTF_8), + mockDataKey_.getEncryptedDataKey()); + + return keyBlob.toByteArray(); + } + + private KeyBlob deserialize(final byte[] keyBlobBytes) { + final KeyBlob reconstructedKeyBlob = new KeyBlob(); + reconstructedKeyBlob.deserialize(keyBlobBytes, 0); + return reconstructedKeyBlob; + } + + @Test + public void serializeDeserialize() { + final byte[] keyBlobBytes = createKeyBlobBytes(); + + final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); + final byte[] reconstructedKeyBlobBytes = reconstructedKeyBlob.toByteArray(); + + assertArrayEquals(reconstructedKeyBlobBytes, keyBlobBytes); + } + + @Test(expected = AwsCryptoException.class) + public void overlyLargeKeyProviderIdLen() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "Test Encryption Context"); + + final DataKey mockDataKey = + masterKeyProvider_.generateDataKey(ALGORITHM, encryptionContext); + + final int providerId_Len = Constants.UNSIGNED_SHORT_MAX_VAL + 1; + final byte[] providerId_Bytes = RandomBytesGenerator.generate(providerId_Len); + final String providerId_ = new String(providerId_Bytes, StandardCharsets.UTF_8); - final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); - final byte[] reconstructedKeyBlobBytes = reconstructedKeyBlob.toByteArray(); - - assertArrayEquals(reconstructedKeyBlobBytes, keyBlobBytes); - } - - @Test(expected = AwsCryptoException.class) - public void overlyLargeKeyProviderIdLen() { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "Test Encryption Context"); - - final DataKey mockDataKey = masterKeyProvider_.generateDataKey(ALGORITHM, encryptionContext); - - final int providerId_Len = Constants.UNSIGNED_SHORT_MAX_VAL + 1; - final byte[] providerId_Bytes = RandomBytesGenerator.generate(providerId_Len); - final String providerId_ = new String(providerId_Bytes, StandardCharsets.UTF_8); - - final String providerInfo_ = "Test Info"; - - new KeyBlob(providerId_, providerInfo_.getBytes(StandardCharsets.UTF_8), mockDataKey.getEncryptedDataKey()); + final String providerInfo_ = "Test Info"; - } + new KeyBlob( + providerId_, + providerInfo_.getBytes(StandardCharsets.UTF_8), + mockDataKey.getEncryptedDataKey()); + } - @Test(expected = AwsCryptoException.class) - public void overlyLargeKeyProviderInfoLen() { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "Test Encryption Context"); + @Test(expected = AwsCryptoException.class) + public void overlyLargeKeyProviderInfoLen() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "Test Encryption Context"); - final DataKey mockDataKey = masterKeyProvider_.generateDataKey(ALGORITHM, encryptionContext); + final DataKey mockDataKey = + masterKeyProvider_.generateDataKey(ALGORITHM, encryptionContext); - final int providerInfo_Len = Constants.UNSIGNED_SHORT_MAX_VAL + 1; - final byte[] providerInfo_ = RandomBytesGenerator.generate(providerInfo_Len); + final int providerInfo_Len = Constants.UNSIGNED_SHORT_MAX_VAL + 1; + final byte[] providerInfo_ = RandomBytesGenerator.generate(providerInfo_Len); - new KeyBlob(providerId_, providerInfo_, mockDataKey.getEncryptedDataKey()); - } + new KeyBlob(providerId_, providerInfo_, mockDataKey.getEncryptedDataKey()); + } - @Test(expected = AwsCryptoException.class) - public void overlyLargeKey() { - final int keyLen = Constants.UNSIGNED_SHORT_MAX_VAL + 1; - final byte[] encryptedKeyBytes = RandomBytesGenerator.generate(keyLen); + @Test(expected = AwsCryptoException.class) + public void overlyLargeKey() { + final int keyLen = Constants.UNSIGNED_SHORT_MAX_VAL + 1; + final byte[] encryptedKeyBytes = RandomBytesGenerator.generate(keyLen); - new KeyBlob(providerId_, providerInfo_.getBytes(StandardCharsets.UTF_8), encryptedKeyBytes); - } + new KeyBlob(providerId_, providerInfo_.getBytes(StandardCharsets.UTF_8), encryptedKeyBytes); + } - @Test - public void deserializeNull() { - final KeyBlob keyBlob = new KeyBlob(); - final int deserializedBytes = keyBlob.deserialize(null, 0); + @Test + public void deserializeNull() { + final KeyBlob keyBlob = new KeyBlob(); + final int deserializedBytes = keyBlob.deserialize(null, 0); - assertEquals(0, deserializedBytes); - } + assertEquals(0, deserializedBytes); + } - @Test - public void checkKeyProviderIdLen() { - final byte[] keyBlobBytes = createKeyBlobBytes(); + @Test + public void checkKeyProviderIdLen() { + final byte[] keyBlobBytes = createKeyBlobBytes(); - final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); + final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); - assertEquals(providerId_.length(), reconstructedKeyBlob.getKeyProviderIdLen()); - } + assertEquals(providerId_.length(), reconstructedKeyBlob.getKeyProviderIdLen()); + } - @Test - public void checkKeyProviderId() { - final byte[] keyBlobBytes = createKeyBlobBytes(); + @Test + public void checkKeyProviderId() { + final byte[] keyBlobBytes = createKeyBlobBytes(); - final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); + final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); - assertArrayEquals(providerId_.getBytes(StandardCharsets.UTF_8), reconstructedKeyBlob - .getProviderId() - .getBytes(StandardCharsets.UTF_8)); - } + assertArrayEquals( + providerId_.getBytes(StandardCharsets.UTF_8), + reconstructedKeyBlob.getProviderId().getBytes(StandardCharsets.UTF_8)); + } - @Test - public void checkKeyProviderInfoLen() { - final byte[] keyBlobBytes = createKeyBlobBytes(); + @Test + public void checkKeyProviderInfoLen() { + final byte[] keyBlobBytes = createKeyBlobBytes(); - final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); + final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); - assertEquals(providerInfo_.length(), reconstructedKeyBlob.getKeyProviderInfoLen()); - } + assertEquals(providerInfo_.length(), reconstructedKeyBlob.getKeyProviderInfoLen()); + } - @Test - public void checkKeyProviderInfo() { - final byte[] keyBlobBytes = createKeyBlobBytes(); + @Test + public void checkKeyProviderInfo() { + final byte[] keyBlobBytes = createKeyBlobBytes(); - final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); + final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); - assertArrayEquals(providerInfo_.getBytes(StandardCharsets.UTF_8), reconstructedKeyBlob.getProviderInformation()); - } + assertArrayEquals( + providerInfo_.getBytes(StandardCharsets.UTF_8), + reconstructedKeyBlob.getProviderInformation()); + } - @Test - public void checkKeyLen() { - final Map encryptionContext = new HashMap(1); - encryptionContext.put("ENC", "Test Encryption Context"); - final DataKey mockDataKey_ = masterKeyProvider_.generateDataKey(ALGORITHM, encryptionContext); + @Test + public void checkKeyLen() { + final Map encryptionContext = new HashMap(1); + encryptionContext.put("ENC", "Test Encryption Context"); + final DataKey mockDataKey_ = + masterKeyProvider_.generateDataKey(ALGORITHM, encryptionContext); - final KeyBlob keyBlob = new KeyBlob( - providerId_, - providerInfo_.getBytes(StandardCharsets.UTF_8), - mockDataKey_.getEncryptedDataKey()); + final KeyBlob keyBlob = + new KeyBlob( + providerId_, + providerInfo_.getBytes(StandardCharsets.UTF_8), + mockDataKey_.getEncryptedDataKey()); - final byte[] keyBlobBytes = keyBlob.toByteArray(); + final byte[] keyBlobBytes = keyBlob.toByteArray(); - final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); + final KeyBlob reconstructedKeyBlob = deserialize(keyBlobBytes); - assertEquals(mockDataKey_.getEncryptedDataKey().length, reconstructedKeyBlob.getEncryptedDataKeyLen()); - } + assertEquals( + mockDataKey_.getEncryptedDataKey().length, reconstructedKeyBlob.getEncryptedDataKeyLen()); + } - private KeyBlob generateRandomKeyBlob(int idLen, int infoLen, int keyLen) { - final byte[] idBytes = new byte[idLen]; - Arrays.fill(idBytes, (byte) 'A'); + private KeyBlob generateRandomKeyBlob(int idLen, int infoLen, int keyLen) { + final byte[] idBytes = new byte[idLen]; + Arrays.fill(idBytes, (byte) 'A'); + + final byte[] infoBytes = RandomBytesGenerator.generate(infoLen); + final byte[] keyBytes = RandomBytesGenerator.generate(keyLen); - final byte[] infoBytes = RandomBytesGenerator.generate(infoLen); - final byte[] keyBytes = RandomBytesGenerator.generate(keyLen); + return new KeyBlob(new String(idBytes, StandardCharsets.UTF_8), infoBytes, keyBytes); + } - return new KeyBlob(new String(idBytes, StandardCharsets.UTF_8), infoBytes, keyBytes); - } + private void assertKeyBlobsEqual(KeyBlob b1, KeyBlob b2) { + assertArrayEquals( + b1.getProviderId().getBytes(StandardCharsets.UTF_8), + b2.getProviderId().getBytes(StandardCharsets.UTF_8)); + assertArrayEquals(b1.getProviderInformation(), b2.getProviderInformation()); + assertArrayEquals(b1.getEncryptedDataKey(), b2.getEncryptedDataKey()); + } - private void assertKeyBlobsEqual(KeyBlob b1, KeyBlob b2) { - assertArrayEquals(b1.getProviderId().getBytes(StandardCharsets.UTF_8), - b2.getProviderId().getBytes(StandardCharsets.UTF_8)); - assertArrayEquals(b1.getProviderInformation(), b2.getProviderInformation()); - assertArrayEquals(b1.getEncryptedDataKey(), b2.getEncryptedDataKey()); - } - - @Test - public void checkKeyProviderIdLenUnsigned() { - // provider id length is too large for a signed short but fits in unsigned - final KeyBlob blob = generateRandomKeyBlob(Constants.UNSIGNED_SHORT_MAX_VAL, Short.MAX_VALUE, Short.MAX_VALUE); - final byte[] arr = blob.toByteArray(); + @Test + public void checkKeyProviderIdLenUnsigned() { + // provider id length is too large for a signed short but fits in unsigned + final KeyBlob blob = + generateRandomKeyBlob(Constants.UNSIGNED_SHORT_MAX_VAL, Short.MAX_VALUE, Short.MAX_VALUE); + final byte[] arr = blob.toByteArray(); - assertKeyBlobsEqual(deserialize(arr), blob); - } + assertKeyBlobsEqual(deserialize(arr), blob); + } - @Test - public void checkKeyProviderInfoLenUnsigned() { - // provider info length is too large for a signed short but fits in unsigned - final KeyBlob blob = generateRandomKeyBlob(Short.MAX_VALUE, Constants.UNSIGNED_SHORT_MAX_VAL, Short.MAX_VALUE); - final byte[] arr = blob.toByteArray(); + @Test + public void checkKeyProviderInfoLenUnsigned() { + // provider info length is too large for a signed short but fits in unsigned + final KeyBlob blob = + generateRandomKeyBlob(Short.MAX_VALUE, Constants.UNSIGNED_SHORT_MAX_VAL, Short.MAX_VALUE); + final byte[] arr = blob.toByteArray(); - assertKeyBlobsEqual(deserialize(arr), blob); - } + assertKeyBlobsEqual(deserialize(arr), blob); + } - @Test - public void checkKeyLenUnsigned() { - // key length is too large for a signed short but fits in unsigned - final KeyBlob blob = generateRandomKeyBlob(Short.MAX_VALUE, Short.MAX_VALUE, Constants.UNSIGNED_SHORT_MAX_VAL); - final byte[] arr = blob.toByteArray(); + @Test + public void checkKeyLenUnsigned() { + // key length is too large for a signed short but fits in unsigned + final KeyBlob blob = + generateRandomKeyBlob(Short.MAX_VALUE, Short.MAX_VALUE, Constants.UNSIGNED_SHORT_MAX_VAL); + final byte[] arr = blob.toByteArray(); - assertKeyBlobsEqual(deserialize(arr), blob); - } + assertKeyBlobsEqual(deserialize(arr), blob); + } } diff --git a/src/test/java/com/amazonaws/encryptionsdk/multi/MultipleMasterKeyTest.java b/src/test/java/com/amazonaws/encryptionsdk/multi/MultipleMasterKeyTest.java index 8f793b7ef..918dec4f2 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/multi/MultipleMasterKeyTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/multi/MultipleMasterKeyTest.java @@ -4,117 +4,114 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -import javax.crypto.spec.SecretKeySpec; - -import org.junit.Test; - import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.CryptoResult; import com.amazonaws.encryptionsdk.MasterKey; import com.amazonaws.encryptionsdk.MasterKeyProvider; import com.amazonaws.encryptionsdk.internal.StaticMasterKey; import com.amazonaws.encryptionsdk.jce.JceMasterKey; -import com.amazonaws.encryptionsdk.CommitmentPolicy; +import javax.crypto.spec.SecretKeySpec; +import org.junit.Test; public class MultipleMasterKeyTest { - private static final String WRAPPING_ALG = "AES/GCM/NoPadding"; - private static final byte[] PLAINTEXT = generate(1024); - - @Test - public void testMultipleJceKeys() { - final SecretKeySpec k1 = new SecretKeySpec(generate(32), "AES"); - final JceMasterKey mk1 = JceMasterKey.getInstance(k1, "jce", "1", WRAPPING_ALG); - final SecretKeySpec k2 = new SecretKeySpec(generate(32), "AES"); - final JceMasterKey mk2 = JceMasterKey.getInstance(k2, "jce", "2", WRAPPING_ALG); - final MasterKeyProvider mkp = MultipleProviderFactory.buildMultiProvider(JceMasterKey.class, - mk1, mk2); - - AwsCrypto crypto = AwsCrypto.standard(); - CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); - assertEquals(2, ct.getMasterKeyIds().size()); - CryptoResult result = crypto.decryptData(mkp, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only the first found key should be used - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk1, result.getMasterKeys().get(0)); - - assertMultiReturnsKeys(mkp, mk1, mk2); - } - - @Test - public void testMultipleJceKeysSingleDecrypt() { - final SecretKeySpec k1 = new SecretKeySpec(generate(32), "AES"); - final JceMasterKey mk1 = JceMasterKey.getInstance(k1, "jce", "1", WRAPPING_ALG); - final SecretKeySpec k2 = new SecretKeySpec(generate(32), "AES"); - final JceMasterKey mk2 = JceMasterKey.getInstance(k2, "jce", "2", WRAPPING_ALG); - final MasterKeyProvider mkp = MultipleProviderFactory.buildMultiProvider(JceMasterKey.class, - mk1, mk2); - - AwsCrypto crypto = AwsCrypto.standard(); - CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); - assertEquals(2, ct.getMasterKeyIds().size()); - - CryptoResult result = crypto.decryptData(mk1, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only the first found key should be used - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk1, result.getMasterKeys().get(0)); - - result = crypto.decryptData(mk2, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only the first found key should be used - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk2, result.getMasterKeys().get(0)); - } - - @Test - public void testMixedKeys() { - final SecretKeySpec k1 = new SecretKeySpec(generate(32), "AES"); - final JceMasterKey mk1 = JceMasterKey.getInstance(k1, "jce", "1", WRAPPING_ALG); - StaticMasterKey mk2 = new StaticMasterKey("mock1"); - final MasterKeyProvider mkp = MultipleProviderFactory.buildMultiProvider(mk1, mk2); - - AwsCrypto crypto = AwsCrypto.standard(); - CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); - assertEquals(2, ct.getMasterKeyIds().size()); - CryptoResult result = crypto.decryptData(mkp, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only the first found key should be used - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk1, result.getMasterKeys().get(0)); - - assertMultiReturnsKeys(mkp, mk1, mk2); - } - - @Test - public void testMixedKeysSingleDecrypt() { - final SecretKeySpec k1 = new SecretKeySpec(generate(32), "AES"); - final JceMasterKey mk1 = JceMasterKey.getInstance(k1, "jce", "1", WRAPPING_ALG); - StaticMasterKey mk2 = new StaticMasterKey("mock1"); - - final MasterKeyProvider mkp = MultipleProviderFactory.buildMultiProvider(mk1, mk2); - - AwsCrypto crypto = AwsCrypto.standard(); - CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); - assertEquals(2, ct.getMasterKeyIds().size()); - - CryptoResult result = crypto.decryptData(mk1, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only the first found key should be used - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk1, result.getMasterKeys().get(0)); - - result = crypto.decryptData(mk2, ct.getResult()); - assertArrayEquals(PLAINTEXT, result.getResult()); - // Only the first found key should be used - assertEquals(1, result.getMasterKeys().size()); - assertEquals(mk2, result.getMasterKeys().get(0)); - } - - private void assertMultiReturnsKeys(MasterKeyProvider mkp, MasterKey... mks) { - for (MasterKey mk : mks) { - assertEquals(mk, mkp.getMasterKey(mk.getKeyId())); - assertEquals(mk, mkp.getMasterKey(mk.getProviderId(), mk.getKeyId())); - } + private static final String WRAPPING_ALG = "AES/GCM/NoPadding"; + private static final byte[] PLAINTEXT = generate(1024); + + @Test + public void testMultipleJceKeys() { + final SecretKeySpec k1 = new SecretKeySpec(generate(32), "AES"); + final JceMasterKey mk1 = JceMasterKey.getInstance(k1, "jce", "1", WRAPPING_ALG); + final SecretKeySpec k2 = new SecretKeySpec(generate(32), "AES"); + final JceMasterKey mk2 = JceMasterKey.getInstance(k2, "jce", "2", WRAPPING_ALG); + final MasterKeyProvider mkp = + MultipleProviderFactory.buildMultiProvider(JceMasterKey.class, mk1, mk2); + + AwsCrypto crypto = AwsCrypto.standard(); + CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); + assertEquals(2, ct.getMasterKeyIds().size()); + CryptoResult result = crypto.decryptData(mkp, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only the first found key should be used + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk1, result.getMasterKeys().get(0)); + + assertMultiReturnsKeys(mkp, mk1, mk2); + } + + @Test + public void testMultipleJceKeysSingleDecrypt() { + final SecretKeySpec k1 = new SecretKeySpec(generate(32), "AES"); + final JceMasterKey mk1 = JceMasterKey.getInstance(k1, "jce", "1", WRAPPING_ALG); + final SecretKeySpec k2 = new SecretKeySpec(generate(32), "AES"); + final JceMasterKey mk2 = JceMasterKey.getInstance(k2, "jce", "2", WRAPPING_ALG); + final MasterKeyProvider mkp = + MultipleProviderFactory.buildMultiProvider(JceMasterKey.class, mk1, mk2); + + AwsCrypto crypto = AwsCrypto.standard(); + CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); + assertEquals(2, ct.getMasterKeyIds().size()); + + CryptoResult result = crypto.decryptData(mk1, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only the first found key should be used + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk1, result.getMasterKeys().get(0)); + + result = crypto.decryptData(mk2, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only the first found key should be used + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk2, result.getMasterKeys().get(0)); + } + + @Test + public void testMixedKeys() { + final SecretKeySpec k1 = new SecretKeySpec(generate(32), "AES"); + final JceMasterKey mk1 = JceMasterKey.getInstance(k1, "jce", "1", WRAPPING_ALG); + StaticMasterKey mk2 = new StaticMasterKey("mock1"); + final MasterKeyProvider mkp = MultipleProviderFactory.buildMultiProvider(mk1, mk2); + + AwsCrypto crypto = AwsCrypto.standard(); + CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); + assertEquals(2, ct.getMasterKeyIds().size()); + CryptoResult result = crypto.decryptData(mkp, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only the first found key should be used + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk1, result.getMasterKeys().get(0)); + + assertMultiReturnsKeys(mkp, mk1, mk2); + } + + @Test + public void testMixedKeysSingleDecrypt() { + final SecretKeySpec k1 = new SecretKeySpec(generate(32), "AES"); + final JceMasterKey mk1 = JceMasterKey.getInstance(k1, "jce", "1", WRAPPING_ALG); + StaticMasterKey mk2 = new StaticMasterKey("mock1"); + + final MasterKeyProvider mkp = MultipleProviderFactory.buildMultiProvider(mk1, mk2); + + AwsCrypto crypto = AwsCrypto.standard(); + CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); + assertEquals(2, ct.getMasterKeyIds().size()); + + CryptoResult result = crypto.decryptData(mk1, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only the first found key should be used + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk1, result.getMasterKeys().get(0)); + + result = crypto.decryptData(mk2, ct.getResult()); + assertArrayEquals(PLAINTEXT, result.getResult()); + // Only the first found key should be used + assertEquals(1, result.getMasterKeys().size()); + assertEquals(mk2, result.getMasterKeys().get(0)); + } + + private void assertMultiReturnsKeys(MasterKeyProvider mkp, MasterKey... mks) { + for (MasterKey mk : mks) { + assertEquals(mk, mkp.getMasterKey(mk.getKeyId())); + assertEquals(mk, mkp.getMasterKey(mk.getProviderId(), mk.getKeyId())); } + } }