4
4
Batch processing utilities
5
5
"""
6
6
import copy
7
+ import inspect
7
8
import logging
8
9
import sys
9
10
from abc import ABC , abstractmethod
15
16
from aws_lambda_powertools .utilities .data_classes .dynamo_db_stream_event import DynamoDBRecord
16
17
from aws_lambda_powertools .utilities .data_classes .kinesis_stream_event import KinesisStreamRecord
17
18
from aws_lambda_powertools .utilities .data_classes .sqs_event import SQSRecord
19
+ from aws_lambda_powertools .utilities .typing import LambdaContext
18
20
19
21
logger = logging .getLogger (__name__ )
20
22
@@ -55,6 +57,8 @@ class BasePartialProcessor(ABC):
55
57
Abstract class for batch processors.
56
58
"""
57
59
60
+ lambda_context : LambdaContext
61
+
58
62
def __init__ (self ):
59
63
self .success_messages : List [BatchEventTypes ] = []
60
64
self .fail_messages : List [BatchEventTypes ] = []
@@ -94,7 +98,7 @@ def __enter__(self):
94
98
def __exit__ (self , exception_type , exception_value , traceback ):
95
99
self ._clean ()
96
100
97
- def __call__ (self , records : List [dict ], handler : Callable ):
101
+ def __call__ (self , records : List [dict ], handler : Callable , lambda_context : Optional [ LambdaContext ] = None ):
98
102
"""
99
103
Set instance attributes before execution
100
104
@@ -107,6 +111,31 @@ def __call__(self, records: List[dict], handler: Callable):
107
111
"""
108
112
self .records = records
109
113
self .handler = handler
114
+
115
+ # NOTE: If a record handler has `lambda_context` parameter in its function signature, we inject it.
116
+ # This is the earliest we can inspect for signature to prevent impacting performance.
117
+ #
118
+ # Mechanism:
119
+ #
120
+ # 1. When using the `@batch_processor` decorator, this happens automatically.
121
+ # 2. When using the context manager, customers have to include `lambda_context` param.
122
+ #
123
+ # Scenario: Injects Lambda context
124
+ #
125
+ # def record_handler(record, lambda_context): ... # noqa: E800
126
+ # with processor(records=batch, handler=record_handler, lambda_context=context): ... # noqa: E800
127
+ #
128
+ # Scenario: Does NOT inject Lambda context (default)
129
+ #
130
+ # def record_handler(record): pass # noqa: E800
131
+ # with processor(records=batch, handler=record_handler): ... # noqa: E800
132
+ #
133
+ if lambda_context is None :
134
+ self ._handler_accepts_lambda_context = False
135
+ else :
136
+ self .lambda_context = lambda_context
137
+ self ._handler_accepts_lambda_context = "lambda_context" in inspect .signature (self .handler ).parameters
138
+
110
139
return self
111
140
112
141
def success_handler (self , record , result : Any ) -> SuccessResponse :
@@ -155,7 +184,7 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
155
184
156
185
@lambda_handler_decorator
157
186
def batch_processor (
158
- handler : Callable , event : Dict , context : Dict , record_handler : Callable , processor : BasePartialProcessor
187
+ handler : Callable , event : Dict , context : LambdaContext , record_handler : Callable , processor : BasePartialProcessor
159
188
):
160
189
"""
161
190
Middleware to handle batch event processing
@@ -166,7 +195,7 @@ def batch_processor(
166
195
Lambda's handler
167
196
event: Dict
168
197
Lambda's Event
169
- context: Dict
198
+ context: LambdaContext
170
199
Lambda's Context
171
200
record_handler: Callable
172
201
Callable to process each record from the batch
@@ -193,7 +222,7 @@ def batch_processor(
193
222
"""
194
223
records = event ["Records" ]
195
224
196
- with processor (records , record_handler ):
225
+ with processor (records , record_handler , lambda_context = context ):
197
226
processor .process ()
198
227
199
228
return handler (event , context )
@@ -365,7 +394,11 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons
365
394
"""
366
395
data = self ._to_batch_type (record = record , event_type = self .event_type , model = self .model )
367
396
try :
368
- result = self .handler (record = data )
397
+ if self ._handler_accepts_lambda_context :
398
+ result = self .handler (record = data , lambda_context = self .lambda_context )
399
+ else :
400
+ result = self .handler (record = data )
401
+
369
402
return self .success_handler (record = record , result = result )
370
403
except Exception :
371
404
return self .failure_handler (record = data , exception = sys .exc_info ())
0 commit comments