Skip to content

Cross account support #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ You can download release builds through the [releases section of this](https://g
<dependency>
<groupId>software.amazon.payloadoffloading</groupId>
<artifactId>payloadoffloading-common</artifactId>
<version>1.0.0</version>
<version>1.1.0</version>
<type>jar</type>
</dependency>
```
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>software.amazon.payloadoffloading</groupId>
<artifactId>payloadoffloading-common</artifactId>
<version>1.0.0</version>
<version>1.1.0</version>
<packaging>jar</packaging>
<name>Payload offloading common library for AWS</name>
<description>Common library between extended Amazon AWS clients to save payloads up to 2GB on Amazon S3.</description>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.amazonaws.AmazonClientException;
import com.amazonaws.annotation.NotThreadSafe;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.CannedAccessControlList;
import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -20,6 +21,10 @@ public class PayloadStorageConfiguration {
private int payloadSizeThreshold = 0;
private boolean alwaysThroughS3 = false;
private boolean payloadSupport = false;
/**
* This field is optional, it is set only when we want to add access control list to Amazon S3 buckets and objects
*/
private CannedAccessControlList cannedAccessControlList;
/**
* This field is optional, it is set only when we want to configure S3 Server Side Encryption with KMS.
*/
Expand All @@ -29,6 +34,7 @@ public PayloadStorageConfiguration() {
s3 = null;
s3BucketName = null;
sseAwsKeyManagementParams = null;
cannedAccessControlList = null;
}

public PayloadStorageConfiguration(PayloadStorageConfiguration other) {
Expand All @@ -38,6 +44,7 @@ public PayloadStorageConfiguration(PayloadStorageConfiguration other) {
this.payloadSupport = other.isPayloadSupportEnabled();
this.alwaysThroughS3 = other.isAlwaysThroughS3();
this.payloadSizeThreshold = other.getPayloadSizeThreshold();
this.cannedAccessControlList = other.cannedAccessControlList;
}

/**
Expand Down Expand Up @@ -212,4 +219,39 @@ public boolean isAlwaysThroughS3() {
public void setAlwaysThroughS3(boolean alwaysThroughS3) {
this.alwaysThroughS3 = alwaysThroughS3;
}

/**
* Configures the ACL to apply to the Amazon S3 putObject request.
* @param cannedAccessControlList
* The ACL to be used when storing objects in Amazon S3
*/
public void setCannedAccessControlList(CannedAccessControlList cannedAccessControlList) {
this.cannedAccessControlList = cannedAccessControlList;
}

/**
* Configures the ACL to apply to the Amazon S3 putObject request.
* @param cannedAccessControlList
* The ACL to be used when storing objects in Amazon S3
*/
public PayloadStorageConfiguration withCannedAccessControlList(CannedAccessControlList cannedAccessControlList) {
setCannedAccessControlList(cannedAccessControlList);
return this;
}

/**
* Checks whether an ACL have been configured for storing objects in Amazon S3.
* @return True if ACL is defined
*/
public boolean isCannedAccessControlListDefined() {
return null != cannedAccessControlList;
}

/**
* Gets the AWS ACL to apply to the Amazon S3 putObject request.
* @return Amazon S3 object ACL
*/
public CannedAccessControlList getCannedAccessControlList() {
return cannedAccessControlList;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package software.amazon.payloadoffloading;

import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

Expand All @@ -14,25 +13,18 @@ public class S3BackedPayloadStore implements PayloadStore {

private final String s3BucketName;
private final S3Dao s3Dao;
private final SSEAwsKeyManagementParams sseAwsKeyManagementParams;

public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName) {
this(s3Dao, s3BucketName, null);
}

public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName,
SSEAwsKeyManagementParams sseAwsKeyManagementParams) {
this.s3BucketName = s3BucketName;
this.s3Dao = s3Dao;
this.sseAwsKeyManagementParams = sseAwsKeyManagementParams;
}

@Override
public String storeOriginalPayload(String payload, Long payloadContentSize) {
String s3Key = UUID.randomUUID().toString();

// Store the payload content in S3.
s3Dao.storeTextInS3(s3BucketName, s3Key, sseAwsKeyManagementParams, payload, payloadContentSize);
s3Dao.storeTextInS3(s3BucketName, s3Key, payload, payloadContentSize);
LOG.info("S3 object created, Bucket name: " + s3BucketName + ", Object key: " + s3Key + ".");

// Convert S3 pointer (bucket name, key, etc) to JSON string
Expand Down
19 changes: 13 additions & 6 deletions src/main/java/software/amazon/payloadoffloading/S3Dao.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,17 @@
public class S3Dao {
private static final Log LOG = LogFactory.getLog(S3Dao.class);
private final AmazonS3 s3Client;
private final SSEAwsKeyManagementParams sseAwsKeyManagementParams;
private final CannedAccessControlList cannedAccessControlList;

public S3Dao(AmazonS3 s3Client) {
this(s3Client, null, null);
}

public S3Dao(AmazonS3 s3Client, SSEAwsKeyManagementParams sseAwsKeyManagementParams, CannedAccessControlList cannedAccessControlList) {
this.s3Client = s3Client;
this.sseAwsKeyManagementParams = sseAwsKeyManagementParams;
this.cannedAccessControlList = cannedAccessControlList;
}

public String getTextFromS3(String s3BucketName, String s3Key) {
Expand Down Expand Up @@ -60,14 +68,17 @@ public String getTextFromS3(String s3BucketName, String s3Key) {
return embeddedText;
}

public void storeTextInS3(String s3BucketName, String s3Key, SSEAwsKeyManagementParams sseAwsKeyManagementParams,
String payloadContentStr, Long payloadContentSize) {
public void storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr, Long payloadContentSize) {
InputStream payloadContentStream = new ByteArrayInputStream(payloadContentStr.getBytes(StandardCharsets.UTF_8));
ObjectMetadata payloadContentStreamMetadata = new ObjectMetadata();
payloadContentStreamMetadata.setContentLength(payloadContentSize);
PutObjectRequest putObjectRequest = new PutObjectRequest(s3BucketName, s3Key,
payloadContentStream, payloadContentStreamMetadata);

if (cannedAccessControlList != null) {
putObjectRequest.withCannedAcl(cannedAccessControlList);
}

// https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html
if (sseAwsKeyManagementParams != null) {
LOG.debug("Using SSE-KMS in put object request.");
Expand All @@ -89,10 +100,6 @@ public void storeTextInS3(String s3BucketName, String s3Key, SSEAwsKeyManagement
}
}

public void storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr, Long payloadContentSize) {
storeTextInS3(s3BucketName, s3Key, null, payloadContentStr, payloadContentSize);
}

public void deletePayloadFromS3(String s3BucketName, String s3Key) {
try {
s3Client.deleteObject(s3BucketName, s3Key);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package software.amazon.payloadoffloading;

import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.CannedAccessControlList;
import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -16,10 +17,12 @@ public class PayloadStorageConfigurationTest {
private static String s3BucketName = "test-bucket-name";
private static String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id";
private SSEAwsKeyManagementParams sseAwsKeyManagementParams;
private CannedAccessControlList cannedAccessControlList;

@Before
public void setup() {
sseAwsKeyManagementParams = new SSEAwsKeyManagementParams(s3ServerSideEncryptionKMSKeyId);
cannedAccessControlList = CannedAccessControlList.BucketOwnerFullControl;
}

@Test
Expand All @@ -33,14 +36,16 @@ public void testCopyConstructor() {

payloadStorageConfiguration.withPayloadSupportEnabled(s3, s3BucketName)
.withAlwaysThroughS3(alwaysThroughS3).withPayloadSizeThreshold(payloadSizeThreshold)
.withSSEAwsKeyManagementParams(sseAwsKeyManagementParams);
.withSSEAwsKeyManagementParams(sseAwsKeyManagementParams)
.withCannedAccessControlList(cannedAccessControlList);

PayloadStorageConfiguration newPayloadStorageConfiguration = new PayloadStorageConfiguration(payloadStorageConfiguration);

assertEquals(s3, newPayloadStorageConfiguration.getAmazonS3Client());
assertEquals(s3BucketName, newPayloadStorageConfiguration.getS3BucketName());
assertEquals(sseAwsKeyManagementParams, newPayloadStorageConfiguration.getSSEAwsKeyManagementParams());
assertEquals(s3ServerSideEncryptionKMSKeyId, newPayloadStorageConfiguration.getSSEAwsKeyManagementParams().getAwsKmsKeyId());
assertEquals(cannedAccessControlList, newPayloadStorageConfiguration.getCannedAccessControlList());
assertTrue(newPayloadStorageConfiguration.isPayloadSupportEnabled());
assertEquals(alwaysThroughS3, newPayloadStorageConfiguration.isAlwaysThroughS3());
assertEquals(payloadSizeThreshold, newPayloadStorageConfiguration.getPayloadSizeThreshold());
Expand Down Expand Up @@ -88,4 +93,16 @@ public void testSseAwsKeyManagementParams() {
assertEquals(s3ServerSideEncryptionKMSKeyId, payloadStorageConfiguration.getSSEAwsKeyManagementParams()
.getAwsKmsKeyId());
}

@Test
public void testCannedAccessControlList() {

PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration();

assertFalse(payloadStorageConfiguration.isCannedAccessControlListDefined());

payloadStorageConfiguration.withCannedAccessControlList(cannedAccessControlList);
assertTrue(payloadStorageConfiguration.isCannedAccessControlListDefined());
assertEquals(cannedAccessControlList, payloadStorageConfiguration.getCannedAccessControlList());
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package software.amazon.payloadoffloading;

import com.amazonaws.AmazonClientException;
import com.amazonaws.services.s3.model.CannedAccessControlList;
import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.hamcrest.Matchers;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -35,65 +35,23 @@ public void setup() {
payloadStore = new S3BackedPayloadStore(s3Dao, S3_BUCKET_NAME);
}

private Object[] testData() {
// Here, we create separate mock of S3Dao because JUnitParamsRunner collects parameters
// for tests well before invocation of @Before or @BeforeClass methods.
// That means our default s3Dao mock isn't instantiated until then. For parameterized tests,
// we instantiate our local S3Dao mock per combination, pass it to S3BackedPayloadStore and also pass it
// as test parameter to allow verifying calls to the mockS3Dao.
S3Dao noEncryptionS3Dao = mock(S3Dao.class);
S3Dao defaultEncryptionS3Dao = mock(S3Dao.class);
S3Dao customerKMSKeyEncryptionS3Dao = mock(S3Dao.class);
return new Object[][]{
// No S3 SSE-KMS encryption
{
new S3BackedPayloadStore(noEncryptionS3Dao, S3_BUCKET_NAME),
null,
noEncryptionS3Dao
},
// S3 SSE-KMS encryption with AWS managed KMS keys
{
new S3BackedPayloadStore(defaultEncryptionS3Dao, S3_BUCKET_NAME, new SSEAwsKeyManagementParams()),
new SSEAwsKeyManagementParams(),
defaultEncryptionS3Dao
},
// S3 SSE-KMS encryption with customer managed KMS key
{
new S3BackedPayloadStore(customerKMSKeyEncryptionS3Dao, S3_BUCKET_NAME,
new SSEAwsKeyManagementParams(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID)),
new SSEAwsKeyManagementParams(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID),
customerKMSKeyEncryptionS3Dao
}
};
}

@Test
@Parameters(method = "testData")
public void testStoreOriginalPayloadOnSuccess(PayloadStore payloadStore,
SSEAwsKeyManagementParams expectedParams, S3Dao mockS3Dao) {
public void testStoreOriginalPayloadOnSuccess() {
String actualPayloadPointer = payloadStore.storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH);

ArgumentCaptor<String> keyCaptor = ArgumentCaptor.forClass(String.class);
ArgumentCaptor<SSEAwsKeyManagementParams> sseArgsCaptor = ArgumentCaptor.forClass(SSEAwsKeyManagementParams.class);
ArgumentCaptor<CannedAccessControlList> cannedArgsCaptor = ArgumentCaptor.forClass(CannedAccessControlList.class);

verify(mockS3Dao, times(1)).storeTextInS3(eq(S3_BUCKET_NAME), keyCaptor.capture(),
sseArgsCaptor.capture(), eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH));
verify(s3Dao, times(1)).storeTextInS3(eq(S3_BUCKET_NAME), keyCaptor.capture(),
eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH));

PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, keyCaptor.getValue());
assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer);

if (expectedParams == null) {
assertTrue(sseArgsCaptor.getValue() == null);
} else {
assertEquals(expectedParams.getAwsKmsKeyId(), sseArgsCaptor.getValue().getAwsKmsKeyId());
}
}

@Test
@Parameters(method = "testData")
public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payloadStore,
SSEAwsKeyManagementParams expectedParams,
S3Dao mockS3Dao) {
public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects() {
//Store any payload
String anyActualPayloadPointer = payloadStore
.storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH);
Expand All @@ -104,11 +62,8 @@ public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payl

ArgumentCaptor<String> anyOtherKeyCaptor = ArgumentCaptor.forClass(String.class);

ArgumentCaptor<SSEAwsKeyManagementParams> sseArgsCaptor = ArgumentCaptor
.forClass(SSEAwsKeyManagementParams.class);

verify(mockS3Dao, times(2)).storeTextInS3(eq(S3_BUCKET_NAME), anyOtherKeyCaptor.capture(),
sseArgsCaptor.capture(), eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH));
verify(s3Dao, times(2)).storeTextInS3(eq(S3_BUCKET_NAME), anyOtherKeyCaptor.capture(),
eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH));

String anyS3Key = anyOtherKeyCaptor.getAllValues().get(0);
String anyOtherS3Key = anyOtherKeyCaptor.getAllValues().get(1);
Expand All @@ -121,26 +76,15 @@ public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payl

assertThat(anyS3Key, Matchers.not(anyOtherS3Key));
assertThat(anyActualPayloadPointer, Matchers.not(anyOtherActualPayloadPointer));

if (expectedParams == null) {
assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(actualParams -> actualParams == null));
} else {
assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(actualParams ->
(actualParams.getAwsKmsKeyId() == null && expectedParams.getAwsKmsKeyId() == null)
|| (actualParams.getAwsKmsKeyId().equals(expectedParams.getAwsKmsKeyId()))));
}
}

@Test
@Parameters(method = "testData")
public void testStoreOriginalPayloadOnS3Failure(PayloadStore payloadStore,
SSEAwsKeyManagementParams expectedParams, S3Dao mockS3Dao) {
public void testStoreOriginalPayloadOnS3Failure() {
doThrow(new AmazonClientException("S3 Exception"))
.when(mockS3Dao)
.when(s3Dao)
.storeTextInS3(
any(String.class),
any(String.class),
expectedParams == null ? isNull() : any(SSEAwsKeyManagementParams.class),
any(String.class),
any(Long.class));

Expand Down
Loading