diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index fc744055e6c..67219c3e21f 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -12,6 +12,10 @@ class ProxyEventType(Enum): + """An enumerations of the supported proxy event types. + + **NOTE:** api_gateway is an alias of http_api_v1""" + http_api_v1 = "APIGatewayProxyEvent" http_api_v2 = "APIGatewayProxyEventV2" alb_event = "ALBEvent" @@ -19,7 +23,46 @@ class ProxyEventType(Enum): class CORSConfig(object): - """CORS Config""" + """CORS Config + + + Examples + -------- + + Simple cors example using the default permissive cors, not this should only be used during early prototyping + + >>> from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver + >>> + >>> app = ApiGatewayResolver() + >>> + >>> @app.get("/my/path", cors=True) + >>> def with_cors(): + >>> return {"message": "Foo"} + + Using a custom CORSConfig where `with_cors` used the custom provided CORSConfig and `without_cors` + do not include any cors headers. + + >>> from aws_lambda_powertools.event_handler.api_gateway import ( + >>> ApiGatewayResolver, CORSConfig + >>> ) + >>> + >>> cors_config = CORSConfig( + >>> allow_origin="https://wwww.example.com/", + >>> expose_headers=["x-exposed-response-header"], + >>> allow_headers=["x-custom-request-header"], + >>> max_age=100, + >>> allow_credentials=True, + >>> ) + >>> app = ApiGatewayResolver(cors=cors_config) + >>> + >>> @app.get("/my/path", cors=True) + >>> def with_cors(): + >>> return {"message": "Foo"} + >>> + >>> @app.get("/another-one") + >>> def without_cors(): + >>> return {"message": "Foo"} + """ _REQUIRED_HEADERS = ["Authorization", "Content-Type", "X-Amz-Date", "X-Api-Key", "X-Amz-Security-Token"] @@ -55,6 +98,7 @@ def __init__( self.allow_credentials = allow_credentials def to_dict(self) -> Dict[str, str]: + """Builds the configured Access-Control http headers""" headers = { "Access-Control-Allow-Origin": self.allow_origin, "Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)), @@ -68,7 +112,37 @@ def to_dict(self) -> Dict[str, str]: return headers +class Response: + """Response data class that provides greater control over what is returned from the proxy event""" + + def __init__( + self, status_code: int, content_type: Optional[str], body: Union[str, bytes, None], headers: Dict = None + ): + """ + + Parameters + ---------- + status_code: int + Http status code, example 200 + content_type: str + Optionally set the Content-Type header, example "application/json". Note this will be merged into any + provided http headers + body: Union[str, bytes, None] + Optionally set the response body. Note: bytes body will be automatically base64 encoded + headers: dict + Optionally set specific http headers. Setting "Content-Type" hear would override the `content_type` value. + """ + self.status_code = status_code + self.body = body + self.base64_encoded = False + self.headers: Dict = headers or {} + if content_type: + self.headers.setdefault("Content-Type", content_type) + + class Route: + """Internally used Route Configuration""" + def __init__( self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str] ): @@ -80,68 +154,125 @@ def __init__( self.cache_control = cache_control -class Response: - def __init__( - self, status_code: int, content_type: Optional[str], body: Union[str, bytes, None], headers: Dict = None - ): - self.status_code = status_code - self.body = body - self.base64_encoded = False - self.headers: Dict = headers or {} - if content_type: - self.headers.setdefault("Content-Type", content_type) +class ResponseBuilder: + """Internally used Response builder""" - def add_cors(self, cors: CORSConfig): - self.headers.update(cors.to_dict()) + def __init__(self, response: Response, route: Route = None): + self.response = response + self.route = route - def add_cache_control(self, cache_control: str): - self.headers["Cache-Control"] = cache_control if self.status_code == 200 else "no-cache" + def _add_cors(self, cors: CORSConfig): + """Update headers to include the configured Access-Control headers""" + self.response.headers.update(cors.to_dict()) - def compress(self): - self.headers["Content-Encoding"] = "gzip" - if isinstance(self.body, str): - self.body = bytes(self.body, "utf-8") - gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) - self.body = gzip.compress(self.body) + gzip.flush() + def _add_cache_control(self, cache_control: str): + """Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used.""" + self.response.headers["Cache-Control"] = cache_control if self.response.status_code == 200 else "no-cache" - def to_dict(self) -> Dict[str, Any]: - if isinstance(self.body, bytes): - self.base64_encoded = True - self.body = base64.b64encode(self.body).decode() + def _compress(self): + """Compress the response body, but only if `Accept-Encoding` headers includes gzip.""" + self.response.headers["Content-Encoding"] = "gzip" + if isinstance(self.response.body, str): + self.response.body = bytes(self.response.body, "utf-8") + gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) + self.response.body = gzip.compress(self.response.body) + gzip.flush() + + def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]): + """Optionally handle any of the route's configure response handling""" + if self.route is None: + return + if self.route.cors: + self._add_cors(cors or CORSConfig()) + if self.route.cache_control: + self._add_cache_control(self.route.cache_control) + if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""): + self._compress() + + def build(self, event: BaseProxyEvent, cors: CORSConfig = None) -> Dict[str, Any]: + """Build the full response dict to be returned by the lambda""" + self._route(event, cors) + + if isinstance(self.response.body, bytes): + self.response.base64_encoded = True + self.response.body = base64.b64encode(self.response.body).decode() return { - "statusCode": self.status_code, - "headers": self.headers, - "body": self.body, - "isBase64Encoded": self.base64_encoded, + "statusCode": self.response.status_code, + "headers": self.response.headers, + "body": self.response.body, + "isBase64Encoded": self.response.base64_encoded, } class ApiGatewayResolver: + """API Gateway and ALB proxy resolver + + Examples + -------- + Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator + + >>> from aws_lambda_powertools import Tracer + >>> from aws_lambda_powertools.event_handler.api_gateway import ( + >>> ApiGatewayResolver + >>> ) + >>> + >>> tracer = Tracer() + >>> app = ApiGatewayResolver() + >>> + >>> @app.get("/get-call") + >>> def simple_get(): + >>> return {"message": "Foo"} + >>> + >>> @app.post("/post-call") + >>> def simple_post(): + >>> post_data: dict = app.current_event.json_body + >>> return {"message": post_data["value"]} + >>> + >>> @tracer.capture_lambda_handler + >>> def lambda_handler(event, context): + >>> return app.resolve(event, context) + + """ + current_event: BaseProxyEvent lambda_context: LambdaContext def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1, cors: CORSConfig = None): + """ + Parameters + ---------- + proxy_type: ProxyEventType + Proxy request type, defaults to API Gateway V1 + cors: CORSConfig + Optionally configure and enabled CORS. Not each route will need to have to cors=True + """ self._proxy_type = proxy_type self._routes: List[Route] = [] self._cors = cors self._cors_methods: Set[str] = {"OPTIONS"} def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + """Get route decorator with GET `method`""" return self.route(rule, "GET", cors, compress, cache_control) def post(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + """Post route decorator with POST `method`""" return self.route(rule, "POST", cors, compress, cache_control) def put(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + """Put route decorator with PUT `method`""" return self.route(rule, "PUT", cors, compress, cache_control) def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + """Delete route decorator with DELETE `method`""" return self.route(rule, "DELETE", cors, compress, cache_control) def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + """Patch route decorator with PATCH `method`""" return self.route(rule, "PATCH", cors, compress, cache_control) def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None): + """Route decorator includes parameter `method`""" + def register_resolver(func: Callable): self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control)) if cors: @@ -151,34 +282,44 @@ def register_resolver(func: Callable): return register_resolver def resolve(self, event, context) -> Dict[str, Any]: - self.current_event = self._to_data_class(event) - self.lambda_context = context - route, response = self._find_route(self.current_event.http_method.upper(), self.current_event.path) - if route is None: # No matching route was found - return response.to_dict() + """Resolves the response based on the provide event and decorator routes - if route.cors: - response.add_cors(self._cors or CORSConfig()) - if route.cache_control: - response.add_cache_control(route.cache_control) - if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""): - response.compress() + Parameters + ---------- + event: Dict[str, Any] + Event + context: LambdaContext + Lambda context + Returns + ------- + dict + Returns the dict response + """ + self.current_event = self._to_proxy_event(event) + self.lambda_context = context + return self._resolve().build(self.current_event, self._cors) - return response.to_dict() + def __call__(self, event, context) -> Any: + return self.resolve(event, context) @staticmethod def _compile_regex(rule: str): + """Precompile regex pattern""" rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) return re.compile("^{}$".format(rule_regex)) - def _to_data_class(self, event: Dict) -> BaseProxyEvent: + def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: + """Convert the event dict to the corresponding data class""" if self._proxy_type == ProxyEventType.http_api_v1: return APIGatewayProxyEvent(event) if self._proxy_type == ProxyEventType.http_api_v2: return APIGatewayProxyEventV2(event) return ALBEvent(event) - def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response]: + def _resolve(self) -> ResponseBuilder: + """Resolves the response or return the not found response""" + method = self.current_event.http_method.upper() + path = self.current_event.path for route in self._routes: if method != route.method: continue @@ -186,25 +327,42 @@ def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response if match: return self._call_route(route, match.groupdict()) + return self._not_found(method, path) + + def _not_found(self, method: str, path: str) -> ResponseBuilder: + """Called when no matching route was found and includes support for the cors preflight response""" headers = {} if self._cors: headers.update(self._cors.to_dict()) + if method == "OPTIONS": # Preflight headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods)) - return None, Response(status_code=204, content_type=None, body=None, headers=headers) + return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None)) - return None, Response( - status_code=404, - content_type="application/json", - body=json.dumps({"message": f"No route found for '{method}.{path}'"}), - headers=headers, + return ResponseBuilder( + Response( + status_code=404, + content_type="application/json", + headers=headers, + body=json.dumps({"message": f"No route found for '{method}.{path}'"}), + ) ) - def _call_route(self, route: Route, args: Dict[str, str]) -> Tuple[Route, Response]: - return route, self._to_response(route.func(**args)) + def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: + """Actually call the matching route with any provided keyword arguments.""" + return ResponseBuilder(self._to_response(route.func(**args)), route) @staticmethod def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response: + """Convert the route's result to a Response + + 3 main result types are supported: + + - Tuple[int, str, bytes] and Tuple[int, str, str]: status code, content-type and body (str|bytes) + - Dict[str, Any]: Rest api response with just the Dict to json stringify and content-type is set to + application/json + - Response: returned as is, and allows for more flexibility + """ if isinstance(result, Response): return result elif isinstance(result, dict): @@ -215,6 +373,3 @@ def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Respons ) else: # Tuple[int, str, Union[bytes, str]] return Response(*result) - - def __call__(self, event, context) -> Any: - return self.resolve(event, context) diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md new file mode 100644 index 00000000000..7ee1785f9d0 --- /dev/null +++ b/docs/core/event_handler/api_gateway.md @@ -0,0 +1,313 @@ +--- +title: API Gateway +description: Core utility +--- + +Event handler for AWS API Gateway and Application Loader Balancers. + +### Key Features + +* Routes - `@app.get("/foo")` +* Path expressions - `@app.delete("/delete/")` +* Cors - `@app.post("/make_foo", cors=True)` or via `CORSConfig` and builtin CORS preflight route +* Base64 encode binary - `@app.get("/logo.png")` +* Gzip Compression - `@app.get("/large-json", compress=True)` +* Cache-control - `@app.get("/foo", cache_control="max-age=600")` +* Rest API simplification with function returns a Dict +* Support function returns a Response object which give fine-grained control of the headers +* JSON encoding of Decimals + +## Examples + +> TODO - Break on into smaller examples + +### All in one example + +=== "app.py" + +```python +from decimal import Decimal +import json +from typing import Dict, Tuple + +from aws_lambda_powertools import Tracer +from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent +from aws_lambda_powertools.event_handler.api_gateway import ( + ApiGatewayResolver, + CORSConfig, + ProxyEventType, + Response, +) + +tracer = Tracer() +# Other supported proxy_types: "APIGatewayProxyEvent", "APIGatewayProxyEventV2", "ALBEvent" +app = ApiGatewayResolver( + proxy_type=ProxyEventType.http_api_v1, + cors=CORSConfig( + allow_origin="https://www.example.com/", + expose_headers=["x-exposed-response-header"], + allow_headers=["x-custom-request-header"], + max_age=100, + allow_credentials=True, + ) +) + +@app.get("/foo", compress=True) +def get_foo() -> Tuple[int, str, str]: + # Matches on http GET and proxy path "/foo" + # and return status code: 200, content-type: text/html and body: Hello + return 200, "text/html", "Hello" + +@app.get("/logo.png") +def get_logo() -> Tuple[int, str, bytes]: + # Base64 encodes the return bytes body automatically + logo: bytes = load_logo() + return 200, "image/png", logo + +@app.post("/make_foo", cors=True) +def make_foo() -> Tuple[int, str, str]: + # Matches on http POST and proxy path "/make_foo" + post_data: dict = app. current_event.json_body + return 200, "application/json", json.dumps(post_data["value"]) + +@app.delete("/delete/") +def delete_foo(uid: str) -> Tuple[int, str, str]: + # Matches on http DELETE and proxy path starting with "/delete/" + assert isinstance(app.current_event, APIGatewayProxyEvent) + assert app.current_event.request_context.authorizer.claims is not None + assert app.current_event.request_context.authorizer.claims["username"] == "Mike" + return 200, "application/json", json.dumps({"id": uid}) + +@app.get("/hello/") +def hello_user(username: str) -> Tuple[int, str, str]: + return 200, "text/html", f"Hello {username}!" + +@app.get("/rest") +def rest_fun() -> Dict: + # Returns a statusCode: 200, Content-Type: application/json and json.dumps dict + # and handles the serialization of decimals to json string + return {"message": "Example", "second": Decimal("100.01")} + +@app.get("/foo3") +def foo3() -> Response: + return Response( + status_code=200, + content_type="application/json", + headers={"custom-header": "value"}, + body=json.dumps({"message": "Foo3"}), + ) + +@tracer.capture_lambda_handler +def lambda_handler(event, context) -> Dict: + return app.resolve(event, context) +``` + +### Compress examples + +=== "app.py" + + ```python + from aws_lambda_powertools.event_handler.api_gateway import ( + ApiGatewayResolver + ) + + app = ApiGatewayResolver() + + @app.get("/foo", compress=True) + def get_foo() -> Tuple[int, str, str]: + # Matches on http GET and proxy path "/foo" + # and return status code: 200, content-type: text/html and body: Hello + return 200, "text/html", "Hello" + ``` + +=== "GET /foo: request" + ```json + { + "headers": { + "Accept-Encoding": "gzip" + }, + "httpMethod": "GET", + "path": "/foo" + } + ``` + +=== "GET /foo: response" + + ```json + { + "body": "H4sIAAAAAAACE/NIzcnJBwCCidH3BQAAAA==", + "headers": { + "Content-Encoding": "gzip", + "Content-Type": "text/html" + }, + "isBase64Encoded": true, + "statusCode": 200 + } + ``` + +### CORS examples + +=== "app.py" + + ```python + from aws_lambda_powertools.event_handler.api_gateway import ( + ApiGatewayResolver, + CORSConfig, + ) + + app = ApiGatewayResolver( + proxy_type=ProxyEventType.http_api_v1, + cors=CORSConfig( + allow_origin="https://www.example.com/", + expose_headers=["x-exposed-response-header"], + allow_headers=["x-custom-request-header"], + max_age=100, + allow_credentials=True, + ) + ) + + @app.post("/make_foo", cors=True) + def make_foo() -> Tuple[int, str, str]: + # Matches on http POST and proxy path "/make_foo" + post_data: dict = app. current_event.json_body + return 200, "application/json", json.dumps(post_data["value"]) + ``` + +=== "OPTIONS /make_foo" + + ```json + { + "httpMethod": "OPTIONS", + "path": "/make_foo" + } + ``` + +=== "<< OPTIONS /make_foo" + + ```json + { + "body": null, + "headers": { + "Access-Control-Allow-Credentials": "true", + "Access-Control-Allow-Headers": "Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token,X-Api-Key,x-custom-request-header", + "Access-Control-Allow-Methods": "OPTIONS,POST", + "Access-Control-Allow-Origin": "https://www.example.com/", + "Access-Control-Expose-Headers": "x-exposed-response-header", + "Access-Control-Max-Age": "100" + }, + "isBase64Encoded": false, + "statusCode": 204 + } + ``` + +=== "POST /make_foo" + + ```json + { + "body": "{\"value\": \"Hello World\"}", + "httpMethod": "POST", + "path": "/make_foo" + } + ``` + +=== "<< POST /make_foo" + + ```json + { + "body": "\"Hello World\"", + "headers": { + "Access-Control-Allow-Credentials": "true", + "Access-Control-Allow-Headers": "Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token,X-Api-Key,x-custom-request-header", + "Access-Control-Allow-Origin": "https://www.example.com/", + "Access-Control-Expose-Headers": "x-exposed-response-header", + "Access-Control-Max-Age": "100", + "Content-Type": "application/json" + }, + "isBase64Encoded": false, + "statusCode": 200 + } + ``` + +### Simple rest example + +=== "app.py" + + ```python + from aws_lambda_powertools.event_handler.api_gateway import ( + ApiGatewayResolver + ) + + app = ApiGatewayResolver() + + @app.get("/rest") + def rest_fun() -> Dict: + # Returns a statusCode: 200, Content-Type: application/json and json.dumps dict + # and handles the serialization of decimals to json string + return {"message": "Example", "second": Decimal("100.01")} + ``` + +=== "GET /rest: request" + + ```json + { + "httpMethod": "GET", + "path": "/rest" + } + ``` + +=== "GET /rest: response" + + ```json + { + "body": "{\"message\":\"Example\",\"second\":\"100.01\"}", + "headers": { + "Content-Type": "application/json" + }, + "isBase64Encoded": false, + "statusCode": 200 + } + ``` + +### Custom response + +=== "app.py" + + ```python + from aws_lambda_powertools.event_handler.api_gateway import ( + ApiGatewayResolver + ) + + app = ApiGatewayResolver() + + @app.get("/foo3") + def foo3() -> Response: + return Response( + status_code=200, + content_type="application/json", + headers={"custom-header": "value"}, + body=json.dumps({"message": "Foo3"}), + ) + ``` + +=== "GET /foo3: request" + + ```json + { + "httpMethod": "GET", + "path": "/foo3" + } + ``` + +=== "GET /foo3: response" + + ```json + { + "body": "{\"message\": \"Foo3\"}", + "headers": { + "Content-Type": "application/json", + "custom-header": "value" + }, + "isBase64Encoded": false, + "statusCode": 200 + } + ``` diff --git a/mkdocs.yml b/mkdocs.yml index 43a7e125696..b07e30386dd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -14,6 +14,7 @@ nav: - core/metrics.md - Event Handler: - core/event_handler/appsync.md + - core/event_handler/api_gateway.md - Utilities: - utilities/middleware_factory.md - utilities/parameters.md diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index df13b047d0d..c9446003163 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -5,7 +5,13 @@ from pathlib import Path from typing import Dict, Tuple -from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, CORSConfig, ProxyEventType, Response +from aws_lambda_powertools.event_handler.api_gateway import ( + ApiGatewayResolver, + CORSConfig, + ProxyEventType, + Response, + ResponseBuilder, +) from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from tests.functional.utils import load_event @@ -106,14 +112,14 @@ def test_include_rule_matching(): @app.get("//") def get_lambda(my_id: str, name: str) -> Tuple[int, str, str]: assert name == "my" - return 200, "plain/html", my_id + return 200, TEXT_HTML, my_id # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) # THEN assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == "plain/html" + assert result["headers"]["Content-Type"] == TEXT_HTML assert result["body"] == "path" @@ -389,14 +395,16 @@ def another_one(): def test_no_content_response(): # GIVEN a response with no content-type or body response = Response(status_code=204, content_type=None, body=None, headers=None) + response_builder = ResponseBuilder(response) # WHEN calling to_dict - result = response.to_dict() + result = response_builder.build(APIGatewayProxyEvent(LOAD_GW_EVENT)) # THEN return an None body and no Content-Type header + assert result["statusCode"] == response.status_code assert result["body"] is None - assert result["statusCode"] == 204 - assert "Content-Type" not in result["headers"] + headers = result["headers"] + assert "Content-Type" not in headers def test_no_matches_with_cors(): @@ -413,7 +421,7 @@ def test_no_matches_with_cors(): assert "Access-Control-Allow-Origin" in result["headers"] -def test_preflight(): +def test_cors_preflight(): # GIVEN an event for an OPTIONS call that does not match any of the given routes # AND cors is enabled app = ApiGatewayResolver(cors=CORSConfig())