diff --git a/aws_lambda_powertools/utilities/batch/__init__.py b/aws_lambda_powertools/utilities/batch/__init__.py
index 02f3e786441..0e2637cc358 100644
--- a/aws_lambda_powertools/utilities/batch/__init__.py
+++ b/aws_lambda_powertools/utilities/batch/__init__.py
@@ -16,16 +16,22 @@
     batch_processor,
 )
 from aws_lambda_powertools.utilities.batch.exceptions import ExceptionInfo
+from aws_lambda_powertools.utilities.batch.sqs_fifo_partial_processor import (
+    SqsFifoPartialProcessor,
+)
+from aws_lambda_powertools.utilities.batch.types import BatchTypeModels
 
 __all__ = (
     "BatchProcessor",
     "AsyncBatchProcessor",
     "BasePartialProcessor",
     "BasePartialBatchProcessor",
+    "BatchTypeModels",
     "ExceptionInfo",
     "EventType",
     "FailureResponse",
     "SuccessResponse",
+    "SqsFifoPartialProcessor",
     "batch_processor",
     "async_batch_processor",
 )
diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py
index 171858c6d11..3aea2b70fa4 100644
--- a/aws_lambda_powertools/utilities/batch/base.py
+++ b/aws_lambda_powertools/utilities/batch/base.py
@@ -19,7 +19,6 @@
     List,
     Optional,
     Tuple,
-    Type,
     Union,
     overload,
 )
@@ -30,6 +29,7 @@
     BatchProcessingError,
     ExceptionInfo,
 )
+from aws_lambda_powertools.utilities.batch.types import BatchTypeModels
 from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import (
     DynamoDBRecord,
 )
@@ -48,24 +48,6 @@ class EventType(Enum):
     DynamoDBStreams = "DynamoDBStreams"
 
 
-#
-# type specifics
-#
-has_pydantic = "pydantic" in sys.modules
-
-# For IntelliSense and Mypy to work, we need to account for possible SQS, Kinesis and DynamoDB subclasses
-# We need them as subclasses as we must access their message ID or sequence number metadata via dot notation
-if has_pydantic:
-    from aws_lambda_powertools.utilities.parser.models import DynamoDBStreamRecordModel
-    from aws_lambda_powertools.utilities.parser.models import (
-        KinesisDataStreamRecord as KinesisDataStreamRecordModel,
-    )
-    from aws_lambda_powertools.utilities.parser.models import SqsRecordModel
-
-    BatchTypeModels = Optional[
-        Union[Type[SqsRecordModel], Type[DynamoDBStreamRecordModel], Type[KinesisDataStreamRecordModel]]
-    ]
-
 # When using processor with default arguments, records will carry EventSourceDataClassTypes
 # and depending on what EventType it's passed it'll correctly map to the right record
 # When using Pydantic Models, it'll accept any subclass from SQS, DynamoDB and Kinesis
