diff --git a/README.md b/README.md index f6eacf9..84175ca 100644 --- a/README.md +++ b/README.md @@ -19,12 +19,21 @@ You can download release builds through the [releases section of this](https://g * **Minimum requirements** -- You'll need Java 8 (or later) and [Maven 3](http://maven.apache.org/). * **Download** -- Download the [latest preview release](https://github.com/awslabs/large-payload-offloading-java-common-lib-for-aws/releases) or pick it up from Maven: +### Version 2.x +```xml + + software.amazon.payloadoffloading + payloadoffloading-common + 2.0.0 + +``` + +### Version 1.x ```xml software.amazon.payloadoffloading payloadoffloading-common 1.0.0 - jar ``` diff --git a/pom.xml b/pom.xml index 39c11f2..74976a3 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ software.amazon.payloadoffloading payloadoffloading-common - 1.0.0 + 2.0.0 jar Payload offloading common library for AWS Common library between extended Amazon AWS clients to save payloads up to 2GB on Amazon S3. @@ -36,15 +36,21 @@ - 1.11.817 + 2.13.64 - com.amazonaws - aws-java-sdk-s3 + software.amazon.awssdk + s3 ${aws-java-sdk.version} + + software.amazon.awssdk + utils + ${aws-java-sdk.version} + + junit junit diff --git a/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java new file mode 100644 index 0000000..ae291f4 --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java @@ -0,0 +1,11 @@ +package software.amazon.payloadoffloading; + +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.ServerSideEncryption; + +public class AwsManagedCmk implements ServerSideEncryptionStrategy { + @Override + public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) { + putObjectRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/CustomerKey.java b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java new file mode 100644 index 0000000..7f62d49 --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java @@ -0,0 +1,18 @@ +package software.amazon.payloadoffloading; + +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.ServerSideEncryption; + +public class CustomerKey implements ServerSideEncryptionStrategy { + private final String awsKmsKeyId; + + public CustomerKey(String awsKmsKeyId) { + this.awsKmsKeyId = awsKmsKeyId; + } + + @Override + public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) { + putObjectRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); + putObjectRequestBuilder.ssekmsKeyId(awsKmsKeyId); + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadS3Pointer.java b/src/main/java/software/amazon/payloadoffloading/PayloadS3Pointer.java index 31f564a..a694b48 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadS3Pointer.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadS3Pointer.java @@ -1,14 +1,14 @@ package software.amazon.payloadoffloading; -import com.amazonaws.AmazonClientException; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.exception.SdkClientException; /** * This class is used for carrying pointer to Amazon S3 objects which contain payloads. */ public class PayloadS3Pointer { - private static final Log LOG = LogFactory.getLog(PayloadS3Pointer.class); + private static final Logger LOG = LoggerFactory.getLogger(PayloadS3Pointer.class); private String s3BucketName; private String s3Key; @@ -38,7 +38,7 @@ public String toJson() { } catch (Exception e) { String errorMessage = "Failed to convert S3 object pointer to text."; LOG.error(errorMessage, e); - throw new AmazonClientException(errorMessage, e); + throw SdkClientException.create(errorMessage, e); } return s3PointerStr; } @@ -52,7 +52,7 @@ public static PayloadS3Pointer fromJson(String s3PointerJson) { } catch (Exception e) { String errorMessage = "Failed to read the S3 object pointer from given string."; LOG.error(errorMessage, e); - throw new AmazonClientException(errorMessage, e); + throw SdkClientException.create(errorMessage, e); } return s3Pointer; } diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java index f1cf7c2..7ec1449 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java @@ -1,21 +1,39 @@ package software.amazon.payloadoffloading; -import com.amazonaws.AmazonClientException; -import com.amazonaws.annotation.NotThreadSafe; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.annotations.NotThreadSafe; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.S3Client; /** - * Amazon payload storage configuration options such as Amazon S3 client, - * bucket name, and payload size threshold for payloads. + *

Amazon payload storage configuration options such as Amazon S3 client, + * bucket name, and payload size threshold for payloads.

+ * + *

Server side encryption is optional and can be enabled using with {@link #withServerSideEncryption(ServerSideEncryptionStrategy)} + * or {@link #setServerSideEncryptionStrategy(ServerSideEncryptionStrategy)}

+ * + *

There are two possible options for server side encrption. This can be using a customer managed key or AWS managed CMK.

+ * + * Example usage: + * + *
+ *     withServerSideEncryption(ServerSideEncrptionFactory.awsManagedCmk())
+ * 
+ * + * or + * + *
+*     withServerSideEncryption(ServerSideEncrptionFactory.customerKey(YOUR_CUSTOMER_ID))
+ * 
+ * + * @see software.amazon.payloadoffloading.ServerSideEncryptionFactory */ @NotThreadSafe public class PayloadStorageConfiguration { - private static final Log LOG = LogFactory.getLog(PayloadStorageConfiguration.class); + private static final Logger LOG = LoggerFactory.getLogger(PayloadStorageConfiguration.class); - private AmazonS3 s3; + private S3Client s3; private String s3BucketName; private int payloadSizeThreshold = 0; private boolean alwaysThroughS3 = false; @@ -23,21 +41,21 @@ public class PayloadStorageConfiguration { /** * This field is optional, it is set only when we want to configure S3 Server Side Encryption with KMS. */ - private SSEAwsKeyManagementParams sseAwsKeyManagementParams; + private ServerSideEncryptionStrategy serverSideEncryptionStrategy; public PayloadStorageConfiguration() { s3 = null; s3BucketName = null; - sseAwsKeyManagementParams = null; + serverSideEncryptionStrategy = null; } public PayloadStorageConfiguration(PayloadStorageConfiguration other) { - this.s3 = other.getAmazonS3Client(); + this.s3 = other.getS3Client(); this.s3BucketName = other.getS3BucketName(); - this.sseAwsKeyManagementParams = other.getSSEAwsKeyManagementParams(); this.payloadSupport = other.isPayloadSupportEnabled(); this.alwaysThroughS3 = other.isAlwaysThroughS3(); this.payloadSizeThreshold = other.getPayloadSizeThreshold(); + this.serverSideEncryptionStrategy = other.getServerSideEncryptionStrategy(); } /** @@ -47,11 +65,11 @@ public PayloadStorageConfiguration(PayloadStorageConfiguration other) { * @param s3BucketName Name of the bucket which is going to be used for storing payload. * The bucket must be already created and configured in s3. */ - public void setPayloadSupportEnabled(AmazonS3 s3, String s3BucketName) { + public void setPayloadSupportEnabled(S3Client s3, String s3BucketName) { if (s3 == null || s3BucketName == null) { String errorMessage = "S3 client and/or S3 bucket name cannot be null."; LOG.error(errorMessage); - throw new AmazonClientException(errorMessage); + throw SdkClientException.create(errorMessage); } if (isPayloadSupportEnabled()) { LOG.warn("Payload support is already enabled. Overwriting AmazonS3Client and S3BucketName."); @@ -70,7 +88,7 @@ public void setPayloadSupportEnabled(AmazonS3 s3, String s3BucketName) { * The bucket must be already created and configured in s3. * @return the updated PayloadStorageConfiguration object. */ - public PayloadStorageConfiguration withPayloadSupportEnabled(AmazonS3 s3, String s3BucketName) { + public PayloadStorageConfiguration withPayloadSupportEnabled(S3Client s3, String s3BucketName) { setPayloadSupportEnabled(s3, s3BucketName); return this; } @@ -109,7 +127,7 @@ public boolean isPayloadSupportEnabled() { * * @return Reference to the Amazon S3 client which is being used. */ - public AmazonS3 getAmazonS3Client() { + public S3Client getS3Client() { return s3; } @@ -122,35 +140,6 @@ public String getS3BucketName() { return s3BucketName; } - /** - * Gets the S3 SSE-KMS encryption params of S3 objects under configured S3 bucket name. - * - * @return The S3 SSE-KMS params used for encryption. - */ - public SSEAwsKeyManagementParams getSSEAwsKeyManagementParams() { - return sseAwsKeyManagementParams; - } - - /** - * Sets the the S3 SSE-KMS encryption params of S3 objects under configured S3 bucket name. - * - * @param sseAwsKeyManagementParams The S3 SSE-KMS params used for encryption. - */ - public void setSSEAwsKeyManagementParams(SSEAwsKeyManagementParams sseAwsKeyManagementParams) { - this.sseAwsKeyManagementParams = sseAwsKeyManagementParams; - } - - /** - * Sets the the S3 SSE-KMS encryption params of S3 objects under configured S3 bucket name. - * - * @param sseAwsKeyManagementParams The S3 SSE-KMS params used for encryption. - * @return the updated PayloadStorageConfiguration object - */ - public PayloadStorageConfiguration withSSEAwsKeyManagementParams(SSEAwsKeyManagementParams sseAwsKeyManagementParams) { - setSSEAwsKeyManagementParams(sseAwsKeyManagementParams); - return this; - } - /** * Sets the payload size threshold for storing payloads in Amazon S3. * @@ -212,4 +201,38 @@ public boolean isAlwaysThroughS3() { public void setAlwaysThroughS3(boolean alwaysThroughS3) { this.alwaysThroughS3 = alwaysThroughS3; } + + /** + * Sets which method of server side encryption should be used, if required. + * + * This is optional, it is set only when you want to configure S3 server side encryption with KMS. + * + * @param serverSideEncryptionStrategy The method of encryption required for S3 server side encryption with KMS. + * @return the updated PayloadStorageConfiguration object. + */ + public PayloadStorageConfiguration withServerSideEncryption(ServerSideEncryptionStrategy serverSideEncryptionStrategy) { + setServerSideEncryptionStrategy(serverSideEncryptionStrategy); + return this; + } + + /** + * Sets which method of server side encryption should be use, if required. + * + * This is optional, it is set only when you want to configure S3 Server Side Encryption with KMS. + * + * @param serverSideEncryptionStrategy The method of encryption required for S3 server side encryption with KMS. + */ + public void setServerSideEncryptionStrategy(ServerSideEncryptionStrategy serverSideEncryptionStrategy) { + this.serverSideEncryptionStrategy = serverSideEncryptionStrategy; + } + + /** + * The method of service side encryption which should be used, if required. + * + * @return The server side encryption method required. Default null. + */ + public ServerSideEncryptionStrategy getServerSideEncryptionStrategy() { + return this.serverSideEncryptionStrategy; + } + } diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStore.java b/src/main/java/software/amazon/payloadoffloading/PayloadStore.java index a47d526..7703e8f 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStore.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStore.java @@ -1,7 +1,7 @@ package software.amazon.payloadoffloading; -import com.amazonaws.AmazonClientException; -import com.amazonaws.AmazonServiceException; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.model.S3Exception; /** * An AWS storage service that supports saving high payload sizes. @@ -11,15 +11,14 @@ public interface PayloadStore { * Stores payload in a store that has higher payload size limit than that is supported by original payload store. * * @param payload - * @param payloadContentSize * @return a pointer that must be used to retrieve the original payload later. - * @throws AmazonClientException If any internal errors are encountered on the client side while + * @throws SdkClientException If any internal errors are encountered on the client side while * attempting to make the request or handle the response. For example * if a network connection is not available. - * @throws AmazonServiceException If an error response is returned by actual PayloadStore indicating + * @throws S3Exception If an error response is returned by actual PayloadStore indicating * either a problem with the data in the request, or a server side issue. */ - String storeOriginalPayload(String payload, Long payloadContentSize); + String storeOriginalPayload(String payload); /** * Retrieves the original payload using the given payloadPointer. The pointer must @@ -27,10 +26,10 @@ public interface PayloadStore { * * @param payloadPointer * @return original payload - * @throws AmazonClientException If any internal errors are encountered on the client side while + * @throws SdkClientException If any internal errors are encountered on the client side while * attempting to make the request or handle the response. For example * if payloadPointer is invalid or a network connection is not available. - * @throws AmazonServiceException If an error response is returned by actual PayloadStore indicating + * @throws S3Exception If an error response is returned by actual PayloadStore indicating * a server side issue. */ String getOriginalPayload(String payloadPointer); @@ -40,10 +39,10 @@ public interface PayloadStore { * have been obtained using {@link storeOriginalPayload} * * @param payloadPointer - * @throws AmazonClientException If any internal errors are encountered on the client side while + * @throws SdkClientException If any internal errors are encountered on the client side while * attempting to make the request or handle the response to/from PayloadStore. * For example, if payloadPointer is invalid or a network connection is not available. - * @throws AmazonServiceException If an error response is returned by actual PayloadStore indicating + * @throws S3Exception If an error response is returned by actual PayloadStore indicating * a server side issue. */ void deleteOriginalPayload(String payloadPointer); diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java index 7fe7965..b8eb6c5 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java +++ b/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java @@ -1,8 +1,7 @@ package software.amazon.payloadoffloading; -import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.UUID; @@ -10,29 +9,28 @@ * S3 based implementation for PayloadStore. */ public class S3BackedPayloadStore implements PayloadStore { - private static final Log LOG = LogFactory.getLog(S3BackedPayloadStore.class); + private static final Logger LOG = LoggerFactory.getLogger(S3BackedPayloadStore.class); private final String s3BucketName; private final S3Dao s3Dao; - private final SSEAwsKeyManagementParams sseAwsKeyManagementParams; + private final ServerSideEncryptionStrategy serverSideEncryptionStrategy; public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName) { this(s3Dao, s3BucketName, null); } - public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName, - SSEAwsKeyManagementParams sseAwsKeyManagementParams) { + public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName, ServerSideEncryptionStrategy serverSideEncryptionStrategy) { this.s3BucketName = s3BucketName; this.s3Dao = s3Dao; - this.sseAwsKeyManagementParams = sseAwsKeyManagementParams; + this.serverSideEncryptionStrategy = serverSideEncryptionStrategy; } @Override - public String storeOriginalPayload(String payload, Long payloadContentSize) { + public String storeOriginalPayload(String payload) { String s3Key = UUID.randomUUID().toString(); // Store the payload content in S3. - s3Dao.storeTextInS3(s3BucketName, s3Key, sseAwsKeyManagementParams, payload, payloadContentSize); + s3Dao.storeTextInS3(s3BucketName, s3Key, serverSideEncryptionStrategy, payload); LOG.info("S3 object created, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); // Convert S3 pointer (bucket name, key, etc) to JSON string diff --git a/src/main/java/software/amazon/payloadoffloading/S3Dao.java b/src/main/java/software/amazon/payloadoffloading/S3Dao.java index a4c5c07..14d7d75 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3Dao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3Dao.java @@ -1,111 +1,92 @@ package software.amazon.payloadoffloading; -import com.amazonaws.AmazonClientException; -import com.amazonaws.AmazonServiceException; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.*; -import com.amazonaws.util.IOUtils; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.utils.IoUtils; -import java.io.ByteArrayInputStream; import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; /** * Dao layer to access S3. */ public class S3Dao { - private static final Log LOG = LogFactory.getLog(S3Dao.class); - private final AmazonS3 s3Client; + private static final Logger LOG = LoggerFactory.getLogger(S3Dao.class); + private final S3Client s3Client; - public S3Dao(AmazonS3 s3Client) { + public S3Dao(S3Client s3Client) { this.s3Client = s3Client; } public String getTextFromS3(String s3BucketName, String s3Key) { - GetObjectRequest getObjectRequest = new GetObjectRequest(s3BucketName, s3Key); - String embeddedText = null; - S3Object obj = null; + GetObjectRequest getObjectRequest = GetObjectRequest.builder() + .bucket(s3BucketName) + .key(s3Key) + .build(); + ResponseInputStream object = null; try { - obj = s3Client.getObject(getObjectRequest); - - } catch (AmazonServiceException e) { - String errorMessage = "Failed to get the S3 object which contains the payload."; - LOG.error(errorMessage, e); - throw new AmazonServiceException(errorMessage, e); - - } catch (AmazonClientException e) { + object = s3Client.getObject(getObjectRequest); + } catch (SdkException e) { String errorMessage = "Failed to get the S3 object which contains the payload."; LOG.error(errorMessage, e); - throw new AmazonClientException(errorMessage, e); + throw SdkException.create(errorMessage, e); } - S3ObjectInputStream is = obj.getObjectContent(); - + String embeddedText; try { - embeddedText = IOUtils.toString(is); - + embeddedText = IoUtils.toUtf8String(object); } catch (IOException e) { String errorMessage = "Failure when handling the message which was read from S3 object."; LOG.error(errorMessage, e); - throw new AmazonClientException(errorMessage, e); + throw SdkClientException.create(errorMessage, e); } finally { - IOUtils.closeQuietly(is, LOG); + IoUtils.closeQuietly(object, LOG); } return embeddedText; } - public void storeTextInS3(String s3BucketName, String s3Key, SSEAwsKeyManagementParams sseAwsKeyManagementParams, - String payloadContentStr, Long payloadContentSize) { - InputStream payloadContentStream = new ByteArrayInputStream(payloadContentStr.getBytes(StandardCharsets.UTF_8)); - ObjectMetadata payloadContentStreamMetadata = new ObjectMetadata(); - payloadContentStreamMetadata.setContentLength(payloadContentSize); - PutObjectRequest putObjectRequest = new PutObjectRequest(s3BucketName, s3Key, - payloadContentStream, payloadContentStreamMetadata); + public void storeTextInS3(String s3BucketName, String s3Key, ServerSideEncryptionStrategy serverSideEncryptionStrategy, String payloadContentStr) { + PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder() + .bucket(s3BucketName) + .key(s3Key); // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html - if (sseAwsKeyManagementParams != null) { - LOG.debug("Using SSE-KMS in put object request."); - putObjectRequest.setSSEAwsKeyManagementParams(sseAwsKeyManagementParams); + if (serverSideEncryptionStrategy != null) { + serverSideEncryptionStrategy.decorate(putObjectRequestBuilder); } try { - s3Client.putObject(putObjectRequest); - - } catch (AmazonServiceException e) { + s3Client.putObject(putObjectRequestBuilder.build(), RequestBody.fromString(payloadContentStr)); + } catch (SdkException e) { String errorMessage = "Failed to store the message content in an S3 object."; LOG.error(errorMessage, e); - throw new AmazonServiceException(errorMessage, e); - - } catch (AmazonClientException e) { - String errorMessage = "Failed to store the message content in an S3 object."; - LOG.error(errorMessage, e); - throw new AmazonClientException(errorMessage, e); + throw SdkException.create(errorMessage, e); } } - public void storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr, Long payloadContentSize) { - storeTextInS3(s3BucketName, s3Key, null, payloadContentStr, payloadContentSize); - } - public void deletePayloadFromS3(String s3BucketName, String s3Key) { try { - s3Client.deleteObject(s3BucketName, s3Key); - - } catch (AmazonServiceException e) { - String errorMessage = "Failed to delete the S3 object which contains the payload"; - LOG.error(errorMessage, e); - throw new AmazonServiceException(errorMessage, e); + DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder() + .bucket(s3BucketName) + .key(s3Key) + .build(); + s3Client.deleteObject(deleteObjectRequest); - } catch (AmazonClientException e) { + } catch (SdkException e) { String errorMessage = "Failed to delete the S3 object which contains the payload"; LOG.error(errorMessage, e); - throw new AmazonClientException(errorMessage, e); + throw SdkException.create(errorMessage, e); } LOG.info("S3 object deleted, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); diff --git a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionFactory.java b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionFactory.java new file mode 100644 index 0000000..eda5d72 --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionFactory.java @@ -0,0 +1,11 @@ +package software.amazon.payloadoffloading; + +public class ServerSideEncryptionFactory { + public static ServerSideEncryptionStrategy awsManagedCmk() { + return new AwsManagedCmk(); + } + + public static ServerSideEncryptionStrategy customerKey(String awsKmsKeyId) { + return new CustomerKey(awsKmsKeyId); + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java new file mode 100644 index 0000000..f385ce6 --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java @@ -0,0 +1,7 @@ +package software.amazon.payloadoffloading; + +import software.amazon.awssdk.services.s3.model.PutObjectRequest; + +public interface ServerSideEncryptionStrategy { + void decorate(PutObjectRequest.Builder putObjectRequestBuilder); +} diff --git a/src/main/java/software/amazon/payloadoffloading/Util.java b/src/main/java/software/amazon/payloadoffloading/Util.java index 5e18fdc..bd3932f 100644 --- a/src/main/java/software/amazon/payloadoffloading/Util.java +++ b/src/main/java/software/amazon/payloadoffloading/Util.java @@ -1,22 +1,23 @@ package software.amazon.payloadoffloading; -import com.amazonaws.AmazonClientException; -import com.amazonaws.util.VersionInfoUtils; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.util.VersionInfo; import java.io.IOException; import java.io.OutputStreamWriter; import java.io.Writer; +import java.nio.charset.StandardCharsets; public class Util { - private static final Log LOG = LogFactory.getLog(Util.class); + private static final Logger LOG = LoggerFactory.getLogger(Util.class); public static long getStringSizeInBytes(String str) { CountingOutputStream counterOutputStream = new CountingOutputStream(); try { - Writer writer = new OutputStreamWriter(counterOutputStream, "UTF-8"); + Writer writer = new OutputStreamWriter(counterOutputStream, StandardCharsets.UTF_8); writer.write(str); writer.flush(); writer.close(); @@ -24,13 +25,13 @@ public static long getStringSizeInBytes(String str) { } catch (IOException e) { String errorMessage = "Failed to calculate the size of payload."; LOG.error(errorMessage, e); - throw new AmazonClientException(errorMessage, e); + throw SdkClientException.create(errorMessage, e); } return counterOutputStream.getTotalSize(); } public static String getUserAgentHeader(String clientName) { - return clientName + "/" + VersionInfoUtils.getVersion(); + return clientName + "/" + VersionInfo.SDK_VERSION; } } diff --git a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java new file mode 100644 index 0000000..4bb95c5 --- /dev/null +++ b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java @@ -0,0 +1,21 @@ +package software.amazon.payloadoffloading; + +import org.junit.Test; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.ServerSideEncryption; + +import static org.junit.Assert.assertEquals; + +public class AwsManagedCmkTest { + + @Test + public void testAwsManagedCmkStrategySetsCorrectEncryptionValues() { + AwsManagedCmk awsManagedCmk = new AwsManagedCmk(); + + PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder(); + awsManagedCmk.decorate(putObjectRequestBuilder); + PutObjectRequest putObjectRequest = putObjectRequestBuilder.build(); + + assertEquals(putObjectRequest.serverSideEncryption(), (ServerSideEncryption.AWS_KMS)); + } +} \ No newline at end of file diff --git a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java new file mode 100644 index 0000000..3da7090 --- /dev/null +++ b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java @@ -0,0 +1,24 @@ +package software.amazon.payloadoffloading; + +import org.junit.Test; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.ServerSideEncryption; + +import static org.junit.Assert.assertEquals; + +public class CustomerKeyTest { + + public static final String AWS_KMS_KEY_ID = "123456"; + + @Test + public void testCustomerKeyStrategySetsCorrectEncryptionValues() { + CustomerKey customerKey = new CustomerKey(AWS_KMS_KEY_ID); + + PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder(); + customerKey.decorate(putObjectRequestBuilder); + PutObjectRequest putObjectRequest = putObjectRequestBuilder.build(); + + assertEquals(putObjectRequest.serverSideEncryption(), ServerSideEncryption.AWS_KMS); + assertEquals(putObjectRequest.ssekmsKeyId(), AWS_KMS_KEY_ID); + } +} \ No newline at end of file diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java index 2c51438..b1dad99 100644 --- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java +++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java @@ -1,9 +1,7 @@ package software.amazon.payloadoffloading; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; -import org.junit.Before; import org.junit.Test; +import software.amazon.awssdk.services.s3.S3Client; import static org.mockito.Mockito.mock; import static org.junit.Assert.*; @@ -13,18 +11,13 @@ */ public class PayloadStorageConfigurationTest { - private static String s3BucketName = "test-bucket-name"; - private static String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id"; - private SSEAwsKeyManagementParams sseAwsKeyManagementParams; - - @Before - public void setup() { - sseAwsKeyManagementParams = new SSEAwsKeyManagementParams(s3ServerSideEncryptionKMSKeyId); - } + private static final String s3BucketName = "test-bucket-name"; + private static final String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id"; + private static final ServerSideEncryptionStrategy SERVER_SIDE_ENCRYPTION_STRATEGY = ServerSideEncryptionFactory.awsManagedCmk(); @Test public void testCopyConstructor() { - AmazonS3 s3 = mock(AmazonS3.class); + S3Client s3 = mock(S3Client.class); boolean alwaysThroughS3 = true; int payloadSizeThreshold = 500; @@ -32,15 +25,15 @@ public void testCopyConstructor() { PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); payloadStorageConfiguration.withPayloadSupportEnabled(s3, s3BucketName) - .withAlwaysThroughS3(alwaysThroughS3).withPayloadSizeThreshold(payloadSizeThreshold) - .withSSEAwsKeyManagementParams(sseAwsKeyManagementParams); + .withAlwaysThroughS3(alwaysThroughS3) + .withPayloadSizeThreshold(payloadSizeThreshold) + .withServerSideEncryption(SERVER_SIDE_ENCRYPTION_STRATEGY); PayloadStorageConfiguration newPayloadStorageConfiguration = new PayloadStorageConfiguration(payloadStorageConfiguration); - assertEquals(s3, newPayloadStorageConfiguration.getAmazonS3Client()); + assertEquals(s3, newPayloadStorageConfiguration.getS3Client()); assertEquals(s3BucketName, newPayloadStorageConfiguration.getS3BucketName()); - assertEquals(sseAwsKeyManagementParams, newPayloadStorageConfiguration.getSSEAwsKeyManagementParams()); - assertEquals(s3ServerSideEncryptionKMSKeyId, newPayloadStorageConfiguration.getSSEAwsKeyManagementParams().getAwsKmsKeyId()); + assertEquals(SERVER_SIDE_ENCRYPTION_STRATEGY, newPayloadStorageConfiguration.getServerSideEncryptionStrategy()); assertTrue(newPayloadStorageConfiguration.isPayloadSupportEnabled()); assertEquals(alwaysThroughS3, newPayloadStorageConfiguration.isAlwaysThroughS3()); assertEquals(payloadSizeThreshold, newPayloadStorageConfiguration.getPayloadSizeThreshold()); @@ -49,12 +42,12 @@ public void testCopyConstructor() { @Test public void testPayloadSupportEnabled() { - AmazonS3 s3 = mock(AmazonS3.class); + S3Client s3 = mock(S3Client.class); PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); payloadStorageConfiguration.setPayloadSupportEnabled(s3, s3BucketName); assertTrue(payloadStorageConfiguration.isPayloadSupportEnabled()); - assertNotNull(payloadStorageConfiguration.getAmazonS3Client()); + assertNotNull(payloadStorageConfiguration.getS3Client()); assertEquals(s3BucketName, payloadStorageConfiguration.getS3BucketName()); } @@ -63,7 +56,7 @@ public void testDisablePayloadSupport() { PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); payloadStorageConfiguration.setPayloadSupportDisabled(); - assertNull(payloadStorageConfiguration.getAmazonS3Client()); + assertNull(payloadStorageConfiguration.getS3Client()); assertNull(payloadStorageConfiguration.getS3BucketName()); } @@ -82,10 +75,9 @@ public void testAlwaysThroughS3() { public void testSseAwsKeyManagementParams() { PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); - assertNull(payloadStorageConfiguration.getSSEAwsKeyManagementParams()); + assertNull(payloadStorageConfiguration.getServerSideEncryptionStrategy()); - payloadStorageConfiguration.setSSEAwsKeyManagementParams(sseAwsKeyManagementParams); - assertEquals(s3ServerSideEncryptionKMSKeyId, payloadStorageConfiguration.getSSEAwsKeyManagementParams() - .getAwsKmsKeyId()); + payloadStorageConfiguration.setServerSideEncryptionStrategy(SERVER_SIDE_ENCRYPTION_STRATEGY); + assertEquals(SERVER_SIDE_ENCRYPTION_STRATEGY, payloadStorageConfiguration.getServerSideEncryptionStrategy()); } } diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java index f6bf2dc..e9e12c1 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java @@ -1,7 +1,5 @@ package software.amazon.payloadoffloading; -import com.amazonaws.AmazonClientException; -import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; import junitparams.JUnitParamsRunner; import junitparams.Parameters; import org.hamcrest.Matchers; @@ -11,14 +9,21 @@ import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.exception.SdkException; + +import java.util.Objects; import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; @RunWith(JUnitParamsRunner.class) public class S3BackedPayloadStoreTest { private static final String S3_BUCKET_NAME = "test-bucket-name"; private static final String S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID = "test-customer-managed-kms-key-id"; + private static final ServerSideEncryptionStrategy KMS_WITH_CUSTOMER_KEY = ServerSideEncryptionFactory.customerKey(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID); + private static final ServerSideEncryptionStrategy KMS_WITH_AWS_MANAGED_CMK = ServerSideEncryptionFactory.awsManagedCmk(); private static final String ANY_PAYLOAD = "AnyPayload"; private static final String ANY_S3_KEY = "AnyS3key"; private static final String INCORRECT_POINTER_EXCEPTION_MSG = "Failed to read the S3 object pointer from given string"; @@ -53,15 +58,14 @@ private Object[] testData() { }, // S3 SSE-KMS encryption with AWS managed KMS keys { - new S3BackedPayloadStore(defaultEncryptionS3Dao, S3_BUCKET_NAME, new SSEAwsKeyManagementParams()), - new SSEAwsKeyManagementParams(), + new S3BackedPayloadStore(defaultEncryptionS3Dao, S3_BUCKET_NAME, KMS_WITH_AWS_MANAGED_CMK), + KMS_WITH_AWS_MANAGED_CMK, defaultEncryptionS3Dao }, // S3 SSE-KMS encryption with customer managed KMS key { - new S3BackedPayloadStore(customerKMSKeyEncryptionS3Dao, S3_BUCKET_NAME, - new SSEAwsKeyManagementParams(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID)), - new SSEAwsKeyManagementParams(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID), + new S3BackedPayloadStore(customerKMSKeyEncryptionS3Dao, S3_BUCKET_NAME, KMS_WITH_CUSTOMER_KEY), + KMS_WITH_CUSTOMER_KEY, customerKMSKeyEncryptionS3Dao } }; @@ -69,15 +73,14 @@ private Object[] testData() { @Test @Parameters(method = "testData") - public void testStoreOriginalPayloadOnSuccess(PayloadStore payloadStore, - SSEAwsKeyManagementParams expectedParams, S3Dao mockS3Dao) { - String actualPayloadPointer = payloadStore.storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH); + public void testStoreOriginalPayloadOnSuccess(PayloadStore payloadStore, ServerSideEncryptionStrategy expectedParams, S3Dao mockS3Dao) { + String actualPayloadPointer = payloadStore.storeOriginalPayload(ANY_PAYLOAD); ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class); - ArgumentCaptor sseArgsCaptor = ArgumentCaptor.forClass(SSEAwsKeyManagementParams.class); + ArgumentCaptor sseArgsCaptor = ArgumentCaptor.forClass(ServerSideEncryptionStrategy.class); verify(mockS3Dao, times(1)).storeTextInS3(eq(S3_BUCKET_NAME), keyCaptor.capture(), - sseArgsCaptor.capture(), eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH)); + sseArgsCaptor.capture(), eq(ANY_PAYLOAD)); PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, keyCaptor.getValue()); assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer); @@ -85,30 +88,26 @@ public void testStoreOriginalPayloadOnSuccess(PayloadStore payloadStore, if (expectedParams == null) { assertTrue(sseArgsCaptor.getValue() == null); } else { - assertEquals(expectedParams.getAwsKmsKeyId(), sseArgsCaptor.getValue().getAwsKmsKeyId()); + assertEquals(expectedParams, sseArgsCaptor.getValue()); } } @Test @Parameters(method = "testData") public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payloadStore, - SSEAwsKeyManagementParams expectedParams, + ServerSideEncryptionStrategy expectedParams, S3Dao mockS3Dao) { //Store any payload - String anyActualPayloadPointer = payloadStore - .storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH); + String anyActualPayloadPointer = payloadStore.storeOriginalPayload(ANY_PAYLOAD); //Store any other payload and validate that the pointers are different - String anyOtherActualPayloadPointer = payloadStore - .storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH); + String anyOtherActualPayloadPointer = payloadStore.storeOriginalPayload(ANY_PAYLOAD); ArgumentCaptor anyOtherKeyCaptor = ArgumentCaptor.forClass(String.class); - - ArgumentCaptor sseArgsCaptor = ArgumentCaptor - .forClass(SSEAwsKeyManagementParams.class); + ArgumentCaptor sseArgsCaptor = ArgumentCaptor.forClass(ServerSideEncryptionStrategy.class); verify(mockS3Dao, times(2)).storeTextInS3(eq(S3_BUCKET_NAME), anyOtherKeyCaptor.capture(), - sseArgsCaptor.capture(), eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH)); + sseArgsCaptor.capture(), eq(ANY_PAYLOAD)); String anyS3Key = anyOtherKeyCaptor.getAllValues().get(0); String anyOtherS3Key = anyOtherKeyCaptor.getAllValues().get(1); @@ -123,31 +122,29 @@ public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payl assertThat(anyActualPayloadPointer, Matchers.not(anyOtherActualPayloadPointer)); if (expectedParams == null) { - assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(actualParams -> actualParams == null)); + assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(Objects::isNull)); } else { assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(actualParams -> - (actualParams.getAwsKmsKeyId() == null && expectedParams.getAwsKmsKeyId() == null) - || (actualParams.getAwsKmsKeyId().equals(expectedParams.getAwsKmsKeyId())))); + actualParams.equals(expectedParams))); } } @Test @Parameters(method = "testData") - public void testStoreOriginalPayloadOnS3Failure(PayloadStore payloadStore, - SSEAwsKeyManagementParams expectedParams, S3Dao mockS3Dao) { - doThrow(new AmazonClientException("S3 Exception")) + public void testStoreOriginalPayloadOnS3Failure(PayloadStore payloadStore, ServerSideEncryptionStrategy awsKmsKeyId, S3Dao mockS3Dao) { + doThrow(SdkException.create("S3 Exception", new Throwable())) .when(mockS3Dao) .storeTextInS3( any(String.class), any(String.class), - expectedParams == null ? isNull() : any(SSEAwsKeyManagementParams.class), - any(String.class), - any(Long.class)); + // Can be String or null + any(), + any(String.class)); - exception.expect(AmazonClientException.class); + exception.expect(SdkException.class); exception.expectMessage("S3 Exception"); //Any S3 Dao exception is thrown back as-is to clients - payloadStore.storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH); + payloadStore.storeOriginalPayload(ANY_PAYLOAD); } @Test @@ -167,7 +164,7 @@ public void testGetOriginalPayloadOnSuccess() { @Test public void testGetOriginalPayloadIncorrectPointer() { - exception.expect(AmazonClientException.class); + exception.expect(SdkClientException.class); exception.expectMessage(INCORRECT_POINTER_EXCEPTION_MSG); //Any S3 Dao exception is thrown back as-is to clients payloadStore.getOriginalPayload("IncorrectPointer"); @@ -176,8 +173,8 @@ public void testGetOriginalPayloadIncorrectPointer() { @Test public void testGetOriginalPayloadOnS3Failure() { - when(s3Dao.getTextFromS3(any(String.class), any(String.class))).thenThrow(new AmazonClientException("S3 Exception")); - exception.expect(AmazonClientException.class); + when(s3Dao.getTextFromS3(any(String.class), any(String.class))).thenThrow(SdkException.create("S3 Exception", new Throwable())); + exception.expect(SdkException.class); exception.expectMessage("S3 Exception"); //Any S3 Dao exception is thrown back as-is to clients PayloadS3Pointer anyPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY); @@ -199,7 +196,7 @@ public void testDeleteOriginalPayloadOnSuccess() { @Test public void testDeleteOriginalPayloadIncorrectPointer() { - exception.expect(AmazonClientException.class); + exception.expect(SdkClientException.class); exception.expectMessage(INCORRECT_POINTER_EXCEPTION_MSG); payloadStore.deleteOriginalPayload("IncorrectPointer"); verifyNoInteractions(s3Dao);