diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 67219c3e21f..9dba4219a95 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -3,7 +3,7 @@ import re import zlib from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 @@ -12,14 +12,11 @@ class ProxyEventType(Enum): - """An enumerations of the supported proxy event types. + """An enumerations of the supported proxy event types.""" - **NOTE:** api_gateway is an alias of http_api_v1""" - - http_api_v1 = "APIGatewayProxyEvent" - http_api_v2 = "APIGatewayProxyEventV2" - alb_event = "ALBEvent" - api_gateway = http_api_v1 + APIGatewayProxyEvent = "APIGatewayProxyEvent" + APIGatewayProxyEventV2 = "APIGatewayProxyEventV2" + ALBEvent = "ALBEvent" class CORSConfig(object): @@ -236,7 +233,7 @@ class ApiGatewayResolver: current_event: BaseProxyEvent lambda_context: LambdaContext - def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1, cors: CORSConfig = None): + def __init__(self, proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, cors: CORSConfig = None): """ Parameters ---------- @@ -310,9 +307,9 @@ def _compile_regex(rule: str): def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: """Convert the event dict to the corresponding data class""" - if self._proxy_type == ProxyEventType.http_api_v1: + if self._proxy_type == ProxyEventType.APIGatewayProxyEvent: return APIGatewayProxyEvent(event) - if self._proxy_type == ProxyEventType.http_api_v2: + if self._proxy_type == ProxyEventType.APIGatewayProxyEventV2: return APIGatewayProxyEventV2(event) return ALBEvent(event) @@ -327,9 +324,9 @@ def _resolve(self) -> ResponseBuilder: if match: return self._call_route(route, match.groupdict()) - return self._not_found(method, path) + return self._not_found(method) - def _not_found(self, method: str, path: str) -> ResponseBuilder: + def _not_found(self, method: str) -> ResponseBuilder: """Called when no matching route was found and includes support for the cors preflight response""" headers = {} if self._cors: @@ -344,7 +341,7 @@ def _not_found(self, method: str, path: str) -> ResponseBuilder: status_code=404, content_type="application/json", headers=headers, - body=json.dumps({"message": f"No route found for '{method}.{path}'"}), + body=json.dumps({"message": "Not found"}), ) ) @@ -353,12 +350,11 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: return ResponseBuilder(self._to_response(route.func(**args)), route) @staticmethod - def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response: + def _to_response(result: Union[Dict, Response]) -> Response: """Convert the route's result to a Response - 3 main result types are supported: + 2 main result types are supported: - - Tuple[int, str, bytes] and Tuple[int, str, str]: status code, content-type and body (str|bytes) - Dict[str, Any]: Rest api response with just the Dict to json stringify and content-type is set to application/json - Response: returned as is, and allows for more flexibility diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 7ee1785f9d0..860a9918e47 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -42,7 +42,7 @@ from aws_lambda_powertools.event_handler.api_gateway import ( tracer = Tracer() # Other supported proxy_types: "APIGatewayProxyEvent", "APIGatewayProxyEventV2", "ALBEvent" app = ApiGatewayResolver( - proxy_type=ProxyEventType.http_api_v1, + proxy_type=ProxyEventType.APIGatewayProxyEvent, cors=CORSConfig( allow_origin="https://www.example.com/", expose_headers=["x-exposed-response-header"], @@ -52,24 +52,28 @@ app = ApiGatewayResolver( ) ) + @app.get("/foo", compress=True) def get_foo() -> Tuple[int, str, str]: # Matches on http GET and proxy path "/foo" # and return status code: 200, content-type: text/html and body: Hello return 200, "text/html", "Hello" + @app.get("/logo.png") def get_logo() -> Tuple[int, str, bytes]: # Base64 encodes the return bytes body automatically logo: bytes = load_logo() return 200, "image/png", logo + @app.post("/make_foo", cors=True) def make_foo() -> Tuple[int, str, str]: # Matches on http POST and proxy path "/make_foo" - post_data: dict = app. current_event.json_body + post_data: dict = app.current_event.json_body return 200, "application/json", json.dumps(post_data["value"]) + @app.delete("/delete/") def delete_foo(uid: str) -> Tuple[int, str, str]: # Matches on http DELETE and proxy path starting with "/delete/" @@ -78,16 +82,19 @@ def delete_foo(uid: str) -> Tuple[int, str, str]: assert app.current_event.request_context.authorizer.claims["username"] == "Mike" return 200, "application/json", json.dumps({"id": uid}) + @app.get("/hello/") def hello_user(username: str) -> Tuple[int, str, str]: return 200, "text/html", f"Hello {username}!" + @app.get("/rest") def rest_fun() -> Dict: # Returns a statusCode: 200, Content-Type: application/json and json.dumps dict # and handles the serialization of decimals to json string return {"message": "Example", "second": Decimal("100.01")} + @app.get("/foo3") def foo3() -> Response: return Response( @@ -97,6 +104,7 @@ def foo3() -> Response: body=json.dumps({"message": "Foo3"}), ) + @tracer.capture_lambda_handler def lambda_handler(event, context) -> Dict: return app.resolve(event, context) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index c9446003163..05c74895eea 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -3,7 +3,7 @@ import zlib from decimal import Decimal from pathlib import Path -from typing import Dict, Tuple +from typing import Dict from aws_lambda_powertools.event_handler.api_gateway import ( ApiGatewayResolver, @@ -29,10 +29,10 @@ def read_media(file_name: str) -> bytes: def test_alb_event(): # GIVEN a Application Load Balancer proxy type event - app = ApiGatewayResolver(proxy_type=ProxyEventType.alb_event) + app = ApiGatewayResolver(proxy_type=ProxyEventType.ALBEvent) @app.get("/lambda") - def foo() -> Tuple[int, str, str]: + def foo(): assert isinstance(app.current_event, ALBEvent) assert app.lambda_context == {} return 200, TEXT_HTML, "foo" @@ -49,13 +49,13 @@ def foo() -> Tuple[int, str, str]: def test_api_gateway_v1(): # GIVEN a Http API V1 proxy type event - app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) + app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) @app.get("/my/path") - def get_lambda() -> Tuple[int, str, str]: + def get_lambda() -> Response: assert isinstance(app.current_event, APIGatewayProxyEvent) assert app.lambda_context == {} - return 200, APPLICATION_JSON, json.dumps({"foo": "value"}) + return Response(200, APPLICATION_JSON, json.dumps({"foo": "value"})) # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) @@ -68,12 +68,12 @@ def get_lambda() -> Tuple[int, str, str]: def test_api_gateway(): # GIVEN a Rest API Gateway proxy type event - app = ApiGatewayResolver(proxy_type=ProxyEventType.api_gateway) + app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) @app.get("/my/path") - def get_lambda() -> Tuple[int, str, str]: + def get_lambda() -> Response: assert isinstance(app.current_event, APIGatewayProxyEvent) - return 200, TEXT_HTML, "foo" + return Response(200, TEXT_HTML, "foo") # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) @@ -87,13 +87,13 @@ def get_lambda() -> Tuple[int, str, str]: def test_api_gateway_v2(): # GIVEN a Http API V2 proxy type event - app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v2) + app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEventV2) @app.post("/my/path") - def my_path() -> Tuple[int, str, str]: + def my_path() -> Response: assert isinstance(app.current_event, APIGatewayProxyEventV2) post_data = app.current_event.json_body - return 200, "plain/text", post_data["username"] + return Response(200, "plain/text", post_data["username"]) # WHEN calling the event handler result = app(load_event("apiGatewayProxyV2Event.json"), {}) @@ -110,9 +110,9 @@ def test_include_rule_matching(): app = ApiGatewayResolver() @app.get("//") - def get_lambda(my_id: str, name: str) -> Tuple[int, str, str]: + def get_lambda(my_id: str, name: str) -> Response: assert name == "my" - return 200, TEXT_HTML, my_id + return Response(200, TEXT_HTML, my_id) # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) @@ -179,8 +179,8 @@ def test_cors(): app = ApiGatewayResolver() @app.get("/my/path", cors=True) - def with_cors() -> Tuple[int, str, str]: - return 200, TEXT_HTML, "test" + def with_cors() -> Response: + return Response(200, TEXT_HTML, "test") def handler(event, context): return app.resolve(event, context) @@ -205,8 +205,8 @@ def test_compress(): expected_value = '{"test": "value"}' @app.get("/my/request", compress=True) - def with_compression() -> Tuple[int, str, str]: - return 200, APPLICATION_JSON, expected_value + def with_compression() -> Response: + return Response(200, APPLICATION_JSON, expected_value) def handler(event, context): return app.resolve(event, context) @@ -230,8 +230,8 @@ def test_base64_encode(): mock_event = {"path": "/my/path", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} @app.get("/my/path", compress=True) - def read_image() -> Tuple[int, str, bytes]: - return 200, "image/png", read_media("idempotent_sequence_exception.png") + def read_image() -> Response: + return Response(200, "image/png", read_media("idempotent_sequence_exception.png")) # WHEN calling the event handler result = app(mock_event, None) @@ -251,8 +251,8 @@ def test_compress_no_accept_encoding(): expected_value = "Foo" @app.get("/my/path", compress=True) - def return_text() -> Tuple[int, str, str]: - return 200, "text/plain", expected_value + def return_text() -> Response: + return Response(200, "text/plain", expected_value) # WHEN calling the event handler result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None) @@ -267,8 +267,8 @@ def test_cache_control_200(): app = ApiGatewayResolver() @app.get("/success", cache_control="max-age=600") - def with_cache_control() -> Tuple[int, str, str]: - return 200, TEXT_HTML, "has 200 response" + def with_cache_control() -> Response: + return Response(200, TEXT_HTML, "has 200 response") def handler(event, context): return app.resolve(event, context) @@ -288,8 +288,8 @@ def test_cache_control_non_200(): app = ApiGatewayResolver() @app.delete("/fails", cache_control="max-age=600") - def with_cache_control_has_500() -> Tuple[int, str, str]: - return 503, TEXT_HTML, "has 503 response" + def with_cache_control_has_500() -> Response: + return Response(503, TEXT_HTML, "has 503 response") def handler(event, context): return app.resolve(event, context) @@ -306,7 +306,7 @@ def handler(event, context): def test_rest_api(): # GIVEN a function that returns a Dict - app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) + app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) expected_dict = {"foo": "value", "second": Decimal("100.01")} @app.get("/my/path") @@ -325,7 +325,7 @@ def rest_func() -> Dict: def test_handling_response_type(): # GIVEN a function that returns Response - app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) + app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) @app.get("/my/path") def rest_func() -> Response: