Skip to content

add support to set a prefix for the S3 key #118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -965,7 +966,7 @@ private SendMessageBatchRequestEntry storeMessageInS3(SendMessageBatchRequestEnt
updateMessageAttributePayloadSize(batchEntry.messageAttributes(), messageContentSize));

// Store the message content in S3.
String largeMessagePointer = payloadStore.storeOriginalPayload(messageContentStr);
String largeMessagePointer = storeOriginalPayload(messageContentStr);
batchEntryBuilder.messageBody(largeMessagePointer);

return batchEntryBuilder.build();
Expand All @@ -984,12 +985,20 @@ private SendMessageRequest storeMessageInS3(SendMessageRequest sendMessageReques
updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), messageContentSize));

// Store the message content in S3.
String largeMessagePointer = payloadStore.storeOriginalPayload(messageContentStr);
String largeMessagePointer = storeOriginalPayload(messageContentStr);
sendMessageRequestBuilder.messageBody(largeMessagePointer);

return sendMessageRequestBuilder.build();
}

private String storeOriginalPayload(String messageContentStr) {
String s3KeyPrefix = clientConfiguration.getS3KeyPrefix();
if (StringUtils.isBlank(s3KeyPrefix)) {
return payloadStore.storeOriginalPayload(messageContentStr);
}
return payloadStore.storeOriginalPayload(messageContentStr, s3KeyPrefix + UUID.randomUUID());
}

