From 12be628fb7c51f012ef291913b44f9e821f26b64 Mon Sep 17 00:00:00 2001
From: Ashley Frieze
Date: Wed, 17 Jul 2019 21:49:30 +0100
Subject: [PATCH] Split batches which would go over the threshold
---
.gitignore | 2 +
.../AmazonSQSExtendedClient.java | 105 ++++++++-
.../AmazonSQSExtendedClientTest.java | 211 ++++++++++++++----
3 files changed, 260 insertions(+), 58 deletions(-)
diff --git a/.gitignore b/.gitignore
index 2f7896d..ec10551 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,3 @@
target/
+.idea/
+*.iml
diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java
index 8200d7b..3c5abd4 100644
--- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java
+++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java
@@ -22,9 +22,7 @@
import java.io.Writer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
-import java.util.List;
-import java.util.Map;
-import java.util.UUID;
+import java.util.*;
import java.util.Map.Entry;
import com.amazonaws.AmazonClientException;
@@ -707,6 +705,13 @@ public ChangeMessageVisibilityResult changeMessageVisibility(ChangeMessageVisibi
* S3 when necessary.
*
*
+ * As the combined size of a batch can exceed the threshold even when every message
+ * fits on its own, this method splits the batch into smaller batches, preserving the
+ * order of messages. If any call to send a batch reports failed entries, the batches
+ * that remain are not sent. For large messages just under the threshold, a batch of
+ * n messages can therefore result in n separate calls.
+ *
+ *
* If the DelaySeconds
parameter is not specified for an entry,
* the default for the queue is used.
*
@@ -741,7 +746,8 @@ public ChangeMessageVisibilityResult changeMessageVisibility(ChangeMessageVisibi
* SendMessageBatch service method on AmazonSQS.
*
* @return The response from the SendMessageBatch service method, as
- * returned by AmazonSQS.
+ * returned by AmazonSQS. Note: if there are multiple calls
+ * then this is a composite of the calls.
*
* @throws BatchEntryIdsNotDistinctException
* @throws TooManyEntriesInBatchRequestException
@@ -774,16 +780,88 @@ public SendMessageBatchResult sendMessageBatch(SendMessageBatchRequest sendMessa
}
List batchEntries = sendMessageBatchRequest.getEntries();
+ List messageSizes = measureMessagesAndMoveLargeOnesToS3(batchEntries);
+
+ return splitAndSendMessageBatches(sendMessageBatchRequest, batchEntries, messageSizes);
+ }
+
+    /**
+     * Sends the leading entries that fit under the size threshold in one call, then recurses
+     * on the remainder; stops at the first call that reports failed entries.
+     *
+     * @param sendMessageBatchRequest request reused for every call; its entries are overwritten
+     * @param batchEntries entries still to send, in their original order
+     * @param messageSizes size of each entry of {@code batchEntries}, index for index
+     * @return the last call's result, holding the successful entries of every call in order
+     */
+    private SendMessageBatchResult splitAndSendMessageBatches(SendMessageBatchRequest sendMessageBatchRequest,
+                                                              List<SendMessageBatchRequestEntry> batchEntries,
+                                                              List<Long> messageSizes) {
+        int range = getSliceOfBatchToSend(messageSizes);
+        // modify the send batch request to use the messages that fit
+        sendMessageBatchRequest.setEntries(batchEntries.subList(0, range));
+        SendMessageBatchResult result = super.sendMessageBatch(sendMessageBatchRequest);
+
+        // what is left over for the next slice; each list is cut using its own size
+        List<SendMessageBatchRequestEntry> remainder = batchEntries.subList(range, batchEntries.size());
+        List<Long> remainingSizes = messageSizes.subList(range, messageSizes.size());
+
+        // return this result if there were errors, or if there's nothing further to send
+        if (!result.getFailed().isEmpty() || remainder.isEmpty()) {
+            return result;
+        }
+
+        // no failures so far, so return the final call's result with our successes placed first
+        SendMessageBatchResult recursiveResult =
+                splitAndSendMessageBatches(sendMessageBatchRequest, remainder, remainingSizes);
+        result.getSuccessful().addAll(recursiveResult.getSuccessful());
+        recursiveResult.setSuccessful(result.getSuccessful());
+        return recursiveResult;
+    }
+
+    /**
+     * Counts the leading messages whose running total size is not large per {@code isLarge(long)}.
+     * @param messageSizes sizes of the messages still to send, in order
+     * @return how many leading messages to send next; 0 only when {@code messageSizes} is empty
+     * @throws IllegalStateException if the first message on its own is already too large
+     */
+    private int getSliceOfBatchToSend(List<Long> messageSizes) {
+        int range = 0;
+        long totalSize = 0;
+        for (long messageSize : messageSizes) {
+            // measure the total including this one; stop once the maximum is exceeded
+            totalSize += messageSize;
+            if (isLarge(totalSize)) {
+                break;
+            }
+            range++;
+        }
+        // should be impossible for a non-empty batch; fail fast instead of recursing forever
+        if (range == 0 && !messageSizes.isEmpty()) {
+            throw new IllegalStateException("A message in the batch is larger than the threshold when " +
+                    "it should already have been exported to S3");
+        }
+        return range;
+    }
+
+    /** Stores large (or always-through-S3) entries in S3 in place; returns each entry's size as sent. */
+    private List<Long> measureMessagesAndMoveLargeOnesToS3(List<SendMessageBatchRequestEntry> batchEntries) {
         int index = 0;
+        // ArrayList: the caller slices this list with subList, which is index based
+        List<Long> messageSizes = new ArrayList<>(batchEntries.size());
         for (SendMessageBatchRequestEntry entry : batchEntries) {
-            if (clientConfiguration.isAlwaysThroughS3() || isLarge(entry)) {
-                batchEntries.set(index, storeMessageInS3(entry));
+            long entrySize = sizeOf(entry);
+            if (clientConfiguration.isAlwaysThroughS3() || isLarge(entrySize)) {
+                SendMessageBatchRequestEntry storedVersion = storeMessageInS3(entry);
+                batchEntries.set(index, storedVersion);
+                messageSizes.add(sizeOf(storedVersion));
+            } else {
+                messageSizes.add(entrySize);
             }
             ++index;
         }
-
-        return super.sendMessageBatch(sendMessageBatchRequest);
+        return messageSizes;
     }
/**
@@ -1215,14 +1293,17 @@ private boolean isLarge(SendMessageRequest sendMessageRequest) {
int msgAttributesSize = getMsgAttributesSize(sendMessageRequest.getMessageAttributes());
long msgBodySize = getStringSizeInBytes(sendMessageRequest.getMessageBody());
long totalMsgSize = msgAttributesSize + msgBodySize;
- return (totalMsgSize > clientConfiguration.getMessageSizeThreshold());
+ return isLarge(totalMsgSize);
}
- private boolean isLarge(SendMessageBatchRequestEntry batchEntry) {
+    /** Tells whether {@code size} is greater than the configured message size threshold. */
+    private boolean isLarge(long size) {
+        return clientConfiguration.getMessageSizeThreshold() < size;
+    }
+
+ /** Returns getMsgAttributesSize(attributes) plus getStringSizeInBytes(body) for the given entry. */
+ private long sizeOf(SendMessageBatchRequestEntry batchEntry) {
int msgAttributesSize = getMsgAttributesSize(batchEntry.getMessageAttributes());
long msgBodySize = getStringSizeInBytes(batchEntry.getMessageBody());
- long totalMsgSize = msgAttributesSize + msgBodySize;
- return (totalMsgSize > clientConfiguration.getMessageSizeThreshold());
+ return msgAttributesSize + msgBodySize;
}
private int getMsgAttributesSize(Map msgAttributes) {
diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java
index 3dd70c3..8c3cd14 100644
--- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java
+++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java
@@ -15,44 +15,47 @@
package com.amazon.sqs.javamessaging;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.PutObjectRequest;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSClient;
-import com.amazonaws.services.sqs.model.MessageAttributeValue;
-import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
-import com.amazonaws.services.sqs.model.ReceiveMessageResult;
-import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
-import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry;
-import com.amazonaws.services.sqs.model.SendMessageRequest;
-
-import junit.framework.Assert;
+import com.amazonaws.services.sqs.model.*;
import org.junit.Before;
import org.junit.Test;
+import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.runners.MockitoJUnitRunner;
+import org.mockito.stubbing.Answer;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import static junit.framework.TestCase.assertTrue;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.isA;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.*;
/**
* Tests the AmazonSQSExtendedClient class.
*/
+@RunWith(MockitoJUnitRunner.class)
public class AmazonSQSExtendedClientTest {
private AmazonSQS extendedSqsWithDefaultConfig;
+
+ @Mock
private AmazonSQS mockSqsBackend;
+
+ @Mock
private AmazonS3 mockS3;
+
+    // entries of each sendMessageBatch call made on the mock backend, in invocation order
+    private final List<List<SendMessageBatchRequestEntry>> sendMessageBatchInvocationEntries = new ArrayList<>();
+
private static final String S3_BUCKET_NAME = "test-bucket-name";
private static final String SQS_QUEUE_URL = "test-queue-url";
@@ -61,14 +64,31 @@ public class AmazonSQSExtendedClientTest {
private static final int MORE_THAN_SQS_SIZE_LIMIT = SQS_SIZE_LIMIT + 1;
// should be > 1 and << SQS_SIZE_LIMIT
- private static final int ARBITRATY_SMALLER_THRESSHOLD = 500;
+ private static final int ARBITRARY_SMALLER_THRESHOLD = 500;
@Before
public void setupClient() {
- mockS3 = mock(AmazonS3.class);
- mockSqsBackend = mock(AmazonSQS.class);
+
when(mockS3.putObject(isA(PutObjectRequest.class))).thenReturn(null);
+ // send message batch must return a result and must record the entries used
+ // we can't use a captor for this as the send message batch request object is
+ // changed during execution, so the captor ends up with multiple references to
+ // the same object and we don't see the state of the earliest
+ when(mockSqsBackend.sendMessageBatch(any(SendMessageBatchRequest.class)))
+ .thenAnswer(new Answer() {
+
+ @Override
+ public SendMessageBatchResult answer(InvocationOnMock invocation) throws Throwable {
+ // record the entries
+ List entries =
+ invocation.getArgumentAt(0, SendMessageBatchRequest.class).getEntries();
+ sendMessageBatchInvocationEntries.add(entries);
+
+ return new SendMessageBatchResult();
+ }
+ });
+
ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration()
.withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME);
@@ -123,15 +143,15 @@ public void testWhenSendMessageWithAlwaysThroughS3AndMessageIsSmallThenItIsStill
SendMessageRequest messageRequest = new SendMessageRequest(SQS_QUEUE_URL, messageBody);
sqsExtended.sendMessage(messageRequest);
- verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class));
+ verify(mockS3).putObject(isA(PutObjectRequest.class));
}
@Test
public void testWhenSendMessageWithSetMessageSizeThresholdThenThresholdIsHonored() {
- int messageLength = ARBITRATY_SMALLER_THRESSHOLD * 2;
+ int messageLength = ARBITRARY_SMALLER_THRESHOLD * 2;
String messageBody = generateStringWithLength(messageLength);
ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration()
- .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withMessageSizeThreshold(ARBITRATY_SMALLER_THRESSHOLD);
+ .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withMessageSizeThreshold(ARBITRARY_SMALLER_THRESHOLD);
AmazonSQS sqsExtended = spy(new AmazonSQSExtendedClient(mock(AmazonSQSClient.class), extendedClientConfiguration));
@@ -152,15 +172,39 @@ public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessag
.withMessageAttributeNames(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME);
sqsExtended.receiveMessage(messageRequest);
- Assert.assertEquals(expectedRequest, messageRequest);
+ assertEquals(expectedRequest, messageRequest);
sqsExtended.receiveMessage(messageRequest);
- Assert.assertEquals(expectedRequest, messageRequest);
+ assertEquals(expectedRequest, messageRequest);
+ }
+
+    @Test
+    public void testWhenSmallMessageBatchIsSentThenNoMessagesStoredInS3() {
+        // Ten messages, each of them well within the threshold
+        int[] smallMessageLengths = new int[10];
+        Arrays.fill(smallMessageLengths, 1_000);
+
+        extendedSqsWithDefaultConfig.sendMessageBatch(createMessageBatchWithSizes(smallMessageLengths));
+
+        // Nothing should have been written to S3
+        verify(mockS3, never()).putObject(isA(PutObjectRequest.class));
     }
@Test
public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStoredInS3() {
- // This creates 10 messages, out of which only two are below the threshold (100K and 200K),
+ // This creates 10 messages, out of which only two are below the threshold (100K and 20K),
// and the other 8 are above the threshold
int[] messageLengthForCounter = new int[] {
@@ -172,25 +216,86 @@ public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStor
700_000,
800_000,
900_000,
- 200_000,
+ 20_000,
1000_000
};
- List batchEntries = new ArrayList();
- for (int i = 0; i < 10; i++) {
- SendMessageBatchRequestEntry entry = new SendMessageBatchRequestEntry();
- int messageLength = messageLengthForCounter[i];
- String messageBody = generateStringWithLength(messageLength);
- entry.setMessageBody(messageBody);
- entry.setId("entry_" + i);
- batchEntries.add(entry);
- }
-
- SendMessageBatchRequest batchRequest = new SendMessageBatchRequest(SQS_QUEUE_URL, batchEntries);
+ SendMessageBatchRequest batchRequest = createMessageBatchWithSizes(messageLengthForCounter);
extendedSqsWithDefaultConfig.sendMessageBatch(batchRequest);
- // There should be 8 puts for the 8 messages above the threshhold
+ // There should be 8 puts for the 8 messages above the threshold
verify(mockS3, times(8)).putObject(isA(PutObjectRequest.class));
+
+ // and one batch send
+ verify(mockSqsBackend).sendMessageBatch(any(SendMessageBatchRequest.class));
+ }
+
+    @Test
+    public void testWhenMessageBatchIsSentWhereSumOfMessageSizesIsOverTheThresholdThenBatchIsSplit() {
+        // Ten messages, each below the threshold on its own, that together would make
+        // a single request over the threshold
+        int[] messageLengths = new int[10];
+        Arrays.fill(messageLengths, 26_214);
+        messageLengths[9] = 26_219;
+
+        extendedSqsWithDefaultConfig.sendMessageBatch(createMessageBatchWithSizes(messageLengths));
+
+        // No message is moved to S3
+        verify(mockS3, never()).putObject(isA(PutObjectRequest.class));
+
+        // SQS should have been called twice ...
+        verify(mockSqsBackend, times(2)).sendMessageBatch(any(SendMessageBatchRequest.class));
+
+        // ... first with most of the messages, then with the one left over
+        assertEquals(9, sendMessageBatchInvocationEntries.get(0).size());
+        assertEquals(1, sendMessageBatchInvocationEntries.get(1).size());
+    }
+
+    @Test
+    public void testWhenMessageBatchIsMadeOfLargeMessagesThenBatchIsSplitAndOrderMaintained() {
+        // Ten messages at or just under the SQS size limit: none is moved to S3, but no two
+        // of them fit into one request, so each must be sent in a batch of its own. Every
+        // message has a distinct length so that the order of the calls can be checked.
+        int[] messageLengthForCounter = new int[10];
+        for (int i = 0; i < messageLengthForCounter.length; i++) {
+            messageLengthForCounter[i] = SQS_SIZE_LIMIT - i;
+        }
+
+        SendMessageBatchRequest batchRequest = createMessageBatchWithSizes(messageLengthForCounter);
+        extendedSqsWithDefaultConfig.sendMessageBatch(batchRequest);
+
+        // The client should not put any objects to S3 as they are all small enough
+        // to send to SQS
+        verify(mockS3, never()).putObject(isA(PutObjectRequest.class));
+
+        // The client should have sent each item as a batch request
+        verify(mockSqsBackend, times(10)).sendMessageBatch(any(SendMessageBatchRequest.class));
+
+        // each call carried exactly one message, and the calls kept the original order
+        for (int i = 0; i < messageLengthForCounter.length; i++) {
+            assertEquals(1, sendMessageBatchInvocationEntries.get(i).size());
+            assertEquals(messageLengthForCounter[i],
+                    sendMessageBatchInvocationEntries.get(i).get(0).getMessageBody().length());
+        }
     }
@Test
@@ -205,11 +310,11 @@ public void testWhenSmallMessageIsSentThenNoAttributeIsAdded() {
verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture());
Map attributes = sendMessageRequestCaptor.getValue().getMessageAttributes();
- Assert.assertTrue(attributes.isEmpty());
+ assertTrue(attributes.isEmpty());
}
@Test
- public void testWhenLargeMessgaeIsSentThenAttributeWithPayloadSizeIsAdded() {
+ public void testWhenLargeMessageIsSentThenAttributeWithPayloadSizeIsAdded() {
int messageLength = MORE_THAN_SQS_SIZE_LIMIT;
String messageBody = generateStringWithLength(messageLength);
@@ -220,8 +325,22 @@ public void testWhenLargeMessgaeIsSentThenAttributeWithPayloadSizeIsAdded() {
verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture());
Map attributes = sendMessageRequestCaptor.getValue().getMessageAttributes();
- Assert.assertEquals("Number", attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getDataType());
- Assert.assertEquals(messageLength, (int)Integer.valueOf(attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getStringValue()));
+ assertEquals("Number", attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getDataType());
+ assertEquals(messageLength, (int)Integer.valueOf(attributes.get(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME).getStringValue()));
+ }
+
+    /**
+     * Builds a batch request for the test queue with one entry per given length.
+     *
+     * @param messageLengthForCounter body length of each entry, in order; entry i gets id "entry_" + i
+     * @return a request for {@code SQS_QUEUE_URL} holding the generated entries
+     */
+    private SendMessageBatchRequest createMessageBatchWithSizes(int[] messageLengthForCounter) {
+        List<SendMessageBatchRequestEntry> batchEntries = new ArrayList<>(messageLengthForCounter.length);
+        // size the batch from the array rather than assuming it always holds ten lengths
+        for (int i = 0; i < messageLengthForCounter.length; i++) {
+            SendMessageBatchRequestEntry entry = new SendMessageBatchRequestEntry();
+            entry.setMessageBody(generateStringWithLength(messageLengthForCounter[i]));
+            entry.setId("entry_" + i);
+            batchEntries.add(entry);
+        }
+
+        return new SendMessageBatchRequest(SQS_QUEUE_URL, batchEntries);
     }
private String generateStringWithLength(int messageLength) {