diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index 8200d7b..c35927a 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -1252,7 +1252,7 @@ private SendMessageBatchRequestEntry storeMessageInS3(SendMessageBatchRequestEnt checkMessageAttributes(batchEntry.getMessageAttributes()); - String s3Key = UUID.randomUUID().toString(); + String s3Key = getS3Key(); // Read the content of the message from message body String messageContentStr = batchEntry.getMessageBody(); @@ -1285,7 +1285,7 @@ private SendMessageRequest storeMessageInS3(SendMessageRequest sendMessageReques checkMessageAttributes(sendMessageRequest.getMessageAttributes()); - String s3Key = UUID.randomUUID().toString(); + String s3Key = getS3Key(); // Read the content of the message from message body String messageContentStr = sendMessageRequest.getMessageBody(); @@ -1315,6 +1315,10 @@ private SendMessageRequest storeMessageInS3(SendMessageRequest sendMessageReques return sendMessageRequest; } + private String getS3Key() { + return clientConfiguration.isS3KeyUsed() ? clientConfiguration.getS3Key() + UUID.randomUUID().toString() : UUID.randomUUID().toString(); + } + private String getJSONFromS3Pointer(MessageS3Pointer s3Pointer) { String s3PointerStr = null; try { diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java index 4852178..f7a565e 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java @@ -21,8 +21,6 @@ import org.apache.commons.logging.LogFactory; import com.amazonaws.annotation.NotThreadSafe; -import java.util.List; - /** * Amazon SQS extended client configuration options such as Amazon S3 client, * bucket name, and message size threshold for large-payload messages. @@ -33,6 +31,7 @@ public class ExtendedClientConfiguration { private AmazonS3 s3; private String s3BucketName; + private String s3Key; private boolean largePayloadSupport = false; private boolean alwaysThroughS3 = false; private int messageSizeThreshold = SQSExtendedClientConstants.DEFAULT_MESSAGE_SIZE_THRESHOLD; @@ -40,11 +39,13 @@ public class ExtendedClientConfiguration { public ExtendedClientConfiguration() { s3 = null; s3BucketName = null; + s3Key= null; } public ExtendedClientConfiguration(ExtendedClientConfiguration other) { this.s3 = other.s3; this.s3BucketName = other.s3BucketName; + this.s3Key = other.s3Key; this.largePayloadSupport = other.largePayloadSupport; this.alwaysThroughS3 = other.alwaysThroughS3; this.messageSizeThreshold = other.messageSizeThreshold; @@ -93,6 +94,57 @@ public ExtendedClientConfiguration withLargePayloadSupportEnabled(AmazonS3 s3, S return this; } + /** + * + * @param s3 + * Amazon S3 client which is going to be used for storing + * large-payload messages. + * @param s3BucketName + * Name of the bucket which is going to be used for storing + * large-payload messages. The bucket must be already created and + * configured in s3. + * @param s3Key + * Name of the s3 key which is going to be used for storing + * large-payload messages. The bucket must be already created and + * configured in s3. + */ + public void setLargePayloadSupportEnabled(AmazonS3 s3, String s3BucketName, String s3Key) { + if (s3 == null || s3BucketName == null || s3Key == null) { + String errorMessage = "S3 client and/or S3 bucket name and/or S3 key cannot be null."; + LOG.error(errorMessage); + throw new AmazonClientException(errorMessage); + } + if (isLargePayloadSupportEnabled()) { + LOG.warn("Large-payload support is already enabled. Overwriting AmazonS3Client, S3BucketName, S3key."); + } + this.s3 = s3; + this.s3BucketName = s3BucketName; + this.s3Key= s3Key; + largePayloadSupport = true; + LOG.info("Large-payload support enabled."); + } + + /** + * + * @param s3 + * Amazon S3 client which is going to be used for storing + * large-payload messages. + * @param s3BucketName + * Name of the bucket which is going to be used for storing + * large-payload messages. The bucket must be already created and + * configured in s3. + * @param s3Key + * Name of the s3 key which is going to be used for storing + * large-payload messages. The bucket must be already created and + * configured in s3. + * @return the updated ExtendedClientConfiguration object. + */ + public ExtendedClientConfiguration withLargePayloadSupportEnabled(AmazonS3 s3, String s3BucketName,String s3Key) { + setLargePayloadSupportEnabled(s3, s3BucketName,s3Key); + return this; + } + + /** * Disables support for large-payload messages. */ @@ -100,6 +152,7 @@ public void setLargePayloadSupportDisabled() { s3 = null; s3BucketName = null; largePayloadSupport = false; + s3Key = null; LOG.info("Large-payload support disabled."); } @@ -141,6 +194,16 @@ public String getS3BucketName() { return s3BucketName; } + /** + * Gets the name of the S3 key which is being used for storing + * large-payload messages. + * + * @return The name of the key which is being used. + */ + public String getS3Key() { + return s3Key; + } + /** * Sets the message size threshold for storing message payloads in Amazon * S3. @@ -214,4 +277,13 @@ public ExtendedClientConfiguration withAlwaysThroughS3(boolean alwaysThroughS3) public boolean isAlwaysThroughS3() { return alwaysThroughS3; } + + /** + * Checks whether or not S3 key exists. + * + * @return True if S3 key is specified in client. Default: false + */ + public boolean isS3KeyUsed() { + return s3Key != null; + } } diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index 3dd70c3..4d98d93 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -51,10 +51,12 @@ public class AmazonSQSExtendedClientTest { private AmazonSQS extendedSqsWithDefaultConfig; + private AmazonSQS extendedSqsWithS3KeyAndDefaultConfig; private AmazonSQS mockSqsBackend; private AmazonS3 mockS3; private static final String S3_BUCKET_NAME = "test-bucket-name"; private static final String SQS_QUEUE_URL = "test-queue-url"; + private static final String S3_KEY = "sqs/messages/"; private static final int LESS_THAN_SQS_SIZE_LIMIT = 3; private static final int SQS_SIZE_LIMIT = 262144; @@ -72,7 +74,11 @@ public void setupClient() { ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME); + ExtendedClientConfiguration extendedClientConfigurationWithS3Key = new ExtendedClientConfiguration() + .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME,S3_KEY); + extendedSqsWithDefaultConfig = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + extendedSqsWithS3KeyAndDefaultConfig = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfigurationWithS3Key)); } @@ -125,6 +131,19 @@ public void testWhenSendMessageWithAlwaysThroughS3AndMessageIsSmallThenItIsStill verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class)); } + @Test + public void testWhenSendMessageWithAlwaysThroughS3AndS3KeyAndMessageIsSmallThenItIsStillStoredInS3() { + int messageLength = LESS_THAN_SQS_SIZE_LIMIT; + String messageBody = generateStringWithLength(messageLength); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME,S3_KEY).withAlwaysThroughS3(true); + AmazonSQS sqsExtended = spy(new AmazonSQSExtendedClient(mock(AmazonSQSClient.class), extendedClientConfiguration)); + + SendMessageRequest messageRequest = new SendMessageRequest(SQS_QUEUE_URL, messageBody); + sqsExtended.sendMessage(messageRequest); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class)); + } @Test public void testWhenSendMessageWithSetMessageSizeThresholdThenThresholdIsHonored() { @@ -139,6 +158,19 @@ public void testWhenSendMessageWithSetMessageSizeThresholdThenThresholdIsHonored sqsExtended.sendMessage(messageRequest); verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class)); } + @Test + public void testWhenSendMessageWithS3KeyAndSetMessageSizeThresholdThenThresholdIsHonored() { + int messageLength = ARBITRATY_SMALLER_THRESSHOLD * 2; + String messageBody = generateStringWithLength(messageLength); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME,S3_KEY).withMessageSizeThreshold(ARBITRATY_SMALLER_THRESSHOLD); + + AmazonSQS sqsExtended = spy(new AmazonSQSExtendedClient(mock(AmazonSQSClient.class), extendedClientConfiguration)); + + SendMessageRequest messageRequest = new SendMessageRequest(SQS_QUEUE_URL, messageBody); + sqsExtended.sendMessage(messageRequest); + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class)); + } @Test public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessageRequest() { @@ -157,6 +189,23 @@ public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessag sqsExtended.receiveMessage(messageRequest); Assert.assertEquals(expectedRequest, messageRequest); } + @Test + public void testReceiveMessageWithS3KeyMultipleTimesDoesNotAdditionallyAlterReceiveMessageRequest() { + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME,S3_KEY); + AmazonSQS sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))).thenReturn(new ReceiveMessageResult()); + + ReceiveMessageRequest messageRequest = new ReceiveMessageRequest(); + ReceiveMessageRequest expectedRequest = new ReceiveMessageRequest() + .withMessageAttributeNames(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME); + + sqsExtended.receiveMessage(messageRequest); + Assert.assertEquals(expectedRequest, messageRequest); + + sqsExtended.receiveMessage(messageRequest); + Assert.assertEquals(expectedRequest, messageRequest); + } @Test public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStoredInS3() { @@ -192,6 +241,40 @@ public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStor // There should be 8 puts for the 8 messages above the threshhold verify(mockS3, times(8)).putObject(isA(PutObjectRequest.class)); } + @Test + public void testWhenMessageWithS3KeyBatchIsSentThenOnlyMessagesLargerThanThresholdAreStoredInS3() { + // This creates 10 messages, out of which only two are below the threshold (100K and 200K), + // and the other 8 are above the threshold + + int[] messageLengthForCounter = new int[] { + 100_000, + 300_000, + 400_000, + 500_000, + 600_000, + 700_000, + 800_000, + 900_000, + 200_000, + 1000_000 + }; + + List batchEntries = new ArrayList(); + for (int i = 0; i < 10; i++) { + SendMessageBatchRequestEntry entry = new SendMessageBatchRequestEntry(); + int messageLength = messageLengthForCounter[i]; + String messageBody = generateStringWithLength(messageLength); + entry.setMessageBody(messageBody); + entry.setId("entry_" + i); + batchEntries.add(entry); + } + + SendMessageBatchRequest batchRequest = new SendMessageBatchRequest(SQS_QUEUE_URL, batchEntries); + extendedSqsWithS3KeyAndDefaultConfig.sendMessageBatch(batchRequest); + + // There should be 8 puts for the 8 messages above the threshhold + verify(mockS3, times(8)).putObject(isA(PutObjectRequest.class)); + } @Test public void testWhenSmallMessageIsSentThenNoAttributeIsAdded() { @@ -207,6 +290,20 @@ public void testWhenSmallMessageIsSentThenNoAttributeIsAdded() { Map attributes = sendMessageRequestCaptor.getValue().getMessageAttributes(); Assert.assertTrue(attributes.isEmpty()); } + @Test + public void testWhenSmallMessageWithS3KeyIsSentThenNoAttributeIsAdded() { + int messageLength = LESS_THAN_SQS_SIZE_LIMIT; + String messageBody = generateStringWithLength(messageLength); + + SendMessageRequest messageRequest = new SendMessageRequest(SQS_QUEUE_URL, messageBody); + extendedSqsWithS3KeyAndDefaultConfig.sendMessage(messageRequest); + + ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture()); + + Map attributes = sendMessageRequestCaptor.getValue().getMessageAttributes(); + Assert.assertTrue(attributes.isEmpty()); + } @Test public void testWhenLargeMessgaeIsSentThenAttributeWithPayloadSizeIsAdded() { @@ -223,6 +320,21 @@ public void testWhenLargeMessgaeIsSentThenAttributeWithPayloadSizeIsAdded() { Assert.assertEquals("Number", attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getDataType()); Assert.assertEquals(messageLength, (int)Integer.valueOf(attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getStringValue())); } + @Test + public void testWhenLargeMessgaeWithS3KeyIsSentThenAttributeWithPayloadSizeIsAdded() { + int messageLength = MORE_THAN_SQS_SIZE_LIMIT; + String messageBody = generateStringWithLength(messageLength); + + SendMessageRequest messageRequest = new SendMessageRequest(SQS_QUEUE_URL, messageBody); + extendedSqsWithS3KeyAndDefaultConfig.sendMessage(messageRequest); + + ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture()); + + Map attributes = sendMessageRequestCaptor.getValue().getMessageAttributes(); + Assert.assertEquals("Number", attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getDataType()); + Assert.assertEquals(messageLength, (int)Integer.valueOf(attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getStringValue())); + } private String generateStringWithLength(int messageLength) { char[] charArray = new char[messageLength]; diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java index 9949853..642a603 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java @@ -29,6 +29,7 @@ public class ExtendedClientConfigurationTest { private static String s3BucketName = "test-bucket-name"; + private static String s3Key = "sqs/messages/"; @Before public void setup() { @@ -46,7 +47,7 @@ public void testCopyConstructor() { ExtendedClientConfiguration extendedClientConfig = new ExtendedClientConfiguration(); - extendedClientConfig.withLargePayloadSupportEnabled(s3, s3BucketName) + extendedClientConfig.withLargePayloadSupportEnabled(s3, s3BucketName,s3Key) .withAlwaysThroughS3(alwaysThroughS3).withMessageSizeThreshold(messageSizeThreshold); ExtendedClientConfiguration newExtendedClientConfig = new ExtendedClientConfiguration(extendedClientConfig); @@ -75,6 +76,48 @@ public void testLargePayloadSupportEnabled() { } + @Test + public void testCopyConstructorWithS3Key() { + + AmazonS3 s3 = mock(AmazonS3.class); + when(s3.putObject(isA(PutObjectRequest.class))).thenReturn(null); + + boolean alwaysThroughS3 = true; + int messageSizeThreshold = 500; + + ExtendedClientConfiguration extendedClientConfig = new ExtendedClientConfiguration(); + + extendedClientConfig.withLargePayloadSupportEnabled(s3, s3BucketName,s3Key) + .withAlwaysThroughS3(alwaysThroughS3).withMessageSizeThreshold(messageSizeThreshold); + + ExtendedClientConfiguration newExtendedClientConfig = new ExtendedClientConfiguration(extendedClientConfig); + + Assert.assertEquals(s3, newExtendedClientConfig.getAmazonS3Client()); + Assert.assertEquals(s3BucketName, newExtendedClientConfig.getS3BucketName()); + Assert.assertEquals(s3Key, newExtendedClientConfig.getS3Key()); + Assert.assertTrue(newExtendedClientConfig.isLargePayloadSupportEnabled()); + Assert.assertEquals(alwaysThroughS3, newExtendedClientConfig.isAlwaysThroughS3()); + Assert.assertEquals(messageSizeThreshold, newExtendedClientConfig.getMessageSizeThreshold()); + + Assert.assertNotSame(newExtendedClientConfig, extendedClientConfig); + } + + @Test + public void testLargePayloadSupportEnabledWithS3Key() { + + AmazonS3 s3 = mock(AmazonS3.class); + when(s3.putObject(isA(PutObjectRequest.class))).thenReturn(null); + + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + extendedClientConfiguration.setLargePayloadSupportEnabled(s3, s3BucketName,s3Key); + + Assert.assertTrue(extendedClientConfiguration.isLargePayloadSupportEnabled()); + Assert.assertNotNull(extendedClientConfiguration.getAmazonS3Client()); + Assert.assertEquals(s3BucketName, extendedClientConfiguration.getS3BucketName()); + Assert.assertEquals(s3Key, extendedClientConfiguration.getS3Key()); + + } + @Test public void testDisableLargePayloadSupport() { @@ -86,6 +129,7 @@ public void testDisableLargePayloadSupport() { Assert.assertNull(extendedClientConfiguration.getAmazonS3Client()); Assert.assertNull(extendedClientConfiguration.getS3BucketName()); + Assert.assertNull(extendedClientConfiguration.getS3Key()); verify(s3, never()).putObject(isA(PutObjectRequest.class)); } @@ -115,6 +159,15 @@ public void testMessageSizeThreshold() { Assert.assertEquals(messageLength, extendedClientConfiguration.getMessageSizeThreshold()); } + @Test + public void testS3KeyUsedWhenKeyNameIsNotSpecified() { + AmazonS3 s3 = mock(AmazonS3.class); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + extendedClientConfiguration.withLargePayloadSupportEnabled(s3,s3BucketName); + + Assert.assertEquals(false,extendedClientConfiguration.isS3KeyUsed()); + + } }