diff --git a/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py b/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py
new file mode 100644
index 00000000000..d48749a137e
--- /dev/null
+++ b/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py
@@ -0,0 +1,92 @@
+from typing import List, Optional, Tuple
+
+from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType
+from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel
+
+
+class SQSFifoCircuitBreakerError(Exception):
+    """
+    Signals a record not processed due to the SQS FIFO processing being interrupted
+    """
+
+    pass
+
+
+class SqsFifoPartialProcessor(BatchProcessor):
+    """Process native partial responses from SQS FIFO queues.
+
+    Stops processing records when the first record fails. The remaining records are reported as failed items.
+
+    Example
+    _______
+
+    ## Process batch triggered by a FIFO SQS
+
+    ```python
+    import json
+
+    from aws_lambda_powertools import Logger, Tracer
+    from aws_lambda_powertools.utilities.batch import SqsFifoPartialProcessor, EventType, batch_processor
+    from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
+    from aws_lambda_powertools.utilities.typing import LambdaContext
+
+
+    processor = SqsFifoPartialProcessor()
+    tracer = Tracer()
+    logger = Logger()
+
+
+    @tracer.capture_method
+    def record_handler(record: SQSRecord):
+        payload: str = record.body
+        if payload:
+            item: dict = json.loads(payload)
+        ...
+
+    @logger.inject_lambda_context
+    @tracer.capture_lambda_handler
+    @batch_processor(record_handler=record_handler, processor=processor)
+    def lambda_handler(event, context: LambdaContext):
+        return processor.response()
+    ```
+    """
+
+    circuit_breaker_exc = (
+        SQSFifoCircuitBreakerError,
+        SQSFifoCircuitBreakerError("A previous record failed processing"),
+        None,
+    )
+
+    def __init__(self, model: Optional["BatchSqsTypeModel"] = None):
+        super().__init__(EventType.SQS, model)
+
+    def process(self) -> List[Tuple]:
+        """
+        Call instance's handler for each record. When the first failed message is detected,
+        the process is short-circuited, and the remaining messages are reported as failed items.
+        """
+        result: List[Tuple] = []
+
+        for i, record in enumerate(self.records):
+            # If we have failed messages, it means that the last message failed.
+            # We then short circuit the process, failing the remaining messages
+            if self.fail_messages:
+                return self._short_circuit_processing(i, result)
+
+            # Otherwise, process the message normally
+            result.append(self._process_record(record))
+
+        return result
+
+    def _short_circuit_processing(self, first_failure_index: int, result: List[Tuple]) -> List[Tuple]:
+        """
+        Starting from the first failure index, fail all the remaining messages, and append them to the result list.
+        """
+        remaining_records = self.records[first_failure_index:]
+        for remaining_record in remaining_records:
+            data = self._to_batch_type(record=remaining_record, event_type=self.event_type, model=self.model)
+            result.append(self.failure_handler(record=data, exception=self.circuit_breaker_exc))
+        return result
+
+    async def _async_process_record(self, record: dict):
+        raise NotImplementedError()
diff --git a/aws_lambda_powertools/utilities/batch/types.py b/aws_lambda_powertools/utilities/batch/types.py
new file mode 100644
index 00000000000..1fc5aba4fc4
--- /dev/null
+++ b/aws_lambda_powertools/utilities/batch/types.py
@@ -0,0 +1,24 @@
+#
+# type specifics
+#
+import sys
+from typing import Optional, Type, Union
+
+has_pydantic = "pydantic" in sys.modules
+
+# For IntelliSense and Mypy to work, we need to account for possible SQS, Kinesis and DynamoDB subclasses
+# We need them as subclasses as we must access their message ID or sequence number metadata via dot notation
+if has_pydantic:
+    from aws_lambda_powertools.utilities.parser.models import DynamoDBStreamRecordModel
+    from aws_lambda_powertools.utilities.parser.models import (
+        KinesisDataStreamRecord as KinesisDataStreamRecordModel,
+    )
+    from aws_lambda_powertools.utilities.parser.models import SqsRecordModel
+
+    BatchTypeModels = Optional[
+        Union[Type[SqsRecordModel], Type[DynamoDBStreamRecordModel], Type[KinesisDataStreamRecordModel]]
+    ]
+    BatchSqsTypeModel = Optional[Type[SqsRecordModel]]
+else:
+    BatchTypeModels = "BatchTypeModels"  # type: ignore
+    BatchSqsTypeModel = "BatchSqsTypeModel"  # type: ignore
diff --git a/docs/utilities/batch.md b/docs/utilities/batch.md
index 4a53e053f44..0f899673c2e 100644
--- a/docs/utilities/batch.md
+++ b/docs/utilities/batch.md
@@ -347,6 +347,23 @@ Processing batches from SQS works in four stages:
     }
     ```
 
+#### FIFO queues
+
+When using [SQS FIFO queues](https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/FIFO-queues.html){target="_blank"}, we will stop processing messages after the first failure, and return all failed and unprocessed messages in `batchItemFailures`.
+This helps preserve the ordering of messages in your queue.
+
+=== "As a decorator"
+
+    ```python hl_lines="5 11"
+    --8<-- "examples/batch_processing/src/sqs_fifo_batch_processor.py"
+    ```
+
+=== "As a context manager"
+
+    ```python hl_lines="4 8"
+    --8<-- "examples/batch_processing/src/sqs_fifo_batch_processor_context_manager.py"
+    ```
+
 ### Processing messages from Kinesis
 
 Processing batches from Kinesis works in four stages:
diff --git a/examples/batch_processing/src/sqs_fifo_batch_processor.py b/examples/batch_processing/src/sqs_fifo_batch_processor.py
new file mode 100644
index 00000000000..a5fe9f23235
--- /dev/null
+++ b/examples/batch_processing/src/sqs_fifo_batch_processor.py
@@ -0,0 +1,23 @@
+from aws_lambda_powertools import Logger, Tracer
+from aws_lambda_powertools.utilities.batch import (
+    SqsFifoPartialProcessor,
+    batch_processor,
+)
+from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
+from aws_lambda_powertools.utilities.typing import LambdaContext
+
+processor = SqsFifoPartialProcessor()
+tracer = Tracer()
+logger = Logger()
+
+
+@tracer.capture_method
+def record_handler(record: SQSRecord):
+    ...
+
+
+@logger.inject_lambda_context
+@tracer.capture_lambda_handler
+@batch_processor(record_handler=record_handler, processor=processor)
+def lambda_handler(event, context: LambdaContext):
+    return processor.response()
diff --git a/examples/batch_processing/src/sqs_fifo_batch_processor_context_manager.py b/examples/batch_processing/src/sqs_fifo_batch_processor_context_manager.py
new file mode 100644
index 00000000000..45759b2a585
--- /dev/null
+++ b/examples/batch_processing/src/sqs_fifo_batch_processor_context_manager.py
@@ -0,0 +1,23 @@
+from aws_lambda_powertools import Logger, Tracer
+from aws_lambda_powertools.utilities.batch import SqsFifoPartialProcessor
+from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
+from aws_lambda_powertools.utilities.typing import LambdaContext
+
+processor = SqsFifoPartialProcessor()
+tracer = Tracer()
+logger = Logger()
+
+
+@tracer.capture_method
+def record_handler(record: SQSRecord):
+    ...
+
+
+@logger.inject_lambda_context
+@tracer.capture_lambda_handler
+def lambda_handler(event, context: LambdaContext):
+    batch = event["Records"]
+    with processor(records=batch, handler=record_handler):
+        processor.process()  # kick off processing, return List[Tuple]
+
+    return processor.response()
diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py
index 6dcfc3d179d..c98d59a7042 100644
--- a/tests/functional/test_utilities_batch.py
+++ b/tests/functional/test_utilities_batch.py
@@ -1,4 +1,5 @@
 import json
+import uuid
 from random import randint
 from typing import Any, Awaitable, Callable, Dict, Optional
 
@@ -9,6 +10,7 @@
     AsyncBatchProcessor,
     BatchProcessor,
     EventType,
+    SqsFifoPartialProcessor,
     async_batch_processor,
     batch_processor,
 )
@@ -40,7 +42,7 @@
 def sqs_event_factory() -> Callable:
     def factory(body: str):
         return {
-            "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d",
+            "messageId": f"{uuid.uuid4()}",
             "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a",
             "body": body,
             "attributes": {
@@ -654,6 +656,48 @@ def lambda_handler(event, context):
     assert "All records failed processing. " in str(e.value)
 
 
+def test_sqs_fifo_batch_processor_middleware_success_only(sqs_event_factory, record_handler):
+    # GIVEN
+    first_record = SQSRecord(sqs_event_factory("success"))
+    second_record = SQSRecord(sqs_event_factory("success"))
+    event = {"Records": [first_record.raw_event, second_record.raw_event]}
+
+    processor = SqsFifoPartialProcessor()
+
+    @batch_processor(record_handler=record_handler, processor=processor)
+    def lambda_handler(event, context):
+        return processor.response()
+
+    # WHEN
+    result = lambda_handler(event, {})
+
+    # THEN
+    assert result["batchItemFailures"] == []
+
+
+def test_sqs_fifo_batch_processor_middleware_with_failure(sqs_event_factory, record_handler):
+    # GIVEN
+    first_record = SQSRecord(sqs_event_factory("success"))
+    second_record = SQSRecord(sqs_event_factory("fail"))
+    # this would normally succeed, but since it's a FIFO queue, it will be marked as failure
+    third_record = SQSRecord(sqs_event_factory("success"))
+    event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}
+
+    processor = SqsFifoPartialProcessor()
+
+    @batch_processor(record_handler=record_handler, processor=processor)
+    def lambda_handler(event, context):
+        return processor.response()
+
+    # WHEN
+    result = lambda_handler(event, {})
+
+    # THEN
+    assert len(result["batchItemFailures"]) == 2
+    assert result["batchItemFailures"][0]["itemIdentifier"] == second_record.message_id
+    assert result["batchItemFailures"][1]["itemIdentifier"] == third_record.message_id
+
+
 def test_async_batch_processor_middleware_success_only(sqs_event_factory, async_record_handler):
     # GIVEN
     first_record = SQSRecord(sqs_event_factory("success"))