Commit f8a7390

Adding a flag to SqsFifoProcessor to allow message processing to continue
1 parent c3e36de commit f8a7390

File tree: 3 files changed (+93 −16 lines)

aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py (+20 −6)

```diff
@@ -57,7 +57,8 @@ def lambda_handler(event, context: LambdaContext):
         None,
     )
 
-    def __init__(self, model: Optional["BatchSqsTypeModel"] = None):
+    def __init__(self, model: Optional["BatchSqsTypeModel"] = None, return_on_first_error: bool = True):
+        self.return_on_first_error = return_on_first_error
         super().__init__(EventType.SQS, model)
 
     def process(self) -> List[Tuple]:
@@ -68,13 +69,26 @@ def process(self) -> List[Tuple]:
         result: List[Tuple] = []
 
         for i, record in enumerate(self.records):
-            # If we have failed messages, it means that the last message failed.
-            # We then short circuit the process, failing the remaining messages
-            if self.fail_messages:
+            # If we have failed messages and we are set to return on the first error,
+            # short circuit the process and return the remaining messages as failed items
+            if self.fail_messages and self.return_on_first_error:
                 return self._short_circuit_processing(i, result)
 
-            # Otherwise, process the message normally
-            result.append(self._process_record(record))
+            # Process the current record
+            processed_messages = self._process_record(record)
+
+            # If a processed message fails,
+            # mark subsequent messages from the same MessageGroupId as skipped
+            if processed_messages[0] == "fail":
+                for subsequent_record in self.records[i + 1 :]:
+                    if subsequent_record.get("attributes", {}).get("MessageGroupId") == record.get(
+                        "attributes",
+                        {},
+                    ).get("MessageGroupId"):
+                        continue  # Skip subsequent message from the same MessageGroupId
+
+            # Append the processed message normally
+            result.append(processed_messages)
 
         return result
 
```
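For context, here is a minimal usage sketch of the new flag. It is not part of this commit: it mirrors the pattern used in the new test further down, and the `record_handler` body and failure condition are made up for illustration. The intent of the change is that with `return_on_first_error=False` a failure only affects records from the failing record's `MessageGroupId`, while other groups keep processing.

```python
from aws_lambda_powertools.utilities.batch import SqsFifoPartialProcessor, batch_processor
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.typing import LambdaContext

# Opt out of the default short-circuit behaviour: processing continues after a
# failure, and only records from the failing MessageGroupId are expected to end
# up in batchItemFailures.
processor = SqsFifoPartialProcessor(return_on_first_error=False)


def record_handler(record: SQSRecord):
    # Raising marks this record as failed; with return_on_first_error=False
    # the rest of the batch is still processed (illustrative condition only).
    if record.body == "fail":
        raise ValueError("simulated failure")
    return record.body


@batch_processor(record_handler=record_handler, processor=processor)
def lambda_handler(event: dict, context: LambdaContext):
    # Report partial failures back to SQS via batchItemFailures
    return processor.response()
```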

docs/utilities/batch.md (+6 −2)

```diff
@@ -141,8 +141,12 @@ Processing batches from SQS works in three stages:
 
 #### FIFO queues
 
-When using [SQS FIFO queues](https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/FIFO-queues.html){target="_blank" rel="nofollow"}, we will stop processing messages after the first failure, and return all failed and unprocessed messages in `batchItemFailures`.
-This helps preserve the ordering of messages in your queue.
+When working with [SQS FIFO queues](https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/FIFO-queues.html){target="_blank"}, it's important to know that a batch sent from SQS to Lambda can include multiple messages from different group IDs.
+
+By default, message processing halts after the initial failure, returning all failed and unprocessed messages in `batchItemFailures` to preserve the ordering of messages in your queue. However, you can opt to continue processing messages and only report failures within the affected message group ID by setting `return_on_first_error` to `False`.
+
+???+ notice "Having problems with DLQ?"
+    `AsyncBatchProcessor` uses `asyncio.gather`. This might cause [side effects and reach trace limits at high concurrency](../core/tracer.md#concurrent-asynchronous-functions){target="_blank"}.
 
 === "Recommended"
 
```
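As a rough illustration of the documented behaviour (and of what the new test below asserts), a handler configured with `return_on_first_error=False` reports only the failing message group in its partial-batch response. The identifiers here are placeholders, not values from this commit.

```python
# Approximate shape of the partial-batch response for a 5-record batch where
# only the two records in message group "2" fail: all other records succeed
# and are omitted from batchItemFailures.
expected_response = {
    "batchItemFailures": [
        {"itemIdentifier": "<message id of the third record>"},
        {"itemIdentifier": "<message id of the fourth record>"},
    ],
}
```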

tests/functional/test_utilities_batch.py (+67 −8)

```diff
@@ -38,6 +38,32 @@
 from tests.functional.utils import b64_to_str, str_to_b64
 
 
+@pytest.fixture(scope="module")
+def sqs_event_fifo_factory() -> Callable:
+    def factory(body: str, message_group_id: str = ""):
+        return {
+            "messageId": f"{uuid.uuid4()}",
+            "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a",
+            "body": body,
+            "attributes": {
+                "ApproximateReceiveCount": "1",
+                "SentTimestamp": "1703675223472",
+                "SequenceNumber": "18882884930918384133",
+                "MessageGroupId": message_group_id,
+                "SenderId": "SenderId",
+                "MessageDeduplicationId": "1eea03c3f7e782c7bdc2f2a917f40389314733ff39f5ab16219580c0109ade98",
+                "ApproximateFirstReceiveTimestamp": "1703675223484",
+            },
+            "messageAttributes": {},
+            "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3",
+            "eventSource": "aws:sqs",
+            "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue",
+            "awsRegion": "us-east-1",
+        }
+
+    return factory
+
+
 @pytest.fixture(scope="module")
 def sqs_event_factory() -> Callable:
     def factory(body: str):
@@ -48,7 +74,7 @@ def factory(body: str):
             "attributes": {
                 "ApproximateReceiveCount": "1",
                 "SentTimestamp": "1545082649183",
-                "SenderId": "AIDAIENQZJOLO23YVJ4VO",
+                "SenderId": "SenderId",
                 "ApproximateFirstReceiveTimestamp": "1545082649185",
             },
             "messageAttributes": {},
@@ -660,10 +686,10 @@ def lambda_handler(event, context):
     assert "All records failed processing. " in str(e.value)
 
 
-def test_sqs_fifo_batch_processor_middleware_success_only(sqs_event_factory, record_handler):
+def test_sqs_fifo_batch_processor_middleware_success_only(sqs_event_fifo_factory, record_handler):
     # GIVEN
-    first_record = SQSRecord(sqs_event_factory("success"))
-    second_record = SQSRecord(sqs_event_factory("success"))
+    first_record = SQSRecord(sqs_event_fifo_factory("success"))
+    second_record = SQSRecord(sqs_event_fifo_factory("success"))
     event = {"Records": [first_record.raw_event, second_record.raw_event]}
 
     processor = SqsFifoPartialProcessor()
@@ -679,12 +705,12 @@ def lambda_handler(event, context):
     assert result["batchItemFailures"] == []
 
 
-def test_sqs_fifo_batch_processor_middleware_with_failure(sqs_event_factory, record_handler):
+def test_sqs_fifo_batch_processor_middleware_with_failure(sqs_event_fifo_factory, record_handler):
     # GIVEN
-    first_record = SQSRecord(sqs_event_factory("success"))
-    second_record = SQSRecord(sqs_event_factory("fail"))
+    first_record = SQSRecord(sqs_event_fifo_factory("success"))
+    second_record = SQSRecord(sqs_event_fifo_factory("fail"))
     # this would normally succeed, but since it's a FIFO queue, it will be marked as failure
-    third_record = SQSRecord(sqs_event_factory("success"))
+    third_record = SQSRecord(sqs_event_fifo_factory("success"))
     event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}
 
     processor = SqsFifoPartialProcessor()
@@ -702,6 +728,39 @@ def lambda_handler(event, context):
     assert result["batchItemFailures"][1]["itemIdentifier"] == third_record.message_id
 
 
+def test_sqs_fifo_batch_processor_middleware_without_first_failure(sqs_event_fifo_factory, record_handler):
+    # GIVEN a batch of 5 records across 3 different MessageGroupIds
+    first_record = SQSRecord(sqs_event_fifo_factory("success", "1"))
+    second_record = SQSRecord(sqs_event_fifo_factory("success", "1"))
+    third_record = SQSRecord(sqs_event_fifo_factory("fail", "2"))
+    fourth_record = SQSRecord(sqs_event_fifo_factory("fail", "2"))
+    fifth_record = SQSRecord(sqs_event_fifo_factory("success", "3"))
+    event = {
+        "Records": [
+            first_record.raw_event,
+            second_record.raw_event,
+            third_record.raw_event,
+            fourth_record.raw_event,
+            fifth_record.raw_event,
+        ],
+    }
+
+    # WHEN the FIFO processor is set to continue processing even after encountering errors in a specific MessageGroupId
+    processor = SqsFifoPartialProcessor(return_on_first_error=False)
+
+    @batch_processor(record_handler=record_handler, processor=processor)
+    def lambda_handler(event, context):
+        return processor.response()
+
+    # WHEN
+    result = lambda_handler(event, {})
+
+    # THEN only the failed messages should originate from MessageGroupId 2
+    assert len(result["batchItemFailures"]) == 2
+    assert result["batchItemFailures"][0]["itemIdentifier"] == third_record.message_id
+    assert result["batchItemFailures"][1]["itemIdentifier"] == fourth_record.message_id
+
+
 def test_async_batch_processor_middleware_success_only(sqs_event_factory, async_record_handler):
     # GIVEN
     first_record = SQSRecord(sqs_event_factory("success"))
```
