diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f42337d5c5b..61e98378017 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,10 +11,6 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - id: check-toml - - repo: https://github.com/pre-commit/pygrep-hooks - rev: v1.5.1 - hooks: - - id: python-use-type-annotations - repo: local hooks: - id: black diff --git a/aws_lambda_powertools/utilities/batch/__init__.py b/aws_lambda_powertools/utilities/batch/__init__.py index d308a56abda..584342e5fd0 100644 --- a/aws_lambda_powertools/utilities/batch/__init__.py +++ b/aws_lambda_powertools/utilities/batch/__init__.py @@ -4,7 +4,25 @@ Batch processing utility """ -from .base import BasePartialProcessor, batch_processor -from .sqs import PartialSQSProcessor, sqs_batch_processor +from aws_lambda_powertools.utilities.batch.base import ( + BasePartialProcessor, + BatchProcessor, + EventType, + ExceptionInfo, + FailureResponse, + SuccessResponse, + batch_processor, +) +from aws_lambda_powertools.utilities.batch.sqs import PartialSQSProcessor, sqs_batch_processor -__all__ = ("BasePartialProcessor", "PartialSQSProcessor", "batch_processor", "sqs_batch_processor") +__all__ = ( + "BatchProcessor", + "BasePartialProcessor", + "ExceptionInfo", + "EventType", + "FailureResponse", + "PartialSQSProcessor", + "SuccessResponse", + "batch_processor", + "sqs_batch_processor", +) diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py index a0ad18a9ec1..02eb00ffaed 100644 --- a/aws_lambda_powertools/utilities/batch/base.py +++ b/aws_lambda_powertools/utilities/batch/base.py @@ -3,24 +3,64 @@ """ Batch processing utilities """ - +import copy import logging +import sys from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Tuple +from enum import Enum +from types import TracebackType +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, overload from aws_lambda_powertools.middleware_factory import lambda_handler_decorator +from aws_lambda_powertools.utilities.batch.exceptions import BatchProcessingError +from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord +from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord +from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord logger = logging.getLogger(__name__) +class EventType(Enum): + SQS = "SQS" + KinesisDataStreams = "KinesisDataStreams" + DynamoDBStreams = "DynamoDBStreams" + + +# +# type specifics +# +has_pydantic = "pydantic" in sys.modules +ExceptionInfo = Tuple[Type[BaseException], BaseException, TracebackType] +OptExcInfo = Union[ExceptionInfo, Tuple[None, None, None]] + +# For IntelliSense and Mypy to work, we need to account for possible SQS, Kinesis and DynamoDB subclasses +# We need them as subclasses as we must access their message ID or sequence number metadata via dot notation +if has_pydantic: + from aws_lambda_powertools.utilities.parser.models import DynamoDBStreamRecordModel + from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecord as KinesisDataStreamRecordModel + from aws_lambda_powertools.utilities.parser.models import SqsRecordModel + + BatchTypeModels = Optional[ + Union[Type[SqsRecordModel], Type[DynamoDBStreamRecordModel], Type[KinesisDataStreamRecordModel]] + ] + +# When using processor with default arguments, records will carry EventSourceDataClassTypes +# and depending on what EventType it's passed it'll correctly map to the right record +# When using Pydantic Models, it'll accept any subclass from SQS, DynamoDB and Kinesis +EventSourceDataClassTypes = Union[SQSRecord, KinesisStreamRecord, DynamoDBRecord] +BatchEventTypes = Union[EventSourceDataClassTypes, "BatchTypeModels"] +SuccessResponse = Tuple[str, Any, BatchEventTypes] +FailureResponse = Tuple[str, str, BatchEventTypes] + + class BasePartialProcessor(ABC): """ Abstract class for batch processors. """ def __init__(self): - self.success_messages: List = [] - self.fail_messages: List = [] + self.success_messages: List[BatchEventTypes] = [] + self.fail_messages: List[BatchEventTypes] = [] self.exceptions: List = [] @abstractmethod @@ -38,7 +78,7 @@ def _clean(self): raise NotImplementedError() @abstractmethod - def _process_record(self, record: Any): + def _process_record(self, record: dict): """ Process record with handler. """ @@ -57,13 +97,13 @@ def __enter__(self): def __exit__(self, exception_type, exception_value, traceback): self._clean() - def __call__(self, records: List[Any], handler: Callable): + def __call__(self, records: List[dict], handler: Callable): """ Set instance attributes before execution Parameters ---------- - records: List[Any] + records: List[dict] List with objects to be processed. handler: Callable Callable to process "records" entries. @@ -72,26 +112,40 @@ def __call__(self, records: List[Any], handler: Callable): self.handler = handler return self - def success_handler(self, record: Any, result: Any): + def success_handler(self, record, result: Any) -> SuccessResponse: """ - Success callback + Keeps track of batch records that were processed successfully + + Parameters + ---------- + record: Any + record that failed processing + result: Any + result from record handler Returns ------- - tuple + SuccessResponse "success", result, original record """ entry = ("success", result, record) self.success_messages.append(record) return entry - def failure_handler(self, record: Any, exception: Tuple): + def failure_handler(self, record, exception: OptExcInfo) -> FailureResponse: """ - Failure callback + Keeps track of batch records that failed processing + + Parameters + ---------- + record: Any + record that failed processing + exception: OptExcInfo + Exception information containing type, value, and traceback (sys.exc_info()) Returns ------- - tuple + FailureResponse "fail", exceptions args, original record """ exception_string = f"{exception[0]}:{exception[1]}" @@ -146,3 +200,243 @@ def batch_processor( processor.process() return handler(event, context) + + +class BatchProcessor(BasePartialProcessor): + """Process native partial responses from SQS, Kinesis Data Streams, and DynamoDB. + + + Example + ------- + + ## Process batch triggered by SQS + + ```python + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.SQS) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: SQSRecord): + payload: str = record.body + if payload: + item: dict = json.loads(payload) + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` + + ## Process batch triggered by Kinesis Data Streams + + ```python + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.KinesisDataStreams) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: KinesisStreamRecord): + logger.info(record.kinesis.data_as_text) + payload: dict = record.kinesis.data_as_json() + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` + + + ## Process batch triggered by DynamoDB Data Streams + + ```python + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.DynamoDBStreams) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: DynamoDBRecord): + logger.info(record.dynamodb.new_image) + payload: dict = json.loads(record.dynamodb.new_image.get("item").s_value) + # alternatively: + # changes: Dict[str, dynamo_db_stream_event.AttributeValue] = record.dynamodb.new_image # noqa: E800 + # payload = change.get("Message").raw_event -> {"S": ""} + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + def lambda_handler(event, context: LambdaContext): + batch = event["Records"] + with processor(records=batch, processor=processor): + processed_messages = processor.process() # kick off processing, return list[tuple] + + return processor.response() + ``` + + + Raises + ------ + BatchProcessingError + When all batch records fail processing + """ + + DEFAULT_RESPONSE: Dict[str, List[Optional[dict]]] = {"batchItemFailures": []} + + def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = None): + """Process batch and partially report failed items + + Parameters + ---------- + event_type: EventType + Whether this is a SQS, DynamoDB Streams, or Kinesis Data Stream event + model: Optional["BatchTypeModels"] + Parser's data model using either SqsRecordModel, DynamoDBStreamRecordModel, KinesisDataStreamRecord + + Exceptions + ---------- + BatchProcessingError + Raised when the entire batch has failed processing + """ + self.event_type = event_type + self.model = model + self.batch_response = copy.deepcopy(self.DEFAULT_RESPONSE) + self._COLLECTOR_MAPPING = { + EventType.SQS: self._collect_sqs_failures, + EventType.KinesisDataStreams: self._collect_kinesis_failures, + EventType.DynamoDBStreams: self._collect_dynamodb_failures, + } + self._DATA_CLASS_MAPPING = { + EventType.SQS: SQSRecord, + EventType.KinesisDataStreams: KinesisStreamRecord, + EventType.DynamoDBStreams: DynamoDBRecord, + } + + super().__init__() + + def response(self): + """Batch items that failed processing, if any""" + return self.batch_response + + def _prepare(self): + """ + Remove results from previous execution. + """ + self.success_messages.clear() + self.fail_messages.clear() + self.batch_response = copy.deepcopy(self.DEFAULT_RESPONSE) + + def _process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]: + """ + Process a record with instance's handler + + Parameters + ---------- + record: dict + A batch record to be processed. + """ + data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) + try: + result = self.handler(record=data) + return self.success_handler(record=record, result=result) + except Exception: + return self.failure_handler(record=data, exception=sys.exc_info()) + + def _clean(self): + """ + Report messages to be deleted in case of partial failure. + """ + + if not self._has_messages_to_report(): + return + + if self._entire_batch_failed(): + raise BatchProcessingError( + msg=f"All records failed processing. {len(self.exceptions)} individual errors logged" + f"separately below.", + child_exceptions=self.exceptions, + ) + + messages = self._get_messages_to_report() + self.batch_response = {"batchItemFailures": [messages]} + + def _has_messages_to_report(self) -> bool: + if self.fail_messages: + return True + + logger.debug(f"All {len(self.success_messages)} records successfully processed") + return False + + def _entire_batch_failed(self) -> bool: + return len(self.exceptions) == len(self.records) + + def _get_messages_to_report(self) -> Dict[str, str]: + """ + Format messages to use in batch deletion + """ + return self._COLLECTOR_MAPPING[self.event_type]() + + # Event Source Data Classes follow python idioms for fields + # while Parser/Pydantic follows the event field names to the latter + def _collect_sqs_failures(self): + if self.model: + return {"itemIdentifier": msg.messageId for msg in self.fail_messages} + else: + return {"itemIdentifier": msg.message_id for msg in self.fail_messages} + + def _collect_kinesis_failures(self): + if self.model: + # Pydantic model uses int but Lambda poller expects str + return {"itemIdentifier": msg.kinesis.sequenceNumber for msg in self.fail_messages} + else: + return {"itemIdentifier": msg.kinesis.sequence_number for msg in self.fail_messages} + + def _collect_dynamodb_failures(self): + if self.model: + return {"itemIdentifier": msg.dynamodb.SequenceNumber for msg in self.fail_messages} + else: + return {"itemIdentifier": msg.dynamodb.sequence_number for msg in self.fail_messages} + + @overload + def _to_batch_type(self, record: dict, event_type: EventType, model: "BatchTypeModels") -> "BatchTypeModels": + ... + + @overload + def _to_batch_type(self, record: dict, event_type: EventType) -> EventSourceDataClassTypes: + ... + + def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["BatchTypeModels"] = None): + if model is not None: + return model.parse_obj(record) + else: + return self._DATA_CLASS_MAPPING[event_type](record) diff --git a/aws_lambda_powertools/utilities/batch/exceptions.py b/aws_lambda_powertools/utilities/batch/exceptions.py index c2ead04a7b1..fe51433a5d6 100644 --- a/aws_lambda_powertools/utilities/batch/exceptions.py +++ b/aws_lambda_powertools/utilities/batch/exceptions.py @@ -2,20 +2,16 @@ Batch processing exceptions """ import traceback +from typing import Optional, Tuple -class SQSBatchProcessingError(Exception): - """When at least one message within a batch could not be processed""" - +class BaseBatchProcessingError(Exception): def __init__(self, msg="", child_exceptions=()): super().__init__(msg) self.msg = msg self.child_exceptions = child_exceptions - # Overriding this method so we can output all child exception tracebacks when we raise this exception to prevent - # errors being lost. See https://github.com/awslabs/aws-lambda-powertools-python/issues/275 - def __str__(self): - parent_exception_str = super(SQSBatchProcessingError, self).__str__() + def format_exceptions(self, parent_exception_str): exception_list = [f"{parent_exception_str}\n"] for exception in self.child_exceptions: extype, ex, tb = exception @@ -23,3 +19,27 @@ def __str__(self): exception_list.append(formatted) return "\n".join(exception_list) + + +class SQSBatchProcessingError(BaseBatchProcessingError): + """When at least one message within a batch could not be processed""" + + def __init__(self, msg="", child_exceptions: Optional[Tuple[Exception]] = None): + super().__init__(msg, child_exceptions) + + # Overriding this method so we can output all child exception tracebacks when we raise this exception to prevent + # errors being lost. See https://github.com/awslabs/aws-lambda-powertools-python/issues/275 + def __str__(self): + parent_exception_str = super(SQSBatchProcessingError, self).__str__() + return self.format_exceptions(parent_exception_str) + + +class BatchProcessingError(BaseBatchProcessingError): + """When all batch records failed to be processed""" + + def __init__(self, msg="", child_exceptions: Optional[Tuple[Exception]] = None): + super().__init__(msg, child_exceptions) + + def __str__(self): + parent_exception_str = super(BatchProcessingError, self).__str__() + return self.format_exceptions(parent_exception_str) diff --git a/docs/utilities/batch.md b/docs/utilities/batch.md index 56ab160e9f9..3ea9413749e 100644 --- a/docs/utilities/batch.md +++ b/docs/utilities/batch.md @@ -1,33 +1,642 @@ --- -title: SQS Batch Processing +title: Batch Processing description: Utility --- -The SQS batch processing utility provides a way to handle partial failures when processing batches of messages from SQS. +The batch processing utility handles partial failures when processing batches from Amazon SQS, Amazon Kinesis Data Streams, and Amazon DynamoDB Streams. ## Key Features -* Prevent successfully processed messages being returned to SQS -* Simple interface for individually processing messages from a batch -* Build your own batch processor using the base classes +* Reports batch item failures to reduce number of retries for a record upon errors +* Simple interface to process each batch record +* Integrates with [Event Source Data Classes](./data_classes.md){target="_blank} and [Parser (Pydantic)](parser.md){target="_blank} for self-documenting record schema +* Build your own batch processor by extending primitives ## Background -When using SQS as a Lambda event source mapping, Lambda functions are triggered with a batch of messages from SQS. +When using SQS, Kinesis Data Streams, or DynamoDB Streams as a Lambda event source, your Lambda functions are triggered with a batch of messages. -If your function fails to process any message from the batch, the entire batch returns to your SQS queue, and your Lambda function is triggered with the same batch one more time. +If your function fails to process any message from the batch, the entire batch returns to your queue or stream. This same batch is then retried until either condition happens first: **a)** your Lambda function returns a successful response, **b)** record reaches maximum retry attempts, or **c)** when records expire. -With this utility, messages within a batch are handled individually - only messages that were not successfully processed -are returned to the queue. +With this utility, batch records are processed individually – only messages that failed to be processed return to the queue or stream for a further retry. This works when two mechanisms are in place: -!!! warning - While this utility lowers the chance of processing messages more than once, it is not guaranteed. We recommend implementing processing logic in an idempotent manner wherever possible. +1. `ReportBatchItemFailures` is set in your SQS, Kinesis, or DynamoDB event source properties +2. [A specific response](https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html#sqs-batchfailurereporting-syntax){target="_blank"} is returned so Lambda knows which records should not be deleted during partial responses + +!!! warning "This utility lowers the chance of processing records more than once; it does not guarantee it" + We recommend implementing processing logic in an [idempotent manner](idempotency.md){target="_blank"} wherever possible. - More details on how Lambda works with SQS can be found in the [AWS documentation](https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html) + You can find more details on how Lambda works with either [SQS](https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html){target="_blank"}, [Kinesis](https://docs.aws.amazon.com/lambda/latest/dg/with-kinesis.html){target="_blank"}, or [DynamoDB](https://docs.aws.amazon.com/lambda/latest/dg/with-ddb.html){target="_blank"} in the AWS Documentation. ## Getting started -### IAM Permissions +Regardless whether you're using SQS, Kinesis Data Streams or DynamoDB Streams, you must configure your Lambda function event source to use ``ReportBatchItemFailures`. + +You do not need any additional IAM permissions to use this utility, except for what each event source requires. + +### Required resources + +The remaining sections of the documentation will rely on these samples. For completeness, this demonstrates IAM permissions and Dead Letter Queue where batch records will be sent after 2 retries were attempted. + + +=== "SQS" + + ```yaml title="template.yaml" hl_lines="31-32" + AWSTemplateFormatVersion: '2010-09-09' + Transform: AWS::Serverless-2016-10-31 + Description: partial batch response sample + + Globals: + Function: + Timeout: 5 + MemorySize: 256 + Runtime: python3.8 + Tracing: Active + Environment: + Variables: + LOG_LEVEL: INFO + POWERTOOLS_SERVICE_NAME: hello + + Resources: + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + Handler: app.lambda_handler + CodeUri: hello_world + Policies: + - SQSPollerPolicy: + QueueName: !GetAtt SampleQueue.QueueName + Events: + Batch: + Type: SQS + Properties: + Queue: !GetAtt SampleQueue.Arn + FunctionResponseTypes: + - ReportBatchItemFailures + + SampleDLQ: + Type: AWS::SQS::Queue + + SampleQueue: + Type: AWS::SQS::Queue + Properties: + VisibilityTimeout: 30 # Fn timeout * 6 + RedrivePolicy: + maxReceiveCount: 2 + deadLetterTargetArn: !GetAtt SampleDLQ.Arn + ``` + +=== "Kinesis Data Streams" + + ```yaml title="template.yaml" hl_lines="44-45" + AWSTemplateFormatVersion: '2010-09-09' + Transform: AWS::Serverless-2016-10-31 + Description: partial batch response sample + + Globals: + Function: + Timeout: 5 + MemorySize: 256 + Runtime: python3.8 + Tracing: Active + Environment: + Variables: + LOG_LEVEL: INFO + POWERTOOLS_SERVICE_NAME: hello + + Resources: + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + Handler: app.lambda_handler + CodeUri: hello_world + Policies: + # Lambda Destinations require additional permissions + # to send failure records to DLQ from Kinesis/DynamoDB + - Version: "2012-10-17" + Statement: + Effect: "Allow" + Action: + - sqs:GetQueueAttributes + - sqs:GetQueueUrl + - sqs:SendMessage + Resource: !GetAtt SampleDLQ.Arn + Events: + KinesisStream: + Type: Kinesis + Properties: + Stream: !GetAtt SampleStream.Arn + BatchSize: 100 + StartingPosition: LATEST + MaximumRetryAttempts: 2 + DestinationConfig: + OnFailure: + Destination: !GetAtt SampleDLQ.Arn + FunctionResponseTypes: + - ReportBatchItemFailures + + SampleDLQ: + Type: AWS::SQS::Queue + + SampleStream: + Type: AWS::Kinesis::Stream + Properties: + ShardCount: 1 + ``` + +=== "DynamoDB Streams" + + ```yaml title="template.yaml" hl_lines="43-44" + AWSTemplateFormatVersion: '2010-09-09' + Transform: AWS::Serverless-2016-10-31 + Description: partial batch response sample + + Globals: + Function: + Timeout: 5 + MemorySize: 256 + Runtime: python3.8 + Tracing: Active + Environment: + Variables: + LOG_LEVEL: INFO + POWERTOOLS_SERVICE_NAME: hello + + Resources: + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + Handler: app.lambda_handler + CodeUri: hello_world + Policies: + # Lambda Destinations require additional permissions + # to send failure records from Kinesis/DynamoDB + - Version: "2012-10-17" + Statement: + Effect: "Allow" + Action: + - sqs:GetQueueAttributes + - sqs:GetQueueUrl + - sqs:SendMessage + Resource: !GetAtt SampleDLQ.Arn + Events: + DynamoDBStream: + Type: DynamoDB + Properties: + Stream: !GetAtt SampleTable.StreamArn + StartingPosition: LATEST + MaximumRetryAttempts: 2 + DestinationConfig: + OnFailure: + Destination: !GetAtt SampleDLQ.Arn + FunctionResponseTypes: + - ReportBatchItemFailures + + SampleDLQ: + Type: AWS::SQS::Queue + + SampleTable: + Type: AWS::DynamoDB::Table + Properties: + BillingMode: PAY_PER_REQUEST + AttributeDefinitions: + - AttributeName: pk + AttributeType: S + - AttributeName: sk + AttributeType: S + KeySchema: + - AttributeName: pk + KeyType: HASH + - AttributeName: sk + KeyType: RANGE + SSESpecification: + SSEEnabled: yes + StreamSpecification: + StreamViewType: NEW_AND_OLD_IMAGES + + ``` + +### Processing messages from SQS + +Processing batches from SQS works in four stages: + +1. Instantiate **`BatchProcessor`** and choose **`EventType.SQS`** for the event type +2. Define your function to handle each batch record, and use [`SQSRecord`](data_classes.md#sqs){target="_blank"} type annotation for autocompletion +3. Use either **`batch_processor`** decorator or your instantiated processor as a context manager to kick off processing +4. Return the appropriate response contract to Lambda via **`.response()`** processor method + +!!! info "This code example optionally uses Tracer and Logger for completion" + +=== "As a decorator" + + ```python hl_lines="4-5 9 15 23 25" + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.SQS) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: SQSRecord): + payload: str = record.body + if payload: + item: dict = json.loads(payload) + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` + +=== "As a context manager" + + ```python hl_lines="4-5 9 15 24-26 28" + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.SQS) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: SQSRecord): + payload: str = record.body + if payload: + item: dict = json.loads(payload) + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + def lambda_handler(event, context: LambdaContext): + batch = event["Records"] + with processor(records=batch, processor=processor): + processed_messages = processor.process() # kick off processing, return list[tuple] + + return processor.response() + ``` + +=== "Sample response" + + The second record failed to be processed, therefore the processor added its message ID in the response. + + ```python + { + 'batchItemFailures': [ + { + 'itemIdentifier': '244fc6b4-87a3-44ab-83d2-361172410c3a' + } + ] + } + ``` + +=== "Sample event" + + ```json + { + "Records": [ + { + "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", + "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a", + "body": "{\"Message\": \"success\"}", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082649183", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082649185" + }, + "messageAttributes": {}, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2: 123456789012:my-queue", + "awsRegion": "us-east-1" + }, + { + "messageId": "244fc6b4-87a3-44ab-83d2-361172410c3a", + "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a", + "body": "SGVsbG8sIHRoaXMgaXMgYSB0ZXN0Lg==", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082649183", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082649185" + }, + "messageAttributes": {}, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2: 123456789012:my-queue", + "awsRegion": "us-east-1" + } + ] + } + ``` + +### Processing messages from Kinesis + +Processing batches from Kinesis works in four stages: + +1. Instantiate **`BatchProcessor`** and choose **`EventType.KinesisDataStreams`** for the event type +2. Define your function to handle each batch record, and use [`KinesisStreamRecord`](data_classes.md#kinesis-streams){target="_blank"} type annotation for autocompletion +3. Use either **`batch_processor`** decorator or your instantiated processor as a context manager to kick off processing +4. Return the appropriate response contract to Lambda via **`.response()`** processor method + +!!! info "This code example optionally uses Tracer and Logger for completion" + +=== "As a decorator" + + ```python hl_lines="4-5 9 15 22 24" + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.KinesisDataStreams) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: KinesisStreamRecord): + logger.info(record.kinesis.data_as_text) + payload: dict = record.kinesis.data_as_json() + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` + +=== "As a context manager" + + ```python hl_lines="4-5 9 15 23-25 27" + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.KinesisDataStreams) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: KinesisStreamRecord): + logger.info(record.kinesis.data_as_text) + payload: dict = record.kinesis.data_as_json() + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + def lambda_handler(event, context: LambdaContext): + batch = event["Records"] + with processor(records=batch, processor=processor): + processed_messages = processor.process() # kick off processing, return list[tuple] + + return processor.response() + ``` + +=== "Sample response" + + The second record failed to be processed, therefore the processor added its sequence number in the response. + + ```python + { + 'batchItemFailures': [ + { + 'itemIdentifier': '6006958808509702859251049540584488075644979031228738' + } + ] + } + ``` + + +=== "Sample event" + + ```json + { + "Records": [ + { + "kinesis": { + "kinesisSchemaVersion": "1.0", + "partitionKey": "1", + "sequenceNumber": "4107859083838847772757075850904226111829882106684065", + "data": "eyJNZXNzYWdlIjogInN1Y2Nlc3MifQ==", + "approximateArrivalTimestamp": 1545084650.987 + }, + "eventSource": "aws:kinesis", + "eventVersion": "1.0", + "eventID": "shardId-000000000006:4107859083838847772757075850904226111829882106684065", + "eventName": "aws:kinesis:record", + "invokeIdentityArn": "arn:aws:iam::123456789012:role/lambda-role", + "awsRegion": "us-east-2", + "eventSourceARN": "arn:aws:kinesis:us-east-2:123456789012:stream/lambda-stream" + }, + { + "kinesis": { + "kinesisSchemaVersion": "1.0", + "partitionKey": "1", + "sequenceNumber": "6006958808509702859251049540584488075644979031228738", + "data": "c3VjY2Vzcw==", + "approximateArrivalTimestamp": 1545084650.987 + }, + "eventSource": "aws:kinesis", + "eventVersion": "1.0", + "eventID": "shardId-000000000006:6006958808509702859251049540584488075644979031228738", + "eventName": "aws:kinesis:record", + "invokeIdentityArn": "arn:aws:iam::123456789012:role/lambda-role", + "awsRegion": "us-east-2", + "eventSourceARN": "arn:aws:kinesis:us-east-2:123456789012:stream/lambda-stream" + } + ] + } + ``` + + +### Processing messages from DynamoDB + +Processing batches from Kinesis works in four stages: + +1. Instantiate **`BatchProcessor`** and choose **`EventType.DynamoDBStreams`** for the event type +2. Define your function to handle each batch record, and use [`DynamoDBRecord`](data_classes.md#dynamodb-streams){target="_blank"} type annotation for autocompletion +3. Use either **`batch_processor`** decorator or your instantiated processor as a context manager to kick off processing +4. Return the appropriate response contract to Lambda via **`.response()`** processor method + +!!! info "This code example optionally uses Tracer and Logger for completion" + +=== "As a decorator" + + ```python hl_lines="4-5 9 15 25 27" + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.DynamoDBStreams) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: DynamoDBRecord): + logger.info(record.dynamodb.new_image) + payload: dict = json.loads(record.dynamodb.new_image.get("Message").get_value) + # alternatively: + # changes: Dict[str, dynamo_db_stream_event.AttributeValue] = record.dynamodb.new_image + # payload = change.get("Message").raw_event -> {"S": ""} + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` + +=== "As a context manager" + + ```python hl_lines="4-5 9 15 26-28 30" + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.DynamoDBStreams) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: DynamoDBRecord): + logger.info(record.dynamodb.new_image) + payload: dict = json.loads(record.dynamodb.new_image.get("item").s_value) + # alternatively: + # changes: Dict[str, dynamo_db_stream_event.AttributeValue] = record.dynamodb.new_image + # payload = change.get("Message").raw_event -> {"S": ""} + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + def lambda_handler(event, context: LambdaContext): + batch = event["Records"] + with processor(records=batch, processor=processor): + processed_messages = processor.process() # kick off processing, return list[tuple] + + return processor.response() + ``` + +=== "Sample response" + + The second record failed to be processed, therefore the processor added its sequence number in the response. + + ```python + { + 'batchItemFailures': [ + { + 'itemIdentifier': '8640712661' + } + ] + } + ``` + + +=== "Sample event" + + ```json + { + "Records": [ + { + "eventID": "1", + "eventVersion": "1.0", + "dynamodb": { + "Keys": { + "Id": { + "N": "101" + } + }, + "NewImage": { + "Message": { + "S": "failure" + } + }, + "StreamViewType": "NEW_AND_OLD_IMAGES", + "SequenceNumber": "3275880929", + "SizeBytes": 26 + }, + "awsRegion": "us-west-2", + "eventName": "INSERT", + "eventSourceARN": "eventsource_arn", + "eventSource": "aws:dynamodb" + }, + { + "eventID": "1", + "eventVersion": "1.0", + "dynamodb": { + "Keys": { + "Id": { + "N": "101" + } + }, + "NewImage": { + "SomethingElse": { + "S": "success" + } + }, + "StreamViewType": "NEW_AND_OLD_IMAGES", + "SequenceNumber": "8640712661", + "SizeBytes": 26 + }, + "awsRegion": "us-west-2", + "eventName": "INSERT", + "eventSourceARN": "eventsource_arn", + "eventSource": "aws:dynamodb" + } + ] + } + ``` + +### Partial failure mechanics + +All records in the batch will be passed to this handler for processing, even if exceptions are thrown - Here's the behaviour after completing the batch: + +* **All records successfully processed**. We will return an empty list of item failures `{'batchItemFailures': []}` +* **Partial success with some exceptions**. We will return a list of all item IDs/sequence numbers that failed processing +* **All records failed to be processed**. We will raise `BatchProcessingError` exception with a list of all exceptions raised when processing + +!!! warning + You will not have access to the **processed messages** within the Lambda Handler; use context manager for that. + + All processing logic will and should be performed by the `record_handler` function. + + + -### Partial failure mechanics -All records in the batch will be passed to this handler for processing, even if exceptions are thrown - Here's the behaviour after completing the batch: +## Advanced -* **Any successfully processed messages**, we will delete them from the queue via `sqs:DeleteMessageBatch` -* **Any unprocessed messages detected**, we will raise `SQSBatchProcessingError` to ensure failed messages return to your SQS queue +### Pydantic integration -!!! warning - You will not have accessed to the **processed messages** within the Lambda Handler. +You can bring your own Pydantic models via **`model`** parameter when inheriting from **`SqsRecordModel`**, **`KinesisDataStreamRecord`**, or **`DynamoDBStreamRecordModel`** - All processing logic will and should be performed by the `record_handler` function. +Inheritance is importance because we need to access message IDs and sequence numbers from these records in the event of failure. Mypy is fully integrated with this utility, so it should identify whether you're passing the incorrect Model. -## Advanced -### Choosing between decorator and context manager +=== "SQS" + + ```python hl_lines="5 9-10 12-19 21 27" + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.parser.models import SqsRecordModel + from aws_lambda_powertools.utilities.typing import LambdaContext + + + class Order(BaseModel): + item: dict + + class OrderSqsRecord(SqsRecordModel): + body: Order + + # auto transform json string + # so Pydantic can auto-initialize nested Order model + @validator("body", pre=True) + def transform_body_to_dict(cls, value: str): + return json.loads(value) + + processor = BatchProcessor(event_type=EventType.SQS, model=OrderSqsRecord) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: OrderSqsRecord): + return record.body.item + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` + +=== "Kinesis Data Streams" + + ```python hl_lines="5 9-10 12-20 22-23 26 32" + import json -They have nearly the same behaviour when it comes to processing messages from the batch: + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecord + from aws_lambda_powertools.utilities.typing import LambdaContext -* **Entire batch has been successfully processed**, where your Lambda handler returned successfully, we will let SQS delete the batch to optimize your cost -* **Entire Batch has been partially processed successfully**, where exceptions were raised within your `record handler`, we will: - * **1)** Delete successfully processed messages from the queue by directly calling `sqs:DeleteMessageBatch` - * **2)** Raise `SQSBatchProcessingError` to ensure failed messages return to your SQS queue -The only difference is that **PartialSQSProcessor** will give you access to processed messages if you need. + class Order(BaseModel): + item: dict + + class OrderKinesisPayloadRecord(KinesisDataStreamRecordPayload): + data: Order + + # auto transform json string + # so Pydantic can auto-initialize nested Order model + @validator("data", pre=True) + def transform_message_to_dict(cls, value: str): + # Powertools KinesisDataStreamRecordModel already decodes b64 to str here + return json.loads(value) + + class OrderKinesisRecord(KinesisDataStreamRecordModel): + kinesis: OrderKinesisPayloadRecord + + + processor = BatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: OrderKinesisRecord): + return record.kinesis.data.item + + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` + +=== "DynamoDB Streams" + + ```python hl_lines="7 11-12 14-21 23-25 27-28 31 37" + import json + + from typing import Dict, Literal + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.parser.models import DynamoDBStreamRecordModel + from aws_lambda_powertools.utilities.typing import LambdaContext + + + class Order(BaseModel): + item: dict + + class OrderDynamoDB(BaseModel): + Message: Order + + # auto transform json string + # so Pydantic can auto-initialize nested Order model + @validator("Message", pre=True) + def transform_message_to_dict(cls, value: Dict[Literal["S"], str]): + return json.loads(value["S"]) + + class OrderDynamoDBChangeRecord(DynamoDBStreamChangedRecordModel): + NewImage: Optional[OrderDynamoDB] + OldImage: Optional[OrderDynamoDB] + + class OrderDynamoDBRecord(DynamoDBStreamRecordModel): + dynamodb: OrderDynamoDBChangeRecord + + + processor = BatchProcessor(event_type=EventType.DynamoDBStreams, model=OrderKinesisRecord) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: OrderDynamoDBRecord): + return record.dynamodb.NewImage.Message.item + + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` ### Accessing processed messages -Use `PartialSQSProcessor` context manager to access a list of all return values from your `record_handler` function. +Use the context manager to access a list of all returned values from your `record_handler` function. + +> Signature: List[Tuple[Union[SuccessResponse, FailureResponse]]] + +* **When successful**. We will include a tuple with `success`, the result of `record_handler`, and the batch record +* **When failed**. We will include a tuple with `fail`, exception as a string, and the batch record + + +=== "app.py" + + ```python hl_lines="31-38" + import json + + from typing import Any, List, Literal, Union + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import (BatchProcessor, + EventType, + FailureResponse, + SuccessResponse, + batch_processor) + from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.SQS) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: SQSRecord): + payload: str = record.body + if payload: + item: dict = json.loads(payload) + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + def lambda_handler(event, context: LambdaContext): + batch = event["Records"] + with processor(records=batch, processor=processor): + processed_messages: List[Union[SuccessResponse, FailureResponse]] = processor.process() + + for messages in processed_messages: + for message in messages: + status: Union[Literal["success"], Literal["fail"]] = message[0] + result: Any = message[1] + record: SQSRecord = message[2] + + + return processor.response() + ``` + + +### Extending BatchProcessor + +You might want to bring custom logic to the existing `BatchProcessor` to slightly override how we handle successes and failures. + +For these scenarios, you can subclass `BatchProcessor` and quickly override `success_handler` and `failure_handler` methods: + +* **`success_handler()`** – Keeps track of successful batch records +* **`failure_handler()`** – Keeps track of failed batch records + +**Example** + +Let's suppose you'd like to add a metric named `BatchRecordFailures` for each batch record that failed processing: === "app.py" ```python - from aws_lambda_powertools.utilities.batch import PartialSQSProcessor + + from typing import Tuple + + from aws_lambda_powertools import Metrics + from aws_lambda_powertools.metrics import MetricUnit + from aws_lambda_powertools.utilities.batch import batch_processor, BatchProcessor, ExceptionInfo, EventType, FailureResponse + from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord + + + class MyProcessor(BatchProcessor): + def failure_handler(self, record: SQSRecord, exception: ExceptionInfo) -> FailureResponse: + metrics.add_metric(name="BatchRecordFailures", unit=MetricUnit.Count, value=1) + return super().failure_handler(record, exception) + + processor = MyProcessor(event_type=EventType.SQS) + metrics = Metrics(namespace="test") + + + @tracer.capture_method + def record_handler(record: SQSRecord): + payload: str = record.body + if payload: + item: dict = json.loads(payload) + ... + + @metrics.log_metrics(capture_cold_start_metric=True) + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` + +### Create your own partial processor + +You can create your own partial batch processor from scratch by inheriting the `BasePartialProcessor` class, and implementing `_prepare()`, `_clean()` and `_process_record()`. + +* **`_process_record()`** – handles all processing logic for each individual message of a batch, including calling the `record_handler` (self.handler) +* **`_prepare()`** – called once as part of the processor initialization +* **`clean()`** – teardown logic called once after `_process_record` completes + +You can then use this class as a context manager, or pass it to `batch_processor` to use as a decorator on your Lambda handler function. + +=== "custom_processor.py" + + ```python hl_lines="3 9 24 30 37 57" + from random import randint + + from aws_lambda_powertools.utilities.batch import BasePartialProcessor, batch_processor + import boto3 + import os + + table_name = os.getenv("TABLE_NAME", "table_not_found") + + class MyPartialProcessor(BasePartialProcessor): + """ + Process a record and stores successful results at a Amazon DynamoDB Table + + Parameters + ---------- + table_name: str + DynamoDB table name to write results to + """ + + def __init__(self, table_name: str): + self.table_name = table_name + + super().__init__() + + def _prepare(self): + # It's called once, *before* processing + # Creates table resource and clean previous results + self.ddb_table = boto3.resource("dynamodb").Table(self.table_name) + self.success_messages.clear() + + def _clean(self): + # It's called once, *after* closing processing all records (closing the context manager) + # Here we're sending, at once, all successful messages to a ddb table + with ddb_table.batch_writer() as batch: + for result in self.success_messages: + batch.put_item(Item=result) + + def _process_record(self, record): + # It handles how your record is processed + # Here we're keeping the status of each run + # where self.handler is the record_handler function passed as an argument + try: + result = self.handler(record) # record_handler passed to decorator/context manager + return self.success_handler(record, result) + except Exception as exc: + return self.failure_handler(record, exc) + + def success_handler(self, record): + entry = ("success", result, record) + message = {"age": result} + self.success_messages.append(message) + return entry + + + def record_handler(record): + return randint(0, 100) + + @batch_processor(record_handler=record_handler, processor=MyPartialProcessor(table_name)) + def lambda_handler(event, context): + return {"statusCode": 200} + ``` + +### Caveats + +#### Tracer response auto-capture for large batch sizes + +When using Tracer to capture responses for each batch record processing, you might exceed 64K of tracing data depending on what you return from your `record_handler` function, or how big is your batch size. + +If that's the case, you can configure [Tracer to disable response auto-capturing](../core/tracer.md#disabling-response-auto-capture){target="_blank"}. + + +```python hl_lines="14" title="Disabling Tracer response auto-capturing" +import json + +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor +from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord +from aws_lambda_powertools.utilities.typing import LambdaContext + + +processor = BatchProcessor(event_type=EventType.SQS) +tracer = Tracer() +logger = Logger() + + +@tracer.capture_method(capture_response=False) +def record_handler(record: SQSRecord): + payload: str = record.body + if payload: + item: dict = json.loads(payload) + ... + +@logger.inject_lambda_context +@tracer.capture_lambda_handler +@batch_processor(record_handler=record_handler, processor=processor) +def lambda_handler(event, context: LambdaContext): + return processor.response() + +``` + +## Testing your code + +As there is no external calls, you can unit test your code with `BatchProcessor` quite easily. + +**Example**: Given a SQS batch where the first batch record succeeds and the second fails processing, we should have a single item reported in the function response. + +=== "test_app.py" + + ```python + import json + + from pathlib import Path + from dataclasses import dataclass + + import pytest + from src.app import lambda_handler, processor + + + def load_event(path: Path): + with path.open() as f: + return json.load(f) + + + @pytest.fixture + def lambda_context(): + @dataclass + class LambdaContext: + function_name: str = "test" + memory_limit_in_mb: int = 128 + invoked_function_arn: str = "arn:aws:lambda:eu-west-1:809313241:function:test" + aws_request_id: str = "52fdfc07-2182-154f-163f-5f0f9a621d72" + + return LambdaContext() + + @pytest.fixture() + def sqs_event(): + """Generates API GW Event""" + return load_event(path=Path("events/sqs_event.json")) + + + def test_app_batch_partial_response(sqs_event, lambda_context): + # GIVEN + processor = app.processor # access processor for additional assertions + successful_record = sqs_event["Records"][0] + failed_record = sqs_event["Records"][1] + expected_response = { + "batchItemFailures: [ + { + "itemIdentifier": failed_record["messageId"] + } + ] + } + + # WHEN + ret = app.lambda_handler(sqs_event, lambda_context) + + # THEN + assert ret == expected_response + assert len(processor.fail_messages) == 1 + assert processor.success_messages[0] == successful_record + ``` + +=== "src/app.py" + + ```python + import json + + from aws_lambda_powertools import Logger, Tracer + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord + from aws_lambda_powertools.utilities.typing import LambdaContext + + + processor = BatchProcessor(event_type=EventType.SQS) + tracer = Tracer() + logger = Logger() + + + @tracer.capture_method + def record_handler(record: SQSRecord): + payload: str = record.body + if payload: + item: dict = json.loads(payload) + ... + + @logger.inject_lambda_context + @tracer.capture_lambda_handler + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context: LambdaContext): + return processor.response() + ``` + +=== "Sample SQS event" + + ```json title="events/sqs_sample.json" + { + "Records": [ + { + "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", + "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a", + "body": "{\"Message\": \"success\"}", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082649183", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082649185" + }, + "messageAttributes": {}, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2: 123456789012:my-queue", + "awsRegion": "us-east-1" + }, + { + "messageId": "244fc6b4-87a3-44ab-83d2-361172410c3a", + "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a", + "body": "SGVsbG8sIHRoaXMgaXMgYSB0ZXN0Lg==", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082649183", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082649185" + }, + "messageAttributes": {}, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2: 123456789012:my-queue", + "awsRegion": "us-east-1" + } + ] + } + ``` + + + +## FAQ + +### Choosing between decorator and context manager + +Use context manager when you want access to the processed messages or handle `BatchProcessingError` exception when all records within the batch fail to be processed. + +### Integrating exception handling with Sentry.io + +When using Sentry.io for error monitoring, you can override `failure_handler` to capture each processing exception with Sentry SDK: + +> Credits to [Charles-Axel Dein](https://github.com/awslabs/aws-lambda-powertools-python/issues/293#issuecomment-781961732) + +=== "sentry_integration.py" + + ```python hl_lines="4 7-8" + from typing import Tuple + + from aws_lambda_powertools.utilities.batch import BatchProcessor, FailureResponse + from sentry_sdk import capture_exception + + + class MyProcessor(BatchProcessor): + def failure_handler(self, record, exception) -> FailureResponse: + capture_exception() # send exception to Sentry + return super().failure_handler(record, exception) + ``` + + +## Legacy + +!!! tip "This is kept for historical purposes. Use the new [BatchProcessor](#processing-messages-from-sqs) instead." + + +### Migration guide + +!!! info "keep reading if you are using `sqs_batch_processor` or `PartialSQSProcessor`" + +[As of Nov 2021](https://aws.amazon.com/about-aws/whats-new/2021/11/aws-lambda-partial-batch-response-sqs-event-source/){target="_blank"}, this is no longer needed as both SQS, Kinesis, and DynamoDB Streams offer this capability natively with one caveat - it's an [opt-in feature](#required-resources). + +Being a native feature, we no longer need to instantiate boto3 nor other customizations like exception suppressing – this lowers the cost of your Lambda function as you can delegate deleting partial failures to Lambda. + +!!! tip "It's also easier to test since it's mostly a [contract based response](https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html#sqs-batchfailurereporting-syntax){target="_blank"}." + +You can migrate in three steps: + +1. If you are using **`sqs_batch_decorator`** you can now use **`batch_processor`** decorator +2. If you were using **`PartialSQSProcessor`** you can now use **`BatchProcessor`** +3. Change your Lambda Handler to return the new response format + + +=== "Decorator: Before" + + ```python hl_lines="1 6" + from aws_lambda_powertools.utilities.batch import sqs_batch_processor + + def record_handler(record): + return do_something_with(record["body"]) + + @sqs_batch_processor(record_handler=record_handler) + def lambda_handler(event, context): + return {"statusCode": 200} + ``` + +=== "Decorator: After" + + ```python hl_lines="3 5 11" + import json + + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + + processor = BatchProcessor(event_type=EventType.SQS) + def record_handler(record): return do_something_with(record["body"]) + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context): + return processor.response() + ``` + + +=== "Context manager: Before" + + ```python hl_lines="1-2 4 14 19" + from aws_lambda_powertools.utilities.batch import PartialSQSProcessor + from botocore.config import Config + + config = Config(region_name="us-east-1") + + def record_handler(record): + return_value = do_something_with(record["body"]) + return return_value + + def lambda_handler(event, context): records = event["Records"] - processor = PartialSQSProcessor() + processor = PartialSQSProcessor(config=config) - with processor(records, record_handler) as proc: - result = proc.process() # Returns a list of all results from record_handler + with processor(records, record_handler): + result = processor.process() return result ``` +=== "Context manager: After" + + ```python hl_lines="1 11" + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, batch_processor + + + def record_handler(record): + return_value = do_something_with(record["body"]) + return return_value + + def lambda_handler(event, context): + records = event["Records"] + + processor = BatchProcessor(event_type=EventType.SQS) + + with processor(records, record_handler): + result = processor.process() + + return processor.response() + ``` + ### Customizing boto configuration The **`config`** and **`boto3_session`** parameters enable you to pass in a custom [botocore config object](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html) @@ -267,98 +1450,3 @@ If you want to disable the default behavior where `SQSBatchProcessingError` is r with processor(records, record_handler): result = processor.process() ``` - -### Create your own partial processor - -You can create your own partial batch processor by inheriting the `BasePartialProcessor` class, and implementing `_prepare()`, `_clean()` and `_process_record()`. - -* **`_process_record()`** - Handles all processing logic for each individual message of a batch, including calling the `record_handler` (self.handler) -* **`_prepare()`** - Called once as part of the processor initialization -* **`clean()`** - Teardown logic called once after `_process_record` completes - -You can then use this class as a context manager, or pass it to `batch_processor` to use as a decorator on your Lambda handler function. - -=== "custom_processor.py" - - ```python hl_lines="3 9 24 30 37 57" - from random import randint - - from aws_lambda_powertools.utilities.batch import BasePartialProcessor, batch_processor - import boto3 - import os - - table_name = os.getenv("TABLE_NAME", "table_not_found") - - class MyPartialProcessor(BasePartialProcessor): - """ - Process a record and stores successful results at a Amazon DynamoDB Table - - Parameters - ---------- - table_name: str - DynamoDB table name to write results to - """ - - def __init__(self, table_name: str): - self.table_name = table_name - - super().__init__() - - def _prepare(self): - # It's called once, *before* processing - # Creates table resource and clean previous results - self.ddb_table = boto3.resource("dynamodb").Table(self.table_name) - self.success_messages.clear() - - def _clean(self): - # It's called once, *after* closing processing all records (closing the context manager) - # Here we're sending, at once, all successful messages to a ddb table - with ddb_table.batch_writer() as batch: - for result in self.success_messages: - batch.put_item(Item=result) - - def _process_record(self, record): - # It handles how your record is processed - # Here we're keeping the status of each run - # where self.handler is the record_handler function passed as an argument - try: - result = self.handler(record) # record_handler passed to decorator/context manager - return self.success_handler(record, result) - except Exception as exc: - return self.failure_handler(record, exc) - - def success_handler(self, record): - entry = ("success", result, record) - message = {"age": result} - self.success_messages.append(message) - return entry - - - def record_handler(record): - return randint(0, 100) - - @batch_processor(record_handler=record_handler, processor=MyPartialProcessor(table_name)) - def lambda_handler(event, context): - return {"statusCode": 200} - ``` - -### Integrating exception handling with Sentry.io - -When using Sentry.io for error monitoring, you can override `failure_handler` to include to capture each processing exception: - -> Credits to [Charles-Axel Dein](https://github.com/awslabs/aws-lambda-powertools-python/issues/293#issuecomment-781961732) - -=== "sentry_integration.py" - - ```python hl_lines="4 7-8" - from typing import Tuple - - from aws_lambda_powertools.utilities.batch import PartialSQSProcessor - from sentry_sdk import capture_exception - - class SQSProcessor(PartialSQSProcessor): - def failure_handler(self, record: Event, exception: Tuple) -> Tuple: # type: ignore - capture_exception() # send exception to Sentry - logger.exception("got exception while processing SQS message") - return super().failure_handler(record, exception) # type: ignore - ``` diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py index a453f0bfe07..cd6fc67ea15 100644 --- a/tests/functional/test_utilities_batch.py +++ b/tests/functional/test_utilities_batch.py @@ -1,12 +1,29 @@ -from typing import Callable +import json +from random import randint +from typing import Callable, Dict, Optional from unittest.mock import patch import pytest from botocore.config import Config from botocore.stub import Stubber -from aws_lambda_powertools.utilities.batch import PartialSQSProcessor, batch_processor, sqs_batch_processor -from aws_lambda_powertools.utilities.batch.exceptions import SQSBatchProcessingError +from aws_lambda_powertools.utilities.batch import ( + BatchProcessor, + EventType, + PartialSQSProcessor, + batch_processor, + sqs_batch_processor, +) +from aws_lambda_powertools.utilities.batch.exceptions import BatchProcessingError, SQSBatchProcessingError +from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord +from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord +from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord +from aws_lambda_powertools.utilities.parser import BaseModel, validator +from aws_lambda_powertools.utilities.parser.models import DynamoDBStreamChangedRecordModel, DynamoDBStreamRecordModel +from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecord as KinesisDataStreamRecordModel +from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecordPayload, SqsRecordModel +from aws_lambda_powertools.utilities.parser.types import Literal +from tests.functional.utils import b64_to_str, str_to_b64 @pytest.fixture(scope="module") @@ -16,7 +33,12 @@ def factory(body: str): "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a", "body": body, - "attributes": {}, + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082649183", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082649185", + }, "messageAttributes": {}, "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", "eventSource": "aws:sqs", @@ -27,6 +49,53 @@ def factory(body: str): return factory +@pytest.fixture(scope="module") +def kinesis_event_factory() -> Callable: + def factory(body: str): + seq = "".join(str(randint(0, 9)) for _ in range(52)) + return { + "kinesis": { + "kinesisSchemaVersion": "1.0", + "partitionKey": "1", + "sequenceNumber": seq, + "data": str_to_b64(body), + "approximateArrivalTimestamp": 1545084650.987, + }, + "eventSource": "aws:kinesis", + "eventVersion": "1.0", + "eventID": f"shardId-000000000006:{seq}", + "eventName": "aws:kinesis:record", + "invokeIdentityArn": "arn:aws:iam::123456789012:role/lambda-role", + "awsRegion": "us-east-2", + "eventSourceARN": "arn:aws:kinesis:us-east-2:123456789012:stream/lambda-stream", + } + + return factory + + +@pytest.fixture(scope="module") +def dynamodb_event_factory() -> Callable: + def factory(body: str): + seq = "".join(str(randint(0, 9)) for _ in range(10)) + return { + "eventID": "1", + "eventVersion": "1.0", + "dynamodb": { + "Keys": {"Id": {"N": "101"}}, + "NewImage": {"Message": {"S": body}}, + "StreamViewType": "NEW_AND_OLD_IMAGES", + "SequenceNumber": seq, + "SizeBytes": 26, + }, + "awsRegion": "us-west-2", + "eventName": "INSERT", + "eventSourceARN": "eventsource_arn", + "eventSource": "aws:dynamodb", + } + + return factory + + @pytest.fixture(scope="module") def record_handler() -> Callable: def handler(record): @@ -38,6 +107,28 @@ def handler(record): return handler +@pytest.fixture(scope="module") +def kinesis_record_handler() -> Callable: + def handler(record: KinesisStreamRecord): + body = b64_to_str(record.kinesis.data) + if "fail" in body: + raise Exception("Failed to process record.") + return body + + return handler + + +@pytest.fixture(scope="module") +def dynamodb_record_handler() -> Callable: + def handler(record: DynamoDBRecord): + body = record.dynamodb.new_image.get("Message").get_value + if "fail" in body: + raise Exception("Failed to process record.") + return body + + return handler + + @pytest.fixture(scope="module") def config() -> Config: return Config(region_name="us-east-1") @@ -67,6 +158,14 @@ def stubbed_partial_processor_suppressed(config) -> PartialSQSProcessor: yield stubber, processor +@pytest.fixture(scope="module") +def order_event_factory() -> Callable: + def factory(item: Dict) -> str: + return json.dumps({"item": item}) + + return factory + + def test_partial_sqs_processor_context_with_failure(sqs_event_factory, record_handler, partial_processor): """ Test processor with one failing record @@ -290,3 +389,448 @@ def test_partial_sqs_processor_context_only_failure(sqs_event_factory, record_ha ctx.process() assert len(error.value.child_exceptions) == 2 + + +def test_batch_processor_middleware_success_only(sqs_event_factory, record_handler): + # GIVEN + first_record = SQSRecord(sqs_event_factory("success")) + second_record = SQSRecord(sqs_event_factory("success")) + event = {"Records": [first_record.raw_event, second_record.raw_event]} + + processor = BatchProcessor(event_type=EventType.SQS) + + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context): + return processor.response() + + # WHEN + result = lambda_handler(event, {}) + + # THEN + assert result["batchItemFailures"] == [] + + +def test_batch_processor_middleware_with_failure(sqs_event_factory, record_handler): + # GIVEN + first_record = SQSRecord(sqs_event_factory("fail")) + second_record = SQSRecord(sqs_event_factory("success")) + event = {"Records": [first_record.raw_event, second_record.raw_event]} + + processor = BatchProcessor(event_type=EventType.SQS) + + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context): + return processor.response() + + # WHEN + result = lambda_handler(event, {}) + + # THEN + assert len(result["batchItemFailures"]) == 1 + + +def test_batch_processor_context_success_only(sqs_event_factory, record_handler): + # GIVEN + first_record = SQSRecord(sqs_event_factory("success")) + second_record = SQSRecord(sqs_event_factory("success")) + records = [first_record.raw_event, second_record.raw_event] + processor = BatchProcessor(event_type=EventType.SQS) + + # WHEN + with processor(records, record_handler) as batch: + processed_messages = batch.process() + + # THEN + assert processed_messages == [ + ("success", first_record.body, first_record.raw_event), + ("success", second_record.body, second_record.raw_event), + ] + + assert batch.response() == {"batchItemFailures": []} + + +def test_batch_processor_context_with_failure(sqs_event_factory, record_handler): + # GIVEN + first_record = SQSRecord(sqs_event_factory("failure")) + second_record = SQSRecord(sqs_event_factory("success")) + records = [first_record.raw_event, second_record.raw_event] + processor = BatchProcessor(event_type=EventType.SQS) + + # WHEN + with processor(records, record_handler) as batch: + processed_messages = batch.process() + + # THEN + assert processed_messages[1] == ("success", second_record.body, second_record.raw_event) + assert len(batch.fail_messages) == 1 + assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record.message_id}]} + + +def test_batch_processor_kinesis_context_success_only(kinesis_event_factory, kinesis_record_handler): + # GIVEN + first_record = KinesisStreamRecord(kinesis_event_factory("success")) + second_record = KinesisStreamRecord(kinesis_event_factory("success")) + + records = [first_record.raw_event, second_record.raw_event] + processor = BatchProcessor(event_type=EventType.KinesisDataStreams) + + # WHEN + with processor(records, kinesis_record_handler) as batch: + processed_messages = batch.process() + + # THEN + assert processed_messages == [ + ("success", b64_to_str(first_record.kinesis.data), first_record.raw_event), + ("success", b64_to_str(second_record.kinesis.data), second_record.raw_event), + ] + + assert batch.response() == {"batchItemFailures": []} + + +def test_batch_processor_kinesis_context_with_failure(kinesis_event_factory, kinesis_record_handler): + # GIVEN + first_record = KinesisStreamRecord(kinesis_event_factory("failure")) + second_record = KinesisStreamRecord(kinesis_event_factory("success")) + + records = [first_record.raw_event, second_record.raw_event] + processor = BatchProcessor(event_type=EventType.KinesisDataStreams) + + # WHEN + with processor(records, kinesis_record_handler) as batch: + processed_messages = batch.process() + + # THEN + assert processed_messages[1] == ("success", b64_to_str(second_record.kinesis.data), second_record.raw_event) + assert len(batch.fail_messages) == 1 + assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record.kinesis.sequence_number}]} + + +def test_batch_processor_kinesis_middleware_with_failure(kinesis_event_factory, kinesis_record_handler): + # GIVEN + first_record = KinesisStreamRecord(kinesis_event_factory("failure")) + second_record = KinesisStreamRecord(kinesis_event_factory("success")) + event = {"Records": [first_record.raw_event, second_record.raw_event]} + + processor = BatchProcessor(event_type=EventType.KinesisDataStreams) + + @batch_processor(record_handler=kinesis_record_handler, processor=processor) + def lambda_handler(event, context): + return processor.response() + + # WHEN + result = lambda_handler(event, {}) + + # THEN + assert len(result["batchItemFailures"]) == 1 + + +def test_batch_processor_dynamodb_context_success_only(dynamodb_event_factory, dynamodb_record_handler): + # GIVEN + first_record = dynamodb_event_factory("success") + second_record = dynamodb_event_factory("success") + records = [first_record, second_record] + processor = BatchProcessor(event_type=EventType.DynamoDBStreams) + + # WHEN + with processor(records, dynamodb_record_handler) as batch: + processed_messages = batch.process() + + # THEN + assert processed_messages == [ + ("success", first_record["dynamodb"]["NewImage"]["Message"]["S"], first_record), + ("success", second_record["dynamodb"]["NewImage"]["Message"]["S"], second_record), + ] + + assert batch.response() == {"batchItemFailures": []} + + +def test_batch_processor_dynamodb_context_with_failure(dynamodb_event_factory, dynamodb_record_handler): + # GIVEN + first_record = dynamodb_event_factory("failure") + second_record = dynamodb_event_factory("success") + records = [first_record, second_record] + processor = BatchProcessor(event_type=EventType.DynamoDBStreams) + + # WHEN + with processor(records, dynamodb_record_handler) as batch: + processed_messages = batch.process() + + # THEN + assert processed_messages[1] == ("success", second_record["dynamodb"]["NewImage"]["Message"]["S"], second_record) + assert len(batch.fail_messages) == 1 + assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}]} + + +def test_batch_processor_dynamodb_middleware_with_failure(dynamodb_event_factory, dynamodb_record_handler): + # GIVEN + first_record = dynamodb_event_factory("failure") + second_record = dynamodb_event_factory("success") + event = {"Records": [first_record, second_record]} + + processor = BatchProcessor(event_type=EventType.DynamoDBStreams) + + @batch_processor(record_handler=dynamodb_record_handler, processor=processor) + def lambda_handler(event, context): + return processor.response() + + # WHEN + result = lambda_handler(event, {}) + + # THEN + assert len(result["batchItemFailures"]) == 1 + + +def test_batch_processor_context_model(sqs_event_factory, order_event_factory): + # GIVEN + class Order(BaseModel): + item: dict + + class OrderSqs(SqsRecordModel): + body: Order + + # auto transform json string + # so Pydantic can auto-initialize nested Order model + @validator("body", pre=True) + def transform_body_to_dict(cls, value: str): + return json.loads(value) + + def record_handler(record: OrderSqs): + return record.body.item + + order_event = order_event_factory({"type": "success"}) + first_record = sqs_event_factory(order_event) + second_record = sqs_event_factory(order_event) + records = [first_record, second_record] + + # WHEN + processor = BatchProcessor(event_type=EventType.SQS, model=OrderSqs) + with processor(records, record_handler) as batch: + processed_messages = batch.process() + + # THEN + order_item = json.loads(order_event)["item"] + assert processed_messages == [ + ("success", order_item, first_record), + ("success", order_item, second_record), + ] + + assert batch.response() == {"batchItemFailures": []} + + +def test_batch_processor_context_model_with_failure(sqs_event_factory, order_event_factory): + # GIVEN + class Order(BaseModel): + item: dict + + class OrderSqs(SqsRecordModel): + body: Order + + # auto transform json string + # so Pydantic can auto-initialize nested Order model + @validator("body", pre=True) + def transform_body_to_dict(cls, value: str): + return json.loads(value) + + def record_handler(record: OrderSqs): + if "fail" in record.body.item["type"]: + raise Exception("Failed to process record.") + return record.body.item + + order_event = order_event_factory({"type": "success"}) + order_event_fail = order_event_factory({"type": "fail"}) + first_record = sqs_event_factory(order_event_fail) + second_record = sqs_event_factory(order_event) + records = [first_record, second_record] + + # WHEN + processor = BatchProcessor(event_type=EventType.SQS, model=OrderSqs) + with processor(records, record_handler) as batch: + batch.process() + + # THEN + assert len(batch.fail_messages) == 1 + assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["messageId"]}]} + + +def test_batch_processor_dynamodb_context_model(dynamodb_event_factory, order_event_factory): + # GIVEN + class Order(BaseModel): + item: dict + + class OrderDynamoDB(BaseModel): + Message: Order + + # auto transform json string + # so Pydantic can auto-initialize nested Order model + @validator("Message", pre=True) + def transform_message_to_dict(cls, value: Dict[Literal["S"], str]): + return json.loads(value["S"]) + + class OrderDynamoDBChangeRecord(DynamoDBStreamChangedRecordModel): + NewImage: Optional[OrderDynamoDB] + OldImage: Optional[OrderDynamoDB] + + class OrderDynamoDBRecord(DynamoDBStreamRecordModel): + dynamodb: OrderDynamoDBChangeRecord + + def record_handler(record: OrderDynamoDBRecord): + return record.dynamodb.NewImage.Message.item + + order_event = order_event_factory({"type": "success"}) + first_record = dynamodb_event_factory(order_event) + second_record = dynamodb_event_factory(order_event) + records = [first_record, second_record] + + # WHEN + processor = BatchProcessor(event_type=EventType.DynamoDBStreams, model=OrderDynamoDBRecord) + with processor(records, record_handler) as batch: + processed_messages = batch.process() + + # THEN + order_item = json.loads(order_event)["item"] + assert processed_messages == [ + ("success", order_item, first_record), + ("success", order_item, second_record), + ] + + assert batch.response() == {"batchItemFailures": []} + + +def test_batch_processor_dynamodb_context_model_with_failure(dynamodb_event_factory, order_event_factory): + # GIVEN + class Order(BaseModel): + item: dict + + class OrderDynamoDB(BaseModel): + Message: Order + + # auto transform json string + # so Pydantic can auto-initialize nested Order model + @validator("Message", pre=True) + def transform_message_to_dict(cls, value: Dict[Literal["S"], str]): + return json.loads(value["S"]) + + class OrderDynamoDBChangeRecord(DynamoDBStreamChangedRecordModel): + NewImage: Optional[OrderDynamoDB] + OldImage: Optional[OrderDynamoDB] + + class OrderDynamoDBRecord(DynamoDBStreamRecordModel): + dynamodb: OrderDynamoDBChangeRecord + + def record_handler(record: OrderDynamoDBRecord): + if "fail" in record.dynamodb.NewImage.Message.item["type"]: + raise Exception("Failed to process record.") + return record.dynamodb.NewImage.Message.item + + order_event = order_event_factory({"type": "success"}) + order_event_fail = order_event_factory({"type": "fail"}) + first_record = dynamodb_event_factory(order_event_fail) + second_record = dynamodb_event_factory(order_event) + records = [first_record, second_record] + + # WHEN + processor = BatchProcessor(event_type=EventType.DynamoDBStreams, model=OrderDynamoDBRecord) + with processor(records, record_handler) as batch: + batch.process() + + # THEN + assert len(batch.fail_messages) == 1 + assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}]} + + +def test_batch_processor_kinesis_context_parser_model(kinesis_event_factory, order_event_factory): + # GIVEN + class Order(BaseModel): + item: dict + + class OrderKinesisPayloadRecord(KinesisDataStreamRecordPayload): + data: Order + + # auto transform json string + # so Pydantic can auto-initialize nested Order model + @validator("data", pre=True) + def transform_message_to_dict(cls, value: str): + # Powertools KinesisDataStreamRecordModel already decodes b64 to str here + return json.loads(value) + + class OrderKinesisRecord(KinesisDataStreamRecordModel): + kinesis: OrderKinesisPayloadRecord + + def record_handler(record: OrderKinesisRecord): + return record.kinesis.data.item + + order_event = order_event_factory({"type": "success"}) + first_record = kinesis_event_factory(order_event) + second_record = kinesis_event_factory(order_event) + records = [first_record, second_record] + + # WHEN + processor = BatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord) + with processor(records, record_handler) as batch: + processed_messages = batch.process() + + # THEN + order_item = json.loads(order_event)["item"] + assert processed_messages == [ + ("success", order_item, first_record), + ("success", order_item, second_record), + ] + + assert batch.response() == {"batchItemFailures": []} + + +def test_batch_processor_kinesis_context_parser_model_with_failure(kinesis_event_factory, order_event_factory): + # GIVEN + class Order(BaseModel): + item: dict + + class OrderKinesisPayloadRecord(KinesisDataStreamRecordPayload): + data: Order + + # auto transform json string + # so Pydantic can auto-initialize nested Order model + @validator("data", pre=True) + def transform_message_to_dict(cls, value: str): + # Powertools KinesisDataStreamRecordModel + return json.loads(value) + + class OrderKinesisRecord(KinesisDataStreamRecordModel): + kinesis: OrderKinesisPayloadRecord + + def record_handler(record: OrderKinesisRecord): + if "fail" in record.kinesis.data.item["type"]: + raise Exception("Failed to process record.") + return record.kinesis.data.item + + order_event = order_event_factory({"type": "success"}) + order_event_fail = order_event_factory({"type": "fail"}) + + first_record = kinesis_event_factory(order_event_fail) + second_record = kinesis_event_factory(order_event) + records = [first_record, second_record] + + # WHEN + processor = BatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord) + with processor(records, record_handler) as batch: + batch.process() + + # THEN + assert len(batch.fail_messages) == 1 + assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["kinesis"]["sequenceNumber"]}]} + + +def test_batch_processor_error_when_entire_batch_fails(sqs_event_factory, record_handler): + # GIVEN + first_record = SQSRecord(sqs_event_factory("fail")) + second_record = SQSRecord(sqs_event_factory("fail")) + event = {"Records": [first_record.raw_event, second_record.raw_event]} + + processor = BatchProcessor(event_type=EventType.SQS) + + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context): + return processor.response() + + # WHEN/THEN + with pytest.raises(BatchProcessingError): + lambda_handler(event, {}) diff --git a/tests/functional/utils.py b/tests/functional/utils.py index a58d27f3526..703f21744e2 100644 --- a/tests/functional/utils.py +++ b/tests/functional/utils.py @@ -1,3 +1,4 @@ +import base64 import json from pathlib import Path from typing import Any @@ -6,3 +7,11 @@ def load_event(file_name: str) -> Any: path = Path(str(Path(__file__).parent.parent) + "/events/" + file_name) return json.loads(path.read_text()) + + +def str_to_b64(data: str) -> str: + return base64.b64encode(data.encode()).decode("utf-8") + + +def b64_to_str(data: str) -> str: + return base64.b64decode(data.encode()).decode("utf-8") diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000000..e69de29bb2d