From ef586c587bb165f8b8735f49a6a9e2772c564ab8 Mon Sep 17 00:00:00 2001 From: Nikita Bahliuk Date: Thu, 13 Jun 2024 17:50:52 +0300 Subject: [PATCH 1/2] Correct offloading of large batch entries to s3. Fix case when individual entries are under the threshold, but total batch size exceeds the threshold. Add test case covering the use-case, adjust old test that is now failing (wrong test-case) --- .../AmazonSQSExtendedClient.java | 44 +++++++++++++------ .../AmazonSQSExtendedClientUtil.java | 10 +++-- .../AmazonSQSExtendedClientTest.java | 42 +++++++++++++++++- 3 files changed, 78 insertions(+), 18 deletions(-) diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index 68317d2..a0889ca 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -22,14 +22,18 @@ import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getReservedAttributeNameIfPresent; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isLarge; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.sizeOf; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize; import java.util.ArrayList; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -73,6 +77,7 @@ import software.amazon.awssdk.services.sqs.model.SendMessageResponse; import software.amazon.awssdk.services.sqs.model.SqsException; import 
software.amazon.awssdk.services.sqs.model.TooManyEntriesInBatchRequestException; +import software.amazon.awssdk.utils.Pair; import software.amazon.awssdk.utils.StringUtils; import software.amazon.payloadoffloading.PayloadStore; import software.amazon.payloadoffloading.S3BackedPayloadStore; @@ -616,23 +621,36 @@ public SendMessageBatchResponse sendMessageBatch(SendMessageBatchRequest sendMes return super.sendMessageBatch(sendMessageBatchRequest); } - List batchEntries = new ArrayList<>(sendMessageBatchRequest.entries().size()); + List originalEntries = sendMessageBatchRequest.entries(); + ArrayList alteredEntries = new ArrayList<>(originalEntries.size()); + alteredEntries.addAll(originalEntries); + // Pairs of (entry index, entry size), sorted by size with the largest first + List> entrySizes = IntStream.range(0, originalEntries.size()) + .boxed() + .map(i -> Pair.of(i, sizeOf(originalEntries.get(i)))) + .sorted((p1, p2) -> Long.compare(p2.right(), p1.right())) + .collect(Collectors.toList()); + + long totalSize = entrySizes.stream().map(Pair::right).mapToLong(Long::longValue).sum(); + + // Move entries to S3, largest first, until the total batch size no longer exceeds the threshold boolean hasS3Entries = false; - for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.entries()) { - //Check message attributes for ExtendedClient related constraints - checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), entry.messageAttributes()); - - if (clientConfiguration.isAlwaysThroughS3() - || isLarge(clientConfiguration.getPayloadSizeThreshold(), entry)) { - entry = storeMessageInS3(entry); - hasS3Entries = true; + for (Pair pair : entrySizes) { + // Stop once the total batch size is within the threshold (unless everything must go through S3) + if (totalSize <= clientConfiguration.getPayloadSizeThreshold() && !clientConfiguration.isAlwaysThroughS3()) { + break; } - batchEntries.add(entry); + Integer entryIndex = pair.left(); + Long originalEntrySize = pair.right(); + SendMessageBatchRequestEntry alteredEntry = 
storeMessageInS3(originalEntries.get(entryIndex)); + totalSize = totalSize - originalEntrySize + sizeOf(alteredEntry); + alteredEntries.set(entryIndex, alteredEntry); + hasS3Entries = true; } if (hasS3Entries) { - sendMessageBatchRequest = sendMessageBatchRequest.toBuilder().entries(batchEntries).build(); + sendMessageBatchRequest = sendMessageBatchRequest.toBuilder().entries(alteredEntries).build(); } return super.sendMessageBatch(sendMessageBatchRequest); @@ -896,6 +914,6 @@ private static T appendUserAgent(final T builder) public void close() { super.close(); this.clientConfiguration.getS3Client().close(); - } - + } + } diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java index 8bf1609..359ffc7 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java @@ -101,10 +101,14 @@ public static boolean isLarge(int payloadSizeThreshold, SendMessageRequest sendM return (totalMsgSize > payloadSizeThreshold); } + public static long sizeOf(SendMessageBatchRequestEntry batchRequestEntry) { + int msgAttributesSize = getMsgAttributesSize(batchRequestEntry.messageAttributes()); + long msgBodySize = Util.getStringSizeInBytes(batchRequestEntry.messageBody()); + return msgAttributesSize + msgBodySize; + } + public static boolean isLarge(int payloadSizeThreshold, SendMessageBatchRequestEntry batchEntry) { - int msgAttributesSize = getMsgAttributesSize(batchEntry.messageAttributes()); - long msgBodySize = Util.getStringSizeInBytes(batchEntry.messageBody()); - long totalMsgSize = msgAttributesSize + msgBodySize; + long totalMsgSize = sizeOf(batchEntry); return (totalMsgSize > payloadSizeThreshold); } diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java 
b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index fd58b0b..20649f5 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -425,7 +425,7 @@ public void testReceiveMessage_when_MessageIsSmall() { @Test public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStoredInS3() { - // This creates 10 messages, out of which only two are below the threshold (100K and 200K), + // This creates 10 messages, out of which only two are below the threshold (100K and 150K), // and the other 8 are above the threshold int[] messageLengthForCounter = new int[] { @@ -437,7 +437,7 @@ public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStor 700_000, 800_000, 900_000, - 200_000, + 150_000, 1000_000 }; @@ -459,6 +459,44 @@ public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStor verify(mockS3, times(8)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); } + + @Test + public void testWhenMessageBatchWithTotalSizeOverTheLimitIsSentThenLargestEntriesAreStoredInS3() { + // This creates 10 messages, each of which is individually below the threshold, + // but whose combined size exceeds the threshold + + int[] messageLengthForCounter = new int[] { + 10_000, + 10_000, + 10_000, + 150_000, + 160_000, + 170_000, + 180_000, + 10_000, + 10_000, + 10_000 + }; + + List batchEntries = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + int messageLength = messageLengthForCounter[i]; + String messageBody = generateStringWithLength(messageLength); + SendMessageBatchRequestEntry entry = SendMessageBatchRequestEntry.builder() + .id("entry_" + i) + .messageBody(messageBody) + .build(); + batchEntries.add(entry); + } + + SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(batchEntries).build(); + 
extendedSqsWithDefaultConfig.sendMessageBatch(batchRequest); + + // There should be 3 puts for the 3 largest messages as sum of sizes of others should be within limit + verify(mockS3, times(3)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); + } + + @Test public void testWhenMessageBatchIsLargeS3PointerIsCorrectlySentToSQSAndNotOriginalMessage() { String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT); From 8e675d93f92bb2d26a501113657b14be687dbbda Mon Sep 17 00:00:00 2001 From: Nikita Bahliuk Date: Thu, 13 Jun 2024 18:08:07 +0300 Subject: [PATCH 2/2] add missing validation check --- .../com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index a0889ca..5902c17 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -643,7 +643,9 @@ public SendMessageBatchResponse sendMessageBatch(SendMessageBatchRequest sendMes } Integer entryIndex = pair.left(); Long originalEntrySize = pair.right(); - SendMessageBatchRequestEntry alteredEntry = storeMessageInS3(originalEntries.get(entryIndex)); + SendMessageBatchRequestEntry originalEntry = originalEntries.get(entryIndex); + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), originalEntry.messageAttributes()); + SendMessageBatchRequestEntry alteredEntry = storeMessageInS3(originalEntry); totalSize = totalSize - originalEntrySize + sizeOf(alteredEntry); alteredEntries.set(entryIndex, alteredEntry); hasS3Entries = true;