Skip to content

feat(batch): inject lambda_context if record handler signature accepts it #1561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions aws_lambda_powertools/utilities/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Batch processing utilities
"""
import copy
import inspect
import logging
import sys
from abc import ABC, abstractmethod
Expand All @@ -15,6 +16,7 @@
from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord
from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,6 +57,8 @@ class BasePartialProcessor(ABC):
Abstract class for batch processors.
"""

lambda_context: LambdaContext

def __init__(self):
self.success_messages: List[BatchEventTypes] = []
self.fail_messages: List[BatchEventTypes] = []
Expand Down Expand Up @@ -94,7 +98,7 @@ def __enter__(self):
def __exit__(self, exception_type, exception_value, traceback):
self._clean()

def __call__(self, records: List[dict], handler: Callable):
def __call__(self, records: List[dict], handler: Callable, lambda_context: Optional[LambdaContext] = None):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to self: take the opportunity to make records typing more specific. Need to check with Mypy accepts it due to liskov substitution principle.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since I can access record properties in my Lambda code, I'm not sure this is a real concern for now. Maybe in the future it's worth spending some time on it?
But yes, I love static and strong typing ❤️

"""
Set instance attributes before execution

Expand All @@ -107,6 +111,31 @@ def __call__(self, records: List[dict], handler: Callable):
"""
self.records = records
self.handler = handler

# NOTE: If a record handler has `lambda_context` parameter in its function signature, we inject it.
# This is the earliest we can inspect for signature to prevent impacting performance.
#
# Mechanism:
#
# 1. When using the `@batch_processor` decorator, this happens automatically.
# 2. When using the context manager, customers have to include `lambda_context` param.
#
# Scenario: Injects Lambda context
#
# def record_handler(record, lambda_context): ... # noqa: E800
# with processor(records=batch, handler=record_handler, lambda_context=context): ... # noqa: E800
#
# Scenario: Does NOT inject Lambda context (default)
#
# def record_handler(record): pass # noqa: E800
# with processor(records=batch, handler=record_handler): ... # noqa: E800
#
if lambda_context is None:
self._handler_accepts_lambda_context = False
else:
self.lambda_context = lambda_context
self._handler_accepts_lambda_context = "lambda_context" in inspect.signature(self.handler).parameters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to update the documentation to make it clear that the method signature must have a lambda_context parameter otherwise it won't work. The first time I tried to add a context parameter and I couldn't access LambdaContext properties LOL 😵‍💫

Mentioning #1369 not to forget when refactoring documentation.


return self

def success_handler(self, record, result: Any) -> SuccessResponse:
Expand Down Expand Up @@ -155,7 +184,7 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:

@lambda_handler_decorator
def batch_processor(
handler: Callable, event: Dict, context: Dict, record_handler: Callable, processor: BasePartialProcessor
handler: Callable, event: Dict, context: LambdaContext, record_handler: Callable, processor: BasePartialProcessor
):
"""
Middleware to handle batch event processing
Expand All @@ -166,7 +195,7 @@ def batch_processor(
Lambda's handler
event: Dict
Lambda's Event
context: Dict
context: LambdaContext
Lambda's Context
record_handler: Callable
Callable to process each record from the batch
Expand All @@ -193,7 +222,7 @@ def batch_processor(
"""
records = event["Records"]

with processor(records, record_handler):
with processor(records, record_handler, lambda_context=context):
processor.process()

return handler(event, context)
Expand Down Expand Up @@ -365,7 +394,11 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons
"""
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
try:
result = self.handler(record=data)
if self._handler_accepts_lambda_context:
result = self.handler(record=data, lambda_context=self.lambda_context)
else:
result = self.handler(record=data)

return self.success_handler(record=record, result=result)
except Exception:
return self.failure_handler(record=data, exception=sys.exc_info())
Expand Down
52 changes: 52 additions & 0 deletions tests/functional/test_utilities_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from random import randint
from typing import Callable, Dict, Optional
from unittest.mock import patch
from uuid import uuid4

import pytest
from botocore.config import Config
Expand All @@ -24,6 +25,7 @@
from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecord as KinesisDataStreamRecordModel
from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecordPayload, SqsRecordModel
from aws_lambda_powertools.utilities.parser.types import Literal
from aws_lambda_powertools.utilities.typing import LambdaContext
from tests.functional.utils import b64_to_str, str_to_b64


Expand Down Expand Up @@ -167,6 +169,18 @@ def factory(item: Dict) -> str:
return factory


@pytest.fixture(scope="module")
def lambda_context() -> LambdaContext:
class DummyLambdaContext:
def __init__(self):
self.function_name = "test-func"
self.memory_limit_in_mb = 128
self.invoked_function_arn = "arn:aws:lambda:eu-west-1:809313241234:function:test-func"
self.aws_request_id = f"{uuid4()}"

return DummyLambdaContext


@pytest.mark.parametrize(
"success_messages_count",
([1, 18, 34]),
Expand Down Expand Up @@ -908,3 +922,41 @@ def lambda_handler(event, context):

# THEN raise BatchProcessingError
assert "All records failed processing. " in str(e.value)


def test_batch_processor_handler_receives_lambda_context(sqs_event_factory, lambda_context: LambdaContext):
# GIVEN
def record_handler(record, lambda_context: LambdaContext = None):
return lambda_context.function_name == "test-func"

first_record = SQSRecord(sqs_event_factory("success"))
event = {"Records": [first_record.raw_event]}

processor = BatchProcessor(event_type=EventType.SQS)

@batch_processor(record_handler=record_handler, processor=processor)
def lambda_handler(event, context):
return processor.response()

# WHEN/THEN
lambda_handler(event, lambda_context())


def test_batch_processor_context_manager_handler_receives_lambda_context(
sqs_event_factory, lambda_context: LambdaContext
):
# GIVEN
def record_handler(record, lambda_context: LambdaContext = None):
return lambda_context.function_name == "test-func"

first_record = SQSRecord(sqs_event_factory("success"))
event = {"Records": [first_record.raw_event]}

processor = BatchProcessor(event_type=EventType.SQS)

def lambda_handler(event, context):
with processor(records=event["Records"], handler=record_handler, lambda_context=context) as batch:
batch.process()

# WHEN/THEN
lambda_handler(event, lambda_context())