Skip to content

refactor(apigateway): Add BaseRouter and duplicate route check #757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 15, 2021
181 changes: 91 additions & 90 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import os
import re
import traceback
import warnings
import zlib
from abc import ABC, abstractmethod
from enum import Enum
from functools import partial, wraps
from http import HTTPStatus
Expand Down Expand Up @@ -227,78 +229,17 @@ def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dic
}


class ApiGatewayResolver:
"""API Gateway and ALB proxy resolver

Examples
--------
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator

```python
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
app = ApiGatewayResolver()

@app.get("/get-call")
def simple_get():
return {"message": "Foo"}

@app.post("/post-call")
def simple_post():
post_data: dict = app.current_event.json_body
return {"message": post_data["value"]}

@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""

current_event: BaseProxyEvent
lambda_context: LambdaContext

def __init__(
class BaseRouter(ABC):
@abstractmethod
def route(
self,
proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent,
cors: Optional[CORSConfig] = None,
debug: Optional[bool] = None,
serializer: Optional[Callable[[Dict], str]] = None,
strip_prefixes: Optional[List[str]] = None,
rule: str,
method: Any,
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
):
"""
Parameters
----------
proxy_type: ProxyEventType
Proxy request type, defaults to API Gateway V1
cors: CORSConfig
Optionally configure and enabled CORS. Not each route will need to have to cors=True
debug: Optional[bool]
Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG"
environment variable
serializer : Callable, optional
function to serialize `obj` to a JSON formatted `str`, by default json.dumps
strip_prefixes: List[str], optional
optional list of prefixes to be removed from the request path before doing the routing. This is often used
with api gateways with multiple custom mappings.
"""
self._proxy_type = proxy_type
self._routes: List[Route] = []
self._cors = cors
self._cors_enabled: bool = cors is not None
self._cors_methods: Set[str] = {"OPTIONS"}
self._debug = resolve_truthy_env_var_choice(
env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug
)
self._strip_prefixes = strip_prefixes

# Allow for a custom serializer or a concise json serialization
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)

if self._debug:
# Always does a pretty print when in debug mode
self._serializer = partial(json.dumps, indent=4, cls=Encoder)
raise NotImplementedError()

def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
"""Get route decorator with GET `method`
Expand Down Expand Up @@ -434,6 +375,81 @@ def lambda_handler(event, context):
"""
return self.route(rule, "PATCH", cors, compress, cache_control)


class ApiGatewayResolver(BaseRouter):
"""API Gateway and ALB proxy resolver

Examples
--------
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator

