Skip to content

Commit 573f048

Browse files
pushkarchawdaPushkar Chawda
and
Pushkar Chawda
authored
Tidy up changes added with PR#115 and limit it to only work with Python3.12 in Lambda. (#124)
* Tidy up changes added with PR#115 andlimit it to only work with Python3.12 in Lambda. --------- Co-authored-by: Pushkar Chawda <[email protected]>
1 parent 80eefef commit 573f048

File tree

4 files changed

+43
-6
lines changed

4 files changed

+43
-6
lines changed

awslambdaric/bootstrap.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,14 @@ def run(app_root, handler, lambda_runtime_api_addr):
462462
sys.stdout = Unbuffered(sys.stdout)
463463
sys.stderr = Unbuffered(sys.stderr)
464464

465+
use_thread_for_polling_next = (
466+
os.environ.get("AWS_EXECUTION_ENV") == "AWS_Lambda_python3.12"
467+
)
468+
465469
with create_log_sink() as log_sink:
466-
lambda_runtime_client = LambdaRuntimeClient(lambda_runtime_api_addr)
470+
lambda_runtime_client = LambdaRuntimeClient(
471+
lambda_runtime_api_addr, use_thread_for_polling_next
472+
)
467473

468474
try:
469475
_setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink)

awslambdaric/lambda_runtime_client.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
"""
44

55
import sys
6-
from concurrent.futures import ThreadPoolExecutor
76
from awslambdaric import __version__
7+
from .lambda_runtime_exception import FaultException
88

99

1010
def _user_agent():
@@ -49,8 +49,9 @@ class LambdaRuntimeClient(object):
4949
and response. It allows for function authors to override the the default implementation, LambdaMarshaller which
5050
unmarshals and marshals JSON, to an instance of a class that implements the same interface."""
5151

52-
def __init__(self, lambda_runtime_address):
52+
def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False):
5353
self.lambda_runtime_address = lambda_runtime_address
54+
self.use_thread_for_polling_next = use_thread_for_polling_next
5455

5556
def post_init_error(self, error_response_data):
5657
# These imports are heavy-weight. They implicitly trigger `import ssl, hashlib`.
@@ -69,9 +70,23 @@ def post_init_error(self, error_response_data):
6970
raise LambdaRuntimeClientError(endpoint, response.code, response_body)
7071

7172
def wait_next_invocation(self):
72-
with ThreadPoolExecutor() as e:
73-
fut = e.submit(runtime_client.next)
74-
response_body, headers = fut.result()
73+
# Calling runtime_client.next() from a separate thread unblocks the main thread,
74+
# which can then process signals.
75+
if self.use_thread_for_polling_next:
76+
try:
77+
from concurrent.futures import ThreadPoolExecutor
78+
79+
with ThreadPoolExecutor(max_workers=1) as executor:
80+
future = executor.submit(runtime_client.next)
81+
response_body, headers = future.result()
82+
except Exception as e:
83+
raise FaultException(
84+
FaultException.LAMBDA_RUNTIME_CLIENT_ERROR,
85+
"LAMBDA_RUNTIME Failed to get next invocation: {}".format(str(e)),
86+
None,
87+
)
88+
else:
89+
response_body, headers = runtime_client.next()
7590
return InvocationRequest(
7691
invoke_id=headers.get("Lambda-Runtime-Aws-Request-Id"),
7792
x_amzn_trace_id=headers.get("Lambda-Runtime-Trace-Id"),

awslambdaric/lambda_runtime_exception.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class FaultException(Exception):
1212
BUILT_IN_MODULE_CONFLICT = "Runtime.BuiltInModuleConflict"
1313
MALFORMED_HANDLER_NAME = "Runtime.MalformedHandlerName"
1414
LAMBDA_CONTEXT_UNMARSHAL_ERROR = "Runtime.LambdaContextUnmarshalError"
15+
LAMBDA_RUNTIME_CLIENT_ERROR = "Runtime.LambdaRuntimeClientError"
1516

1617
def __init__(self, exception_type, msg, trace=None):
1718
self.msg = msg

tests/test_lambda_runtime_client.py

+15
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,21 @@ def test_wait_next_invocation(self, mock_runtime_client):
8484
self.assertEqual(event_request.content_type, "application/json")
8585
self.assertEqual(event_request.event_body, response_body)
8686

87+
# Using ThreadPoolExecutor to polling next()
88+
runtime_client = LambdaRuntimeClient("localhost:1234", True)
89+
90+
event_request = runtime_client.wait_next_invocation()
91+
92+
self.assertIsNotNone(event_request)
93+
self.assertEqual(event_request.invoke_id, "RID1234")
94+
self.assertEqual(event_request.x_amzn_trace_id, "TID1234")
95+
self.assertEqual(event_request.invoked_function_arn, "FARN1234")
96+
self.assertEqual(event_request.deadline_time_in_ms, 12)
97+
self.assertEqual(event_request.client_context, "client_context")
98+
self.assertEqual(event_request.cognito_identity, "cognito_identity")
99+
self.assertEqual(event_request.content_type, "application/json")
100+
self.assertEqual(event_request.event_body, response_body)
101+
87102
@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
88103
def test_post_init_error(self, MockHTTPConnection):
89104
mock_conn = MockHTTPConnection.return_value

0 commit comments

Comments
 (0)