From dcac11196b915bf360ba7ef4e3cb79f0d506893a Mon Sep 17 00:00:00 2001 From: Evangilo Morais Date: Fri, 22 Sep 2023 10:10:27 -0300 Subject: [PATCH] add support to set a prefix for the S3 key #7 --- .../AmazonSQSExtendedClient.java | 13 +++- .../ExtendedClientConfiguration.java | 71 +++++++++++++++++++ .../AmazonSQSExtendedClientTest.java | 59 +++++++++++++-- .../ExtendedClientConfigurationTest.java | 60 ++++++++++++++++ .../sqs/javamessaging/StringTestUtil.java | 11 +++ 5 files changed, 205 insertions(+), 9 deletions(-) create mode 100644 src/test/java/com/amazon/sqs/javamessaging/StringTestUtil.java diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index b8d42d4..3fb596f 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.UUID; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -965,7 +966,7 @@ private SendMessageBatchRequestEntry storeMessageInS3(SendMessageBatchRequestEnt updateMessageAttributePayloadSize(batchEntry.messageAttributes(), messageContentSize)); // Store the message content in S3. - String largeMessagePointer = payloadStore.storeOriginalPayload(messageContentStr); + String largeMessagePointer = storeOriginalPayload(messageContentStr); batchEntryBuilder.messageBody(largeMessagePointer); return batchEntryBuilder.build(); @@ -984,12 +985,20 @@ private SendMessageRequest storeMessageInS3(SendMessageRequest sendMessageReques updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), messageContentSize)); // Store the message content in S3. - String largeMessagePointer = payloadStore.storeOriginalPayload(messageContentStr); + String largeMessagePointer = storeOriginalPayload(messageContentStr); sendMessageRequestBuilder.messageBody(largeMessagePointer); return sendMessageRequestBuilder.build(); } + private String storeOriginalPayload(String messageContentStr) { + String s3KeyPrefix = clientConfiguration.getS3KeyPrefix(); + if (StringUtils.isBlank(s3KeyPrefix)) { + return payloadStore.storeOriginalPayload(messageContentStr); + } + return payloadStore.storeOriginalPayload(messageContentStr, s3KeyPrefix + UUID.randomUUID()); + } + private Map updateMessageAttributePayloadSize( Map messageAttributes, Long messageContentSize) { Map updatedMessageAttributes = new HashMap<>(messageAttributes); diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java index 039ca06..19e2d39 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java @@ -15,12 +15,18 @@ package com.amazon.sqs.javamessaging; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import software.amazon.awssdk.annotations.NotThreadSafe; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; +import software.amazon.awssdk.utils.StringUtils; import software.amazon.payloadoffloading.PayloadStorageConfiguration; import software.amazon.payloadoffloading.ServerSideEncryptionStrategy; +import java.util.regex.Pattern; + /** * Amazon SQS extended client configuration options such as Amazon S3 client, @@ -28,10 +34,17 @@ */ @NotThreadSafe public class ExtendedClientConfiguration extends PayloadStorageConfiguration { + private static final Log LOG = LogFactory.getLog(ExtendedClientConfiguration.class); + + private static final int UUID_LENGTH = 36; + private static final int MAX_S3_KEY_LENGTH = 1024; + private static final int MAX_S3_KEY_PREFIX_LENGTH = MAX_S3_KEY_LENGTH - UUID_LENGTH; + private static final Pattern INVALID_S3_PREFIX_KEY_CHARACTERS_PATTERN = Pattern.compile("[^a-zA-Z0-9./_-]"); private boolean cleanupS3Payload = true; private boolean useLegacyReservedAttributeName = true; private boolean ignorePayloadNotFound = false; + private String s3KeyPrefix = ""; public ExtendedClientConfiguration() { super(); @@ -43,6 +56,7 @@ public ExtendedClientConfiguration(ExtendedClientConfiguration other) { this.cleanupS3Payload = other.doesCleanupS3Payload(); this.useLegacyReservedAttributeName = other.usesLegacyReservedAttributeName(); this.ignorePayloadNotFound = other.ignoresPayloadNotFound(); + this.s3KeyPrefix = other.s3KeyPrefix; } /** @@ -128,6 +142,63 @@ public ExtendedClientConfiguration withIgnorePayloadNotFound(boolean ignorePaylo return this; } + /** + * Sets a string that will be used as prefix of the S3 Key. + * + * @param s3KeyPrefix + * A S3 key prefix value + */ + public void setS3KeyPrefix(String s3KeyPrefix) { + String trimmedPrefix = StringUtils.trimToEmpty(s3KeyPrefix); + + if (trimmedPrefix.length() > MAX_S3_KEY_PREFIX_LENGTH) { + String errorMessage = "The S3 key prefix length must not be greater than " + MAX_S3_KEY_PREFIX_LENGTH; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + if (trimmedPrefix.startsWith(".") || trimmedPrefix.startsWith("/")) { + String errorMessage = "The S3 key prefix must not starts with '.' or '/'"; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + if (trimmedPrefix.contains("..")) { + String errorMessage = "The S3 key prefix must not contains the string '..'"; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + if (INVALID_S3_PREFIX_KEY_CHARACTERS_PATTERN.matcher(trimmedPrefix).find()) { + String errorMessage = "The S3 key prefix contain invalid characters. The allowed characters are: letters, digits, '/', '_', '-', and '.'"; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + this.s3KeyPrefix = trimmedPrefix; + } + + /** + * Sets a string that will be used as prefix of the S3 Key. + * + * @param s3KeyPrefix + * A S3 key prefix value + * + * @return the updated ExtendedClientConfiguration object. + */ + public ExtendedClientConfiguration withS3KeyPrefix(String s3KeyPrefix) { + setS3KeyPrefix(s3KeyPrefix); + return this; + } + + /** + * Gets the S3 key prefix + * @return the prefix value which is being used for compose the S3 key. + */ + public String getS3KeyPrefix() { + return this.s3KeyPrefix; + } + /** * Checks whether or not clean up large objects in S3 is enabled. * diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index 10c1b52..eda2179 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -15,9 +15,13 @@ package com.amazon.sqs.javamessaging; +import static com.amazon.sqs.javamessaging.StringTestUtil.generateStringWithLength; + +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.mockito.MockedStatic; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.ApiName; import software.amazon.awssdk.core.ResponseInputStream; @@ -62,9 +66,11 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.isA; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -82,11 +88,16 @@ public class AmazonSQSExtendedClientTest { private SqsClient extendedSqsWithDefaultKMS; private SqsClient extendedSqsWithGenericReservedAttributeName; private SqsClient extendedSqsWithDeprecatedMethods; + private SqsClient extendedSqsWithS3KeyPrefix; private SqsClient mockSqsBackend; private S3Client mockS3; + + private MockedStatic uuidMockStatic; private static final String S3_BUCKET_NAME = "test-bucket-name"; private static final String SQS_QUEUE_URL = "test-queue-url"; private static final String S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID = "test-customer-managed-kms-key-id"; + private static final String S3_KEY_PREFIX = "test-s3-key-prefix"; + private static final String S3_KEY_UUID = "test-s3-key-uuid"; private static final int LESS_THAN_SQS_SIZE_LIMIT = 3; private static final int SQS_SIZE_LIMIT = 262144; @@ -101,6 +112,7 @@ public class AmazonSQSExtendedClientTest { @BeforeEach public void setupClients() { + uuidMockStatic = mockStatic(UUID.class); mockS3 = mock(S3Client.class); mockSqsBackend = mock(SqsClient.class); when(mockS3.putObject(isA(PutObjectRequest.class), isA(RequestBody.class))).thenReturn(null); @@ -121,11 +133,25 @@ public void setupClients() { ExtendedClientConfiguration extendedClientConfigurationDeprecated = new ExtendedClientConfiguration().withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME); + ExtendedClientConfiguration extendedClientConfigurationWithS3KeyPrefix = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withS3KeyPrefix(S3_KEY_PREFIX); + + UUID uuidMock = mock(UUID.class); + when(uuidMock.toString()).thenReturn(S3_KEY_UUID); + uuidMockStatic.when(UUID::randomUUID).thenReturn(uuidMock); + extendedSqsWithDefaultConfig = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); extendedSqsWithCustomKMS = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfigurationWithCustomKMS)); extendedSqsWithDefaultKMS = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfigurationWithDefaultKMS)); extendedSqsWithGenericReservedAttributeName = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfigurationWithGenericReservedAttributeName)); extendedSqsWithDeprecatedMethods = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfigurationDeprecated)); + extendedSqsWithS3KeyPrefix = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfigurationWithS3KeyPrefix)); + } + + @AfterEach + public void tearDown() { + uuidMockStatic.close(); } @Test @@ -617,6 +643,32 @@ public void testWhenSendMessageWIthCannedAccessControlListDefined() { assertEquals(expected, captor.getValue().acl()); } + @Test + public void testWhenSendLargeMessageWithS3PrefixKeyDefined() { + String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + + extendedSqsWithS3KeyPrefix.sendMessage(messageRequest); + + verify(mockS3, times(1)).putObject( + argThat((PutObjectRequest obj) -> obj.key().equals(S3_KEY_PREFIX + S3_KEY_UUID)), + isA(RequestBody.class)); + } + + @Test + public void testWhenSendLargeMessageWithUndefinedS3PrefixKey() { + String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + + extendedSqsWithDefaultConfig.sendMessage(messageRequest); + + verify(mockS3, times(1)).putObject( + argThat((PutObjectRequest obj) -> obj.key().equals(S3_KEY_UUID)), + isA(RequestBody.class)); + } + private void testReceiveMessage_when_MessageIsLarge(String reservedAttributeName) { String pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "S3Key").toJson(); Message message = Message.builder() @@ -665,11 +717,4 @@ private String getLargeReceiptHandle(String s3Key, String originalReceiptHandle) private String getSampleLargeReceiptHandle(String originalReceiptHandle) { return getLargeReceiptHandle(UUID.randomUUID().toString(), originalReceiptHandle); } - - private String generateStringWithLength(int messageLength) { - char[] charArray = new char[messageLength]; - Arrays.fill(charArray, 'x'); - return new String(charArray); - } - } diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java index 35de297..2dc5b6b 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java @@ -15,7 +15,12 @@ package com.amazon.sqs.javamessaging; +import static com.amazon.sqs.javamessaging.StringTestUtil.generateStringWithLength; + import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.payloadoffloading.ServerSideEncryptionFactory; import software.amazon.payloadoffloading.ServerSideEncryptionStrategy; @@ -24,6 +29,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.mockito.Mockito.mock; @@ -164,4 +170,58 @@ public void testMessageSizeThreshold() { assertEquals(messageLength, extendedClientConfiguration.getPayloadSizeThreshold()); } + + @ParameterizedTest + @ValueSource(strings = { + "test-s3-key-prefix", + "TEST-S3-KEY-PREFIX", + "test.s3.key.prefix", + "test_s3_key_prefix", + "test/s3/key/prefix/" + }) + public void testS3keyPrefix(String s3KeyPrefix) { + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + + extendedClientConfiguration.withS3KeyPrefix(s3KeyPrefix); + + assertEquals(s3KeyPrefix, extendedClientConfiguration.getS3KeyPrefix()); + } + + @Test + public void testTrimS3keyPrefix() { + String s3KeyPrefix = "test-s3-key-prefix"; + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + + extendedClientConfiguration.withS3KeyPrefix(String.format(" %s ", s3KeyPrefix)); + + assertEquals(s3KeyPrefix, extendedClientConfiguration.getS3KeyPrefix()); + } + + @ParameterizedTest + @ValueSource(strings = { + ".test-s3-key-prefix", + "./test-s3-key-prefix", + "../test-s3-key-prefix", + "/test-s3-key-prefix", + "test..s3..key..prefix", + "test-s3-key-prefix@", + "test s3 key prefix" + }) + public void testS3KeyPrefixWithInvalidCharacters(String s3KeyPrefix) { + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + + assertThrows(SdkClientException.class, () -> extendedClientConfiguration.withS3KeyPrefix(s3KeyPrefix)); + } + + @Test + public void testS3keyPrefixWithALargeString() { + int maxS3KeyLength = 1024; + int uuidLength = 36; + int maxS3KeyPrefixLength = maxS3KeyLength - uuidLength; + String s3KeyPrefix = generateStringWithLength(maxS3KeyPrefixLength + 1); + + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + + assertThrows(SdkClientException.class, () -> extendedClientConfiguration.withS3KeyPrefix(s3KeyPrefix)); + } } diff --git a/src/test/java/com/amazon/sqs/javamessaging/StringTestUtil.java b/src/test/java/com/amazon/sqs/javamessaging/StringTestUtil.java new file mode 100644 index 0000000..4de5ac4 --- /dev/null +++ b/src/test/java/com/amazon/sqs/javamessaging/StringTestUtil.java @@ -0,0 +1,11 @@ +package com.amazon.sqs.javamessaging; + +import java.util.Arrays; + +public class StringTestUtil { + public static String generateStringWithLength(int messageLength) { + char[] charArray = new char[messageLength]; + Arrays.fill(charArray, 'x'); + return new String(charArray); + } +}