Skip to content

docs(api_gateway): new event handler for API Gateway and ALB #418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 195 additions & 73 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import json
import logging
import re
import zlib
from enum import Enum
Expand All @@ -10,6 +11,8 @@
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)


class ProxyEventType(Enum):
"""An enumerations of the supported proxy event types."""
Expand All @@ -28,47 +31,47 @@ class CORSConfig(object):

Simple cors example using the default permissive cors, not this should only be used during early prototyping

>>> from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver
>>>
>>> app = ApiGatewayResolver()
>>>
>>> @app.get("/my/path", cors=True)
>>> def with_cors():
>>> return {"message": "Foo"}
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

app = ApiGatewayResolver()

@app.get("/my/path", cors=True)
def with_cors():
return {"message": "Foo"}

Using a custom CORSConfig where `with_cors` used the custom provided CORSConfig and `without_cors`
do not include any cors headers.

>>> from aws_lambda_powertools.event_handler.api_gateway import (
>>> ApiGatewayResolver, CORSConfig
>>> )
>>>
>>> cors_config = CORSConfig(
>>> allow_origin="https://wwww.example.com/",
>>> expose_headers=["x-exposed-response-header"],
>>> allow_headers=["x-custom-request-header"],
>>> max_age=100,
>>> allow_credentials=True,
>>> )
>>> app = ApiGatewayResolver(cors=cors_config)
>>>
>>> @app.get("/my/path", cors=True)
>>> def with_cors():
>>> return {"message": "Foo"}
>>>
>>> @app.get("/another-one")
>>> def without_cors():
>>> return {"message": "Foo"}
from aws_lambda_powertools.event_handler.api_gateway import (
ApiGatewayResolver, CORSConfig
)

cors_config = CORSConfig(
allow_origin="https://wwww.example.com/",
expose_headers=["x-exposed-response-header"],
allow_headers=["x-custom-request-header"],
max_age=100,
allow_credentials=True,
)
app = ApiGatewayResolver(cors=cors_config)

@app.get("/my/path", cors=True)
def with_cors():
return {"message": "Foo"}

