diff --git a/aws_lambda_powertools/utilities/batch/decorators.py b/aws_lambda_powertools/utilities/batch/decorators.py index 0cba41f98fe..2b9f5433e70 100644 --- a/aws_lambda_powertools/utilities/batch/decorators.py +++ b/aws_lambda_powertools/utilities/batch/decorators.py @@ -12,6 +12,7 @@ BatchProcessor, EventType, ) +from aws_lambda_powertools.utilities.batch.exceptions import UnexpectedBatchTypeError from aws_lambda_powertools.warnings import PowertoolsDeprecationWarning if TYPE_CHECKING: @@ -204,6 +205,11 @@ def handler(event, context): """ try: records: list[dict] = event.get("Records", []) + if not records or not isinstance(records, list): + raise UnexpectedBatchTypeError( + "Unexpected batch event type. Possible values are: SQS, KinesisDataStreams, DynamoDBStreams", + ) + except AttributeError: event_types = ", ".join(list(EventType.__members__)) docs = "https://docs.powertools.aws.dev/lambda/python/latest/utilities/batch/#processing-messages-from-sqs" # noqa: E501 # long-line @@ -268,6 +274,11 @@ def handler(event, context): """ try: records: list[dict] = event.get("Records", []) + if not records or not isinstance(records, list): + raise UnexpectedBatchTypeError( + "Unexpected batch event type. Possible values are: SQS, KinesisDataStreams, DynamoDBStreams", + ) + except AttributeError: event_types = ", ".join(list(EventType.__members__)) docs = "https://docs.powertools.aws.dev/lambda/python/latest/utilities/batch/#processing-messages-from-sqs" # noqa: E501 # long-line diff --git a/aws_lambda_powertools/utilities/batch/exceptions.py b/aws_lambda_powertools/utilities/batch/exceptions.py index c93b96a8f34..87a2df22d6d 100644 --- a/aws_lambda_powertools/utilities/batch/exceptions.py +++ b/aws_lambda_powertools/utilities/batch/exceptions.py @@ -38,6 +38,12 @@ def __str__(self): return self.format_exceptions(parent_exception_str) +class UnexpectedBatchTypeError(BatchProcessingError): + """Error thrown by the Batch Processing utility when a partial processor receives an unexpected batch type""" + + pass + + class SQSFifoCircuitBreakerError(Exception): """ Signals a record not processed due to the SQS FIFO processing being interrupted diff --git a/tests/functional/batch/required_dependencies/test_utilities_batch.py b/tests/functional/batch/required_dependencies/test_utilities_batch.py index 9327a7d70fc..4c91dd54a1e 100644 --- a/tests/functional/batch/required_dependencies/test_utilities_batch.py +++ b/tests/functional/batch/required_dependencies/test_utilities_batch.py @@ -15,7 +15,7 @@ batch_processor, process_partial_response, ) -from aws_lambda_powertools.utilities.batch.exceptions import BatchProcessingError +from aws_lambda_powertools.utilities.batch.exceptions import BatchProcessingError, UnexpectedBatchTypeError from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import ( DynamoDBRecord, ) @@ -708,3 +708,56 @@ def test_async_process_partial_response_invalid_input(async_record_handler: Call # WHEN/THEN with pytest.raises(ValueError): async_process_partial_response(batch, record_handler, processor) + + +@pytest.mark.parametrize( + "event", + [ + {}, + {"Records": None}, + {"Records": "not a list"}, + ], +) +def test_process_partial_response_raises_unexpected_batch_type(event, record_handler): + # GIVEN a batch processor configured for SQS events + processor = BatchProcessor(event_type=EventType.SQS) + + # WHEN processing an event with invalid Records + with pytest.raises(UnexpectedBatchTypeError) as exc_info: + process_partial_response( + event=event, + record_handler=record_handler, + processor=processor, + ) + + # THEN the correct error message is raised + assert "Unexpected batch event type. Possible values are: SQS, KinesisDataStreams, DynamoDBStreams" in str( + exc_info.value, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "event", + [ + {}, + {"Records": None}, + {"Records": "not a list"}, + ], +) +async def test_async_process_partial_response_raises_unexpected_batch_type(event, async_record_handler): + # GIVEN a batch processor configured for SQS events + processor = BatchProcessor(event_type=EventType.SQS) + + # WHEN processing an event with invalid Records asynchronously + with pytest.raises(UnexpectedBatchTypeError) as exc_info: + await async_process_partial_response( + event=event, + record_handler=async_record_handler, + processor=processor, + ) + + # THEN the correct error message is raised + assert "Unexpected batch event type. Possible values are: SQS, KinesisDataStreams, DynamoDBStreams" in str( + exc_info.value, + )