import json
import uuid
from random import randint
from typing import Any, Awaitable, Callable, Dict

import pytest

from aws_lambda_powertools.utilities.batch import (
    AsyncBatchProcessor,
    BatchProcessor,
    EventType,
    SqsFifoPartialProcessor,
    async_batch_processor,
    async_process_partial_response,
    batch_processor,
    process_partial_response,
)
from aws_lambda_powertools.utilities.batch.exceptions import BatchProcessingError, UnexpectedBatchTypeError
from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import (
    DynamoDBRecord,
)
from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import (
    KinesisStreamRecord,
)
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.warnings import PowertoolsDeprecationWarning
from tests.functional.utils import b64_to_str, str_to_b64


@pytest.fixture(scope="module")
def sqs_event_fifo_factory() -> Callable:
    def factory(body: str, message_group_id: str = ""):
        return {
            "messageId": f"{uuid.uuid4()}",
            "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a",
            "body": body,
            "attributes": {
                "ApproximateReceiveCount": "1",
                "SentTimestamp": "1703675223472",
                "SequenceNumber": "18882884930918384133",
                "MessageGroupId": message_group_id,
                "SenderId": "SenderId",
                "MessageDeduplicationId": "1eea03c3f7e782c7bdc2f2a917f40389314733ff39f5ab16219580c0109ade98",
                "ApproximateFirstReceiveTimestamp": "1703675223484",
            },
            "messageAttributes": {},
            "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3",
            "eventSource": "aws:sqs",
            "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue",
            "awsRegion": "us-east-1",
        }

    return factory


@pytest.fixture(scope="module")
def sqs_event_factory() -> Callable:
    def factory(body: str):
        return {
            "messageId": f"{uuid.uuid4()}",
            "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a",
            "body": body,
            "attributes": {
                "ApproximateReceiveCount": "1",
                "SentTimestamp": "1545082649183",
                "SenderId": "SenderId",
                "ApproximateFirstReceiveTimestamp": "1545082649185",
            },
            "messageAttributes": {},
            "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3",
            "eventSource": "aws:sqs",
            "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue",
            "awsRegion": "us-east-1",
        }

    return factory


@pytest.fixture(scope="module")
def kinesis_event_factory() -> Callable:
    def factory(body: str):
        seq = "".join(str(randint(0, 9)) for _ in range(52))
        return {
            "kinesis": {
                "kinesisSchemaVersion": "1.0",
                "partitionKey": "1",
                "sequenceNumber": seq,
                "data": str_to_b64(body),
                "approximateArrivalTimestamp": 1545084650.987,
            },
            "eventSource": "aws:kinesis",
            "eventVersion": "1.0",
            "eventID": f"shardId-000000000006:{seq}",
            "eventName": "aws:kinesis:record",
            "invokeIdentityArn": "arn:aws:iam::123456789012:role/lambda-role",
            "awsRegion": "us-east-2",
            "eventSourceARN": "arn:aws:kinesis:us-east-2:123456789012:stream/lambda-stream",
        }

    return factory


@pytest.fixture(scope="module")
def dynamodb_event_factory() -> Callable:
    def factory(body: str):
        seq = "".join(str(randint(0, 9)) for _ in range(10))
        return {
            "eventID": "1",
            "eventVersion": "1.0",
            "dynamodb": {
                "Keys": {"Id": {"N": "101"}},
                "NewImage": {"Message": {"S": body}},
                "StreamViewType": "NEW_AND_OLD_IMAGES",
                "SequenceNumber": seq,
                "SizeBytes": 26,
            },
            "awsRegion": "us-west-2",
            "eventName": "INSERT",
            "eventSourceARN": "eventsource_arn",
            "eventSource": "aws:dynamodb",
        }

    return factory

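
# All record handler fixtures below share one convention: a record whose
# payload contains the substring "fail" raises an exception, and every other
# record is returned as-is (and therefore counted as successfully processed).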
@pytest.fixture(scope="module")
def record_handler() -> Callable:
    def handler(record):
        body = record["body"]
        if "fail" in body:
            raise Exception("Failed to process record.")
        return body

    return handler


@pytest.fixture(scope="module")
def async_record_handler() -> Callable[..., Awaitable[Any]]:
    async def handler(record):
        body = record["body"]
        if "fail" in body:
            raise Exception("Failed to process record.")
        return body

    return handler


@pytest.fixture(scope="module")
def kinesis_record_handler() -> Callable:
    def handler(record: KinesisStreamRecord):
        body = b64_to_str(record.kinesis.data)
        if "fail" in body:
            raise Exception("Failed to process record.")
        return body

    return handler


@pytest.fixture(scope="module")
def dynamodb_record_handler() -> Callable:
    def handler(record: DynamoDBRecord):
        body = record.dynamodb.new_image.get("Message")
        if "fail" in body:
            raise ValueError("Failed to process record.")
        return body

    return handler


@pytest.fixture(scope="module")
def order_event_factory() -> Callable:
    def factory(item: Dict) -> str:
        return json.dumps({"item": item})

    return factory


def test_batch_processor_middleware_success_only(sqs_event_factory, record_handler):
    # GIVEN
    first_record = SQSRecord(sqs_event_factory("success"))
    second_record = SQSRecord(sqs_event_factory("success"))
    event = {"Records": [first_record.raw_event, second_record.raw_event]}

    processor = BatchProcessor(event_type=EventType.SQS)

    @batch_processor(record_handler=record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN
    result = lambda_handler(event, {})

    # THEN
    assert result["batchItemFailures"] == []


def test_batch_processor_middleware_with_failure(sqs_event_factory, record_handler):
    # GIVEN
    first_record = SQSRecord(sqs_event_factory("fail"))
    second_record = SQSRecord(sqs_event_factory("success"))
    third_record = SQSRecord(sqs_event_factory("fail"))
    event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}

    processor = BatchProcessor(event_type=EventType.SQS)

    @batch_processor(record_handler=record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN
    result = lambda_handler(event, {})

    # THEN
    assert len(result["batchItemFailures"]) == 2


def test_batch_processor_context_success_only(sqs_event_factory, record_handler):
    # GIVEN
    first_record = SQSRecord(sqs_event_factory("success"))
    second_record = SQSRecord(sqs_event_factory("success"))
    records = [first_record.raw_event, second_record.raw_event]
    processor = BatchProcessor(event_type=EventType.SQS)

    # WHEN
    with processor(records, record_handler) as batch:
        processed_messages = batch.process()

    # THEN
    assert processed_messages == [
        ("success", first_record.body, first_record.raw_event),
        ("success", second_record.body, second_record.raw_event),
    ]

    assert batch.response() == {"batchItemFailures": []}


def test_batch_processor_context_with_failure(sqs_event_factory, record_handler):
    # GIVEN
    first_record = SQSRecord(sqs_event_factory("failure"))
    second_record = SQSRecord(sqs_event_factory("success"))
    third_record = SQSRecord(sqs_event_factory("fail"))
    records = [first_record.raw_event, second_record.raw_event, third_record.raw_event]
    processor = BatchProcessor(event_type=EventType.SQS)

    # WHEN
    with processor(records, record_handler) as batch:
        processed_messages = batch.process()

    # THEN
    assert processed_messages[1] == ("success", second_record.body, second_record.raw_event)
    assert len(batch.fail_messages) == 2
    assert batch.response() == {
        "batchItemFailures": [{"itemIdentifier": first_record.message_id}, {"itemIdentifier": third_record.message_id}],
    }

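
# The `itemIdentifier` reported in `batchItemFailures` depends on the event
# source: SQS reports the record's messageId, Kinesis the shard sequence
# number, and DynamoDB Streams the stream record's SequenceNumber. The
# context-manager tests below assert that mapping for each event type.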
def test_batch_processor_kinesis_context_success_only(kinesis_event_factory, kinesis_record_handler):
    # GIVEN
    first_record = KinesisStreamRecord(kinesis_event_factory("success"))
    second_record = KinesisStreamRecord(kinesis_event_factory("success"))
    records = [first_record.raw_event, second_record.raw_event]
    processor = BatchProcessor(event_type=EventType.KinesisDataStreams)

    # WHEN
    with processor(records, kinesis_record_handler) as batch:
        processed_messages = batch.process()

    # THEN
    assert processed_messages == [
        ("success", b64_to_str(first_record.kinesis.data), first_record.raw_event),
        ("success", b64_to_str(second_record.kinesis.data), second_record.raw_event),
    ]

    assert batch.response() == {"batchItemFailures": []}


def test_batch_processor_kinesis_context_with_failure(kinesis_event_factory, kinesis_record_handler):
    # GIVEN
    first_record = KinesisStreamRecord(kinesis_event_factory("failure"))
    second_record = KinesisStreamRecord(kinesis_event_factory("success"))
    third_record = KinesisStreamRecord(kinesis_event_factory("failure"))
    records = [first_record.raw_event, second_record.raw_event, third_record.raw_event]
    processor = BatchProcessor(event_type=EventType.KinesisDataStreams)

    # WHEN
    with processor(records, kinesis_record_handler) as batch:
        processed_messages = batch.process()

    # THEN
    assert processed_messages[1] == ("success", b64_to_str(second_record.kinesis.data), second_record.raw_event)
    assert len(batch.fail_messages) == 2
    assert batch.response() == {
        "batchItemFailures": [
            {"itemIdentifier": first_record.kinesis.sequence_number},
            {"itemIdentifier": third_record.kinesis.sequence_number},
        ],
    }


def test_batch_processor_kinesis_middleware_with_failure(kinesis_event_factory, kinesis_record_handler):
    # GIVEN
    first_record = KinesisStreamRecord(kinesis_event_factory("failure"))
    second_record = KinesisStreamRecord(kinesis_event_factory("success"))
    third_record = KinesisStreamRecord(kinesis_event_factory("failure"))
    event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}

    processor = BatchProcessor(event_type=EventType.KinesisDataStreams)

    @batch_processor(record_handler=kinesis_record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN
    result = lambda_handler(event, {})

    # THEN
    assert len(result["batchItemFailures"]) == 2


def test_batch_processor_dynamodb_context_success_only(dynamodb_event_factory, dynamodb_record_handler):
    # GIVEN
    first_record = dynamodb_event_factory("success")
    second_record = dynamodb_event_factory("success")
    records = [first_record, second_record]
    processor = BatchProcessor(event_type=EventType.DynamoDBStreams)

    # WHEN
    with processor(records, dynamodb_record_handler) as batch:
        processed_messages = batch.process()

    # THEN
    assert processed_messages == [
        ("success", first_record["dynamodb"]["NewImage"]["Message"]["S"], first_record),
        ("success", second_record["dynamodb"]["NewImage"]["Message"]["S"], second_record),
    ]

    assert batch.response() == {"batchItemFailures": []}


def test_batch_processor_dynamodb_context_with_failure(dynamodb_event_factory, dynamodb_record_handler):
    # GIVEN
    first_record = dynamodb_event_factory("failure")
    second_record = dynamodb_event_factory("success")
    third_record = dynamodb_event_factory("failure")
    records = [first_record, second_record, third_record]
    processor = BatchProcessor(event_type=EventType.DynamoDBStreams)

    # WHEN
    with processor(records, dynamodb_record_handler) as batch:
        processed_messages = batch.process()

    # THEN
    assert processed_messages[1] == ("success", second_record["dynamodb"]["NewImage"]["Message"]["S"], second_record)
    assert len(batch.fail_messages) == 2
    assert batch.response() == {
        "batchItemFailures": [
            {"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]},
            {"itemIdentifier": third_record["dynamodb"]["SequenceNumber"]},
        ],
    }

def test_batch_processor_dynamodb_middleware_with_failure(dynamodb_event_factory, dynamodb_record_handler):
    # GIVEN
    first_record = dynamodb_event_factory("failure")
    second_record = dynamodb_event_factory("success")
    third_record = dynamodb_event_factory("failure")
    event = {"Records": [first_record, second_record, third_record]}

    processor = BatchProcessor(event_type=EventType.DynamoDBStreams)

    @batch_processor(record_handler=dynamodb_record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN
    result = lambda_handler(event, {})

    # THEN
    assert len(result["batchItemFailures"]) == 2


def test_batch_processor_error_when_entire_batch_fails(sqs_event_factory, record_handler):
    # GIVEN
    first_record = SQSRecord(sqs_event_factory("fail"))
    second_record = SQSRecord(sqs_event_factory("fail"))
    event = {"Records": [first_record.raw_event, second_record.raw_event]}

    processor = BatchProcessor(event_type=EventType.SQS)

    @batch_processor(record_handler=record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN calling `lambda_handler` in cold start
    with pytest.raises(BatchProcessingError) as e:
        lambda_handler(event, {})

    # THEN raise BatchProcessingError
    assert "All records failed processing. " in str(e.value)

    # WHEN calling `lambda_handler` in warm start
    with pytest.raises(BatchProcessingError) as e:
        lambda_handler(event, {})

    # THEN raise BatchProcessingError
    assert "All records failed processing. " in str(e.value)


def test_batch_processor_not_raise_when_entire_batch_fails_sync(sqs_event_factory, record_handler):
    first_record = SQSRecord(sqs_event_factory("fail"))
    second_record = SQSRecord(sqs_event_factory("fail"))
    event = {"Records": [first_record.raw_event, second_record.raw_event]}

    # GIVEN the BatchProcessor constructor with raise_on_entire_batch_failure False
    processor = BatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_failure=False)

    # WHEN processing the messages
    @batch_processor(record_handler=record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    response = lambda_handler(event, {})

    # THEN assert the `itemIdentifier` of each failure matches the message ID of the corresponding record
    assert len(response["batchItemFailures"]) == 2
    assert response["batchItemFailures"][0]["itemIdentifier"] == first_record.message_id
    assert response["batchItemFailures"][1]["itemIdentifier"] == second_record.message_id


def test_batch_processor_not_raise_when_entire_batch_fails_async(sqs_event_factory, async_record_handler):
    first_record = SQSRecord(sqs_event_factory("fail"))
    second_record = SQSRecord(sqs_event_factory("fail"))
    event = {"Records": [first_record.raw_event, second_record.raw_event]}

    # GIVEN the AsyncBatchProcessor constructor with raise_on_entire_batch_failure False
    processor = AsyncBatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_failure=False)

    # WHEN processing the messages
    @async_batch_processor(record_handler=async_record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    response = lambda_handler(event, {})

    # THEN assert the `itemIdentifier` of each failure matches the message ID of the corresponding record
    assert len(response["batchItemFailures"]) == 2
    assert response["batchItemFailures"][0]["itemIdentifier"] == first_record.message_id
    assert response["batchItemFailures"][1]["itemIdentifier"] == second_record.message_id

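
# SqsFifoPartialProcessor preserves FIFO ordering: by default, once a record
# fails, every remaining record in the batch is also reported as failed. With
# skip_group_on_error=True, only records sharing the failing record's
# MessageGroupId are skipped; other message groups keep processing. The tests
# below exercise both behaviours.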
def test_sqs_fifo_batch_processor_middleware_success_only(sqs_event_fifo_factory, record_handler):
    # GIVEN
    first_record = SQSRecord(sqs_event_fifo_factory("success"))
    second_record = SQSRecord(sqs_event_fifo_factory("success"))
    event = {"Records": [first_record.raw_event, second_record.raw_event]}

    processor = SqsFifoPartialProcessor()

    @batch_processor(record_handler=record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN
    result = lambda_handler(event, {})

    # THEN
    assert result["batchItemFailures"] == []


def test_sqs_fifo_batch_processor_middleware_with_failure(sqs_event_fifo_factory, record_handler):
    # GIVEN
    first_record = SQSRecord(sqs_event_fifo_factory("success"))
    second_record = SQSRecord(sqs_event_fifo_factory("fail"))
    # this would normally succeed, but since it's a FIFO queue, it will be marked as failure
    third_record = SQSRecord(sqs_event_fifo_factory("success"))
    event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}

    processor = SqsFifoPartialProcessor()

    @batch_processor(record_handler=record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN
    result = lambda_handler(event, {})

    # THEN
    assert len(result["batchItemFailures"]) == 2
    assert result["batchItemFailures"][0]["itemIdentifier"] == second_record.message_id
    assert result["batchItemFailures"][1]["itemIdentifier"] == third_record.message_id


def test_sqs_fifo_batch_processor_middleware_with_skip_group_on_error(sqs_event_fifo_factory, record_handler):
    # GIVEN a batch of 5 records with 3 different MessageGroupID
    first_record = SQSRecord(sqs_event_fifo_factory("success", "1"))
    second_record = SQSRecord(sqs_event_fifo_factory("success", "1"))
    third_record = SQSRecord(sqs_event_fifo_factory("fail", "2"))
    fourth_record = SQSRecord(sqs_event_fifo_factory("success", "2"))
    fifth_record = SQSRecord(sqs_event_fifo_factory("fail", "3"))
    event = {
        "Records": [
            first_record.raw_event,
            second_record.raw_event,
            third_record.raw_event,
            fourth_record.raw_event,
            fifth_record.raw_event,
        ],
    }

    # WHEN the FIFO processor is set to continue processing even after encountering errors in a specific MessageGroupID
    processor = SqsFifoPartialProcessor(skip_group_on_error=True)

    @batch_processor(record_handler=record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN
    result = lambda_handler(event, {})

    # THEN only messages from MessageGroupID 2 and 3 should fail; group 1 succeeds
    assert len(result["batchItemFailures"]) == 3
    assert result["batchItemFailures"][0]["itemIdentifier"] == third_record.message_id
    assert result["batchItemFailures"][1]["itemIdentifier"] == fourth_record.message_id
    assert result["batchItemFailures"][2]["itemIdentifier"] == fifth_record.message_id

def test_sqs_fifo_batch_processor_middleware_with_skip_group_on_error_first_message_fail(
    sqs_event_fifo_factory,
    record_handler,
):
    # GIVEN a batch of 5 records with 3 different MessageGroupID
    first_record = SQSRecord(sqs_event_fifo_factory("fail", "1"))
    second_record = SQSRecord(sqs_event_fifo_factory("success", "1"))
    third_record = SQSRecord(sqs_event_fifo_factory("fail", "2"))
    fourth_record = SQSRecord(sqs_event_fifo_factory("success", "2"))
    fifth_record = SQSRecord(sqs_event_fifo_factory("success", "3"))
    event = {
        "Records": [
            first_record.raw_event,
            second_record.raw_event,
            third_record.raw_event,
            fourth_record.raw_event,
            fifth_record.raw_event,
        ],
    }

    # WHEN the FIFO processor is set to continue processing even after encountering errors in a specific MessageGroupID
    processor = SqsFifoPartialProcessor(skip_group_on_error=True)

    @batch_processor(record_handler=record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN the handler is invoked
    result = lambda_handler(event, {})

    # THEN messages from group 1 and 2 should fail, but not group 3
    assert len(result["batchItemFailures"]) == 4
    assert result["batchItemFailures"][0]["itemIdentifier"] == first_record.message_id
    assert result["batchItemFailures"][1]["itemIdentifier"] == second_record.message_id
    assert result["batchItemFailures"][2]["itemIdentifier"] == third_record.message_id
    assert result["batchItemFailures"][3]["itemIdentifier"] == fourth_record.message_id


def test_async_batch_processor_middleware_success_only(sqs_event_factory, async_record_handler):
    # GIVEN
    first_record = SQSRecord(sqs_event_factory("success"))
    second_record = SQSRecord(sqs_event_factory("success"))
    event = {"Records": [first_record.raw_event, second_record.raw_event]}

    processor = AsyncBatchProcessor(event_type=EventType.SQS)

    @async_batch_processor(record_handler=async_record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN
    with pytest.warns(PowertoolsDeprecationWarning, match="The `async_batch_processor` decorator is deprecated in V3*"):
        result = lambda_handler(event, {})

    # THEN
    assert result["batchItemFailures"] == []


def test_async_batch_processor_middleware_with_failure(sqs_event_factory, async_record_handler):
    # GIVEN
    first_record = SQSRecord(sqs_event_factory("fail"))
    second_record = SQSRecord(sqs_event_factory("success"))
    third_record = SQSRecord(sqs_event_factory("fail"))
    event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}

    processor = AsyncBatchProcessor(event_type=EventType.SQS)

    @async_batch_processor(record_handler=async_record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    # WHEN
    with pytest.warns(PowertoolsDeprecationWarning, match="The `async_batch_processor` decorator is deprecated in V3*"):
        result = lambda_handler(event, {})

    # THEN
    assert len(result["batchItemFailures"]) == 2


def test_async_batch_processor_context_success_only(sqs_event_factory, async_record_handler):
    # GIVEN
    first_record = SQSRecord(sqs_event_factory("success"))
    second_record = SQSRecord(sqs_event_factory("success"))
    records = [first_record.raw_event, second_record.raw_event]
    processor = AsyncBatchProcessor(event_type=EventType.SQS)

    # WHEN
    with processor(records, async_record_handler) as batch:
        processed_messages = batch.async_process()

    # THEN
    assert processed_messages == [
        ("success", first_record.body, first_record.raw_event),
        ("success", second_record.body, second_record.raw_event),
    ]

    assert batch.response() == {"batchItemFailures": []}


def test_async_batch_processor_context_with_failure(sqs_event_factory, async_record_handler):
    # GIVEN
    first_record = SQSRecord(sqs_event_factory("failure"))
    second_record = SQSRecord(sqs_event_factory("success"))
    third_record = SQSRecord(sqs_event_factory("fail"))
    records = [first_record.raw_event, second_record.raw_event, third_record.raw_event]
    processor = AsyncBatchProcessor(event_type=EventType.SQS)

    # WHEN
    with processor(records, async_record_handler) as batch:
        processed_messages = batch.async_process()

    # THEN
    assert processed_messages[1] == ("success", second_record.body, second_record.raw_event)
    assert len(batch.fail_messages) == 2
    assert batch.response() == {
        "batchItemFailures": [{"itemIdentifier": first_record.message_id}, {"itemIdentifier": third_record.message_id}],
    }

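
# process_partial_response and async_process_partial_response are the
# non-decorator entry points: they run the record handler over
# event["Records"] inside the processor context and return
# processor.response(). The tests below also cover their input validation:
# non-mapping events raise ValueError, and events without a valid Records
# list raise UnexpectedBatchTypeError.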
def test_process_partial_response(sqs_event_factory, record_handler):
    # GIVEN
    records = [sqs_event_factory("success"), sqs_event_factory("success")]
    batch = {"Records": records}
    processor = BatchProcessor(event_type=EventType.SQS)

    # WHEN
    ret = process_partial_response(batch, record_handler, processor)

    # THEN
    assert ret == {"batchItemFailures": []}


@pytest.mark.parametrize(
    "batch",
    [
        pytest.param(123456789, id="num"),
        pytest.param([], id="list"),
        pytest.param(False, id="bool"),
        pytest.param(object, id="object"),
        pytest.param(lambda x: x, id="callable"),
    ],
)
def test_process_partial_response_invalid_input(record_handler: Callable, batch: Any):
    # GIVEN
    processor = BatchProcessor(event_type=EventType.SQS)

    # WHEN/THEN
    with pytest.raises(ValueError):
        process_partial_response(batch, record_handler, processor)


def test_async_process_partial_response(sqs_event_factory, async_record_handler):
    # GIVEN
    records = [sqs_event_factory("success"), sqs_event_factory("success")]
    batch = {"Records": records}
    processor = AsyncBatchProcessor(event_type=EventType.SQS)

    # WHEN
    ret = async_process_partial_response(batch, async_record_handler, processor)

    # THEN
    assert ret == {"batchItemFailures": []}


@pytest.mark.parametrize(
    "batch",
    [
        pytest.param(123456789, id="num"),
        pytest.param([], id="list"),
        pytest.param(False, id="bool"),
        pytest.param(object, id="object"),
        pytest.param(lambda x: x, id="callable"),
    ],
)
def test_async_process_partial_response_invalid_input(async_record_handler: Callable, batch: Any):
    # GIVEN
    processor = AsyncBatchProcessor(event_type=EventType.SQS)

    # WHEN/THEN
    with pytest.raises(ValueError):
        async_process_partial_response(batch, async_record_handler, processor)


@pytest.mark.parametrize(
    "event",
    [
        {},
        {"Records": None},
        {"Records": "not a list"},
    ],
)
def test_process_partial_response_raises_unexpected_batch_type(event, record_handler):
    # GIVEN a batch processor configured for SQS events
    processor = BatchProcessor(event_type=EventType.SQS)

    # WHEN processing an event with invalid Records
    with pytest.raises(UnexpectedBatchTypeError) as exc_info:
        process_partial_response(
            event=event,
            record_handler=record_handler,
            processor=processor,
        )

    # THEN the correct error message is raised
    assert "Unexpected batch event type. Possible values are: SQS, KinesisDataStreams, DynamoDBStreams" in str(
        exc_info.value,
    )


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "event",
    [
        {},
        {"Records": None},
        {"Records": "not a list"},
    ],
)
async def test_async_process_partial_response_raises_unexpected_batch_type(event, async_record_handler):
    # GIVEN a batch processor configured for SQS events
    processor = BatchProcessor(event_type=EventType.SQS)

    # WHEN processing an event with invalid Records asynchronously
    with pytest.raises(UnexpectedBatchTypeError) as exc_info:
        await async_process_partial_response(
            event=event,
            record_handler=async_record_handler,
            processor=processor,
        )

    # THEN the correct error message is raised
    assert "Unexpected batch event type. Possible values are: SQS, KinesisDataStreams, DynamoDBStreams" in str(
        exc_info.value,
    )