@app.get("/another-one")
def without_cors():
return {"message": "Foo"}
"""

_REQUIRED_HEADERS = ["Authorization", "Content-Type", "X-Amz-Date", "X-Api-Key", "X-Amz-Security-Token"]

def __init__(
self,
allow_origin: str = "*",
allow_headers: List[str] = None,
expose_headers: List[str] = None,
max_age: int = None,
allow_headers: Optional[List[str]] = None,
expose_headers: Optional[List[str]] = None,
max_age: Optional[int] = None,
allow_credentials: bool = False,
):
"""
Expand All @@ -77,13 +80,13 @@ def __init__(
allow_origin: str
The value of the `Access-Control-Allow-Origin` to send in the response. Defaults to "*", but should
only be used during development.
allow_headers: str
allow_headers: Optional[List[str]]
The list of additional allowed headers. This list is added to list of
built in allowed headers: `Authorization`, `Content-Type`, `X-Amz-Date`,
`X-Api-Key`, `X-Amz-Security-Token`.
expose_headers: str
expose_headers: Optional[List[str]]
A list of values to return for the Access-Control-Expose-Headers
max_age: int
max_age: Optional[int]
The value for the `Access-Control-Max-Age`
allow_credentials: bool
A boolean value that sets the value of `Access-Control-Allow-Credentials`
Expand Down Expand Up @@ -170,6 +173,7 @@ def _compress(self):
"""Compress the response body, but only if `Accept-Encoding` headers includes gzip."""
self.response.headers["Content-Encoding"] = "gzip"
if isinstance(self.response.body, str):
logger.debug("Converting string response to bytes before compressing it")
self.response.body = bytes(self.response.body, "utf-8")
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
self.response.body = gzip.compress(self.response.body) + gzip.flush()
Expand All @@ -190,6 +194,7 @@ def build(self, event: BaseProxyEvent, cors: CORSConfig = None) -> Dict[str, Any
self._route(event, cors)

if isinstance(self.response.body, bytes):
logger.debug("Encoding bytes response with base64")
self.response.base64_encoded = True
self.response.body = base64.b64encode(self.response.body).decode()
return {
Expand All @@ -207,27 +212,26 @@ class ApiGatewayResolver:
--------
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator

>>> from aws_lambda_powertools import Tracer
>>> from aws_lambda_powertools.event_handler.api_gateway import (
>>> ApiGatewayResolver
>>> )
>>>
>>> tracer = Tracer()
>>> app = ApiGatewayResolver()
>>>
>>> @app.get("/get-call")
>>> def simple_get():
>>> return {"message": "Foo"}
>>>
>>> @app.post("/post-call")
>>> def simple_post():
>>> post_data: dict = app.current_event.json_body
>>> return {"message": post_data["value"]}
>>>
>>> @tracer.capture_lambda_handler
>>> def lambda_handler(event, context):
>>> return app.resolve(event, context)
```python
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
app = ApiGatewayResolver()

@app.get("/get-call")
def simple_get():
return {"message": "Foo"}

@app.post("/post-call")
def simple_post():
post_data: dict = app.current_event.json_body
return {"message": post_data["value"]}

@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""

current_event: BaseProxyEvent
Expand All @@ -247,32 +251,144 @@ def __init__(self, proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, cors:
self._cors = cors
self._cors_methods: Set[str] = {"OPTIONS"}

def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
"""Get route decorator with GET `method`"""
def get(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
"""Get route decorator with GET `method`

Examples
--------
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator

```python
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
app = ApiGatewayResolver()

@app.get("/get-call")
def simple_get():
return {"message": "Foo"}

@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""
return self.route(rule, "GET", cors, compress, cache_control)

def post(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
"""Post route decorator with POST `method`"""
def post(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
"""Post route decorator with POST `method`

Examples
--------
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator

```python
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
app = ApiGatewayResolver()

@app.post("/post-call")
def simple_post():
post_data: dict = app.current_event.json_body
return {"message": post_data["value"]}

@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""
return self.route(rule, "POST", cors, compress, cache_control)

def put(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
"""Put route decorator with PUT `method`"""
def put(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
"""Put route decorator with PUT `method`

Examples
--------
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator

```python
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
app = ApiGatewayResolver()

@app.put("/put-call")
def simple_post():
put_data: dict = app.current_event.json_body
return {"message": put_data["value"]}

@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""
return self.route(rule, "PUT", cors, compress, cache_control)

def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
"""Delete route decorator with DELETE `method`"""
def delete(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
"""Delete route decorator with DELETE `method`

Examples
--------
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator

```python
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
app = ApiGatewayResolver()

@app.delete("/delete-call")
def simple_delete():
return {"message": "deleted"}

@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""
return self.route(rule, "DELETE", cors, compress, cache_control)

def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
"""Patch route decorator with PATCH `method`"""
def patch(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
"""Patch route decorator with PATCH `method`

Examples
--------
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator

```python
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
app = ApiGatewayResolver()

@app.patch("/patch-call")
def simple_patch():
patch_data: dict = app.current_event.json_body
patch_data["value"] = patched

return {"message": patch_data}

@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""
return self.route(rule, "PATCH", cors, compress, cache_control)

def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None):
def route(self, rule: str, method: str, cors: bool = True, compress: bool = False, cache_control: str = None):
"""Route decorator includes parameter `method`"""

def register_resolver(func: Callable):
logger.debug(f"Adding route using rule {rule} and method {method.upper()}")
self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control))
if cors:
logger.debug(f"Registering method {method.upper()} to Allow Methods in CORS")
self._cors_methods.add(method.upper())
return func

Expand Down Expand Up @@ -308,9 +424,12 @@ def _compile_regex(rule: str):
def _to_proxy_event(self, event: Dict) -> BaseProxyEvent:
"""Convert the event dict to the corresponding data class"""
if self._proxy_type == ProxyEventType.APIGatewayProxyEvent:
logger.debug("Converting event to API Gateway REST API contract")
return APIGatewayProxyEvent(event)
if self._proxy_type == ProxyEventType.APIGatewayProxyEventV2:
logger.debug("Converting event to API Gateway HTTP API contract")
return APIGatewayProxyEventV2(event)
logger.debug("Converting event to ALB contract")
return ALBEvent(event)

def _resolve(self) -> ResponseBuilder:
Expand All @@ -322,17 +441,21 @@ def _resolve(self) -> ResponseBuilder:
continue
match: Optional[re.Match] = route.rule.match(path)
if match:
logger.debug("Found a registered route. Calling function")
return self._call_route(route, match.groupdict())

logger.debug(f"No match found for path {path} and method {method}")
return self._not_found(method)

def _not_found(self, method: str) -> ResponseBuilder:
"""Called when no matching route was found and includes support for the cors preflight response"""
headers = {}
if self._cors:
logger.debug("CORS is enabled, updating headers.")
headers.update(self._cors.to_dict())

if method == "OPTIONS": # Preflight
if method == "OPTIONS": # Pre-flight
logger.debug("Pre-flight request detected. Returning CORS with null response")
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))

Expand Down Expand Up @@ -361,11 +484,10 @@ def _to_response(result: Union[Dict, Response]) -> Response:
"""
if isinstance(result, Response):
return result
elif isinstance(result, dict):
return Response(
status_code=200,
content_type="application/json",
body=json.dumps(result, separators=(",", ":"), cls=Encoder),
)
else: # Tuple[int, str, Union[bytes, str]]
return Response(*result)

logger.debug("Simple response detected, serializing return before constructing final response")
return Response(
status_code=200,
content_type="application/json",
body=json.dumps(result, separators=(",", ":"), cls=Encoder),
)
Loading