```python
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
app = ApiGatewayResolver()

@app.get("/get-call")
def simple_get():
return {"message": "Foo"}

@app.post("/post-call")
def simple_post():
post_data: dict = app.current_event.json_body
return {"message": post_data["value"]}

@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""

current_event: BaseProxyEvent
lambda_context: LambdaContext

def __init__(
self,
proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent,
cors: Optional[CORSConfig] = None,
debug: Optional[bool] = None,
serializer: Optional[Callable[[Dict], str]] = None,
strip_prefixes: Optional[List[str]] = None,
):
"""
Parameters
----------
proxy_type: ProxyEventType
Proxy request type, defaults to API Gateway V1
cors: CORSConfig
Optionally configure and enabled CORS. Not each route will need to have to cors=True
debug: Optional[bool]
Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG"
environment variable
serializer : Callable, optional
function to serialize `obj` to a JSON formatted `str`, by default json.dumps
strip_prefixes: List[str], optional
optional list of prefixes to be removed from the request path before doing the routing. This is often used
with api gateways with multiple custom mappings.
"""
self._proxy_type = proxy_type
self._routes: List[Route] = []
self._route_keys: List[str] = []
self._cors = cors
self._cors_enabled: bool = cors is not None
self._cors_methods: Set[str] = {"OPTIONS"}
self._debug = resolve_truthy_env_var_choice(
env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug
)
self._strip_prefixes = strip_prefixes

# Allow for a custom serializer or a concise json serialization
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)

if self._debug:
# Always does a pretty print when in debug mode
self._serializer = partial(json.dumps, indent=4, cls=Encoder)

def route(
self,
rule: str,
Expand All @@ -451,6 +467,10 @@ def register_resolver(func: Callable):
else:
cors_enabled = cors
self._routes.append(Route(method, self._compile_regex(rule), func, cors_enabled, compress, cache_control))
route_key = method + rule
if route_key in self._route_keys:
warnings.warn(f"A route like this was already registered. method: '{method}' rule: '{rule}'")
self._route_keys.append(route_key)
if cors_enabled:
logger.debug(f"Registering method {method.upper()} to Allow Methods in CORS")
self._cors_methods.add(method.upper())
Expand Down Expand Up @@ -642,7 +662,7 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None
self.route(*route)(func())


class Router:
class Router(BaseRouter):
"""Router helper class to allow splitting ApiGatewayResolver into multiple files"""

_app: ApiGatewayResolver
Expand Down Expand Up @@ -681,22 +701,3 @@ def inner_wrapper(**kwargs):
self.api[(rule, method, cors, compress, cache_control)] = wrapper

return actual_decorator

def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
return self.route(rule, "GET", cors, compress, cache_control)

def post(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
return self.route(rule, "POST", cors, compress, cache_control)

def put(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
return self.route(rule, "PUT", cors, compress, cache_control)

def delete(
self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None
):
return self.route(rule, "DELETE", cors, compress, cache_control)

def patch(
self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None
):
return self.route(rule, "PATCH", cors, compress, cache_control)
4 changes: 2 additions & 2 deletions aws_lambda_powertools/utilities/validation/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class SchemaValidationError(Exception):

def __init__(
self,
message: str,
message: Optional[str] = None,
validation_message: Optional[str] = None,
name: Optional[str] = None,
path: Optional[List] = None,
Expand All @@ -21,7 +21,7 @@ def __init__(

Parameters
----------
message : str
message : str, optional
Powertools formatted error message
validation_message : str, optional
Containing human-readable information what is wrong
Expand Down
27 changes: 27 additions & 0 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,3 +994,30 @@ def patch_func():
assert result["statusCode"] == 404
# AND cors headers are not returned
assert "Access-Control-Allow-Origin" not in result["headers"]


def test_duplicate_routes():
# GIVEN a duplicate routes
app = ApiGatewayResolver()
router = Router()

@router.get("/my/path")
def get_func_duplicate():
raise RuntimeError()

@app.get("/my/path")
def get_func():
return {}

@router.get("/my/path")
def get_func_another_duplicate():
raise RuntimeError()

app.include_router(router)

# WHEN calling the handler
result = app(LOAD_GW_EVENT, None)

# THEN only execute the first registered route
# AND print warnings
assert result["statusCode"] == 200
7 changes: 5 additions & 2 deletions tests/functional/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,11 +537,11 @@ def format(self, record: logging.LogRecord) -> str: # noqa: A003
logger = Logger(service=service_name, stream=stdout, logger_formatter=custom_formatter)

# WHEN a lambda function is decorated with logger
@logger.inject_lambda_context
@logger.inject_lambda_context(correlation_id_path="foo")
def handler(event, context):
logger.info("Hello")

handler({}, lambda_context)
handler({"foo": "value"}, lambda_context)

lambda_context_keys = (
"function_name",
Expand All @@ -554,8 +554,11 @@ def handler(event, context):

# THEN custom key should always be present
# and lambda contextual info should also be in the logs
# and get_correlation_id should return None
assert "my_default_key" in log
assert all(k in log for k in lambda_context_keys)
assert log["correlation_id"] == "value"
assert logger.get_correlation_id() is None


def test_logger_custom_handler(lambda_context, service_name, tmp_path):
Expand Down