diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py
index 569467f2248..72d43d8af82 100644
--- a/aws_lambda_powertools/utilities/batch/base.py
+++ b/aws_lambda_powertools/utilities/batch/base.py
@@ -222,7 +222,12 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
 class BasePartialBatchProcessor(BasePartialProcessor):  # noqa
     DEFAULT_RESPONSE: PartialItemFailureResponse = {"batchItemFailures": []}
 
-    def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = None):
+    def __init__(
+        self,
+        event_type: EventType,
+        model: Optional["BatchTypeModels"] = None,
+        raise_on_entire_batch_failure: bool = True,
+    ):
         """Process batch and partially report failed items
 
         Parameters
@@ -231,6 +236,9 @@ def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = N
             Whether this is a SQS, DynamoDB Streams, or Kinesis Data Stream event
         model: Optional["BatchTypeModels"]
             Parser's data model using either SqsRecordModel, DynamoDBStreamRecordModel, KinesisDataStreamRecord
+        raise_on_entire_batch_failure: bool
+            Raise an exception when the entire batch has failed processing.
+            When set to False, partial failures are reported in the response
 
         Exceptions
         ----------
@@ -239,6 +247,7 @@ def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = N
         """
         self.event_type = event_type
         self.model = model
+        self.raise_on_entire_batch_failure = raise_on_entire_batch_failure
         self.batch_response: PartialItemFailureResponse = copy.deepcopy(self.DEFAULT_RESPONSE)
         self._COLLECTOR_MAPPING = {
             EventType.SQS: self._collect_sqs_failures,
@@ -274,7 +283,7 @@ def _clean(self):
         if not self._has_messages_to_report():
             return
 
-        if self._entire_batch_failed():
+        if self._entire_batch_failed() and self.raise_on_entire_batch_failure:
             raise BatchProcessingError(
                 msg=f"All records failed processing. {len(self.exceptions)} individual errors logged "
                 f"separately below.",
@@ -475,7 +484,7 @@ def lambda_handler(event, context: LambdaContext):
     Raises
     ------
     BatchProcessingError
-        When all batch records fail processing
+        When all batch records fail processing and raise_on_entire_batch_failure is True
 
     Limitations
     -----------
@@ -624,7 +633,7 @@ def lambda_handler(event, context: LambdaContext):
     Raises
     ------
     BatchProcessingError
-        When all batch records fail processing
+        When all batch records fail processing and raise_on_entire_batch_failure is True
 
     Limitations
     -----------
diff --git a/docs/utilities/batch.md b/docs/utilities/batch.md
index 6b8e0fd3000..65efb6a1805 100644
--- a/docs/utilities/batch.md
+++ b/docs/utilities/batch.md
@@ -491,6 +491,20 @@ Inheritance is importance because we need to access message IDs and sequence num
 --8<-- "examples/batch_processing/src/pydantic_dynamodb_event.json"
 ```
 
+### Working with full batch failures
+
+By default, the `BatchProcessor` raises `BatchProcessingError` when all records in the batch fail to process. We do this to reflect the failure in your operational metrics.
+
+When your function handles batches with a small number of records, or when you use errors as a flow-control mechanism, this behavior might not be desirable, as your function can generate an unnaturally high number of errors. When this happens, the [Lambda service will scale down the concurrency of your function](https://docs.aws.amazon.com/lambda/latest/dg/services-sqs-errorhandling.html#services-sqs-backoff-strategy){target="_blank"}, potentially impacting performance.
+
+For these scenarios, you can set the `raise_on_entire_batch_failure` option to `False`.
+
+=== "working_with_entire_batch_fail.py"
+
+    ```python hl_lines="10"
+    --8<-- "examples/batch_processing/src/working_with_entire_batch_fail.py"
+    ```
+
 ### Accessing processed messages
 
 Use the context manager to access a list of all returned values from your `record_handler` function.
diff --git a/examples/batch_processing/src/working_with_entire_batch_fail.py b/examples/batch_processing/src/working_with_entire_batch_fail.py
new file mode 100644
index 00000000000..9058ce23483
--- /dev/null
+++ b/examples/batch_processing/src/working_with_entire_batch_fail.py
@@ -0,0 +1,29 @@
+from aws_lambda_powertools import Logger, Tracer
+from aws_lambda_powertools.utilities.batch import (
+    BatchProcessor,
+    EventType,
+    process_partial_response,
+)
+from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
+from aws_lambda_powertools.utilities.typing import LambdaContext
+
+processor = BatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_failure=False)
+tracer = Tracer()
+logger = Logger()
+
+
+@tracer.capture_method
+def record_handler(record: SQSRecord):
+    payload: str = record.json_body  # if JSON string data; otherwise use record.body for str
+    logger.info(payload)
+
+
+@logger.inject_lambda_context
+@tracer.capture_lambda_handler
+def lambda_handler(event, context: LambdaContext):
+    return process_partial_response(
+        event=event,
+        record_handler=record_handler,
+        processor=processor,
+        context=context,
+    )
diff --git a/tests/functional/batch/required_dependencies/test_utilities_batch.py b/tests/functional/batch/required_dependencies/test_utilities_batch.py
index 732e2f0ef78..77b1f865dca 100644
--- a/tests/functional/batch/required_dependencies/test_utilities_batch.py
+++ b/tests/functional/batch/required_dependencies/test_utilities_batch.py
@@ -408,6 +408,48 @@ def lambda_handler(event, context):
     assert "All records failed processing. " in str(e.value)
 
 
+def test_batch_processor_not_raise_when_entire_batch_fails_sync(sqs_event_factory, record_handler):
+    first_record = SQSRecord(sqs_event_factory("fail"))
+    second_record = SQSRecord(sqs_event_factory("fail"))
+    event = {"Records": [first_record.raw_event, second_record.raw_event]}
+
+    # GIVEN a BatchProcessor configured with raise_on_entire_batch_failure=False
+    processor = BatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_failure=False)
+
+    # WHEN processing the messages
+    @batch_processor(record_handler=record_handler, processor=processor)
+    def lambda_handler(event, context):
+        return processor.response()
+
+    response = lambda_handler(event, {})
+
+    # THEN the `itemIdentifier` of each failure matches the message ID of the corresponding record
+    assert len(response["batchItemFailures"]) == 2
+    assert response["batchItemFailures"][0]["itemIdentifier"] == first_record.message_id
+    assert response["batchItemFailures"][1]["itemIdentifier"] == second_record.message_id
+
+
+def test_batch_processor_not_raise_when_entire_batch_fails_async(sqs_event_factory, async_record_handler):
+    first_record = SQSRecord(sqs_event_factory("fail"))
+    second_record = SQSRecord(sqs_event_factory("fail"))
+    event = {"Records": [first_record.raw_event, second_record.raw_event]}
+
+    # GIVEN an AsyncBatchProcessor configured with raise_on_entire_batch_failure=False
+    processor = AsyncBatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_failure=False)
+
+    # WHEN processing the messages
+    @async_batch_processor(record_handler=async_record_handler, processor=processor)
+    def lambda_handler(event, context):
+        return processor.response()
+
+    response = lambda_handler(event, {})
+
+    # THEN the `itemIdentifier` of each failure matches the message ID of the corresponding record
+    assert len(response["batchItemFailures"]) == 2
+    assert response["batchItemFailures"][0]["itemIdentifier"] == first_record.message_id
+    assert response["batchItemFailures"][1]["itemIdentifier"] == second_record.message_id
+
+
 def test_sqs_fifo_batch_processor_middleware_success_only(sqs_event_fifo_factory, record_handler):
     # GIVEN
     first_record = SQSRecord(sqs_event_fifo_factory("success"))