private Map<String, MessageAttributeValue> updateMessageAttributePayloadSize(
Map<String, MessageAttributeValue> messageAttributes, Long messageContentSize) {
Map<String, MessageAttributeValue> updatedMessageAttributes = new HashMap<>(messageAttributes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,36 @@

package com.amazon.sqs.javamessaging;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import software.amazon.awssdk.annotations.NotThreadSafe;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.ObjectCannedACL;
import software.amazon.awssdk.utils.StringUtils;
import software.amazon.payloadoffloading.PayloadStorageConfiguration;
import software.amazon.payloadoffloading.ServerSideEncryptionStrategy;

import java.util.regex.Pattern;


/**
* Amazon SQS extended client configuration options such as Amazon S3 client,
* bucket name, and message size threshold for large-payload messages.
*/
@NotThreadSafe
public class ExtendedClientConfiguration extends PayloadStorageConfiguration {
private static final Log LOG = LogFactory.getLog(ExtendedClientConfiguration.class);

private static final int UUID_LENGTH = 36;
private static final int MAX_S3_KEY_LENGTH = 1024;
private static final int MAX_S3_KEY_PREFIX_LENGTH = MAX_S3_KEY_LENGTH - UUID_LENGTH;
private static final Pattern INVALID_S3_PREFIX_KEY_CHARACTERS_PATTERN = Pattern.compile("[^a-zA-Z0-9./_-]");

private boolean cleanupS3Payload = true;
private boolean useLegacyReservedAttributeName = true;
private boolean ignorePayloadNotFound = false;
private String s3KeyPrefix = "";

public ExtendedClientConfiguration() {
super();
Expand All @@ -43,6 +56,7 @@ public ExtendedClientConfiguration(ExtendedClientConfiguration other) {
this.cleanupS3Payload = other.doesCleanupS3Payload();
this.useLegacyReservedAttributeName = other.usesLegacyReservedAttributeName();
this.ignorePayloadNotFound = other.ignoresPayloadNotFound();
this.s3KeyPrefix = other.s3KeyPrefix;
}

/**
Expand Down Expand Up @@ -128,6 +142,63 @@ public ExtendedClientConfiguration withIgnorePayloadNotFound(boolean ignorePaylo
return this;
}

/**
* Sets a string that will be used as prefix of the S3 Key.
*
* @param s3KeyPrefix
* A S3 key prefix value
*/
public void setS3KeyPrefix(String s3KeyPrefix) {
String trimmedPrefix = StringUtils.trimToEmpty(s3KeyPrefix);

if (trimmedPrefix.length() > MAX_S3_KEY_PREFIX_LENGTH) {
String errorMessage = "The S3 key prefix length must not be greater than " + MAX_S3_KEY_PREFIX_LENGTH;
LOG.error(errorMessage);
throw SdkClientException.create(errorMessage);
}

if (trimmedPrefix.startsWith(".") || trimmedPrefix.startsWith("/")) {
String errorMessage = "The S3 key prefix must not starts with '.' or '/'";
LOG.error(errorMessage);
throw SdkClientException.create(errorMessage);
}

if (trimmedPrefix.contains("..")) {
String errorMessage = "The S3 key prefix must not contains the string '..'";
LOG.error(errorMessage);
throw SdkClientException.create(errorMessage);
}

if (INVALID_S3_PREFIX_KEY_CHARACTERS_PATTERN.matcher(trimmedPrefix).find()) {
String errorMessage = "The S3 key prefix contain invalid characters. The allowed characters are: letters, digits, '/', '_', '-', and '.'";
LOG.error(errorMessage);
throw SdkClientException.create(errorMessage);
}

this.s3KeyPrefix = trimmedPrefix;
}

/**
* Sets a string that will be used as prefix of the S3 Key.
*
* @param s3KeyPrefix
* A S3 key prefix value
*
* @return the updated ExtendedClientConfiguration object.
*/
public ExtendedClientConfiguration withS3KeyPrefix(String s3KeyPrefix) {
setS3KeyPrefix(s3KeyPrefix);
return this;
}

/**
* Gets the S3 key prefix
* @return the prefix value which is being used for compose the S3 key.
*/
public String getS3KeyPrefix() {
return this.s3KeyPrefix;
}

/**
* Checks whether or not clean up large objects in S3 is enabled.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@

package com.amazon.sqs.javamessaging;

import static com.amazon.sqs.javamessaging.StringTestUtil.generateStringWithLength;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.MockedStatic;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.core.ApiName;
import software.amazon.awssdk.core.ResponseInputStream;
Expand Down Expand Up @@ -62,9 +66,11 @@
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
Expand All @@ -82,11 +88,16 @@ public class AmazonSQSExtendedClientTest {
private SqsClient extendedSqsWithDefaultKMS;
private SqsClient extendedSqsWithGenericReservedAttributeName;
private SqsClient extendedSqsWithDeprecatedMethods;
private SqsClient extendedSqsWithS3KeyPrefix;
private SqsClient mockSqsBackend;
private S3Client mockS3;

private MockedStatic<UUID> uuidMockStatic;
private static final String S3_BUCKET_NAME = "test-bucket-name";
private static final String SQS_QUEUE_URL = "test-queue-url";
private static final String S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID = "test-customer-managed-kms-key-id";
private static final String S3_KEY_PREFIX = "test-s3-key-prefix";
private static final String S3_KEY_UUID = "test-s3-key-uuid";

private static final int LESS_THAN_SQS_SIZE_LIMIT = 3;
private static final int SQS_SIZE_LIMIT = 262144;
Expand All @@ -101,6 +112,7 @@ public class AmazonSQSExtendedClientTest {

@BeforeEach
public void setupClients() {
uuidMockStatic = mockStatic(UUID.class);
mockS3 = mock(S3Client.class);
mockSqsBackend = mock(SqsClient.class);
when(mockS3.putObject(isA(PutObjectRequest.class), isA(RequestBody.class))).thenReturn(null);
Expand All @@ -121,11 +133,25 @@ public void setupClients() {

ExtendedClientConfiguration extendedClientConfigurationDeprecated = new ExtendedClientConfiguration().withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME);

ExtendedClientConfiguration extendedClientConfigurationWithS3KeyPrefix = new ExtendedClientConfiguration()
.withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
.withS3KeyPrefix(S3_KEY_PREFIX);

UUID uuidMock = mock(UUID.class);
when(uuidMock.toString()).thenReturn(S3_KEY_UUID);
uuidMockStatic.when(UUID::randomUUID).thenReturn(uuidMock);

extendedSqsWithDefaultConfig = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration));
extendedSqsWithCustomKMS = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfigurationWithCustomKMS));
extendedSqsWithDefaultKMS = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfigurationWithDefaultKMS));
extendedSqsWithGenericReservedAttributeName = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfigurationWithGenericReservedAttributeName));
extendedSqsWithDeprecatedMethods = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfigurationDeprecated));
extendedSqsWithS3KeyPrefix = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfigurationWithS3KeyPrefix));
}

@AfterEach
public void tearDown() {
uuidMockStatic.close();
}

@Test
Expand Down Expand Up @@ -617,6 +643,32 @@ public void testWhenSendMessageWIthCannedAccessControlListDefined() {
assertEquals(expected, captor.getValue().acl());
}

@Test
public void testWhenSendLargeMessageWithS3PrefixKeyDefined() {
String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT);

SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();

extendedSqsWithS3KeyPrefix.sendMessage(messageRequest);

verify(mockS3, times(1)).putObject(
argThat((PutObjectRequest obj) -> obj.key().equals(S3_KEY_PREFIX + S3_KEY_UUID)),
isA(RequestBody.class));
}

@Test
public void testWhenSendLargeMessageWithUndefinedS3PrefixKey() {
String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT);

SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();

extendedSqsWithDefaultConfig.sendMessage(messageRequest);

verify(mockS3, times(1)).putObject(
argThat((PutObjectRequest obj) -> obj.key().equals(S3_KEY_UUID)),
isA(RequestBody.class));
}

private void testReceiveMessage_when_MessageIsLarge(String reservedAttributeName) {
String pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "S3Key").toJson();
Message message = Message.builder()
Expand Down Expand Up @@ -665,11 +717,4 @@ private String getLargeReceiptHandle(String s3Key, String originalReceiptHandle)
private String getSampleLargeReceiptHandle(String originalReceiptHandle) {
return getLargeReceiptHandle(UUID.randomUUID().toString(), originalReceiptHandle);
}

private String generateStringWithLength(int messageLength) {
char[] charArray = new char[messageLength];
Arrays.fill(charArray, 'x');
return new String(charArray);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@

package com.amazon.sqs.javamessaging;

import static com.amazon.sqs.javamessaging.StringTestUtil.generateStringWithLength;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.payloadoffloading.ServerSideEncryptionFactory;
import software.amazon.payloadoffloading.ServerSideEncryptionStrategy;
Expand All @@ -24,6 +29,7 @@
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -164,4 +170,58 @@ public void testMessageSizeThreshold() {
assertEquals(messageLength, extendedClientConfiguration.getPayloadSizeThreshold());

}

@ParameterizedTest
@ValueSource(strings = {
"test-s3-key-prefix",
"TEST-S3-KEY-PREFIX",
"test.s3.key.prefix",
"test_s3_key_prefix",
"test/s3/key/prefix/"
})
public void testS3keyPrefix(String s3KeyPrefix) {
ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration();

extendedClientConfiguration.withS3KeyPrefix(s3KeyPrefix);

assertEquals(s3KeyPrefix, extendedClientConfiguration.getS3KeyPrefix());
}

@Test
public void testTrimS3keyPrefix() {
String s3KeyPrefix = "test-s3-key-prefix";
ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration();

extendedClientConfiguration.withS3KeyPrefix(String.format(" %s ", s3KeyPrefix));

assertEquals(s3KeyPrefix, extendedClientConfiguration.getS3KeyPrefix());
}

@ParameterizedTest
@ValueSource(strings = {
".test-s3-key-prefix",
"./test-s3-key-prefix",
"../test-s3-key-prefix",
"/test-s3-key-prefix",
"test..s3..key..prefix",
"test-s3-key-prefix@",
"test s3 key prefix"
})
public void testS3KeyPrefixWithInvalidCharacters(String s3KeyPrefix) {
ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration();

assertThrows(SdkClientException.class, () -> extendedClientConfiguration.withS3KeyPrefix(s3KeyPrefix));
}

@Test
public void testS3keyPrefixWithALargeString() {
int maxS3KeyLength = 1024;
int uuidLength = 36;
int maxS3KeyPrefixLength = maxS3KeyLength - uuidLength;
String s3KeyPrefix = generateStringWithLength(maxS3KeyPrefixLength + 1);

ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration();

assertThrows(SdkClientException.class, () -> extendedClientConfiguration.withS3KeyPrefix(s3KeyPrefix));
}
}
11 changes: 11 additions & 0 deletions src/test/java/com/amazon/sqs/javamessaging/StringTestUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.amazon.sqs.javamessaging;

import java.util.Arrays;

public class StringTestUtil {
public static String generateStringWithLength(int messageLength) {
char[] charArray = new char[messageLength];
Arrays.fill(charArray, 'x');
return new String(charArray);
}
}