
feat(batch): add option to not raise BatchProcessingError exception when the entire batch fails #4719


17 changes: 13 additions & 4 deletions aws_lambda_powertools/utilities/batch/base.py
@@ -222,7 +222,12 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
class BasePartialBatchProcessor(BasePartialProcessor): # noqa
DEFAULT_RESPONSE: PartialItemFailureResponse = {"batchItemFailures": []}

def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = None):
def __init__(
self,
event_type: EventType,
model: Optional["BatchTypeModels"] = None,
raise_on_entire_batch_fail: bool = True,
):
"""Process batch and partially report failed items

Parameters
@@ -231,6 +236,9 @@ def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = N
Whether this is a SQS, DynamoDB Streams, or Kinesis Data Stream event
model: Optional["BatchTypeModels"]
Parser's data model using either SqsRecordModel, DynamoDBStreamRecordModel, KinesisDataStreamRecord
raise_on_entire_batch_fail: bool
Raise an exception when the entire batch has failed processing.
When set to False, partial failures are reported in the response

Exceptions
----------
@@ -239,6 +247,7 @@ def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = N
"""
self.event_type = event_type
self.model = model
self.raise_on_entire_batch_fail = raise_on_entire_batch_fail
self.batch_response: PartialItemFailureResponse = copy.deepcopy(self.DEFAULT_RESPONSE)
self._COLLECTOR_MAPPING = {
EventType.SQS: self._collect_sqs_failures,
@@ -274,7 +283,7 @@ def _clean(self):
if not self._has_messages_to_report():
return

if self._entire_batch_failed():
if self._entire_batch_failed() and self.raise_on_entire_batch_fail:
raise BatchProcessingError(
msg=f"All records failed processing. {len(self.exceptions)} individual errors logged "
f"separately below.",
@@ -475,7 +484,7 @@ def lambda_handler(event, context: LambdaContext):
Raises
------
BatchProcessingError
When all batch records fail processing
When all batch records fail processing and raise_on_entire_batch_fail is True

Limitations
-----------
@@ -624,7 +633,7 @@ def lambda_handler(event, context: LambdaContext):
Raises
------
BatchProcessingError
When all batch records fail processing
When all batch records fail processing and raise_on_entire_batch_fail is True

Limitations
-----------
14 changes: 14 additions & 0 deletions docs/utilities/batch.md
@@ -491,6 +491,20 @@ Inheritance is important because we need to access message IDs and sequence num
--8<-- "examples/batch_processing/src/pydantic_dynamodb_event.json"
```

### Working with full batch failures

By default, the `BatchProcessor` raises `BatchProcessingError` if all records in the batch fail to process. We do this to reflect the failure in your operational metrics.

When working with functions that handle batches with a small number of records, or when you use errors as a flow-control mechanism, this behavior might not be desirable, as your function could generate an unnaturally high number of errors. When this happens, the [Lambda service will scale down the concurrency of your function](https://docs.aws.amazon.com/lambda/latest/dg/services-sqs-errorhandling.html#services-sqs-backoff-strategy){target="_blank"}, potentially impacting performance.

For these scenarios, you can set the `raise_on_entire_batch_fail` option to `False`.

=== "working_with_entire_batch_fail.py"

```python hl_lines="10"
--8<-- "examples/batch_processing/src/working_with_entire_batch_fail.py"
```
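
With the flag disabled, even a fully failed batch returns the standard partial-failure payload. For illustration (hypothetical message IDs, not from the example above), the Lambda response would look like:

```python
# Illustrative response shape for a two-record batch where both records
# failed and raise_on_entire_batch_fail=False (message IDs hypothetical):
{
    "batchItemFailures": [
        {"itemIdentifier": "msg-1"},
        {"itemIdentifier": "msg-2"},
    ],
}
```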

### Accessing processed messages

Use the context manager to access a list of all returned values from your `record_handler` function.
Expand Down
29 changes: 29 additions & 0 deletions examples/batch_processing/src/working_with_entire_batch_fail.py
@@ -0,0 +1,29 @@
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.utilities.batch import (
BatchProcessor,
EventType,
process_partial_response,
)
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.typing import LambdaContext

processor = BatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_fail=False)
tracer = Tracer()
logger = Logger()


@tracer.capture_method
def record_handler(record: SQSRecord):
payload: str = record.json_body  # if the body is a JSON string; otherwise use record.body for plain text
logger.info(payload)


@logger.inject_lambda_context
@tracer.capture_lambda_handler
def lambda_handler(event, context: LambdaContext):
return process_partial_response(
event=event,
record_handler=record_handler,
processor=processor,
context=context,
)
@@ -408,6 +408,48 @@ def lambda_handler(event, context):
assert "All records failed processing. " in str(e.value)


def test_batch_processor_not_raise_when_entire_batch_fails_sync(sqs_event_factory, record_handler):
first_record = SQSRecord(sqs_event_factory("fail"))
second_record = SQSRecord(sqs_event_factory("fail"))
event = {"Records": [first_record.raw_event, second_record.raw_event]}

# GIVEN a BatchProcessor configured with raise_on_entire_batch_fail=False
processor = BatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_fail=False)

# WHEN processing the messages
@batch_processor(record_handler=record_handler, processor=processor)
def lambda_handler(event, context):
return processor.response()

response = lambda_handler(event, {})

# THEN assert the `itemIdentifier` of each failure matches the message ID of the corresponding record
assert len(response["batchItemFailures"]) == 2
assert response["batchItemFailures"][0]["itemIdentifier"] == first_record.message_id
assert response["batchItemFailures"][1]["itemIdentifier"] == second_record.message_id


def test_batch_processor_not_raise_when_entire_batch_fails_async(sqs_event_factory, async_record_handler):
first_record = SQSRecord(sqs_event_factory("fail"))
second_record = SQSRecord(sqs_event_factory("fail"))
event = {"Records": [first_record.raw_event, second_record.raw_event]}

# GIVEN an AsyncBatchProcessor configured with raise_on_entire_batch_fail=False
processor = AsyncBatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_fail=False)

# WHEN processing the messages
@async_batch_processor(record_handler=async_record_handler, processor=processor)
def lambda_handler(event, context):
return processor.response()

response = lambda_handler(event, {})

# THEN assert the `itemIdentifier` of each failure matches the message ID of the corresponding record
assert len(response["batchItemFailures"]) == 2
assert response["batchItemFailures"][0]["itemIdentifier"] == first_record.message_id
assert response["batchItemFailures"][1]["itemIdentifier"] == second_record.message_id


def test_sqs_fifo_batch_processor_middleware_success_only(sqs_event_fifo_factory, record_handler):
# GIVEN
first_record = SQSRecord(sqs_event_fifo_factory("success"))