From f3b64c35e225645e9b481d1e082552afc8993883 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Fri, 30 Sep 2022 16:14:26 +0200 Subject: [PATCH 1/2] feat(batch): pass lambda_context if record handler signature accepts it --- aws_lambda_powertools/utilities/batch/base.py | 16 ++++++++--- tests/functional/test_utilities_batch.py | 27 +++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py index 1122bea4c03..6f61f05d88c 100644 --- a/aws_lambda_powertools/utilities/batch/base.py +++ b/aws_lambda_powertools/utilities/batch/base.py @@ -4,6 +4,7 @@ Batch processing utilities """ import copy +import inspect import logging import sys from abc import ABC, abstractmethod @@ -15,6 +16,7 @@ from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord +from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) @@ -55,6 +57,8 @@ class BasePartialProcessor(ABC): Abstract class for batch processors. """ + lambda_context: LambdaContext + def __init__(self): self.success_messages: List[BatchEventTypes] = [] self.fail_messages: List[BatchEventTypes] = [] @@ -155,7 +159,7 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse: @lambda_handler_decorator def batch_processor( - handler: Callable, event: Dict, context: Dict, record_handler: Callable, processor: BasePartialProcessor + handler: Callable, event: Dict, context: LambdaContext, record_handler: Callable, processor: BasePartialProcessor ): """ Middleware to handle batch event processing @@ -166,7 +170,7 @@ def batch_processor( Lambda's handler event: Dict Lambda's Event - context: Dict + context: LambdaContext Lambda's Context record_handler: Callable Callable to process each record from the batch @@ -193,6 +197,7 @@ def batch_processor( """ records = event["Records"] + processor.lambda_context = context with processor(records, record_handler): processor.process() @@ -364,8 +369,13 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons A batch record to be processed. """ data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) + handler_signature = inspect.signature(self.handler).parameters try: - result = self.handler(record=data) + # NOTE: negative first for faster execution, since that's how >80% customers use + if "lambda_context" not in handler_signature: + result = self.handler(record=data) + else: + result = self.handler(record=data, lambda_context=self.lambda_context) return self.success_handler(record=record, result=result) except Exception: return self.failure_handler(record=data, exception=sys.exc_info()) diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py index a5e1e706437..0f3d286c329 100644 --- a/tests/functional/test_utilities_batch.py +++ b/tests/functional/test_utilities_batch.py @@ -3,6 +3,7 @@ from random import randint from typing import Callable, Dict, Optional from unittest.mock import patch +from uuid import uuid4 import pytest from botocore.config import Config @@ -24,6 +25,7 @@ from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecord as KinesisDataStreamRecordModel from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecordPayload, SqsRecordModel from aws_lambda_powertools.utilities.parser.types import Literal +from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.utils import b64_to_str, str_to_b64 @@ -908,3 +910,28 @@ def lambda_handler(event, context): # THEN raise BatchProcessingError assert "All records failed processing. " in str(e.value) + + +def test_batch_processor_handler_receives_lambda_context(sqs_event_factory): + # GIVEN + class DummyLambdaContext: + def __init__(self): + self.function_name = "test-func" + self.memory_limit_in_mb = 128 + self.invoked_function_arn = "arn:aws:lambda:eu-west-1:809313241234:function:test-func" + self.aws_request_id = f"{uuid4()}" + + def record_handler(record, lambda_context: LambdaContext = None): + return lambda_context.function_name == "test-func" + + first_record = SQSRecord(sqs_event_factory("success")) + event = {"Records": [first_record.raw_event]} + + processor = BatchProcessor(event_type=EventType.SQS) + + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context): + return processor.response() + + # WHEN/THEN + lambda_handler(event, DummyLambdaContext()) From 280355d1788b0e78557382c472cd11a779e7ce4f Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Fri, 30 Sep 2022 16:45:28 +0200 Subject: [PATCH 2/2] feat(batch): optimize logic for perf and to support customers using context manager --- aws_lambda_powertools/utilities/batch/base.py | 39 +++++++++++++---- tests/functional/test_utilities_batch.py | 43 +++++++++++++++---- 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py index 6f61f05d88c..80191503055 100644 --- a/aws_lambda_powertools/utilities/batch/base.py +++ b/aws_lambda_powertools/utilities/batch/base.py @@ -98,7 +98,7 @@ def __enter__(self): def __exit__(self, exception_type, exception_value, traceback): self._clean() - def __call__(self, records: List[dict], handler: Callable): + def __call__(self, records: List[dict], handler: Callable, lambda_context: Optional[LambdaContext] = None): """ Set instance attributes before execution @@ -111,6 +111,31 @@ def __call__(self, records: List[dict], handler: Callable): """ self.records = records self.handler = handler + + # NOTE: If a record handler has `lambda_context` parameter in its function signature, we inject it. + # This is the earliest we can inspect for signature to prevent impacting performance. + # + # Mechanism: + # + # 1. When using the `@batch_processor` decorator, this happens automatically. + # 2. When using the context manager, customers have to include `lambda_context` param. + # + # Scenario: Injects Lambda context + # + # def record_handler(record, lambda_context): ... # noqa: E800 + # with processor(records=batch, handler=record_handler, lambda_context=context): ... # noqa: E800 + # + # Scenario: Does NOT inject Lambda context (default) + # + # def record_handler(record): pass # noqa: E800 + # with processor(records=batch, handler=record_handler): ... # noqa: E800 + # + if lambda_context is None: + self._handler_accepts_lambda_context = False + else: + self.lambda_context = lambda_context + self._handler_accepts_lambda_context = "lambda_context" in inspect.signature(self.handler).parameters + return self def success_handler(self, record, result: Any) -> SuccessResponse: @@ -197,8 +222,7 @@ def batch_processor( """ records = event["Records"] - processor.lambda_context = context - with processor(records, record_handler): + with processor(records, record_handler, lambda_context=context): processor.process() return handler(event, context) @@ -369,13 +393,12 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons A batch record to be processed. """ data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) - handler_signature = inspect.signature(self.handler).parameters try: - # NOTE: negative first for faster execution, since that's how >80% customers use - if "lambda_context" not in handler_signature: - result = self.handler(record=data) - else: + if self._handler_accepts_lambda_context: result = self.handler(record=data, lambda_context=self.lambda_context) + else: + result = self.handler(record=data) + return self.success_handler(record=record, result=result) except Exception: return self.failure_handler(record=data, exception=sys.exc_info()) diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py index 0f3d286c329..8654b96e9b1 100644 --- a/tests/functional/test_utilities_batch.py +++ b/tests/functional/test_utilities_batch.py @@ -169,6 +169,18 @@ def factory(item: Dict) -> str: return factory +@pytest.fixture(scope="module") +def lambda_context() -> LambdaContext: + class DummyLambdaContext: + def __init__(self): + self.function_name = "test-func" + self.memory_limit_in_mb = 128 + self.invoked_function_arn = "arn:aws:lambda:eu-west-1:809313241234:function:test-func" + self.aws_request_id = f"{uuid4()}" + + return DummyLambdaContext + + @pytest.mark.parametrize( "success_messages_count", ([1, 18, 34]), @@ -912,15 +924,8 @@ def lambda_handler(event, context): assert "All records failed processing. " in str(e.value) -def test_batch_processor_handler_receives_lambda_context(sqs_event_factory): +def test_batch_processor_handler_receives_lambda_context(sqs_event_factory, lambda_context: LambdaContext): # GIVEN - class DummyLambdaContext: - def __init__(self): - self.function_name = "test-func" - self.memory_limit_in_mb = 128 - self.invoked_function_arn = "arn:aws:lambda:eu-west-1:809313241234:function:test-func" - self.aws_request_id = f"{uuid4()}" - def record_handler(record, lambda_context: LambdaContext = None): return lambda_context.function_name == "test-func" @@ -934,4 +939,24 @@ def lambda_handler(event, context): return processor.response() # WHEN/THEN - lambda_handler(event, DummyLambdaContext()) + lambda_handler(event, lambda_context()) + + +def test_batch_processor_context_manager_handler_receives_lambda_context( + sqs_event_factory, lambda_context: LambdaContext +): + # GIVEN + def record_handler(record, lambda_context: LambdaContext = None): + return lambda_context.function_name == "test-func" + + first_record = SQSRecord(sqs_event_factory("success")) + event = {"Records": [first_record.raw_event]} + + processor = BatchProcessor(event_type=EventType.SQS) + + def lambda_handler(event, context): + with processor(records=event["Records"], handler=record_handler, lambda_context=context) as batch: + batch.process() + + # WHEN/THEN + lambda_handler(event, lambda_context())