Skip to content

refactor(apigateway): Add BaseRouter and duplicate route check #757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 15, 2021
240 changes: 115 additions & 125 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import os
import re
import traceback
import warnings
import zlib
from abc import ABC, abstractmethod
from enum import Enum
from functools import partial, wraps
from functools import partial
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union

from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import ServiceError
Expand Down Expand Up @@ -227,78 +229,20 @@ def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dic
}


class ApiGatewayResolver:
"""API Gateway and ALB proxy resolver

Examples
--------
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator

```python
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
app = ApiGatewayResolver()

@app.get("/get-call")
def simple_get():
return {"message": "Foo"}

@app.post("/post-call")
def simple_post():
post_data: dict = app.current_event.json_body
return {"message": post_data["value"]}

@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""

class BaseRouter(ABC):
current_event: BaseProxyEvent
lambda_context: LambdaContext

def __init__(
@abstractmethod
def route(
self,
proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent,
cors: Optional[CORSConfig] = None,
debug: Optional[bool] = None,
serializer: Optional[Callable[[Dict], str]] = None,
strip_prefixes: Optional[List[str]] = None,
rule: str,
method: Any,
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
):
"""
Parameters
----------
proxy_type: ProxyEventType
Proxy request type, defaults to API Gateway V1
cors: CORSConfig
Optionally configure and enabled CORS. Not each route will need to have to cors=True
debug: Optional[bool]
Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG"
environment variable
serializer : Callable, optional
function to serialize `obj` to a JSON formatted `str`, by default json.dumps
strip_prefixes: List[str], optional
optional list of prefixes to be removed from the request path before doing the routing. This is often used
with api gateways with multiple custom mappings.
"""
self._proxy_type = proxy_type
self._routes: List[Route] = []
self._cors = cors
self._cors_enabled: bool = cors is not None
self._cors_methods: Set[str] = {"OPTIONS"}
self._debug = resolve_truthy_env_var_choice(
env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug
)
self._strip_prefixes = strip_prefixes

# Allow for a custom serializer or a concise json serialization
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)

if self._debug:
# Always does a pretty print when in debug mode
self._serializer = partial(json.dumps, indent=4, cls=Encoder)
raise NotImplementedError()

def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
"""Get route decorator with GET `method`
Expand Down Expand Up @@ -434,6 +378,78 @@ def lambda_handler(event, context):
"""
return self.route(rule, "PATCH", cors, compress, cache_control)


class ApiGatewayResolver(BaseRouter):
"""API Gateway and ALB proxy resolver

Examples
--------
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator

