Skip to content

Commit 9545f69

Browse files
authored
feat(batch): inject lambda_context if record handler signature accepts it (#1561)
1 parent e076037 commit 9545f69

File tree

2 files changed

+90
-5
lines changed

2 files changed

+90
-5
lines changed

Diff for: aws_lambda_powertools/utilities/batch/base.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Batch processing utilities
55
"""
66
import copy
7+
import inspect
78
import logging
89
import sys
910
from abc import ABC, abstractmethod
@@ -15,6 +16,7 @@
1516
from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord
1617
from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord
1718
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
19+
from aws_lambda_powertools.utilities.typing import LambdaContext
1820

1921
logger = logging.getLogger(__name__)
2022

@@ -55,6 +57,8 @@ class BasePartialProcessor(ABC):
5557
Abstract class for batch processors.
5658
"""
5759

60+
lambda_context: LambdaContext
61+
5862
def __init__(self):
5963
self.success_messages: List[BatchEventTypes] = []
6064
self.fail_messages: List[BatchEventTypes] = []
@@ -94,7 +98,7 @@ def __enter__(self):
9498
def __exit__(self, exception_type, exception_value, traceback):
9599
self._clean()
96100

97-
def __call__(self, records: List[dict], handler: Callable):
101+
def __call__(self, records: List[dict], handler: Callable, lambda_context: Optional[LambdaContext] = None):
98102
"""
99103
Set instance attributes before execution
100104
@@ -107,6 +111,31 @@ def __call__(self, records: List[dict], handler: Callable):
107111
"""
108112
self.records = records
109113
self.handler = handler
114+
115+
# NOTE: If a record handler has `lambda_context` parameter in its function signature, we inject it.
116+
# This is the earliest we can inspect for signature to prevent impacting performance.
117+
#
118+
# Mechanism:
119+
#
120+
# 1. When using the `@batch_processor` decorator, this happens automatically.
121+
# 2. When using the context manager, customers have to include `lambda_context` param.
122+
#
123+
# Scenario: Injects Lambda context
124+
#
125+
# def record_handler(record, lambda_context): ... # noqa: E800
126+
# with processor(records=batch, handler=record_handler, lambda_context=context): ... # noqa: E800
127+
#
128+
# Scenario: Does NOT inject Lambda context (default)
129+
#
130+
# def record_handler(record): pass # noqa: E800
131+
# with processor(records=batch, handler=record_handler): ... # noqa: E800
132+
#
133+
if lambda_context is None:
134+
self._handler_accepts_lambda_context = False
135+
else:
136+
self.lambda_context = lambda_context
137+
self._handler_accepts_lambda_context = "lambda_context" in inspect.signature(self.handler).parameters
138+
110139
return self
111140

112141
def success_handler(self, record, result: Any) -> SuccessResponse:
@@ -155,7 +184,7 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
155184

156185
@lambda_handler_decorator
157186
def batch_processor(
158-
handler: Callable, event: Dict, context: Dict, record_handler: Callable, processor: BasePartialProcessor
187+
handler: Callable, event: Dict, context: LambdaContext, record_handler: Callable, processor: BasePartialProcessor
159188
):
160189
"""
161190
Middleware to handle batch event processing
@@ -166,7 +195,7 @@ def batch_processor(
166195
Lambda's handler
167196
event: Dict
168197
Lambda's Event
169-
context: Dict
198+
context: LambdaContext
170199
Lambda's Context
171200
record_handler: Callable
172201
Callable to process each record from the batch
@@ -193,7 +222,7 @@ def batch_processor(
193222
"""
194223
records = event["Records"]
195224

196-
with processor(records, record_handler):
225+
with processor(records, record_handler, lambda_context=context):
197226
processor.process()
198227

199228
return handler(event, context)
@@ -365,7 +394,11 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons
365394
"""
366395
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
367396
try:
368-
result = self.handler(record=data)
397+
if self._handler_accepts_lambda_context:
398+
result = self.handler(record=data, lambda_context=self.lambda_context)
399+
else:
400+
result = self.handler(record=data)
401+
369402
return self.success_handler(record=record, result=result)
370403
except Exception:
371404
return self.failure_handler(record=data, exception=sys.exc_info())

Diff for: tests/functional/test_utilities_batch.py

+52
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from random import randint
44
from typing import Callable, Dict, Optional
55
from unittest.mock import patch
6+
from uuid import uuid4
67

78
import pytest
89
from botocore.config import Config
@@ -24,6 +25,7 @@
2425
from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecord as KinesisDataStreamRecordModel
2526
from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecordPayload, SqsRecordModel
2627
from aws_lambda_powertools.utilities.parser.types import Literal
28+
from aws_lambda_powertools.utilities.typing import LambdaContext
2729
from tests.functional.utils import b64_to_str, str_to_b64
2830

2931

@@ -167,6 +169,18 @@ def factory(item: Dict) -> str:
167169
return factory
168170

169171

172+
@pytest.fixture(scope="module")
173+
def lambda_context() -> LambdaContext:
174+
class DummyLambdaContext:
175+
def __init__(self):
176+
self.function_name = "test-func"
177+
self.memory_limit_in_mb = 128
178+
self.invoked_function_arn = "arn:aws:lambda:eu-west-1:809313241234:function:test-func"
179+
self.aws_request_id = f"{uuid4()}"
180+
181+
return DummyLambdaContext
182+
183+
170184
@pytest.mark.parametrize(
171185
"success_messages_count",
172186
([1, 18, 34]),
@@ -908,3 +922,41 @@ def lambda_handler(event, context):
908922

909923
# THEN raise BatchProcessingError
910924
assert "All records failed processing. " in str(e.value)
925+
926+
927+
def test_batch_processor_handler_receives_lambda_context(sqs_event_factory, lambda_context: LambdaContext):
928+
# GIVEN
929+
def record_handler(record, lambda_context: LambdaContext = None):
930+
return lambda_context.function_name == "test-func"
931+
932+
first_record = SQSRecord(sqs_event_factory("success"))
933+
event = {"Records": [first_record.raw_event]}
934+
935+
processor = BatchProcessor(event_type=EventType.SQS)
936+
937+
@batch_processor(record_handler=record_handler, processor=processor)
938+
def lambda_handler(event, context):
939+
return processor.response()
940+
941+
# WHEN/THEN
942+
lambda_handler(event, lambda_context())
943+
944+
945+
def test_batch_processor_context_manager_handler_receives_lambda_context(
946+
sqs_event_factory, lambda_context: LambdaContext
947+
):
948+
# GIVEN
949+
def record_handler(record, lambda_context: LambdaContext = None):
950+
return lambda_context.function_name == "test-func"
951+
952+
first_record = SQSRecord(sqs_event_factory("success"))
953+
event = {"Records": [first_record.raw_event]}
954+
955+
processor = BatchProcessor(event_type=EventType.SQS)
956+
957+
def lambda_handler(event, context):
958+
with processor(records=event["Records"], handler=record_handler, lambda_context=context) as batch:
959+
batch.process()
960+
961+
# WHEN/THEN
962+
lambda_handler(event, lambda_context())

0 commit comments

Comments
 (0)