From eda8b26446f4d637f96870cbbd97598c4360778e Mon Sep 17 00:00:00 2001 From: Muthu Venkatachalam Date: Sat, 25 Feb 2023 20:52:16 -0500 Subject: [PATCH] add route kwargs - event & context --- .../event_handler/api_gateway.py | 3 +- .../event_handler/test_api_gateway.py | 197 +++++++++--------- .../event_handler/test_lambda_function_url.py | 9 +- tests/functional/event_handler/test_router.py | 9 +- 4 files changed, 111 insertions(+), 107 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 78993f92c5e..e9c63aaa962 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -667,7 +667,8 @@ def _not_found(self, method: str) -> ResponseBuilder: def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: """Actually call the matching route with any provided keyword arguments.""" try: - return ResponseBuilder(self._to_response(route.func(**args)), route) + kwargs = {"event": self.current_event, "context": self.context} + return ResponseBuilder(self._to_response(route.func(**args, **kwargs)), route) except Exception as exc: response_builder = self._call_exception_handler(exc, route) if response_builder: diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index ad9f834dbb2..5c46840bcd9 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -38,6 +38,7 @@ APIGatewayProxyEventV2, event_source, ) +from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext from tests.functional.utils import load_event @@ -61,10 +62,10 @@ def test_alb_event(): app = ALBResolver() @app.get("/lambda") - def foo(): - assert isinstance(app.current_event, ALBEvent) - assert app.lambda_context == {} - assert app.current_event.request_context.elb_target_group_arn is not None + def foo(event: ALBEvent, context: LambdaContext): + assert isinstance(event, ALBEvent) + assert context == {} + assert event.request_context.elb_target_group_arn is not None return Response(200, content_types.TEXT_HTML, "foo") # WHEN calling the event handler @@ -82,10 +83,10 @@ def test_alb_event_path_trailing_slash(json_dump): app = ALBResolver() @app.get("/lambda") - def foo(): - assert isinstance(app.current_event, ALBEvent) - assert app.lambda_context == {} - assert app.current_event.request_context.elb_target_group_arn is not None + def foo(event: ALBEvent, context: LambdaContext): + assert isinstance(event, ALBEvent) + assert context == {} + assert event.request_context.elb_target_group_arn is not None return Response(200, content_types.TEXT_HTML, "foo") # WHEN calling the event handler using path with trailing "/" @@ -103,10 +104,10 @@ def test_api_gateway_v1(): app = APIGatewayRestResolver() @app.get("/my/path") - def get_lambda() -> Response: - assert isinstance(app.current_event, APIGatewayProxyEvent) - assert app.lambda_context == {} - assert app.current_event.request_context.domain_name == "id.execute-api.us-east-1.amazonaws.com" + def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: + assert isinstance(event, APIGatewayProxyEvent) + assert context == {} + assert event.request_context.domain_name == "id.execute-api.us-east-1.amazonaws.com" return Response(200, content_types.APPLICATION_JSON, json.dumps({"foo": "value"})) # WHEN calling the event handler @@ -123,7 +124,7 @@ def test_api_gateway_v1_path_trailing_slash(): app = APIGatewayRestResolver() @app.get("/my/path") - def get_lambda() -> Response: + def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response(200, content_types.APPLICATION_JSON, json.dumps({"foo": "value"})) # WHEN calling the event handler @@ -141,8 +142,8 @@ def test_api_gateway_v1_cookies(): cookie = Cookie(name="CookieMonster", value="MonsterCookie") @app.get("/my/path") - def get_lambda() -> Response: - assert isinstance(app.current_event, APIGatewayProxyEvent) + def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: + assert isinstance(event, APIGatewayProxyEvent) return Response(200, content_types.TEXT_PLAIN, "Hello world", cookies=[cookie]) # WHEN calling the event handler @@ -159,8 +160,8 @@ def test_api_gateway(): app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) @app.get("/my/path") - def get_lambda() -> Response: - assert isinstance(app.current_event, APIGatewayProxyEvent) + def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: + assert isinstance(event, APIGatewayProxyEvent) return Response(200, content_types.TEXT_HTML, "foo") # WHEN calling the event handler @@ -178,8 +179,8 @@ def test_api_gateway_event_path_trailing_slash(json_dump): app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) @app.get("/my/path") - def get_lambda() -> Response: - assert isinstance(app.current_event, APIGatewayProxyEvent) + def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: + assert isinstance(event, APIGatewayProxyEvent) return Response(200, content_types.TEXT_HTML, "foo") # WHEN calling the event handler @@ -196,10 +197,10 @@ def test_api_gateway_v2(): app = APIGatewayHttpResolver() @app.post("/my/path") - def my_path() -> Response: - assert isinstance(app.current_event, APIGatewayProxyEventV2) - post_data = app.current_event.json_body - assert app.current_event.cookies[0] == "cookie1" + def my_path(event: APIGatewayProxyEventV2, context: LambdaContext) -> Response: + assert isinstance(event, APIGatewayProxyEventV2) + post_data = event.json_body + assert event.cookies[0] == "cookie1" return Response(200, content_types.TEXT_PLAIN, post_data["username"]) # WHEN calling the event handler @@ -218,8 +219,8 @@ def test_api_gateway_v2_http_path_trailing_slash(json_dump): app = APIGatewayHttpResolver() @app.post("/my/path") - def my_path() -> Response: - post_data = app.current_event.json_body + def my_path(event: APIGatewayProxyEventV2, context: LambdaContext) -> Response: + post_data = event.json_body return Response(200, content_types.TEXT_PLAIN, post_data["username"]) # WHEN calling the event handler @@ -238,8 +239,8 @@ def test_api_gateway_v2_cookies(): cookie = Cookie(name="CookieMonster", value="MonsterCookie") @app.post("/my/path") - def my_path() -> Response: - assert isinstance(app.current_event, APIGatewayProxyEventV2) + def my_path(event: APIGatewayProxyEventV2, context: LambdaContext) -> Response: + assert isinstance(event, APIGatewayProxyEventV2) return Response(200, content_types.TEXT_PLAIN, "Hello world", cookies=[cookie]) # WHEN calling the event handler @@ -257,7 +258,7 @@ def test_include_rule_matching(): app = ApiGatewayResolver() @app.get("//") - def get_lambda(my_id: str, name: str) -> Response: + def get_lambda(my_id: str, name: str, event: APIGatewayProxyEvent, context: LambdaContext) -> Response: assert name == "my" return Response(200, content_types.TEXT_HTML, my_id) @@ -275,23 +276,23 @@ def test_no_matches(): app = ApiGatewayResolver() @app.get("/not_matching_get") - def get_func(): + def get_func(event: APIGatewayProxyEvent, context: LambdaContext): raise RuntimeError() @app.post("/no_matching_post") - def post_func(): + def post_func(event: APIGatewayProxyEvent, context: LambdaContext): raise RuntimeError() @app.put("/no_matching_put") - def put_func(): + def put_func(event: APIGatewayProxyEvent, context: LambdaContext): raise RuntimeError() @app.delete("/no_matching_delete") - def delete_func(): + def delete_func(event: APIGatewayProxyEvent, context: LambdaContext): raise RuntimeError() @app.patch("/no_matching_patch") - def patch_func(): + def patch_func(event: APIGatewayProxyEvent, context: LambdaContext): raise RuntimeError() def handler(event, context): @@ -326,11 +327,11 @@ def test_cors(): app = ApiGatewayResolver() @app.get("/my/path", cors=True) - def with_cors() -> Response: + def with_cors(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response(200, content_types.TEXT_HTML, "test") @app.get("/without-cors") - def without_cors() -> Response: + def without_cors(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response(200, content_types.TEXT_HTML, "test") def handler(event, context): @@ -374,7 +375,7 @@ def test_compress(): expected_value = '{"test": "value"}' @app.get("/my/request", compress=True) - def with_compression() -> Response: + def with_compression(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response(200, content_types.APPLICATION_JSON, expected_value) def handler(event, context): @@ -399,7 +400,7 @@ def test_base64_encode(): mock_event = {"path": "/my/path", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} @app.get("/my/path", compress=True) - def read_image() -> Response: + def read_image(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response(200, "image/png", read_media("tracer_utility_showcase.png")) # WHEN calling the event handler @@ -420,7 +421,7 @@ def test_compress_no_accept_encoding(): expected_value = "Foo" @app.get("/my/path", compress=True) - def return_text() -> Response: + def return_text(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response(200, content_types.TEXT_PLAIN, expected_value) # WHEN calling the event handler @@ -438,7 +439,7 @@ def test_compress_no_accept_encoding_null_headers(): expected_value = "Foo" @app.get("/my/path", compress=True) - def return_text() -> Response: + def return_text(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response(200, content_types.TEXT_PLAIN, expected_value) # WHEN calling the event handler @@ -454,7 +455,7 @@ def test_cache_control_200(): app = ApiGatewayResolver() @app.get("/success", cache_control="max-age=600") - def with_cache_control() -> Response: + def with_cache_control(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response(200, content_types.TEXT_HTML, "has 200 response") def handler(event, context): @@ -475,7 +476,7 @@ def test_cache_control_non_200(): app = ApiGatewayResolver() @app.delete("/fails", cache_control="max-age=600") - def with_cache_control_has_500() -> Response: + def with_cache_control_has_500(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response(503, content_types.TEXT_HTML, "has 503 response") def handler(event, context): @@ -497,7 +498,7 @@ def test_rest_api(): expected_dict = {"foo": "value", "second": Decimal("100.01")} @app.get("/my/path") - def rest_func() -> Dict: + def rest_func(event: APIGatewayProxyEvent, context: LambdaContext) -> Dict: return expected_dict # WHEN calling the event handler @@ -515,7 +516,7 @@ def test_handling_response_type(): app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) @app.get("/my/path") - def rest_func() -> Response: + def rest_func(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response( status_code=404, content_type="used-if-not-set-in-header", @@ -547,11 +548,11 @@ def test_custom_cors_config(): event = {"path": "/cors", "httpMethod": "GET"} @app.get("/cors") - def get_with_cors(): + def get_with_cors(event: APIGatewayProxyEvent, context: LambdaContext): return {} @app.get("/another-one", cors=False) - def another_one(): + def another_one(event: APIGatewayProxyEvent, context: LambdaContext): return {} # WHEN calling the event handler @@ -616,15 +617,15 @@ def test_cors_preflight(): app = ApiGatewayResolver(cors=CORSConfig()) @app.get("/foo") - def foo_cors(): + def foo_cors(event: APIGatewayProxyEvent, context: LambdaContext): ... @app.route(method="delete", rule="/foo") - def foo_delete_cors(): + def foo_delete_cors(event: APIGatewayProxyEvent, context: LambdaContext): ... @app.post("/foo", cors=False) - def post_no_cors(): + def post_no_cors(event: APIGatewayProxyEvent, context: LambdaContext): ... # WHEN calling the handler @@ -647,7 +648,7 @@ def test_custom_preflight_response(): app = ApiGatewayResolver(cors=CORSConfig()) @app.route(method="OPTIONS", rule="/some-call", cors=True) - def custom_preflight(): + def custom_preflight(event: APIGatewayProxyEvent, context: LambdaContext): return Response( status_code=200, content_type=content_types.TEXT_HTML, @@ -656,7 +657,7 @@ def custom_preflight(): ) @app.route(method="CUSTOM", rule="/some-call", cors=True) - def custom_method(): + def custom_method(event: APIGatewayProxyEvent, context: LambdaContext): ... # WHEN calling the handler @@ -677,7 +678,7 @@ def test_service_error_responses(json_dump): # GIVEN an BadRequestError @app.get(rule="/bad-request-error", cors=False) - def bad_request_error(): + def bad_request_error(event: APIGatewayProxyEvent, context: LambdaContext): raise BadRequestError("Missing required parameter") # WHEN calling the handler @@ -692,7 +693,7 @@ def bad_request_error(): # GIVEN an UnauthorizedError @app.get(rule="/unauthorized-error", cors=False) - def unauthorized_error(): + def unauthorized_error(event: APIGatewayProxyEvent, context: LambdaContext): raise UnauthorizedError("Unauthorized") # WHEN calling the handler @@ -707,7 +708,7 @@ def unauthorized_error(): # GIVEN an NotFoundError @app.get(rule="/not-found-error", cors=False) - def not_found_error(): + def not_found_error(event: APIGatewayProxyEvent, context: LambdaContext): raise NotFoundError # WHEN calling the handler @@ -722,7 +723,7 @@ def not_found_error(): # GIVEN an InternalServerError @app.get(rule="/internal-server-error", cors=False) - def internal_server_error(): + def internal_server_error(event: APIGatewayProxyEvent, context: LambdaContext): raise InternalServerError("Internal server error") # WHEN calling the handler @@ -737,7 +738,7 @@ def internal_server_error(): # GIVEN an ServiceError with a custom status code @app.get(rule="/service-error", cors=True) - def service_error(): + def service_error(event: APIGatewayProxyEvent, context: LambdaContext): raise ServiceError(502, "Something went wrong!") # WHEN calling the handler @@ -759,7 +760,7 @@ def test_debug_unhandled_exceptions_debug_on(): assert app._debug @app.get("/raises-error") - def raises_error(): + def raises_error(event: APIGatewayProxyEvent, context: LambdaContext): raise RuntimeError("Foo") # WHEN calling the handler @@ -781,7 +782,7 @@ def test_debug_unhandled_exceptions_debug_off(): assert not app._debug @app.get("/raises-error") - def raises_error(): + def raises_error(event: APIGatewayProxyEvent, context: LambdaContext): raise RuntimeError("Foo") # WHEN calling the handler @@ -809,7 +810,7 @@ def test_debug_json_formatting(json_dump): response = {"message": "Foo"} @app.get("/foo") - def foo(): + def foo(event: APIGatewayProxyEvent, context: LambdaContext): return response # WHEN calling the handler @@ -841,17 +842,17 @@ def test_similar_dynamic_routes(): # WHEN # r'^/accounts/(?P\\w+\\b)$' # noqa: E800 @app.get("/accounts/") - def get_account(account_id: str): + def get_account(account_id: str, event: APIGatewayProxyEvent, context: LambdaContext): assert account_id == "single_account" # r'^/accounts/(?P\\w+\\b)/source_networks$' # noqa: E800 @app.get("/accounts//source_networks") - def get_account_networks(account_id: str): + def get_account_networks(account_id: str, event: APIGatewayProxyEvent, context: LambdaContext): assert account_id == "nested_account" # r'^/accounts/(?P\\w+\\b)/source_networks/(?P\\w+\\b)$' # noqa: E800 @app.get("/accounts//source_networks/") - def get_network_account(account_id: str, network_id: str): + def get_network_account(account_id: str, network_id: str, event: APIGatewayProxyEvent, context: LambdaContext): assert account_id == "nested_account" assert network_id == "network" @@ -877,17 +878,17 @@ def test_similar_dynamic_routes_with_whitespaces(): # WHEN # r'^/accounts/(?P\\w+\\b)$' # noqa: E800 @app.get("/accounts/") - def get_account(account_id: str): + def get_account(account_id: str, event: APIGatewayProxyEvent, context: LambdaContext): assert account_id == "single account" # r'^/accounts/(?P\\w+\\b)/source_networks$' # noqa: E800 @app.get("/accounts//source_networks") - def get_account_networks(account_id: str): + def get_account_networks(account_id: str, event: APIGatewayProxyEvent, context: LambdaContext): assert account_id == "nested account" # r'^/accounts/(?P\\w+\\b)/source_networks/(?P\\w+\\b)$' # noqa: E800 @app.get("/accounts//source_networks/") - def get_network_account(account_id: str, network_id: str): + def get_network_account(account_id: str, network_id: str, event: APIGatewayProxyEvent, context: LambdaContext): assert account_id == "nested account" assert network_id == "network 123" @@ -921,7 +922,7 @@ def test_non_word_chars_route(req): # WHEN @app.get("/accounts/") - def get_account(account_id: str): + def get_account(account_id: str, event: APIGatewayProxyEvent, context: LambdaContext): assert account_id == f"{req}" # THEN @@ -956,7 +957,7 @@ class Color(Enum): BLUE = 2 @app.get("/colors") - def get_color() -> Dict: + def get_color(event: APIGatewayProxyEvent, context: LambdaContext) -> Dict: return { "color": Color.RED, "variations": {"light", "dark"}, @@ -985,11 +986,11 @@ def test_remove_prefix(path: str): app = ApiGatewayResolver(strip_prefixes=["/pay", "/payment"]) @app.get("/pay/foo") - def pay_foo(): + def pay_foo(event: APIGatewayProxyEvent, context: LambdaContext): raise ValueError("should not be matching") @app.get("/foo") - def foo(): + def foo(event: APIGatewayProxyEvent, context: LambdaContext): ... # WHEN calling handler @@ -1014,7 +1015,7 @@ def test_ignore_invalid(prefix): app = ApiGatewayResolver(strip_prefixes=prefix) @app.get("/foo/status") - def foo(): + def foo(event: APIGatewayProxyEvent, context: LambdaContext): ... # WHEN calling handler @@ -1032,7 +1033,7 @@ def test_api_gateway_v2_raw_path(): event = {"rawPath": "/dev/foo", "requestContext": {"http": {"method": "GET"}, "stage": "dev"}} @app.get("/foo") - def foo(): + def foo(event: APIGatewayProxyEvent, context: LambdaContext): return {} # WHEN calling the event handler @@ -1050,7 +1051,7 @@ def test_api_gateway_request_path_equals_strip_prefix(): event = {"httpMethod": "GET", "path": "/foo"} @app.get("/") - def base(): + def base(event: APIGatewayProxyEvent, context: LambdaContext): return {} # WHEN calling the event handler @@ -1068,7 +1069,7 @@ def test_api_gateway_app_router(): router = Router() @router.get("/my/path") - def foo(): + def foo(event: APIGatewayProxyEvent, context: LambdaContext): return {} app.include_router(router) @@ -1091,9 +1092,9 @@ def test_api_gateway_app_router_with_params(): lambda_context = {} @router.route(rule="/accounts/", method=["GET", "POST"]) - def foo(account_id): - assert router.current_event.raw_event == event - assert router.lambda_context == lambda_context + def foo(account_id, event: APIGatewayProxyEvent, context: LambdaContext): + assert router.current_event == event + assert context == lambda_context assert account_id == f"{req}" return {} @@ -1113,7 +1114,7 @@ def test_api_gateway_app_router_with_prefix(): router = Router() @router.get(rule="/path") - def foo(): + def foo(event: APIGatewayProxyEvent, context: LambdaContext): return {} app.include_router(router, prefix="/my") @@ -1132,7 +1133,7 @@ def test_api_gateway_app_router_with_prefix_equals_path(): router = Router() @router.get(rule="/") - def foo(): + def foo(event: APIGatewayProxyEvent, context: LambdaContext): return {} app.include_router(router, prefix="/my/path") @@ -1201,15 +1202,15 @@ def test_duplicate_routes(): router = Router() @router.get("/my/path") - def get_func_duplicate(): + def get_func_duplicate(event: APIGatewayProxyEvent, context: LambdaContext): raise RuntimeError() @app.get("/my/path") - def get_func(): + def get_func(event: APIGatewayProxyEvent, context: LambdaContext): return {} @router.get("/my/path") - def get_func_another_duplicate(): + def get_func_another_duplicate(event: APIGatewayProxyEvent, context: LambdaContext): raise RuntimeError() app.include_router(router) @@ -1239,8 +1240,8 @@ def test_route_multiple_methods(): lambda_context = {} @app.route(rule="/accounts/", method=["GET", "POST"]) - def foo(account_id): - assert app.lambda_context == lambda_context + def foo(account_id, event: APIGatewayProxyEvent, context: LambdaContext): + assert context == lambda_context assert account_id == f"{req}" return {} @@ -1264,7 +1265,7 @@ def test_api_gateway_app_router_access_to_resolver(): router = Router() @router.get("/my/path") - def foo(): + def foo(event: APIGatewayProxyEvent, context: LambdaContext): # WHEN accessing the api resolver instance via the router # THEN it is accessible and equal to the instantiated api resolver assert app == router.api_resolver @@ -1291,7 +1292,7 @@ def handle_value_error(ex: ValueError): ) @app.get("/my/path") - def get_lambda() -> Response: + def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: raise ValueError("Foo!") # WHEN calling the event handler @@ -1318,7 +1319,7 @@ def service_error(ex: ServiceError): ) @app.get("/my/path") - def get_lambda() -> Response: + def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: raise InternalServerError("Something sensitive") # WHEN calling the event handler @@ -1375,7 +1376,7 @@ def client_error(ex: ValueError): raise BadRequestError("Bad request") @app.get("/my/path") - def get_lambda() -> Response: + def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: raise ValueError("foo") # WHEN calling the event handler @@ -1399,11 +1400,11 @@ def multiple_error(ex: Exception): raise BadRequestError("Bad request") @app.get("/path/a") - def path_a() -> Response: + def path_a(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: raise ValueError("foo") @app.get("/path/b") - def path_b() -> Response: + def path_b(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: raise NotFoundError # WHEN calling the app generating each exception @@ -1429,11 +1430,11 @@ def multiple_error(ex: Exception): raise BadRequestError("Bad request") @app.get("/path/a") - def path_a() -> Response: + def path_a(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: raise ValueError("foo") @app.get("/path/b") - def path_b() -> Response: + def path_b(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: raise NotFoundError # WHEN calling the app generating each exception @@ -1453,8 +1454,8 @@ def test_event_source_compatibility(): app = APIGatewayHttpResolver() @app.post("/my/path") - def my_path(): - assert isinstance(app.current_event, APIGatewayProxyEventV2) + def my_path(event: APIGatewayProxyEventV2, context: LambdaContext): + assert isinstance(event, APIGatewayProxyEventV2) return {} # WHEN @@ -1493,7 +1494,7 @@ def test_route_context_is_cleared_after_resolve(): app.append_context(is_admin=True) @app.get("/my/path") - def my_path(): + def my_path(event: APIGatewayProxyEvent, context: LambdaContext): return {"is_admin": app.context["is_admin"]} # WHEN event resolution kicks in @@ -1510,7 +1511,7 @@ def test_router_has_access_to_app_context(json_dump): ctx = {"is_admin": True} @router.get("/my/path") - def my_path(): + def my_path(event: APIGatewayProxyEvent, context: LambdaContext): return {"is_admin": router.context["is_admin"]} app.include_router(router) @@ -1545,7 +1546,7 @@ def test_nested_app_decorator(): @app.get("/my/path") @app.get("/my/anotherPath") - def get_lambda() -> Response: + def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response(200, content_types.APPLICATION_JSON, json.dumps({"foo": "value"})) # WHEN calling the event handler @@ -1566,7 +1567,7 @@ def test_nested_router_decorator(): @router.get("/my/path") @router.get("/my/anotherPath") - def get_lambda() -> Response: + def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext) -> Response: return Response(200, content_types.APPLICATION_JSON, json.dumps({"foo": "value"})) app.include_router(router) @@ -1586,7 +1587,7 @@ def test_dict_response(): app = ApiGatewayResolver() @app.get("/lambda") - def get_message(): + def get_message(event: APIGatewayProxyEvent, context: LambdaContext): return {"message": "success"} # WHEN calling handler @@ -1604,7 +1605,7 @@ def test_dict_response_with_status_code(): app = ApiGatewayResolver() @app.get("/lambda") - def get_message(): + def get_message(event: APIGatewayProxyEvent, context: LambdaContext): return {"message": "success"}, 201 # WHEN calling handler diff --git a/tests/functional/event_handler/test_lambda_function_url.py b/tests/functional/event_handler/test_lambda_function_url.py index 41baed68a7c..c87f4679335 100644 --- a/tests/functional/event_handler/test_lambda_function_url.py +++ b/tests/functional/event_handler/test_lambda_function_url.py @@ -5,6 +5,7 @@ ) from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.utilities.data_classes import LambdaFunctionUrlEvent +from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext from tests.functional.utils import load_event @@ -13,7 +14,7 @@ def test_lambda_function_url_event(): app = LambdaFunctionUrlResolver() @app.post("/my/path") - def foo(): + def foo(event: LambdaFunctionUrlEvent, context: LambdaContext): assert isinstance(app.current_event, LambdaFunctionUrlEvent) assert app.lambda_context == {} assert app.current_event.request_context.stage is not None @@ -35,7 +36,7 @@ def test_lambda_function_url_event_path_trailing_slash(): app = LambdaFunctionUrlResolver() @app.post("/my/path") - def foo(): + def foo(event: LambdaFunctionUrlEvent, context: LambdaContext): return Response(200, content_types.TEXT_HTML, "foo") # WHEN calling the event handler with an event with a trailing slash @@ -52,7 +53,7 @@ def test_lambda_function_url_event_with_cookies(): cookie = Cookie(name="CookieMonster", value="MonsterCookie") @app.get("/") - def foo(): + def foo(event: LambdaFunctionUrlEvent, context: LambdaContext): assert isinstance(app.current_event, LambdaFunctionUrlEvent) assert app.lambda_context == {} return Response(200, content_types.TEXT_PLAIN, "foo", cookies=[cookie]) @@ -71,7 +72,7 @@ def test_lambda_function_url_no_matches(): app = LambdaFunctionUrlResolver() @app.post("/no_match") - def foo(): + def foo(event: LambdaFunctionUrlEvent, context: LambdaContext): raise RuntimeError() # WHEN calling the event handler diff --git a/tests/functional/event_handler/test_router.py b/tests/functional/event_handler/test_router.py index d96f5035114..4fbb5c4aefe 100644 --- a/tests/functional/event_handler/test_router.py +++ b/tests/functional/event_handler/test_router.py @@ -17,6 +17,7 @@ APIGatewayProxyEventV2, LambdaFunctionUrlEvent, ) +from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext from tests.functional.utils import load_event @@ -25,7 +26,7 @@ def test_alb_router_event_type(): router = ALBRouter() @router.route(rule="/lambda", method=["GET"]) - def foo(): + def foo(event: ALBEvent, context: LambdaContext): assert type(router.current_event) is ALBEvent return Response(status_code=200, body="routed") @@ -39,7 +40,7 @@ def test_apigateway_router_event_type(): router = APIGatewayRouter() @router.route(rule="/my/path", method=["GET"]) - def foo(): + def foo(event: APIGatewayProxyEvent, context: LambdaContext): assert type(router.current_event) is APIGatewayProxyEvent return Response(status_code=200, body="routed") @@ -53,7 +54,7 @@ def test_apigatewayhttp_router_event_type(): router = APIGatewayHttpRouter() @router.route(rule="/my/path", method=["POST"]) - def foo(): + def foo(event: APIGatewayProxyEventV2, context: LambdaContext): assert type(router.current_event) is APIGatewayProxyEventV2 return Response(status_code=200, body="routed") @@ -67,7 +68,7 @@ def test_lambda_function_url_router_event_type(): router = LambdaFunctionUrlRouter() @router.route(rule="/", method=["GET"]) - def foo(): + def foo(event: LambdaFunctionUrlEvent, context: LambdaContext): assert type(router.current_event) is LambdaFunctionUrlEvent return Response(status_code=200, body="routed")