diff --git a/aws_lambda_powertools/utilities/batch/__init__.py b/aws_lambda_powertools/utilities/batch/__init__.py index 08c35560b3f..02f3e786441 100644 --- a/aws_lambda_powertools/utilities/batch/__init__.py +++ b/aws_lambda_powertools/utilities/batch/__init__.py @@ -5,21 +5,27 @@ """ from aws_lambda_powertools.utilities.batch.base import ( + AsyncBatchProcessor, + BasePartialBatchProcessor, BasePartialProcessor, BatchProcessor, EventType, FailureResponse, SuccessResponse, + async_batch_processor, batch_processor, ) from aws_lambda_powertools.utilities.batch.exceptions import ExceptionInfo __all__ = ( "BatchProcessor", + "AsyncBatchProcessor", "BasePartialProcessor", + "BasePartialBatchProcessor", "ExceptionInfo", "EventType", "FailureResponse", "SuccessResponse", "batch_processor", + "async_batch_processor", ) diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py index 4f9c4ca8780..171858c6d11 100644 --- a/aws_lambda_powertools/utilities/batch/base.py +++ b/aws_lambda_powertools/utilities/batch/base.py @@ -3,15 +3,29 @@ """ Batch processing utilities """ +import asyncio import copy import inspect import logging +import os import sys from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, overload +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, + overload, +) from aws_lambda_powertools.middleware_factory import lambda_handler_decorator +from aws_lambda_powertools.shared import constants from aws_lambda_powertools.utilities.batch.exceptions import ( BatchProcessingError, ExceptionInfo, @@ -100,6 +114,49 @@ def process(self) -> List[Tuple]: """ return [self._process_record(record) for record in self.records] + @abstractmethod + async def _async_process_record(self, record: dict): + """ + Async process record with handler. + """ + raise NotImplementedError() + + def async_process(self) -> List[Tuple]: + """ + Async call instance's handler for each record. + + Note + ---- + + We keep the outer function synchronous to prevent making Lambda handler async, so to not impact + customers' existing middlewares. Instead, we create an async closure to handle asynchrony. + + We also handle edge cases like Lambda container thaw by getting an existing or creating an event loop. + + See: https://docs.aws.amazon.com/lambda/latest/dg/lambda-runtime-environment.html#runtimes-lifecycle-shutdown + """ + + async def async_process_closure(): + return list(await asyncio.gather(*[self._async_process_record(record) for record in self.records])) + + # WARNING + # Do not use "asyncio.run(async_process())" due to Lambda container thaws/freeze, otherwise we might get "Event Loop is closed" # noqa: E501 + # Instead, get_event_loop() can also create one if a previous was erroneously closed + # Mangum library does this as well. 
It's battle tested with other popular async-only frameworks like FastAPI + # https://github.com/jordaneremieff/mangum/discussions/256#discussioncomment-2638946 + # https://github.com/jordaneremieff/mangum/blob/b85cd4a97f8ddd56094ccc540ca7156c76081745/mangum/protocols/http.py#L44 + + # Let's prime the coroutine and decide + # whether we create an event loop (Lambda) or schedule it as usual (non-Lambda) + coro = async_process_closure() + if os.getenv(constants.LAMBDA_TASK_ROOT_ENV): + loop = asyncio.get_event_loop() # NOTE: this might return an error starting in Python 3.12 in a few years + task_instance = loop.create_task(coro) + return loop.run_until_complete(task_instance) + + # Non-Lambda environment, run coroutine as usual + return asyncio.run(coro) + def __enter__(self): self._prepare() return self @@ -191,9 +248,262 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse: return entry +class BasePartialBatchProcessor(BasePartialProcessor): # noqa + DEFAULT_RESPONSE: Dict[str, List[Optional[dict]]] = {"batchItemFailures": []} + + def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = None): + """Process batch and partially report failed items + + Parameters + ---------- + event_type: EventType + Whether this is a SQS, DynamoDB Streams, or Kinesis Data Stream event + model: Optional["BatchTypeModels"] + Parser's data model using either SqsRecordModel, DynamoDBStreamRecordModel, KinesisDataStreamRecord + + Exceptions + ---------- + BatchProcessingError + Raised when the entire batch has failed processing + """ + self.event_type = event_type + self.model = model + self.batch_response = copy.deepcopy(self.DEFAULT_RESPONSE) + self._COLLECTOR_MAPPING = { + EventType.SQS: self._collect_sqs_failures, + EventType.KinesisDataStreams: self._collect_kinesis_failures, + EventType.DynamoDBStreams: self._collect_dynamodb_failures, + } + self._DATA_CLASS_MAPPING = { + EventType.SQS: SQSRecord, + EventType.KinesisDataStreams: KinesisStreamRecord, + EventType.DynamoDBStreams: DynamoDBRecord, + } + + super().__init__() + + def response(self): + """Batch items that failed processing, if any""" + return self.batch_response + + def _prepare(self): + """ + Remove results from previous execution. + """ + self.success_messages.clear() + self.fail_messages.clear() + self.exceptions.clear() + self.batch_response = copy.deepcopy(self.DEFAULT_RESPONSE) + + def _clean(self): + """ + Report messages to be deleted in case of partial failure. + """ + + if not self._has_messages_to_report(): + return + + if self._entire_batch_failed(): + raise BatchProcessingError( + msg=f"All records failed processing. 
{len(self.exceptions)} individual errors logged " + f"separately below.", + child_exceptions=self.exceptions, + ) + + messages = self._get_messages_to_report() + self.batch_response = {"batchItemFailures": messages} + + def _has_messages_to_report(self) -> bool: + if self.fail_messages: + return True + + logger.debug(f"All {len(self.success_messages)} records successfully processed") + return False + + def _entire_batch_failed(self) -> bool: + return len(self.exceptions) == len(self.records) + + def _get_messages_to_report(self) -> List[Dict[str, str]]: + """ + Format messages to use in batch deletion + """ + return self._COLLECTOR_MAPPING[self.event_type]() + + # Event Source Data Classes follow python idioms for fields + # while Parser/Pydantic follows the event field names to the letter + def _collect_sqs_failures(self): + failures = [] + for msg in self.fail_messages: + msg_id = msg.messageId if self.model else msg.message_id + failures.append({"itemIdentifier": msg_id}) + return failures + + def _collect_kinesis_failures(self): + failures = [] + for msg in self.fail_messages: + msg_id = msg.kinesis.sequenceNumber if self.model else msg.kinesis.sequence_number + failures.append({"itemIdentifier": msg_id}) + return failures + + def _collect_dynamodb_failures(self): + failures = [] + for msg in self.fail_messages: + msg_id = msg.dynamodb.SequenceNumber if self.model else msg.dynamodb.sequence_number + failures.append({"itemIdentifier": msg_id}) + return failures + + @overload + def _to_batch_type(self, record: dict, event_type: EventType, model: "BatchTypeModels") -> "BatchTypeModels": + ... # pragma: no cover + + @overload + def _to_batch_type(self, record: dict, event_type: EventType) -> EventSourceDataClassTypes: + ... # pragma: no cover + + def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["BatchTypeModels"] = None): + if model is not None: + return model.parse_obj(record) + return self._DATA_CLASS_MAPPING[event_type](record) + + +class BatchProcessor(BasePartialBatchProcessor): # Keep old name for compatibility + """Process native partial responses from SQS, Kinesis Data Streams, and DynamoDB. + + Example + ------- + + ## Process batch triggered by SQS + + ```python + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.SQS) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: SQSRecord): + payload: str = record.body + if payload: + item: dict = json.loads(payload) + ...
+ + @logger.inject_lambda_context + @tracer.capture_lambda_handler + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` + + ## Process batch triggered by Kinesis Data Streams + + ```python + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.KinesisDataStreams) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: KinesisStreamRecord): + logger.info(record.kinesis.data_as_text) + payload: dict = record.kinesis.data_as_json() + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` + + ## Process batch triggered by DynamoDB Data Streams + + ```python + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.DynamoDBStreams) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: DynamoDBRecord): + logger.info(record.dynamodb.new_image) + payload: dict = json.loads(record.dynamodb.new_image.get("item")) + # alternatively: + # changes: Dict[str, Any] = record.dynamodb.new_image # noqa: E800 + # payload = change.get("Message") -> "" + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + def lambda_handler(event, context: LambdaContext): + batch = event["Records"] + with processor(records=batch, handler=record_handler): + processed_messages = processor.process() # kick off processing, return list[tuple] + + return processor.response() + ``` + + + Raises + ------ + BatchProcessingError + When all batch records fail processing + + Limitations + ----------- + * Async record handler not supported, use AsyncBatchProcessor instead. + """ + + async def _async_process_record(self, record: dict): + raise NotImplementedError() + + def _process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]: + """ + Process a record with instance's handler + + Parameters + ---------- + record: dict + A batch record to be processed.
+ """ + data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) + try: + if self._handler_accepts_lambda_context: + result = self.handler(record=data, lambda_context=self.lambda_context) + else: + result = self.handler(record=data) + + return self.success_handler(record=record, result=result) + except Exception: + return self.failure_handler(record=data, exception=sys.exc_info()) + + @lambda_handler_decorator def batch_processor( - handler: Callable, event: Dict, context: LambdaContext, record_handler: Callable, processor: BasePartialProcessor + handler: Callable, event: Dict, context: LambdaContext, record_handler: Callable, processor: BatchProcessor ): """ Middleware to handle batch event processing @@ -207,8 +517,8 @@ def batch_processor( context: LambdaContext Lambda's Context record_handler: Callable - Callable to process each record from the batch - processor: BasePartialProcessor + Callable or coroutine to process each record from the batch + processor: BatchProcessor Batch Processor to handle partial failure cases Examples @@ -226,8 +536,7 @@ def batch_processor( Limitations ----------- - * Async batch processors - + * Async batch processors. Use `async_batch_processor` instead. """ records = event["Records"] @@ -237,9 +546,8 @@ def batch_processor( return handler(event, context) -class BatchProcessor(BasePartialProcessor): - """Process native partial responses from SQS, Kinesis Data Streams, and DynamoDB. - +class AsyncBatchProcessor(BasePartialBatchProcessor): + """Process native partial responses from SQS, Kinesis Data Streams, and DynamoDB asynchronously. Example ------- @@ -261,7 +569,7 @@ class BatchProcessor(BasePartialProcessor): @tracer.capture_method - def record_handler(record: SQSRecord): + async def record_handler(record: SQSRecord): payload: str = record.body if payload: item: dict = json.loads(payload) @@ -291,7 +599,7 @@ def lambda_handler(event, context: LambdaContext): @tracer.capture_method - def record_handler(record: KinesisStreamRecord): + async def record_handler(record: KinesisStreamRecord): logger.info(record.kinesis.data_as_text) payload: dict = record.kinesis.data_as_json() ...
@@ -303,7 +611,6 @@ def lambda_handler(event, context: LambdaContext): return processor.response() ``` - ## Process batch triggered by DynamoDB Data Streams ```python @@ -321,7 +628,7 @@ def lambda_handler(event, context: LambdaContext): @tracer.capture_method - def record_handler(record: DynamoDBRecord): + async def record_handler(record: DynamoDBRecord): logger.info(record.dynamodb.new_image) payload: dict = json.loads(record.dynamodb.new_image.get("item")) # alternatively: @@ -344,55 +651,16 @@ def lambda_handler(event, context: LambdaContext): ------ BatchProcessingError When all batch records fail processing - """ - - DEFAULT_RESPONSE: Dict[str, List[Optional[dict]]] = {"batchItemFailures": []} - - def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = None): - """Process batch and partially report failed items - - Parameters - ---------- - event_type: EventType - Whether this is a SQS, DynamoDB Streams, or Kinesis Data Stream event - model: Optional["BatchTypeModels"] - Parser's data model using either SqsRecordModel, DynamoDBStreamRecordModel, KinesisDataStreamRecord - - Exceptions - ---------- - BatchProcessingError - Raised when the entire batch has failed processing - """ - self.event_type = event_type - self.model = model - self.batch_response = copy.deepcopy(self.DEFAULT_RESPONSE) - self._COLLECTOR_MAPPING = { - EventType.SQS: self._collect_sqs_failures, - EventType.KinesisDataStreams: self._collect_kinesis_failures, - EventType.DynamoDBStreams: self._collect_dynamodb_failures, - } - self._DATA_CLASS_MAPPING = { - EventType.SQS: SQSRecord, - EventType.KinesisDataStreams: KinesisStreamRecord, - EventType.DynamoDBStreams: DynamoDBRecord, - } - - super().__init__() - def response(self): - """Batch items that failed processing, if any""" - return self.batch_response + Limitations + ----------- + * Sync record handler not supported, use BatchProcessor instead. + """ - def _prepare(self): - """ - Remove results from previous execution. - """ - self.success_messages.clear() - self.fail_messages.clear() - self.exceptions.clear() - self.batch_response = copy.deepcopy(self.DEFAULT_RESPONSE) + def _process_record(self, record: dict): + raise NotImplementedError() - def _process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]: + async def _async_process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]: """ Process a record with instance's handler @@ -404,80 +672,59 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) try: if self._handler_accepts_lambda_context: - result = self.handler(record=data, lambda_context=self.lambda_context) + result = await self.handler(record=data, lambda_context=self.lambda_context) else: - result = self.handler(record=data) + result = await self.handler(record=data) return self.success_handler(record=record, result=result) except Exception: return self.failure_handler(record=data, exception=sys.exc_info()) - def _clean(self): - """ - Report messages to be deleted in case of partial failure. - """ - - if not self._has_messages_to_report(): - return - - if self._entire_batch_failed(): - raise BatchProcessingError( - msg=f"All records failed processing. 
{len(self.exceptions)} individual errors logged " - f"separately below.", - child_exceptions=self.exceptions, - ) - - messages = self._get_messages_to_report() - self.batch_response = {"batchItemFailures": messages} - - def _has_messages_to_report(self) -> bool: - if self.fail_messages: - return True - - logger.debug(f"All {len(self.success_messages)} records successfully processed") - return False - - def _entire_batch_failed(self) -> bool: - return len(self.exceptions) == len(self.records) - - def _get_messages_to_report(self) -> List[Dict[str, str]]: - """ - Format messages to use in batch deletion - """ - return self._COLLECTOR_MAPPING[self.event_type]() - - # Event Source Data Classes follow python idioms for fields - # while Parser/Pydantic follows the event field names to the latter - def _collect_sqs_failures(self): - failures = [] - for msg in self.fail_messages: - msg_id = msg.messageId if self.model else msg.message_id - failures.append({"itemIdentifier": msg_id}) - return failures - - def _collect_kinesis_failures(self): - failures = [] - for msg in self.fail_messages: - msg_id = msg.kinesis.sequenceNumber if self.model else msg.kinesis.sequence_number - failures.append({"itemIdentifier": msg_id}) - return failures - def _collect_dynamodb_failures(self): - failures = [] - for msg in self.fail_messages: - msg_id = msg.dynamodb.SequenceNumber if self.model else msg.dynamodb.sequence_number - failures.append({"itemIdentifier": msg_id}) - return failures +@lambda_handler_decorator +def async_batch_processor( + handler: Callable, + event: Dict, + context: LambdaContext, + record_handler: Callable[..., Awaitable[Any]], + processor: AsyncBatchProcessor, +): + """ + Middleware to handle batch event processing + Parameters + ---------- + handler: Callable + Lambda's handler + event: Dict + Lambda's Event + context: LambdaContext + Lambda's Context + record_handler: Callable[..., Awaitable[Any]] + Callable to process each record from the batch + processor: AsyncBatchProcessor + Batch Processor to handle partial failure cases + Examples + -------- + **Processes Lambda's event with a BasePartialProcessor** + >>> from aws_lambda_powertools.utilities.batch import async_batch_processor, AsyncBatchProcessor + >>> + >>> async def async_record_handler(record): + >>> payload: str = record.body + >>> return payload + >>> + >>> processor = AsyncBatchProcessor(event_type=EventType.SQS) + >>> + >>> @async_batch_processor(record_handler=async_record_handler, processor=processor) + >>> async def lambda_handler(event, context: LambdaContext): + >>> return processor.response() - @overload - def _to_batch_type(self, record: dict, event_type: EventType, model: "BatchTypeModels") -> "BatchTypeModels": - ... # pragma: no cover + Limitations + ----------- + * Sync batch processors. Use `batch_processor` instead. + """ + records = event["Records"] - @overload - def _to_batch_type(self, record: dict, event_type: EventType) -> EventSourceDataClassTypes: - ... 
# pragma: no cover + with processor(records, record_handler, lambda_context=context): + processor.async_process() - def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["BatchTypeModels"] = None): - if model is not None: - return model.parse_obj(record) - return self._DATA_CLASS_MAPPING[event_type](record) + return handler(event, context) diff --git a/docs/utilities/batch.md b/docs/utilities/batch.md index 988b1937b5b..4a53e053f44 100644 --- a/docs/utilities/batch.md +++ b/docs/utilities/batch.md @@ -636,6 +636,28 @@ All records in the batch will be passed to this handler for processing, even if All processing logic will and should be performed by the `record_handler` function. +### Processing messages asynchronously + +!!! tip "New to AsyncIO? Read this [comprehensive guide first](https://realpython.com/async-io-python/){target="_blank"}." + +You can use `AsyncBatchProcessor` class and `async_batch_processor` decorator to process messages concurrently. + +???+ question "When is this useful?" + Your use case might be able to process multiple records at the same time without conflicting with one another. + + For example, imagine you need to process multiple loyalty points and incrementally save in a database. While you await the database to confirm your records are saved, you could start processing another request concurrently. + + The reason this is not the default behaviour is that not all use cases can handle concurrency safely (e.g., loyalty points must be updated in order). + +```python hl_lines="4 6 11 14 23" title="High-concurrency with AsyncBatchProcessor" +--8<-- "examples/batch_processing/src/getting_started_async_batch_processor.py" +``` + +???+ warning "Using tracer?" + `AsyncBatchProcessor` uses `asyncio.gather` which can cause side effects and reach trace limits at high concurrency. + + See [Tracing concurrent asynchronous functions](../core/tracer.md#concurrent-asynchronous-functions). + ## Advanced ### Pydantic integration diff --git a/examples/batch_processing/src/getting_started_async_batch_processor.py b/examples/batch_processing/src/getting_started_async_batch_processor.py new file mode 100644 index 00000000000..594be0540f3 --- /dev/null +++ b/examples/batch_processing/src/getting_started_async_batch_processor.py @@ -0,0 +1,25 @@ +import httpx # external dependency + +from aws_lambda_powertools.utilities.batch import ( + AsyncBatchProcessor, + EventType, + async_batch_processor, +) +from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord +from aws_lambda_powertools.utilities.typing import LambdaContext + +processor = AsyncBatchProcessor(event_type=EventType.SQS) + + +async def async_record_handler(record: SQSRecord): + # Yield control back to the event loop to schedule other tasks + # while you await from a response from httpbin.org + async with httpx.AsyncClient() as client: + ret = await client.get("https://httpbin.org/get") + + return ret.status_code + + +@async_batch_processor(record_handler=async_record_handler, processor=processor) +def lambda_handler(event, context: LambdaContext): + return processor.response() diff --git a/poetry.lock b/poetry.lock index fbc1aef492a..49a3b94e304 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,27 @@ # This file is automatically @generated by Poetry and should not be changed by hand. 
+[[package]] +name = "anyio" +version = "3.6.2" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "dev" +optional = false +python-versions = ">=3.6.2" +files = [ + {file = "anyio-3.6.2-py3-none-any.whl", hash = "sha256:fbbe32bd270d2a2ef3ed1c5d45041250284e31fc0a4df4a5a6071842051a51e3"}, + {file = "anyio-3.6.2.tar.gz", hash = "sha256:25ea0d673ae30af41a0c442f81cf3b38c7e79fdc7b60335a4c14e05eb0947421"}, +] + +[package.dependencies] +idna = ">=2.8" +sniffio = ">=1.1" +typing-extensions = {version = "*", markers = "python_version < \"3.8\""} + +[package.extras] +doc = ["packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["contextlib2", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (<0.15)", "uvloop (>=0.15)"] +trio = ["trio (>=0.16,<0.22)"] + [[package]] name = "attrs" version = "22.2.0" @@ -21,18 +43,18 @@ tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy [[package]] name = "aws-cdk-asset-awscli-v1" -version = "2.2.52" +version = "2.2.63" description = "A library that contains the AWS CLI for use in Lambda Layers" category = "dev" optional = false python-versions = "~=3.7" files = [ - {file = "aws-cdk.asset-awscli-v1-2.2.52.tar.gz", hash = "sha256:ab04beec8e267e363931df2caf48a24100cb5799d7fd8db51efe881d117efa7a"}, - {file = "aws_cdk.asset_awscli_v1-2.2.52-py3-none-any.whl", hash = "sha256:6e9d686bb0b00242e869e91d57b65b619ffb42e99abe482436e3a6692485dbfe"}, + {file = "aws-cdk.asset-awscli-v1-2.2.63.tar.gz", hash = "sha256:76154ade5391f8927c932b609028b28426af34215f144d07576ba35e4eca9442"}, + {file = "aws_cdk.asset_awscli_v1-2.2.63-py3-none-any.whl", hash = "sha256:1ad1d5b7287097f6546902801a40f39b6580c99e5d0eb07dfc5e8ddf428167b0"}, ] [package.dependencies] -jsii = ">=1.73.0,<2.0.0" +jsii = ">=1.74.0,<2.0.0" publication = ">=0.0.3" typeguard = ">=2.13.3,<2.14.0" @@ -55,77 +77,77 @@ typeguard = ">=2.13.3,<2.14.0" [[package]] name = "aws-cdk-asset-node-proxy-agent-v5" -version = "2.0.42" +version = "2.0.52" description = "@aws-cdk/asset-node-proxy-agent-v5" category = "dev" optional = false python-versions = "~=3.7" files = [ - {file = "aws-cdk.asset-node-proxy-agent-v5-2.0.42.tar.gz", hash = "sha256:ae1b615be42e78681e05b145460603f171c06b671a2d1caa060a159b94b06366"}, - {file = "aws_cdk.asset_node_proxy_agent_v5-2.0.42-py3-none-any.whl", hash = "sha256:6e0174802097d558daa1be5c4e6e7f309eeba626392955e596bf967ee37865d3"}, + {file = "aws-cdk.asset-node-proxy-agent-v5-2.0.52.tar.gz", hash = "sha256:1346ce52303e8b8c7c88ce16599a36d947e9546fc6cae0965182594d7b0e600d"}, + {file = "aws_cdk.asset_node_proxy_agent_v5-2.0.52-py3-none-any.whl", hash = "sha256:1a08b261ea2bf10f07fe89a7502686e6be2adea636e6bb1ee1f56b678231fe02"}, ] [package.dependencies] -jsii = ">=1.73.0,<2.0.0" +jsii = ">=1.74.0,<2.0.0" publication = ">=0.0.3" typeguard = ">=2.13.3,<2.14.0" [[package]] name = "aws-cdk-aws-apigatewayv2-alpha" -version = "2.62.2a0" +version = "2.64.0a0" description = "The CDK Construct Library for AWS::APIGatewayv2" category = "dev" optional = false python-versions = "~=3.7" files = [ - {file = "aws-cdk.aws-apigatewayv2-alpha-2.62.2a0.tar.gz", hash = "sha256:63c191fdcb8b20d1afd34af84ae465740b14009a06af7bdc8e78475614f85a23"}, - {file = "aws_cdk.aws_apigatewayv2_alpha-2.62.2a0-py3-none-any.whl", hash = "sha256:32ff5d8745b71ef30ba009de4d8d9f12bd34a4f3c940500ba34367211f05c9f4"}, + {file = 
"aws-cdk.aws-apigatewayv2-alpha-2.64.0a0.tar.gz", hash = "sha256:7e33fb04b10c1668abe334e25a998967b51aeed76243fc591b66705c8d6241d4"}, + {file = "aws_cdk.aws_apigatewayv2_alpha-2.64.0a0-py3-none-any.whl", hash = "sha256:88f72a435fc91f7c02a8f1fb564958ac1c8125c5319021d61b67d00466185199"}, ] [package.dependencies] -aws-cdk-lib = ">=2.62.2,<3.0.0" +aws-cdk-lib = ">=2.64.0,<3.0.0" constructs = ">=10.0.0,<11.0.0" -jsii = ">=1.73.0,<2.0.0" +jsii = ">=1.74.0,<2.0.0" publication = ">=0.0.3" typeguard = ">=2.13.3,<2.14.0" [[package]] name = "aws-cdk-aws-apigatewayv2-authorizers-alpha" -version = "2.62.2a0" +version = "2.64.0a0" description = "Authorizers for AWS APIGateway V2" category = "dev" optional = false python-versions = "~=3.7" files = [ - {file = "aws-cdk.aws-apigatewayv2-authorizers-alpha-2.62.2a0.tar.gz", hash = "sha256:9a4ba121c49e4ba866b985495b87e9ecaec50c1f26e0d8cb116e15492196c042"}, - {file = "aws_cdk.aws_apigatewayv2_authorizers_alpha-2.62.2a0-py3-none-any.whl", hash = "sha256:9cfb1495b618880b395d6ecbd45c3c524c67013f2567eae6e19e6f06586b9a38"}, + {file = "aws-cdk.aws-apigatewayv2-authorizers-alpha-2.64.0a0.tar.gz", hash = "sha256:670ee77f19818723aeeea47fbac1441d58f39b5eff79332e15196452ec6183bf"}, + {file = "aws_cdk.aws_apigatewayv2_authorizers_alpha-2.64.0a0-py3-none-any.whl", hash = "sha256:e2377441ad33aa43453f5c501e00a9a0c261627e78b2080617edd6e09949c139"}, ] [package.dependencies] -"aws-cdk.aws-apigatewayv2-alpha" = "2.62.2.a0" -aws-cdk-lib = ">=2.62.2,<3.0.0" +"aws-cdk.aws-apigatewayv2-alpha" = "2.64.0.a0" +aws-cdk-lib = ">=2.64.0,<3.0.0" constructs = ">=10.0.0,<11.0.0" -jsii = ">=1.73.0,<2.0.0" +jsii = ">=1.74.0,<2.0.0" publication = ">=0.0.3" typeguard = ">=2.13.3,<2.14.0" [[package]] name = "aws-cdk-aws-apigatewayv2-integrations-alpha" -version = "2.62.2a0" +version = "2.64.0a0" description = "Integrations for AWS APIGateway V2" category = "dev" optional = false python-versions = "~=3.7" files = [ - {file = "aws-cdk.aws-apigatewayv2-integrations-alpha-2.62.2a0.tar.gz", hash = "sha256:4ae06b6585664c659eb6b88ff70eaa628a96ffb4728ab0d0eb7ff1f23913565b"}, - {file = "aws_cdk.aws_apigatewayv2_integrations_alpha-2.62.2a0-py3-none-any.whl", hash = "sha256:497e93d193895b1b38545d5ca152e31f575b971ce371ad655aeb3bbed7fc6052"}, + {file = "aws-cdk.aws-apigatewayv2-integrations-alpha-2.64.0a0.tar.gz", hash = "sha256:1826fa641a0e849cff90e681033066fa3fea44bca447c6696681dddf862df364"}, + {file = "aws_cdk.aws_apigatewayv2_integrations_alpha-2.64.0a0-py3-none-any.whl", hash = "sha256:a34f87cafbbdf76078ce564642f7f11771f4693a04bb7f41eca7d76b26ffe562"}, ] [package.dependencies] -"aws-cdk.aws-apigatewayv2-alpha" = "2.62.2.a0" -aws-cdk-lib = ">=2.62.2,<3.0.0" +"aws-cdk.aws-apigatewayv2-alpha" = "2.64.0.a0" +aws-cdk-lib = ">=2.64.0,<3.0.0" constructs = ">=10.0.0,<11.0.0" -jsii = ">=1.73.0,<2.0.0" +jsii = ">=1.74.0,<2.0.0" publication = ">=0.0.3" typeguard = ">=2.13.3,<2.14.0" @@ -167,21 +189,20 @@ requests = ">=0.14.0" [[package]] name = "aws-sam-translator" -version = "1.58.1" +version = "1.59.0" description = "AWS SAM Translator is a library that transform SAM templates into AWS CloudFormation templates" category = "dev" optional = false python-versions = ">=3.7, <=4.0, !=4.0" files = [ - {file = "aws-sam-translator-1.58.1.tar.gz", hash = "sha256:cd60a19085d432bc00769b597bc2e6854f546ff9928f8067fc5fbcb5a1ed74ff"}, - {file = "aws_sam_translator-1.58.1-py2-none-any.whl", hash = "sha256:c4e261e450d574572d389edcafab04d1fe337615f867610410390c2435cb1f26"}, - {file = "aws_sam_translator-1.58.1-py3-none-any.whl", 
hash = "sha256:ca47d6eb04d8cf358bea9160411193da40a80dc3e79bb0c5bace0c21f0e4c888"}, + {file = "aws-sam-translator-1.59.0.tar.gz", hash = "sha256:9b8f23a5754cba92677d334ece5c5d9dc9b1f1a327a650fc8939ae3fc6da4141"}, + {file = "aws_sam_translator-1.59.0-py3-none-any.whl", hash = "sha256:6761293a21bd1cb0e19f168926ebfc4a3a6c9011aca67bd448ef485a55d6f658"}, ] [package.dependencies] boto3 = ">=1.19.5,<2.0.0" jsonschema = ">=3.2,<5" -pydantic = ">=1.10.2,<1.11.0" +pydantic = ">=1.8,<2.0" typing-extensions = ">=4.4.0,<4.5.0" [package.extras] @@ -279,18 +300,18 @@ uvloop = ["uvloop (>=0.15.2)"] [[package]] name = "boto3" -version = "1.26.60" +version = "1.26.68" description = "The AWS SDK for Python" category = "main" optional = false python-versions = ">= 3.7" files = [ - {file = "boto3-1.26.60-py3-none-any.whl", hash = "sha256:5fd2810217a74a38078a19fb85a9e5d6934d0c146eb060967a3ffd7ab33cdf00"}, - {file = "boto3-1.26.60.tar.gz", hash = "sha256:f0824b3bcf803800d3ecef903b4840427e4b3d37a069f6fc9a86310f7e036ad5"}, + {file = "boto3-1.26.68-py3-none-any.whl", hash = "sha256:bbb426a9b3afd3ccbac25e03b215d79e90b4c47905b1b08b3b9d86fc74096974"}, + {file = "boto3-1.26.68.tar.gz", hash = "sha256:c92dd0fde7839c0ca9c16a989d67ceb7f80f53de19f2b087fd1182f2af41b2ae"}, ] [package.dependencies] -botocore = ">=1.29.60,<1.30.0" +botocore = ">=1.29.68,<1.30.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.6.0,<0.7.0" @@ -299,14 +320,14 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.29.60" +version = "1.29.68" description = "Low-level, data-driven core of boto 3." category = "main" optional = false python-versions = ">= 3.7" files = [ - {file = "botocore-1.29.60-py3-none-any.whl", hash = "sha256:c4ae251e7df0cf01d893eb945bc8f23c14989ed349775a8e16c949f08a068f9a"}, - {file = "botocore-1.29.60.tar.gz", hash = "sha256:a21217ccf4613c9ebbe4c3192e13ba91d46be642560e39a16406662a398a107b"}, + {file = "botocore-1.29.68-py3-none-any.whl", hash = "sha256:08fa8302a22553e69b70b1de2cc8cec61a3a878546658d091473e13d5b9d2ca4"}, + {file = "botocore-1.29.68.tar.gz", hash = "sha256:8f5cb96dc0862809d29fe512087c77c15fe6328a2d8238f0a96cccb6eb77ec12"}, ] [package.dependencies] @@ -315,7 +336,7 @@ python-dateutil = ">=2.1,<3.0.0" urllib3 = ">=1.25.4,<1.27" [package.extras] -crt = ["awscrt (==0.15.3)"] +crt = ["awscrt (==0.16.9)"] [[package]] name = "cattrs" @@ -508,14 +529,14 @@ files = [ [[package]] name = "constructs" -version = "10.1.236" +version = "10.1.246" description = "A programming model for software-defined state" category = "dev" optional = false python-versions = "~=3.7" files = [ - {file = "constructs-10.1.236-py3-none-any.whl", hash = "sha256:e51d8fac38b12a88359d5d2bedb535987eaa54e68631add29726652be66490e9"}, - {file = "constructs-10.1.236.tar.gz", hash = "sha256:10b3c5ed3d4c6fd930bd8f59c8a5926028dafe8a5bf703fba5bcc53c89fce002"}, + {file = "constructs-10.1.246-py3-none-any.whl", hash = "sha256:f07c7c4aa2d22ff960a9f51f7011030b4a3d8cc6df0e0a84e30ea63c2c8c8456"}, + {file = "constructs-10.1.246.tar.gz", hash = "sha256:26d0b017eef92bde3ece7454b524dddc051425819c59932ebe3c1ff6f9e1cb4a"}, ] [package.dependencies] @@ -901,6 +922,67 @@ files = [ gitdb = ">=4.0.1,<5" typing-extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.8\""} +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = 
"sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[package.dependencies] +typing-extensions = {version = "*", markers = "python_version < \"3.8\""} + +[[package]] +name = "httpcore" +version = "0.16.3" +description = "A minimal low-level HTTP client." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "httpcore-0.16.3-py3-none-any.whl", hash = "sha256:da1fb708784a938aa084bde4feb8317056c55037247c787bd7e19eb2c2949dc0"}, + {file = "httpcore-0.16.3.tar.gz", hash = "sha256:c5d6f04e2fc530f39e0c077e6a30caa53f1451096120f1f38b954afd0b17c0cb"}, +] + +[package.dependencies] +anyio = ">=3.0,<5.0" +certifi = "*" +h11 = ">=0.13,<0.15" +sniffio = ">=1.0.0,<2.0.0" + +[package.extras] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] + +[[package]] +name = "httpx" +version = "0.23.3" +description = "The next generation HTTP client." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "httpx-0.23.3-py3-none-any.whl", hash = "sha256:a211fcce9b1254ea24f0cd6af9869b3d29aba40154e947d2a07bb499b3e310d6"}, + {file = "httpx-0.23.3.tar.gz", hash = "sha256:9818458eb565bb54898ccb9b8b251a28785dd4a55afbc23d0eb410754fe7d0f9"}, +] + +[package.dependencies] +certifi = "*" +httpcore = ">=0.15.0,<0.17.0" +rfc3986 = {version = ">=1.3,<2", extras = ["idna2008"]} +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<13)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] + [[package]] name = "hvac" version = "1.0.2" @@ -1233,7 +1315,6 @@ category = "dev" optional = false python-versions = "*" files = [ - {file = "junit-xml-1.9.tar.gz", hash = "sha256:de16a051990d4e25a3982b2dd9e89d671067548718866416faec14d9de56db9f"}, {file = "junit_xml-1.9-py2.py3-none-any.whl", hash = "sha256:ec5ca1a55aefdd76d28fcc0b135251d156c7106fa979686a4b48d62b761b4732"}, ] @@ -1708,14 +1789,14 @@ typing-extensions = ">=4.1.0" [[package]] name = "mypy-extensions" -version = "0.4.3" -description = "Experimental type system extensions for programs checked with the mypy typechecker." +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." category = "dev" optional = false -python-versions = "*" +python-versions = ">=3.5" files = [ - {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, - {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] [[package]] @@ -1803,22 +1884,22 @@ files = [ [[package]] name = "platformdirs" -version = "2.6.2" +version = "3.0.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "platformdirs-2.6.2-py3-none-any.whl", hash = "sha256:83c8f6d04389165de7c9b6f0c682439697887bca0aa2f1c87ef1826be3584490"}, - {file = "platformdirs-2.6.2.tar.gz", hash = "sha256:e1fea1fe471b9ff8332e229df3cb7de4f53eeea4998d3b6bfff542115e998bd2"}, + {file = "platformdirs-3.0.0-py3-none-any.whl", hash = "sha256:b1d5eb14f221506f50d6604a561f4c5786d9e80355219694a1b244bcd96f4567"}, + {file = "platformdirs-3.0.0.tar.gz", hash = "sha256:8a1228abb1ef82d788f74139988b137e78692984ec7b08eaa6c65f1723af28f9"}, ] [package.dependencies] typing-extensions = {version = ">=4.4", markers = "python_version < \"3.8\""} [package.extras] -docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=5.3)", "sphinx-autodoc-typehints (>=1.19.5)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] +docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] [[package]] name = "pluggy" @@ -2469,6 +2550,24 @@ files = [ decorator = ">=3.4.2" py = ">=1.4.26,<2.0.0" +[[package]] +name = "rfc3986" +version = "1.5.0" +description = "Validating URI References per RFC 3986" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "rfc3986-1.5.0-py2.py3-none-any.whl", hash = "sha256:a86d6e1f5b1dc238b218b012df0aa79409667bb209e58da56d0b94704e712a97"}, + {file = "rfc3986-1.5.0.tar.gz", hash = "sha256:270aaf10d87d0d4e095063c65bf3ddbc6ee3d0b226328ce21e036f946e421835"}, +] + +[package.dependencies] +idna = {version = "*", optional = true, markers = "extra == \"idna2008\""} + +[package.extras] +idna2008 = ["idna"] + [[package]] name = "s3transfer" version = "0.6.0" @@ -2527,6 +2626,18 @@ files = [ {file = "smmap-5.0.0.tar.gz", hash = "sha256:c840e62059cd3be204b0c9c9f74be2c09d5648eddd4580d9314c3ecde0b30936"}, ] +[[package]] +name = "sniffio" +version = "1.3.0" +description = "Sniff out which async library your code is running under" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, + {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, +] + [[package]] name = "stevedore" version = "3.5.2" @@ -2634,14 +2745,14 @@ types-urllib3 = "<1.27" [[package]] name = "types-urllib3" -version = "1.26.25.4" +version = "1.26.25.5" description = "Typing stubs for urllib3" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-urllib3-1.26.25.4.tar.gz", hash = "sha256:eec5556428eec862b1ac578fb69aab3877995a99ffec9e5a12cf7fbd0cc9daee"}, - {file = "types_urllib3-1.26.25.4-py3-none-any.whl", hash = "sha256:ed6b9e8a8be488796f72306889a06a3fc3cb1aa99af02ab8afb50144d7317e49"}, + {file = "types-urllib3-1.26.25.5.tar.gz", hash = "sha256:5630e578246d170d91ebe3901788cd28d53c4e044dc2e2488e3b0d55fb6895d8"}, + {file = "types_urllib3-1.26.25.5-py3-none-any.whl", hash = "sha256:e8f25c8bb85cde658c72ee931e56e7abd28803c26032441eea9ff4a4df2b0c31"}, ] [[package]] @@ -2822,14 +2933,14 @@ requests = ">=2.0,<3.0" [[package]] name = "zipp" -version = "3.12.0" +version = "3.13.0" description = "Backport of pathlib-compatible object wrapper for zip files" category = "dev" optional = false python-versions = 
">=3.7" files = [ - {file = "zipp-3.12.0-py3-none-any.whl", hash = "sha256:9eb0a4c5feab9b08871db0d672745b53450d7f26992fd1e4653aa43345e97b86"}, - {file = "zipp-3.12.0.tar.gz", hash = "sha256:73efd63936398aac78fd92b6f4865190119d6c91b531532e798977ea8dd402eb"}, + {file = "zipp-3.13.0-py3-none-any.whl", hash = "sha256:e8b2a36ea17df80ffe9e2c4fda3f693c3dad6df1697d3cd3af232db680950b0b"}, + {file = "zipp-3.13.0.tar.gz", hash = "sha256:23f70e964bc11a34cef175bc90ba2914e1e4545ea1e3e2f67c079671883f9cb6"}, ] [package.extras] @@ -2846,4 +2957,4 @@ validation = ["fastjsonschema"] [metadata] lock-version = "2.0" python-versions = "^3.7.4" -content-hash = "ddd991646d99a0521be85e8210ba52d1d85bd645d6e0624f44168fff1887af6c" +content-hash = "62a6b0896bad16de0b814e025384cc7c078c72cead1e5c4926700c118d8b7dda" diff --git a/pyproject.toml b/pyproject.toml index 04b0af0fd10..e8fdc91ca0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,7 @@ aws-sdk = ["boto3"] cfn-lint = "0.67.0" mypy = "^0.982" types-python-dateutil = "^2.8.19.6" +httpx = "^0.23.3" [tool.coverage.run] source = ["aws_lambda_powertools"] diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py index 1d50de9e85e..6dcfc3d179d 100644 --- a/tests/functional/test_utilities_batch.py +++ b/tests/functional/test_utilities_batch.py @@ -1,13 +1,15 @@ import json from random import randint -from typing import Callable, Dict, Optional +from typing import Any, Awaitable, Callable, Dict, Optional import pytest from botocore.config import Config from aws_lambda_powertools.utilities.batch import ( + AsyncBatchProcessor, BatchProcessor, EventType, + async_batch_processor, batch_processor, ) from aws_lambda_powertools.utilities.batch.exceptions import BatchProcessingError @@ -115,6 +117,17 @@ def handler(record): return handler +@pytest.fixture(scope="module") +def async_record_handler() -> Callable[..., Awaitable[Any]]: + async def handler(record): + body = record["body"] + if "fail" in body: + raise Exception("Failed to process record.") + return body + + return handler + + @pytest.fixture(scope="module") def kinesis_record_handler() -> Callable: def handler(record: KinesisStreamRecord): @@ -639,3 +652,82 @@ def lambda_handler(event, context): # THEN raise BatchProcessingError assert "All records failed processing. 
" in str(e.value) + + +def test_async_batch_processor_middleware_success_only(sqs_event_factory, async_record_handler): + # GIVEN + first_record = SQSRecord(sqs_event_factory("success")) + second_record = SQSRecord(sqs_event_factory("success")) + event = {"Records": [first_record.raw_event, second_record.raw_event]} + + processor = AsyncBatchProcessor(event_type=EventType.SQS) + + @async_batch_processor(record_handler=async_record_handler, processor=processor) + def lambda_handler(event, context): + return processor.response() + + # WHEN + result = lambda_handler(event, {}) + + # THEN + assert result["batchItemFailures"] == [] + + +def test_async_batch_processor_middleware_with_failure(sqs_event_factory, async_record_handler): + # GIVEN + first_record = SQSRecord(sqs_event_factory("fail")) + second_record = SQSRecord(sqs_event_factory("success")) + third_record = SQSRecord(sqs_event_factory("fail")) + event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]} + + processor = AsyncBatchProcessor(event_type=EventType.SQS) + + @async_batch_processor(record_handler=async_record_handler, processor=processor) + def lambda_handler(event, context): + return processor.response() + + # WHEN + result = lambda_handler(event, {}) + + # THEN + assert len(result["batchItemFailures"]) == 2 + + +def test_async_batch_processor_context_success_only(sqs_event_factory, async_record_handler): + # GIVEN + first_record = SQSRecord(sqs_event_factory("success")) + second_record = SQSRecord(sqs_event_factory("success")) + records = [first_record.raw_event, second_record.raw_event] + processor = AsyncBatchProcessor(event_type=EventType.SQS) + + # WHEN + with processor(records, async_record_handler) as batch: + processed_messages = batch.async_process() + + # THEN + assert processed_messages == [ + ("success", first_record.body, first_record.raw_event), + ("success", second_record.body, second_record.raw_event), + ] + + assert batch.response() == {"batchItemFailures": []} + + +def test_async_batch_processor_context_with_failure(sqs_event_factory, async_record_handler): + # GIVEN + first_record = SQSRecord(sqs_event_factory("failure")) + second_record = SQSRecord(sqs_event_factory("success")) + third_record = SQSRecord(sqs_event_factory("fail")) + records = [first_record.raw_event, second_record.raw_event, third_record.raw_event] + processor = AsyncBatchProcessor(event_type=EventType.SQS) + + # WHEN + with processor(records, async_record_handler) as batch: + processed_messages = batch.async_process() + + # THEN + assert processed_messages[1] == ("success", second_record.body, second_record.raw_event) + assert len(batch.fail_messages) == 2 + assert batch.response() == { + "batchItemFailures": [{"itemIdentifier": first_record.message_id}, {"itemIdentifier": third_record.message_id}] + }