diff --git a/.gitignore b/.gitignore index 2f7896d..ec10551 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ target/ +.idea/ +*.iml diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index 8200d7b..3c5abd4 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -22,9 +22,7 @@ import java.io.Writer; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Map; -import java.util.UUID; +import java.util.*; import java.util.Map.Entry; import com.amazonaws.AmazonClientException; @@ -707,6 +705,13 @@ public ChangeMessageVisibilityResult changeMessageVisibility(ChangeMessageVisibi * S3 when necessary. *

*

+ * Because the combined size of a batch can exceed the threshold even when every + * message is individually below it, this method splits the batch into smaller batches, + * preserving the order of messages. If any call to send a batch reports a failed entry, + * sending stops and the remaining messages are not sent. In the worst case (every message + * just under the threshold), a batch of size n produces n batches. + *

+ *

* If the DelaySeconds parameter is not specified for an entry, * the default for the queue is used. *

@@ -741,7 +746,8 @@ public ChangeMessageVisibilityResult changeMessageVisibility(ChangeMessageVisibi * SendMessageBatch service method on AmazonSQS. * * @return The response from the SendMessageBatch service method, as - * returned by AmazonSQS. + * returned by AmazonSQS. Note: if there are multiple calls + * then this is a composite of the calls. * * @throws BatchEntryIdsNotDistinctException * @throws TooManyEntriesInBatchRequestException @@ -774,16 +780,88 @@ public SendMessageBatchResult sendMessageBatch(SendMessageBatchRequest sendMessa } List batchEntries = sendMessageBatchRequest.getEntries(); + List messageSizes = measureMessagesAndMoveLargeOnesToS3(batchEntries); + + return splitAndSendMessageBatches(sendMessageBatchRequest, batchEntries, messageSizes); + } + + private SendMessageBatchResult splitAndSendMessageBatches(SendMessageBatchRequest sendMessageBatchRequest, + List batchEntries, + List messageSizes) { + int range = getSliceOfBatchToSend(messageSizes); + + // modify the send batch request to use the messages that fit + List sendThisTime = batchEntries.subList(0, range); + sendMessageBatchRequest.setEntries(sendThisTime); + SendMessageBatchResult result = super.sendMessageBatch(sendMessageBatchRequest); + + // calculate how many will be sent in the next slice + List remainder = batchEntries.subList(range, batchEntries.size()); + List remainingSizes = messageSizes.subList(range, batchEntries.size()); + + // return this result if there were errors, or if there's + // nothing further to send + if (!result.getFailed().isEmpty() || remainder.isEmpty()) { + return result; + } + + // recurse into the method to send the remainder + SendMessageBatchResult recursiveResult = + splitAndSendMessageBatches(sendMessageBatchRequest, remainder, remainingSizes); + + // add the messages successfully sent by the earlier call + // preserving the order of successful messages + // Note: the result from earlier cannot have failures in it so we don't need to copy them + // 
similarly, we want the recursive result's overall status as it's the last call in the chain. + result.getSuccessful().addAll(recursiveResult.getSuccessful()); + recursiveResult.setSuccessful(result.getSuccessful()); + + return recursiveResult; + } + + private int getSliceOfBatchToSend(List messageSizes) { + int range = 0; + long totalSize = 0; + for (Long messageSize : messageSizes) { + // measure the total including this, to see if + // we've exceeded the maximum size + totalSize += messageSize; + if (isLarge(totalSize)) { + // stop here + break; + } + + // we can include this item + range++; + } + + // the earlier code should already have made this impossible, but add error handling rather than + // a confusing infinite loop/stack overflow + if (range == 0) { + throw new IllegalStateException("A message in the batch is larger than the threshold when " + + "it should already have been exported to S3"); + } + return range; + } + + private List measureMessagesAndMoveLargeOnesToS3(List batchEntries) { int index = 0; + List messageSizes = new LinkedList<>(); for (SendMessageBatchRequestEntry entry : batchEntries) { - if (clientConfiguration.isAlwaysThroughS3() || isLarge(entry)) { - batchEntries.set(index, storeMessageInS3(entry)); + long entrySize = sizeOf(entry); + + if (clientConfiguration.isAlwaysThroughS3() || isLarge(entrySize)) { + SendMessageBatchRequestEntry storedVersion = storeMessageInS3(entry); + batchEntries.set(index, storedVersion); + messageSizes.add(sizeOf(storedVersion)); + } else { + messageSizes.add(entrySize); } + ++index; } - - return super.sendMessageBatch(sendMessageBatchRequest); + return messageSizes; } /** @@ -1215,14 +1293,17 @@ private boolean isLarge(SendMessageRequest sendMessageRequest) { int msgAttributesSize = getMsgAttributesSize(sendMessageRequest.getMessageAttributes()); long msgBodySize = getStringSizeInBytes(sendMessageRequest.getMessageBody()); long totalMsgSize = msgAttributesSize + msgBodySize; - return (totalMsgSize > 
clientConfiguration.getMessageSizeThreshold()); + return isLarge(totalMsgSize); } - private boolean isLarge(SendMessageBatchRequestEntry batchEntry) { + private boolean isLarge(long size) { + return (size > clientConfiguration.getMessageSizeThreshold()); + } + + private long sizeOf(SendMessageBatchRequestEntry batchEntry) { int msgAttributesSize = getMsgAttributesSize(batchEntry.getMessageAttributes()); long msgBodySize = getStringSizeInBytes(batchEntry.getMessageBody()); - long totalMsgSize = msgAttributesSize + msgBodySize; - return (totalMsgSize > clientConfiguration.getMessageSizeThreshold()); + return msgAttributesSize + msgBodySize; } private int getMsgAttributesSize(Map msgAttributes) { diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index 3dd70c3..8c3cd14 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -15,44 +15,47 @@ package com.amazon.sqs.javamessaging; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; - import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.s3.model.PutObjectRequest; import com.amazonaws.services.sqs.AmazonSQS; import com.amazonaws.services.sqs.AmazonSQSClient; -import com.amazonaws.services.sqs.model.MessageAttributeValue; -import com.amazonaws.services.sqs.model.ReceiveMessageRequest; -import com.amazonaws.services.sqs.model.ReceiveMessageResult; -import com.amazonaws.services.sqs.model.SendMessageBatchRequest; -import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry; -import com.amazonaws.services.sqs.model.SendMessageRequest; - -import junit.framework.Assert; +import com.amazonaws.services.sqs.model.*; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import 
org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.isA; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; /** * Tests the AmazonSQSExtendedClient class. */ +@RunWith(MockitoJUnitRunner.class) public class AmazonSQSExtendedClientTest { private AmazonSQS extendedSqsWithDefaultConfig; + + @Mock private AmazonSQS mockSqsBackend; + + @Mock private AmazonS3 mockS3; + + private List> sendMessageBatchInvocationEntries = new ArrayList<>(); + private static final String S3_BUCKET_NAME = "test-bucket-name"; private static final String SQS_QUEUE_URL = "test-queue-url"; @@ -61,14 +64,31 @@ public class AmazonSQSExtendedClientTest { private static final int MORE_THAN_SQS_SIZE_LIMIT = SQS_SIZE_LIMIT + 1; // should be > 1 and << SQS_SIZE_LIMIT - private static final int ARBITRATY_SMALLER_THRESSHOLD = 500; + private static final int ARBITRARY_SMALLER_THRESHOLD = 500; @Before public void setupClient() { - mockS3 = mock(AmazonS3.class); - mockSqsBackend = mock(AmazonSQS.class); + when(mockS3.putObject(isA(PutObjectRequest.class))).thenReturn(null); + // send message batch must return a result and must record the entries used + // we can't use a captor for this as the send message batch request object is + // changed during execution, so the captor ends up with multiple references to + // the same object 
and we don't see the state of the earliest + when(mockSqsBackend.sendMessageBatch(any(SendMessageBatchRequest.class))) + .thenAnswer(new Answer() { + + @Override + public SendMessageBatchResult answer(InvocationOnMock invocation) throws Throwable { + // record the entries + List entries = + invocation.getArgumentAt(0, SendMessageBatchRequest.class).getEntries(); + sendMessageBatchInvocationEntries.add(entries); + + return new SendMessageBatchResult(); + } + }); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME); @@ -123,15 +143,15 @@ public void testWhenSendMessageWithAlwaysThroughS3AndMessageIsSmallThenItIsStill SendMessageRequest messageRequest = new SendMessageRequest(SQS_QUEUE_URL, messageBody); sqsExtended.sendMessage(messageRequest); - verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class)); + verify(mockS3).putObject(isA(PutObjectRequest.class)); } @Test public void testWhenSendMessageWithSetMessageSizeThresholdThenThresholdIsHonored() { - int messageLength = ARBITRATY_SMALLER_THRESSHOLD * 2; + int messageLength = ARBITRARY_SMALLER_THRESHOLD * 2; String messageBody = generateStringWithLength(messageLength); ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() - .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withMessageSizeThreshold(ARBITRATY_SMALLER_THRESSHOLD); + .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withMessageSizeThreshold(ARBITRARY_SMALLER_THRESHOLD); AmazonSQS sqsExtended = spy(new AmazonSQSExtendedClient(mock(AmazonSQSClient.class), extendedClientConfiguration)); @@ -152,15 +172,39 @@ public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessag .withMessageAttributeNames(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME); sqsExtended.receiveMessage(messageRequest); - Assert.assertEquals(expectedRequest, messageRequest); + assertEquals(expectedRequest, 
messageRequest); sqsExtended.receiveMessage(messageRequest); - Assert.assertEquals(expectedRequest, messageRequest); + assertEquals(expectedRequest, messageRequest); + } + + @Test + public void testWhenSmallMessageBatchIsSentThenNoMessagesStoredInS3() { + // This creates 10 messages all well within the threshold + + int[] messageLengthForCounter = new int[] { + 1_000, + 1_000, + 1_000, + 1_000, + 1_000, + 1_000, + 1_000, + 1_000, + 1_000, + 1_000 + }; + + SendMessageBatchRequest batchRequest = createMessageBatchWithSizes(messageLengthForCounter); + extendedSqsWithDefaultConfig.sendMessageBatch(batchRequest); + + // There should be no puts + verify(mockS3, never()).putObject(isA(PutObjectRequest.class)); } @Test public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStoredInS3() { - // This creates 10 messages, out of which only two are below the threshold (100K and 200K), + // This creates 10 messages, out of which only two are below the threshold (100K and 20K), // and the other 8 are above the threshold int[] messageLengthForCounter = new int[] { @@ -172,25 +216,86 @@ public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStor 700_000, 800_000, 900_000, - 200_000, + 20_000, 1000_000 }; - List batchEntries = new ArrayList(); - for (int i = 0; i < 10; i++) { - SendMessageBatchRequestEntry entry = new SendMessageBatchRequestEntry(); - int messageLength = messageLengthForCounter[i]; - String messageBody = generateStringWithLength(messageLength); - entry.setMessageBody(messageBody); - entry.setId("entry_" + i); - batchEntries.add(entry); - } - - SendMessageBatchRequest batchRequest = new SendMessageBatchRequest(SQS_QUEUE_URL, batchEntries); + SendMessageBatchRequest batchRequest = createMessageBatchWithSizes(messageLengthForCounter); extendedSqsWithDefaultConfig.sendMessageBatch(batchRequest); - // There should be 8 puts for the 8 messages above the threshhold + // There should be 8 puts for the 8 messages above the threshold 
verify(mockS3, times(8)).putObject(isA(PutObjectRequest.class)); + + // and one batch send + verify(mockSqsBackend).sendMessageBatch(any(SendMessageBatchRequest.class)); + } + + @Test + public void testWhenMessageBatchIsSentWhereSumOfMessageSizesIsOverTheThresholdThenBatchIsSplit() { + // This creates 10 messages, all of which are below the threshold, but together would make + // a single request over the threshold + + int[] messageLengthForCounter = new int[] { + 26_214, + 26_214, + 26_214, + 26_214, + 26_214, + 26_214, + 26_214, + 26_214, + 26_214, + 26_219 + }; + + SendMessageBatchRequest batchRequest = createMessageBatchWithSizes(messageLengthForCounter); + extendedSqsWithDefaultConfig.sendMessageBatch(batchRequest); + + // The client should not put any objects to S3 + verify(mockS3, never()).putObject(isA(PutObjectRequest.class)); + + // The client should have made two requests to SQS + verify(mockSqsBackend, times(2)).sendMessageBatch(any(SendMessageBatchRequest.class)); + + // the client will have put most messages in the first batch, then the remainder in a second + assertEquals(9, sendMessageBatchInvocationEntries.get(0).size()); + assertEquals(1, sendMessageBatchInvocationEntries.get(1).size()); + } + + @Test + public void testWhenMessageBatchIsMadeOfLargeMessagesThenBatchIsSplitAndOrderMaintained() { + // This creates 10 messages, all of which are below the threshold, but together would make + // a single request over the threshold + + int[] messageLengthForCounter = new int[] { + SQS_SIZE_LIMIT, + SQS_SIZE_LIMIT - 1, + SQS_SIZE_LIMIT - 2, + SQS_SIZE_LIMIT - 3, + SQS_SIZE_LIMIT - 4, + SQS_SIZE_LIMIT - 5, + SQS_SIZE_LIMIT - 6, + SQS_SIZE_LIMIT - 7, + SQS_SIZE_LIMIT - 8, + SQS_SIZE_LIMIT - 9 + }; + + SendMessageBatchRequest batchRequest = createMessageBatchWithSizes(messageLengthForCounter); + extendedSqsWithDefaultConfig.sendMessageBatch(batchRequest); + + // The client should not put any objects to S3 as they are all small enough + // to send to SQS + 
verify(mockS3, never()).putObject(isA(PutObjectRequest.class)); + + // The client should have sent each item as a batch request + verify(mockSqsBackend, times(10)).sendMessageBatch(any(SendMessageBatchRequest.class)); + + // the order of messages has been preserved + for (int i = 0; i < messageLengthForCounter.length; i++) { + // each batch should correspond to the message length from the list + assertEquals(messageLengthForCounter[i], + sendMessageBatchInvocationEntries.get(i).get(0).getMessageBody().length()); + } } @Test @@ -205,11 +310,11 @@ public void testWhenSmallMessageIsSentThenNoAttributeIsAdded() { verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture()); Map attributes = sendMessageRequestCaptor.getValue().getMessageAttributes(); - Assert.assertTrue(attributes.isEmpty()); + assertTrue(attributes.isEmpty()); } @Test - public void testWhenLargeMessgaeIsSentThenAttributeWithPayloadSizeIsAdded() { + public void testWhenLargeMessageIsSentThenAttributeWithPayloadSizeIsAdded() { int messageLength = MORE_THAN_SQS_SIZE_LIMIT; String messageBody = generateStringWithLength(messageLength); @@ -220,8 +325,22 @@ public void testWhenLargeMessgaeIsSentThenAttributeWithPayloadSizeIsAdded() { verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture()); Map attributes = sendMessageRequestCaptor.getValue().getMessageAttributes(); - Assert.assertEquals("Number", attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getDataType()); - Assert.assertEquals(messageLength, (int)Integer.valueOf(attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getStringValue())); + assertEquals("Number", attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getDataType()); + assertEquals(messageLength, (int)Integer.valueOf(attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getStringValue())); + } + + private SendMessageBatchRequest createMessageBatchWithSizes(int[] messageLengthForCounter) { + List batchEntries = 
new ArrayList<>(); + for (int i = 0; i < 10; i++) { + SendMessageBatchRequestEntry entry = new SendMessageBatchRequestEntry(); + int messageLength = messageLengthForCounter[i]; + String messageBody = generateStringWithLength(messageLength); + entry.setMessageBody(messageBody); + entry.setId("entry_" + i); + batchEntries.add(entry); + } + + return new SendMessageBatchRequest(SQS_QUEUE_URL, batchEntries); } private String generateStringWithLength(int messageLength) {