```python
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
app = ApiGatewayResolver()

@app.get("/get-call")
def simple_get():
return {"message": "Foo"}

@app.post("/post-call")
def simple_post():
post_data: dict = app.current_event.json_body
return {"message": post_data["value"]}

@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""

def __init__(
self,
proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent,
cors: Optional[CORSConfig] = None,
debug: Optional[bool] = None,
serializer: Optional[Callable[[Dict], str]] = None,
strip_prefixes: Optional[List[str]] = None,
):
"""
Parameters
----------
proxy_type: ProxyEventType
Proxy request type, defaults to API Gateway V1
cors: CORSConfig
Optionally configure and enabled CORS. Not each route will need to have to cors=True
debug: Optional[bool]
Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG"
environment variable
serializer : Callable, optional
function to serialize `obj` to a JSON formatted `str`, by default json.dumps
strip_prefixes: List[str], optional
optional list of prefixes to be removed from the request path before doing the routing. This is often used
with api gateways with multiple custom mappings.
"""
self._proxy_type = proxy_type
self._routes: List[Route] = []
self._route_keys: List[str] = []
self._cors = cors
self._cors_enabled: bool = cors is not None
self._cors_methods: Set[str] = {"OPTIONS"}
self._debug = resolve_truthy_env_var_choice(
env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug
)
self._strip_prefixes = strip_prefixes

# Allow for a custom serializer or a concise json serialization
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)

if self._debug:
# Always does a pretty print when in debug mode
self._serializer = partial(json.dumps, indent=4, cls=Encoder)

def route(
self,
rule: str,
Expand All @@ -451,6 +467,10 @@ def register_resolver(func: Callable):
else:
cors_enabled = cors
self._routes.append(Route(method, self._compile_regex(rule), func, cors_enabled, compress, cache_control))
route_key = method + rule
if route_key in self._route_keys:
warnings.warn(f"A route like this was already registered. method: '{method}' rule: '{rule}'")
self._route_keys.append(route_key)
if cors_enabled:
logger.debug(f"Registering method {method.upper()} to Allow Methods in CORS")
self._cors_methods.add(method.upper())
Expand All @@ -474,8 +494,8 @@ def resolve(self, event, context) -> Dict[str, Any]:
"""
if self._debug:
print(self._json_dump(event))
self.current_event = self._to_proxy_event(event)
self.lambda_context = context
BaseRouter.current_event = self._to_proxy_event(event)
BaseRouter.lambda_context = context
return self._resolve().build(self.current_event, self._cors)

def __call__(self, event, context) -> Any:
Expand Down Expand Up @@ -632,71 +652,41 @@ def _json_dump(self, obj: Any) -> str:
return self._serializer(obj)

def include_router(self, router: "Router", prefix: Optional[str] = None) -> None:
"""Adds all routes defined in a router"""
router._app = self
for route, func in router.api.items():
if prefix and route[0] == "/":
route = (prefix, *route[1:])
elif prefix:
route = (f"{prefix}{route[0]}", *route[1:])
self.route(*route)(func())

"""Adds all routes defined in a router

class Router:
"""Router helper class to allow splitting ApiGatewayResolver into multiple files"""
Parameters
----------
router : Router
The Router containing a list of routes to be registered after the existing routes
prefix : str, optional
An optional prefix to be added to the originally defined rule
"""
for route, func in router._routes.items():
if prefix:
rule = route[0]
rule = prefix if rule == "/" else f"{prefix}{rule}"
route = (rule, *route[1:])

_app: ApiGatewayResolver
self.route(*route)(func)

def __init__(self):
self.api: Dict[tuple, Callable] = {}

@property
def current_event(self) -> BaseProxyEvent:
return self._app.current_event
class Router(BaseRouter):
"""Router helper class to allow splitting ApiGatewayResolver into multiple files"""

@property
def lambda_context(self) -> LambdaContext:
return self._app.lambda_context
def __init__(self):
self._routes: Dict[tuple, Callable] = {}

def route(
self,
rule: str,
method: Union[str, Tuple[str], List[str]],
method: Union[str, List[str]],
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
):
def actual_decorator(func: Callable):
@wraps(func)
def wrapper():
def inner_wrapper(**kwargs):
return func(**kwargs)

return inner_wrapper

if isinstance(method, (list, tuple)):
for item in method:
self.api[(rule, item, cors, compress, cache_control)] = wrapper
else:
self.api[(rule, method, cors, compress, cache_control)] = wrapper

return actual_decorator

def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
return self.route(rule, "GET", cors, compress, cache_control)

def post(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
return self.route(rule, "POST", cors, compress, cache_control)

def put(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
return self.route(rule, "PUT", cors, compress, cache_control)

def delete(
self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None
):
return self.route(rule, "DELETE", cors, compress, cache_control)
def register_route(func: Callable):
methods = method if isinstance(method, list) else [method]
for item in methods:
self._routes[(rule, item, cors, compress, cache_control)] = func

def patch(
self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None
):
return self.route(rule, "PATCH", cors, compress, cache_control)
return register_route
4 changes: 2 additions & 2 deletions aws_lambda_powertools/utilities/validation/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class SchemaValidationError(Exception):

def __init__(
self,
message: str,
message: Optional[str] = None,
validation_message: Optional[str] = None,
name: Optional[str] = None,
path: Optional[List] = None,
Expand All @@ -21,7 +21,7 @@ def __init__(

Parameters
----------
message : str
message : str, optional
Powertools formatted error message
validation_message : str, optional
Containing human-readable information what is wrong
Expand Down
27 changes: 27 additions & 0 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,3 +994,30 @@ def patch_func():
assert result["statusCode"] == 404
# AND cors headers are not returned
assert "Access-Control-Allow-Origin" not in result["headers"]


def test_duplicate_routes():
# GIVEN a duplicate routes
app = ApiGatewayResolver()
router = Router()

@router.get("/my/path")
def get_func_duplicate():
raise RuntimeError()

@app.get("/my/path")
def get_func():
return {}

@router.get("/my/path")
def get_func_another_duplicate():
raise RuntimeError()

app.include_router(router)

# WHEN calling the handler
result = app(LOAD_GW_EVENT, None)

# THEN only execute the first registered route
# AND print warnings
assert result["statusCode"] == 200
7 changes: 5 additions & 2 deletions tests/functional/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,11 +537,11 @@ def format(self, record: logging.LogRecord) -> str: # noqa: A003
logger = Logger(service=service_name, stream=stdout, logger_formatter=custom_formatter)

# WHEN a lambda function is decorated with logger
@logger.inject_lambda_context
@logger.inject_lambda_context(correlation_id_path="foo")
def handler(event, context):
logger.info("Hello")

handler({}, lambda_context)
handler({"foo": "value"}, lambda_context)

lambda_context_keys = (
"function_name",
Expand All @@ -554,8 +554,11 @@ def handler(event, context):

# THEN custom key should always be present
# and lambda contextual info should also be in the logs
# and get_correlation_id should return None
assert "my_default_key" in log
assert all(k in log for k in lambda_context_keys)
assert log["correlation_id"] == "value"
assert logger.get_correlation_id() is None


def test_logger_custom_handler(lambda_context, service_name, tmp_path):
Expand Down