junit
junit
diff --git a/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java
new file mode 100644
index 0000000..ae291f4
--- /dev/null
+++ b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java
@@ -0,0 +1,11 @@
+package software.amazon.payloadoffloading;
+
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
+
+public class AwsManagedCmk implements ServerSideEncryptionStrategy {
+ @Override
+ public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) {
+ putObjectRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS);
+ }
+}
diff --git a/src/main/java/software/amazon/payloadoffloading/CustomerKey.java b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java
new file mode 100644
index 0000000..7f62d49
--- /dev/null
+++ b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java
@@ -0,0 +1,18 @@
+package software.amazon.payloadoffloading;
+
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
+
+public class CustomerKey implements ServerSideEncryptionStrategy {
+ private final String awsKmsKeyId;
+
+ public CustomerKey(String awsKmsKeyId) {
+ this.awsKmsKeyId = awsKmsKeyId;
+ }
+
+ @Override
+ public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) {
+ putObjectRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS);
+ putObjectRequestBuilder.ssekmsKeyId(awsKmsKeyId);
+ }
+}
diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadS3Pointer.java b/src/main/java/software/amazon/payloadoffloading/PayloadS3Pointer.java
index 31f564a..a694b48 100644
--- a/src/main/java/software/amazon/payloadoffloading/PayloadS3Pointer.java
+++ b/src/main/java/software/amazon/payloadoffloading/PayloadS3Pointer.java
@@ -1,14 +1,14 @@
package software.amazon.payloadoffloading;
-import com.amazonaws.AmazonClientException;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.core.exception.SdkClientException;
/**
* This class is used for carrying pointer to Amazon S3 objects which contain payloads.
*/
public class PayloadS3Pointer {
- private static final Log LOG = LogFactory.getLog(PayloadS3Pointer.class);
+ private static final Logger LOG = LoggerFactory.getLogger(PayloadS3Pointer.class);
private String s3BucketName;
private String s3Key;
@@ -38,7 +38,7 @@ public String toJson() {
} catch (Exception e) {
String errorMessage = "Failed to convert S3 object pointer to text.";
LOG.error(errorMessage, e);
- throw new AmazonClientException(errorMessage, e);
+ throw SdkClientException.create(errorMessage, e);
}
return s3PointerStr;
}
@@ -52,7 +52,7 @@ public static PayloadS3Pointer fromJson(String s3PointerJson) {
} catch (Exception e) {
String errorMessage = "Failed to read the S3 object pointer from given string.";
LOG.error(errorMessage, e);
- throw new AmazonClientException(errorMessage, e);
+ throw SdkClientException.create(errorMessage, e);
}
return s3Pointer;
}
diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java
index f1cf7c2..7ec1449 100644
--- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java
+++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java
@@ -1,21 +1,39 @@
package software.amazon.payloadoffloading;
-import com.amazonaws.AmazonClientException;
-import com.amazonaws.annotation.NotThreadSafe;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.annotations.NotThreadSafe;
+import software.amazon.awssdk.core.exception.SdkClientException;
+import software.amazon.awssdk.services.s3.S3Client;
/**
- * Amazon payload storage configuration options such as Amazon S3 client,
- * bucket name, and payload size threshold for payloads.
+ * Amazon payload storage configuration options such as Amazon S3 client,
+ * bucket name, and payload size threshold for payloads.
+ *
+ * Server side encryption is optional and can be enabled using with {@link #withServerSideEncryption(ServerSideEncryptionStrategy)}
+ * or {@link #setServerSideEncryptionStrategy(ServerSideEncryptionStrategy)}
+ *
+ * There are two possible options for server side encrption. This can be using a customer managed key or AWS managed CMK.
+ *
+ * Example usage:
+ *
+ *
+ * withServerSideEncryption(ServerSideEncrptionFactory.awsManagedCmk())
+ *
+ *
+ * or
+ *
+ *
+* withServerSideEncryption(ServerSideEncrptionFactory.customerKey(YOUR_CUSTOMER_ID))
+ *
+ *
+ * @see software.amazon.payloadoffloading.ServerSideEncryptionFactory
*/
@NotThreadSafe
public class PayloadStorageConfiguration {
- private static final Log LOG = LogFactory.getLog(PayloadStorageConfiguration.class);
+ private static final Logger LOG = LoggerFactory.getLogger(PayloadStorageConfiguration.class);
- private AmazonS3 s3;
+ private S3Client s3;
private String s3BucketName;
private int payloadSizeThreshold = 0;
private boolean alwaysThroughS3 = false;
@@ -23,21 +41,21 @@ public class PayloadStorageConfiguration {
/**
* This field is optional, it is set only when we want to configure S3 Server Side Encryption with KMS.
*/
- private SSEAwsKeyManagementParams sseAwsKeyManagementParams;
+ private ServerSideEncryptionStrategy serverSideEncryptionStrategy;
public PayloadStorageConfiguration() {
s3 = null;
s3BucketName = null;
- sseAwsKeyManagementParams = null;
+ serverSideEncryptionStrategy = null;
}
public PayloadStorageConfiguration(PayloadStorageConfiguration other) {
- this.s3 = other.getAmazonS3Client();
+ this.s3 = other.getS3Client();
this.s3BucketName = other.getS3BucketName();
- this.sseAwsKeyManagementParams = other.getSSEAwsKeyManagementParams();
this.payloadSupport = other.isPayloadSupportEnabled();
this.alwaysThroughS3 = other.isAlwaysThroughS3();
this.payloadSizeThreshold = other.getPayloadSizeThreshold();
+ this.serverSideEncryptionStrategy = other.getServerSideEncryptionStrategy();
}
/**
@@ -47,11 +65,11 @@ public PayloadStorageConfiguration(PayloadStorageConfiguration other) {
* @param s3BucketName Name of the bucket which is going to be used for storing payload.
* The bucket must be already created and configured in s3.
*/
- public void setPayloadSupportEnabled(AmazonS3 s3, String s3BucketName) {
+ public void setPayloadSupportEnabled(S3Client s3, String s3BucketName) {
if (s3 == null || s3BucketName == null) {
String errorMessage = "S3 client and/or S3 bucket name cannot be null.";
LOG.error(errorMessage);
- throw new AmazonClientException(errorMessage);
+ throw SdkClientException.create(errorMessage);
}
if (isPayloadSupportEnabled()) {
LOG.warn("Payload support is already enabled. Overwriting AmazonS3Client and S3BucketName.");
@@ -70,7 +88,7 @@ public void setPayloadSupportEnabled(AmazonS3 s3, String s3BucketName) {
* The bucket must be already created and configured in s3.
* @return the updated PayloadStorageConfiguration object.
*/
- public PayloadStorageConfiguration withPayloadSupportEnabled(AmazonS3 s3, String s3BucketName) {
+ public PayloadStorageConfiguration withPayloadSupportEnabled(S3Client s3, String s3BucketName) {
setPayloadSupportEnabled(s3, s3BucketName);
return this;
}
@@ -109,7 +127,7 @@ public boolean isPayloadSupportEnabled() {
*
* @return Reference to the Amazon S3 client which is being used.
*/
- public AmazonS3 getAmazonS3Client() {
+ public S3Client getS3Client() {
return s3;
}
@@ -122,35 +140,6 @@ public String getS3BucketName() {
return s3BucketName;
}
- /**
- * Gets the S3 SSE-KMS encryption params of S3 objects under configured S3 bucket name.
- *
- * @return The S3 SSE-KMS params used for encryption.
- */
- public SSEAwsKeyManagementParams getSSEAwsKeyManagementParams() {
- return sseAwsKeyManagementParams;
- }
-
- /**
- * Sets the the S3 SSE-KMS encryption params of S3 objects under configured S3 bucket name.
- *
- * @param sseAwsKeyManagementParams The S3 SSE-KMS params used for encryption.
- */
- public void setSSEAwsKeyManagementParams(SSEAwsKeyManagementParams sseAwsKeyManagementParams) {
- this.sseAwsKeyManagementParams = sseAwsKeyManagementParams;
- }
-
- /**
- * Sets the the S3 SSE-KMS encryption params of S3 objects under configured S3 bucket name.
- *
- * @param sseAwsKeyManagementParams The S3 SSE-KMS params used for encryption.
- * @return the updated PayloadStorageConfiguration object
- */
- public PayloadStorageConfiguration withSSEAwsKeyManagementParams(SSEAwsKeyManagementParams sseAwsKeyManagementParams) {
- setSSEAwsKeyManagementParams(sseAwsKeyManagementParams);
- return this;
- }
-
/**
* Sets the payload size threshold for storing payloads in Amazon S3.
*
@@ -212,4 +201,38 @@ public boolean isAlwaysThroughS3() {
public void setAlwaysThroughS3(boolean alwaysThroughS3) {
this.alwaysThroughS3 = alwaysThroughS3;
}
+
+ /**
+ * Sets which method of server side encryption should be used, if required.
+ *
+ * This is optional, it is set only when you want to configure S3 server side encryption with KMS.
+ *
+ * @param serverSideEncryptionStrategy The method of encryption required for S3 server side encryption with KMS.
+ * @return the updated PayloadStorageConfiguration object.
+ */
+ public PayloadStorageConfiguration withServerSideEncryption(ServerSideEncryptionStrategy serverSideEncryptionStrategy) {
+ setServerSideEncryptionStrategy(serverSideEncryptionStrategy);
+ return this;
+ }
+
+ /**
+ * Sets which method of server side encryption should be use, if required.
+ *
+ * This is optional, it is set only when you want to configure S3 Server Side Encryption with KMS.
+ *
+ * @param serverSideEncryptionStrategy The method of encryption required for S3 server side encryption with KMS.
+ */
+ public void setServerSideEncryptionStrategy(ServerSideEncryptionStrategy serverSideEncryptionStrategy) {
+ this.serverSideEncryptionStrategy = serverSideEncryptionStrategy;
+ }
+
+ /**
+ * The method of service side encryption which should be used, if required.
+ *
+ * @return The server side encryption method required. Default null.
+ */
+ public ServerSideEncryptionStrategy getServerSideEncryptionStrategy() {
+ return this.serverSideEncryptionStrategy;
+ }
+
}
diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStore.java b/src/main/java/software/amazon/payloadoffloading/PayloadStore.java
index a47d526..7703e8f 100644
--- a/src/main/java/software/amazon/payloadoffloading/PayloadStore.java
+++ b/src/main/java/software/amazon/payloadoffloading/PayloadStore.java
@@ -1,7 +1,7 @@
package software.amazon.payloadoffloading;
-import com.amazonaws.AmazonClientException;
-import com.amazonaws.AmazonServiceException;
+import software.amazon.awssdk.core.exception.SdkClientException;
+import software.amazon.awssdk.services.s3.model.S3Exception;
/**
* An AWS storage service that supports saving high payload sizes.
@@ -11,15 +11,14 @@ public interface PayloadStore {
* Stores payload in a store that has higher payload size limit than that is supported by original payload store.
*
* @param payload
- * @param payloadContentSize
* @return a pointer that must be used to retrieve the original payload later.
- * @throws AmazonClientException If any internal errors are encountered on the client side while
+ * @throws SdkClientException If any internal errors are encountered on the client side while
* attempting to make the request or handle the response. For example
* if a network connection is not available.
- * @throws AmazonServiceException If an error response is returned by actual PayloadStore indicating
+ * @throws S3Exception If an error response is returned by actual PayloadStore indicating
* either a problem with the data in the request, or a server side issue.
*/
- String storeOriginalPayload(String payload, Long payloadContentSize);
+ String storeOriginalPayload(String payload);
/**
* Retrieves the original payload using the given payloadPointer. The pointer must
@@ -27,10 +26,10 @@ public interface PayloadStore {
*
* @param payloadPointer
* @return original payload
- * @throws AmazonClientException If any internal errors are encountered on the client side while
+ * @throws SdkClientException If any internal errors are encountered on the client side while
* attempting to make the request or handle the response. For example
* if payloadPointer is invalid or a network connection is not available.
- * @throws AmazonServiceException If an error response is returned by actual PayloadStore indicating
+ * @throws S3Exception If an error response is returned by actual PayloadStore indicating
* a server side issue.
*/
String getOriginalPayload(String payloadPointer);
@@ -40,10 +39,10 @@ public interface PayloadStore {
* have been obtained using {@link storeOriginalPayload}
*
* @param payloadPointer
- * @throws AmazonClientException If any internal errors are encountered on the client side while
+ * @throws SdkClientException If any internal errors are encountered on the client side while
* attempting to make the request or handle the response to/from PayloadStore.
* For example, if payloadPointer is invalid or a network connection is not available.
- * @throws AmazonServiceException If an error response is returned by actual PayloadStore indicating
+ * @throws S3Exception If an error response is returned by actual PayloadStore indicating
* a server side issue.
*/
void deleteOriginalPayload(String payloadPointer);
diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java
index 7fe7965..b8eb6c5 100644
--- a/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java
+++ b/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java
@@ -1,8 +1,7 @@
package software.amazon.payloadoffloading;
-import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.util.UUID;
@@ -10,29 +9,28 @@
* S3 based implementation for PayloadStore.
*/
public class S3BackedPayloadStore implements PayloadStore {
- private static final Log LOG = LogFactory.getLog(S3BackedPayloadStore.class);
+ private static final Logger LOG = LoggerFactory.getLogger(S3BackedPayloadStore.class);
private final String s3BucketName;
private final S3Dao s3Dao;
- private final SSEAwsKeyManagementParams sseAwsKeyManagementParams;
+ private final ServerSideEncryptionStrategy serverSideEncryptionStrategy;
public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName) {
this(s3Dao, s3BucketName, null);
}
- public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName,
- SSEAwsKeyManagementParams sseAwsKeyManagementParams) {
+ public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName, ServerSideEncryptionStrategy serverSideEncryptionStrategy) {
this.s3BucketName = s3BucketName;
this.s3Dao = s3Dao;
- this.sseAwsKeyManagementParams = sseAwsKeyManagementParams;
+ this.serverSideEncryptionStrategy = serverSideEncryptionStrategy;
}
@Override
- public String storeOriginalPayload(String payload, Long payloadContentSize) {
+ public String storeOriginalPayload(String payload) {
String s3Key = UUID.randomUUID().toString();
// Store the payload content in S3.
- s3Dao.storeTextInS3(s3BucketName, s3Key, sseAwsKeyManagementParams, payload, payloadContentSize);
+ s3Dao.storeTextInS3(s3BucketName, s3Key, serverSideEncryptionStrategy, payload);
LOG.info("S3 object created, Bucket name: " + s3BucketName + ", Object key: " + s3Key + ".");
// Convert S3 pointer (bucket name, key, etc) to JSON string
diff --git a/src/main/java/software/amazon/payloadoffloading/S3Dao.java b/src/main/java/software/amazon/payloadoffloading/S3Dao.java
index a4c5c07..14d7d75 100644
--- a/src/main/java/software/amazon/payloadoffloading/S3Dao.java
+++ b/src/main/java/software/amazon/payloadoffloading/S3Dao.java
@@ -1,111 +1,92 @@
package software.amazon.payloadoffloading;
-import com.amazonaws.AmazonClientException;
-import com.amazonaws.AmazonServiceException;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.s3.model.*;
-import com.amazonaws.util.IOUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.core.ResponseInputStream;
+import software.amazon.awssdk.core.exception.SdkClientException;
+import software.amazon.awssdk.core.exception.SdkException;
+import software.amazon.awssdk.core.sync.RequestBody;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
+import software.amazon.awssdk.services.s3.model.GetObjectRequest;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.utils.IoUtils;
-import java.io.ByteArrayInputStream;
import java.io.IOException;
-import java.io.InputStream;
-import java.nio.charset.StandardCharsets;
/**
* Dao layer to access S3.
*/
public class S3Dao {
- private static final Log LOG = LogFactory.getLog(S3Dao.class);
- private final AmazonS3 s3Client;
+ private static final Logger LOG = LoggerFactory.getLogger(S3Dao.class);
+ private final S3Client s3Client;
- public S3Dao(AmazonS3 s3Client) {
+ public S3Dao(S3Client s3Client) {
this.s3Client = s3Client;
}
public String getTextFromS3(String s3BucketName, String s3Key) {
- GetObjectRequest getObjectRequest = new GetObjectRequest(s3BucketName, s3Key);
- String embeddedText = null;
- S3Object obj = null;
+ GetObjectRequest getObjectRequest = GetObjectRequest.builder()
+ .bucket(s3BucketName)
+ .key(s3Key)
+ .build();
+ ResponseInputStream object = null;
try {
- obj = s3Client.getObject(getObjectRequest);
-
- } catch (AmazonServiceException e) {
- String errorMessage = "Failed to get the S3 object which contains the payload.";
- LOG.error(errorMessage, e);
- throw new AmazonServiceException(errorMessage, e);
-
- } catch (AmazonClientException e) {
+ object = s3Client.getObject(getObjectRequest);
+ } catch (SdkException e) {
String errorMessage = "Failed to get the S3 object which contains the payload.";
LOG.error(errorMessage, e);
- throw new AmazonClientException(errorMessage, e);
+ throw SdkException.create(errorMessage, e);
}
- S3ObjectInputStream is = obj.getObjectContent();
-
+ String embeddedText;
try {
- embeddedText = IOUtils.toString(is);
-
+ embeddedText = IoUtils.toUtf8String(object);
} catch (IOException e) {
String errorMessage = "Failure when handling the message which was read from S3 object.";
LOG.error(errorMessage, e);
- throw new AmazonClientException(errorMessage, e);
+ throw SdkClientException.create(errorMessage, e);
} finally {
- IOUtils.closeQuietly(is, LOG);
+ IoUtils.closeQuietly(object, LOG);
}
return embeddedText;
}
- public void storeTextInS3(String s3BucketName, String s3Key, SSEAwsKeyManagementParams sseAwsKeyManagementParams,
- String payloadContentStr, Long payloadContentSize) {
- InputStream payloadContentStream = new ByteArrayInputStream(payloadContentStr.getBytes(StandardCharsets.UTF_8));
- ObjectMetadata payloadContentStreamMetadata = new ObjectMetadata();
- payloadContentStreamMetadata.setContentLength(payloadContentSize);
- PutObjectRequest putObjectRequest = new PutObjectRequest(s3BucketName, s3Key,
- payloadContentStream, payloadContentStreamMetadata);
+ public void storeTextInS3(String s3BucketName, String s3Key, ServerSideEncryptionStrategy serverSideEncryptionStrategy, String payloadContentStr) {
+ PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder()
+ .bucket(s3BucketName)
+ .key(s3Key);
// https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html
- if (sseAwsKeyManagementParams != null) {
- LOG.debug("Using SSE-KMS in put object request.");
- putObjectRequest.setSSEAwsKeyManagementParams(sseAwsKeyManagementParams);
+ if (serverSideEncryptionStrategy != null) {
+ serverSideEncryptionStrategy.decorate(putObjectRequestBuilder);
}
try {
- s3Client.putObject(putObjectRequest);
-
- } catch (AmazonServiceException e) {
+ s3Client.putObject(putObjectRequestBuilder.build(), RequestBody.fromString(payloadContentStr));
+ } catch (SdkException e) {
String errorMessage = "Failed to store the message content in an S3 object.";
LOG.error(errorMessage, e);
- throw new AmazonServiceException(errorMessage, e);
-
- } catch (AmazonClientException e) {
- String errorMessage = "Failed to store the message content in an S3 object.";
- LOG.error(errorMessage, e);
- throw new AmazonClientException(errorMessage, e);
+ throw SdkException.create(errorMessage, e);
}
}
- public void storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr, Long payloadContentSize) {
- storeTextInS3(s3BucketName, s3Key, null, payloadContentStr, payloadContentSize);
- }
-
public void deletePayloadFromS3(String s3BucketName, String s3Key) {
try {
- s3Client.deleteObject(s3BucketName, s3Key);
-
- } catch (AmazonServiceException e) {
- String errorMessage = "Failed to delete the S3 object which contains the payload";
- LOG.error(errorMessage, e);
- throw new AmazonServiceException(errorMessage, e);
+ DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder()
+ .bucket(s3BucketName)
+ .key(s3Key)
+ .build();
+ s3Client.deleteObject(deleteObjectRequest);
- } catch (AmazonClientException e) {
+ } catch (SdkException e) {
String errorMessage = "Failed to delete the S3 object which contains the payload";
LOG.error(errorMessage, e);
- throw new AmazonClientException(errorMessage, e);
+ throw SdkException.create(errorMessage, e);
}
LOG.info("S3 object deleted, Bucket name: " + s3BucketName + ", Object key: " + s3Key + ".");
diff --git a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionFactory.java b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionFactory.java
new file mode 100644
index 0000000..eda5d72
--- /dev/null
+++ b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionFactory.java
@@ -0,0 +1,11 @@
+package software.amazon.payloadoffloading;
+
+public class ServerSideEncryptionFactory {
+ public static ServerSideEncryptionStrategy awsManagedCmk() {
+ return new AwsManagedCmk();
+ }
+
+ public static ServerSideEncryptionStrategy customerKey(String awsKmsKeyId) {
+ return new CustomerKey(awsKmsKeyId);
+ }
+}
diff --git a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java
new file mode 100644
index 0000000..f385ce6
--- /dev/null
+++ b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java
@@ -0,0 +1,7 @@
+package software.amazon.payloadoffloading;
+
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+
+public interface ServerSideEncryptionStrategy {
+ void decorate(PutObjectRequest.Builder putObjectRequestBuilder);
+}
diff --git a/src/main/java/software/amazon/payloadoffloading/Util.java b/src/main/java/software/amazon/payloadoffloading/Util.java
index 5e18fdc..bd3932f 100644
--- a/src/main/java/software/amazon/payloadoffloading/Util.java
+++ b/src/main/java/software/amazon/payloadoffloading/Util.java
@@ -1,22 +1,23 @@
package software.amazon.payloadoffloading;
-import com.amazonaws.AmazonClientException;
-import com.amazonaws.util.VersionInfoUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.core.exception.SdkClientException;
+import software.amazon.awssdk.core.util.VersionInfo;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
+import java.nio.charset.StandardCharsets;
public class Util {
- private static final Log LOG = LogFactory.getLog(Util.class);
+ private static final Logger LOG = LoggerFactory.getLogger(Util.class);
public static long getStringSizeInBytes(String str) {
CountingOutputStream counterOutputStream = new CountingOutputStream();
try {
- Writer writer = new OutputStreamWriter(counterOutputStream, "UTF-8");
+ Writer writer = new OutputStreamWriter(counterOutputStream, StandardCharsets.UTF_8);
writer.write(str);
writer.flush();
writer.close();
@@ -24,13 +25,13 @@ public static long getStringSizeInBytes(String str) {
} catch (IOException e) {
String errorMessage = "Failed to calculate the size of payload.";
LOG.error(errorMessage, e);
- throw new AmazonClientException(errorMessage, e);
+ throw SdkClientException.create(errorMessage, e);
}
return counterOutputStream.getTotalSize();
}
public static String getUserAgentHeader(String clientName) {
- return clientName + "/" + VersionInfoUtils.getVersion();
+ return clientName + "/" + VersionInfo.SDK_VERSION;
}
}
diff --git a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java
new file mode 100644
index 0000000..4bb95c5
--- /dev/null
+++ b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java
@@ -0,0 +1,21 @@
+package software.amazon.payloadoffloading;
+
+import org.junit.Test;
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
+
+import static org.junit.Assert.assertEquals;
+
+public class AwsManagedCmkTest {
+
+ @Test
+ public void testAwsManagedCmkStrategySetsCorrectEncryptionValues() {
+ AwsManagedCmk awsManagedCmk = new AwsManagedCmk();
+
+ PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder();
+ awsManagedCmk.decorate(putObjectRequestBuilder);
+ PutObjectRequest putObjectRequest = putObjectRequestBuilder.build();
+
+ assertEquals(putObjectRequest.serverSideEncryption(), (ServerSideEncryption.AWS_KMS));
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java
new file mode 100644
index 0000000..3da7090
--- /dev/null
+++ b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java
@@ -0,0 +1,24 @@
+package software.amazon.payloadoffloading;
+
+import org.junit.Test;
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
+
+import static org.junit.Assert.assertEquals;
+
+public class CustomerKeyTest {
+
+ public static final String AWS_KMS_KEY_ID = "123456";
+
+ @Test
+ public void testCustomerKeyStrategySetsCorrectEncryptionValues() {
+ CustomerKey customerKey = new CustomerKey(AWS_KMS_KEY_ID);
+
+ PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder();
+ customerKey.decorate(putObjectRequestBuilder);
+ PutObjectRequest putObjectRequest = putObjectRequestBuilder.build();
+
+ assertEquals(putObjectRequest.serverSideEncryption(), ServerSideEncryption.AWS_KMS);
+ assertEquals(putObjectRequest.ssekmsKeyId(), AWS_KMS_KEY_ID);
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java
index 2c51438..b1dad99 100644
--- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java
+++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java
@@ -1,9 +1,7 @@
package software.amazon.payloadoffloading;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
-import org.junit.Before;
import org.junit.Test;
+import software.amazon.awssdk.services.s3.S3Client;
import static org.mockito.Mockito.mock;
import static org.junit.Assert.*;
@@ -13,18 +11,13 @@
*/
public class PayloadStorageConfigurationTest {
- private static String s3BucketName = "test-bucket-name";
- private static String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id";
- private SSEAwsKeyManagementParams sseAwsKeyManagementParams;
-
- @Before
- public void setup() {
- sseAwsKeyManagementParams = new SSEAwsKeyManagementParams(s3ServerSideEncryptionKMSKeyId);
- }
+ private static final String s3BucketName = "test-bucket-name";
+ private static final String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id";
+ private static final ServerSideEncryptionStrategy SERVER_SIDE_ENCRYPTION_STRATEGY = ServerSideEncryptionFactory.awsManagedCmk();
@Test
public void testCopyConstructor() {
- AmazonS3 s3 = mock(AmazonS3.class);
+ S3Client s3 = mock(S3Client.class);
boolean alwaysThroughS3 = true;
int payloadSizeThreshold = 500;
@@ -32,15 +25,15 @@ public void testCopyConstructor() {
PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration();
payloadStorageConfiguration.withPayloadSupportEnabled(s3, s3BucketName)
- .withAlwaysThroughS3(alwaysThroughS3).withPayloadSizeThreshold(payloadSizeThreshold)
- .withSSEAwsKeyManagementParams(sseAwsKeyManagementParams);
+ .withAlwaysThroughS3(alwaysThroughS3)
+ .withPayloadSizeThreshold(payloadSizeThreshold)
+ .withServerSideEncryption(SERVER_SIDE_ENCRYPTION_STRATEGY);
PayloadStorageConfiguration newPayloadStorageConfiguration = new PayloadStorageConfiguration(payloadStorageConfiguration);
- assertEquals(s3, newPayloadStorageConfiguration.getAmazonS3Client());
+ assertEquals(s3, newPayloadStorageConfiguration.getS3Client());
assertEquals(s3BucketName, newPayloadStorageConfiguration.getS3BucketName());
- assertEquals(sseAwsKeyManagementParams, newPayloadStorageConfiguration.getSSEAwsKeyManagementParams());
- assertEquals(s3ServerSideEncryptionKMSKeyId, newPayloadStorageConfiguration.getSSEAwsKeyManagementParams().getAwsKmsKeyId());
+ assertEquals(SERVER_SIDE_ENCRYPTION_STRATEGY, newPayloadStorageConfiguration.getServerSideEncryptionStrategy());
assertTrue(newPayloadStorageConfiguration.isPayloadSupportEnabled());
assertEquals(alwaysThroughS3, newPayloadStorageConfiguration.isAlwaysThroughS3());
assertEquals(payloadSizeThreshold, newPayloadStorageConfiguration.getPayloadSizeThreshold());
@@ -49,12 +42,12 @@ public void testCopyConstructor() {
@Test
public void testPayloadSupportEnabled() {
- AmazonS3 s3 = mock(AmazonS3.class);
+ S3Client s3 = mock(S3Client.class);
PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration();
payloadStorageConfiguration.setPayloadSupportEnabled(s3, s3BucketName);
assertTrue(payloadStorageConfiguration.isPayloadSupportEnabled());
- assertNotNull(payloadStorageConfiguration.getAmazonS3Client());
+ assertNotNull(payloadStorageConfiguration.getS3Client());
assertEquals(s3BucketName, payloadStorageConfiguration.getS3BucketName());
}
@@ -63,7 +56,7 @@ public void testDisablePayloadSupport() {
PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration();
payloadStorageConfiguration.setPayloadSupportDisabled();
- assertNull(payloadStorageConfiguration.getAmazonS3Client());
+ assertNull(payloadStorageConfiguration.getS3Client());
assertNull(payloadStorageConfiguration.getS3BucketName());
}
@@ -82,10 +75,9 @@ public void testAlwaysThroughS3() {
public void testSseAwsKeyManagementParams() {
PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration();
- assertNull(payloadStorageConfiguration.getSSEAwsKeyManagementParams());
+ assertNull(payloadStorageConfiguration.getServerSideEncryptionStrategy());
- payloadStorageConfiguration.setSSEAwsKeyManagementParams(sseAwsKeyManagementParams);
- assertEquals(s3ServerSideEncryptionKMSKeyId, payloadStorageConfiguration.getSSEAwsKeyManagementParams()
- .getAwsKmsKeyId());
+ payloadStorageConfiguration.setServerSideEncryptionStrategy(SERVER_SIDE_ENCRYPTION_STRATEGY);
+ assertEquals(SERVER_SIDE_ENCRYPTION_STRATEGY, payloadStorageConfiguration.getServerSideEncryptionStrategy());
}
}
diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java
index f6bf2dc..e9e12c1 100644
--- a/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java
+++ b/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java
@@ -1,7 +1,5 @@
package software.amazon.payloadoffloading;
-import com.amazonaws.AmazonClientException;
-import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.hamcrest.Matchers;
@@ -11,14 +9,21 @@
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
+import software.amazon.awssdk.core.exception.SdkClientException;
+import software.amazon.awssdk.core.exception.SdkException;
+
+import java.util.Objects;
import static org.junit.Assert.*;
+import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
@RunWith(JUnitParamsRunner.class)
public class S3BackedPayloadStoreTest {
private static final String S3_BUCKET_NAME = "test-bucket-name";
private static final String S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID = "test-customer-managed-kms-key-id";
+ private static final ServerSideEncryptionStrategy KMS_WITH_CUSTOMER_KEY = ServerSideEncryptionFactory.customerKey(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID);
+ private static final ServerSideEncryptionStrategy KMS_WITH_AWS_MANAGED_CMK = ServerSideEncryptionFactory.awsManagedCmk();
private static final String ANY_PAYLOAD = "AnyPayload";
private static final String ANY_S3_KEY = "AnyS3key";
private static final String INCORRECT_POINTER_EXCEPTION_MSG = "Failed to read the S3 object pointer from given string";
@@ -53,15 +58,14 @@ private Object[] testData() {
},
// S3 SSE-KMS encryption with AWS managed KMS keys
{
- new S3BackedPayloadStore(defaultEncryptionS3Dao, S3_BUCKET_NAME, new SSEAwsKeyManagementParams()),
- new SSEAwsKeyManagementParams(),
+ new S3BackedPayloadStore(defaultEncryptionS3Dao, S3_BUCKET_NAME, KMS_WITH_AWS_MANAGED_CMK),
+ KMS_WITH_AWS_MANAGED_CMK,
defaultEncryptionS3Dao
},
// S3 SSE-KMS encryption with customer managed KMS key
{
- new S3BackedPayloadStore(customerKMSKeyEncryptionS3Dao, S3_BUCKET_NAME,
- new SSEAwsKeyManagementParams(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID)),
- new SSEAwsKeyManagementParams(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID),
+ new S3BackedPayloadStore(customerKMSKeyEncryptionS3Dao, S3_BUCKET_NAME, KMS_WITH_CUSTOMER_KEY),
+ KMS_WITH_CUSTOMER_KEY,
customerKMSKeyEncryptionS3Dao
}
};
@@ -69,15 +73,14 @@ private Object[] testData() {
@Test
@Parameters(method = "testData")
- public void testStoreOriginalPayloadOnSuccess(PayloadStore payloadStore,
- SSEAwsKeyManagementParams expectedParams, S3Dao mockS3Dao) {
- String actualPayloadPointer = payloadStore.storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH);
+ public void testStoreOriginalPayloadOnSuccess(PayloadStore payloadStore, ServerSideEncryptionStrategy expectedParams, S3Dao mockS3Dao) {
+ String actualPayloadPointer = payloadStore.storeOriginalPayload(ANY_PAYLOAD);
ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class);
- ArgumentCaptor sseArgsCaptor = ArgumentCaptor.forClass(SSEAwsKeyManagementParams.class);
+ ArgumentCaptor sseArgsCaptor = ArgumentCaptor.forClass(ServerSideEncryptionStrategy.class);
verify(mockS3Dao, times(1)).storeTextInS3(eq(S3_BUCKET_NAME), keyCaptor.capture(),
- sseArgsCaptor.capture(), eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH));
+ sseArgsCaptor.capture(), eq(ANY_PAYLOAD));
PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, keyCaptor.getValue());
assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer);
@@ -85,30 +88,26 @@ public void testStoreOriginalPayloadOnSuccess(PayloadStore payloadStore,
if (expectedParams == null) {
assertTrue(sseArgsCaptor.getValue() == null);
} else {
- assertEquals(expectedParams.getAwsKmsKeyId(), sseArgsCaptor.getValue().getAwsKmsKeyId());
+ assertEquals(expectedParams, sseArgsCaptor.getValue());
}
}
@Test
@Parameters(method = "testData")
public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payloadStore,
- SSEAwsKeyManagementParams expectedParams,
+ ServerSideEncryptionStrategy expectedParams,
S3Dao mockS3Dao) {
//Store any payload
- String anyActualPayloadPointer = payloadStore
- .storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH);
+ String anyActualPayloadPointer = payloadStore.storeOriginalPayload(ANY_PAYLOAD);
//Store any other payload and validate that the pointers are different
- String anyOtherActualPayloadPointer = payloadStore
- .storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH);
+ String anyOtherActualPayloadPointer = payloadStore.storeOriginalPayload(ANY_PAYLOAD);
ArgumentCaptor anyOtherKeyCaptor = ArgumentCaptor.forClass(String.class);
-
- ArgumentCaptor sseArgsCaptor = ArgumentCaptor
- .forClass(SSEAwsKeyManagementParams.class);
+ ArgumentCaptor sseArgsCaptor = ArgumentCaptor.forClass(ServerSideEncryptionStrategy.class);
verify(mockS3Dao, times(2)).storeTextInS3(eq(S3_BUCKET_NAME), anyOtherKeyCaptor.capture(),
- sseArgsCaptor.capture(), eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH));
+ sseArgsCaptor.capture(), eq(ANY_PAYLOAD));
String anyS3Key = anyOtherKeyCaptor.getAllValues().get(0);
String anyOtherS3Key = anyOtherKeyCaptor.getAllValues().get(1);
@@ -123,31 +122,29 @@ public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payl
assertThat(anyActualPayloadPointer, Matchers.not(anyOtherActualPayloadPointer));
if (expectedParams == null) {
- assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(actualParams -> actualParams == null));
+ assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(Objects::isNull));
} else {
assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(actualParams ->
- (actualParams.getAwsKmsKeyId() == null && expectedParams.getAwsKmsKeyId() == null)
- || (actualParams.getAwsKmsKeyId().equals(expectedParams.getAwsKmsKeyId()))));
+ actualParams.equals(expectedParams)));
}
}
@Test
@Parameters(method = "testData")
- public void testStoreOriginalPayloadOnS3Failure(PayloadStore payloadStore,
- SSEAwsKeyManagementParams expectedParams, S3Dao mockS3Dao) {
- doThrow(new AmazonClientException("S3 Exception"))
+ public void testStoreOriginalPayloadOnS3Failure(PayloadStore payloadStore, ServerSideEncryptionStrategy awsKmsKeyId, S3Dao mockS3Dao) {
+ doThrow(SdkException.create("S3 Exception", new Throwable()))
.when(mockS3Dao)
.storeTextInS3(
any(String.class),
any(String.class),
- expectedParams == null ? isNull() : any(SSEAwsKeyManagementParams.class),
- any(String.class),
- any(Long.class));
+ // Can be String or null
+ any(),
+ any(String.class));
- exception.expect(AmazonClientException.class);
+ exception.expect(SdkException.class);
exception.expectMessage("S3 Exception");
//Any S3 Dao exception is thrown back as-is to clients
- payloadStore.storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH);
+ payloadStore.storeOriginalPayload(ANY_PAYLOAD);
}
@Test
@@ -167,7 +164,7 @@ public void testGetOriginalPayloadOnSuccess() {
@Test
public void testGetOriginalPayloadIncorrectPointer() {
- exception.expect(AmazonClientException.class);
+ exception.expect(SdkClientException.class);
exception.expectMessage(INCORRECT_POINTER_EXCEPTION_MSG);
//Any S3 Dao exception is thrown back as-is to clients
payloadStore.getOriginalPayload("IncorrectPointer");
@@ -176,8 +173,8 @@ public void testGetOriginalPayloadIncorrectPointer() {
@Test
public void testGetOriginalPayloadOnS3Failure() {
- when(s3Dao.getTextFromS3(any(String.class), any(String.class))).thenThrow(new AmazonClientException("S3 Exception"));
- exception.expect(AmazonClientException.class);
+ when(s3Dao.getTextFromS3(any(String.class), any(String.class))).thenThrow(SdkException.create("S3 Exception", new Throwable()));
+ exception.expect(SdkException.class);
exception.expectMessage("S3 Exception");
//Any S3 Dao exception is thrown back as-is to clients
PayloadS3Pointer anyPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY);
@@ -199,7 +196,7 @@ public void testDeleteOriginalPayloadOnSuccess() {
@Test
public void testDeleteOriginalPayloadIncorrectPointer() {
- exception.expect(AmazonClientException.class);
+ exception.expect(SdkClientException.class);
exception.expectMessage(INCORRECT_POINTER_EXCEPTION_MSG);
payloadStore.deleteOriginalPayload("IncorrectPointer");
verifyNoInteractions(s3Dao);