diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index d3a79761556..dce520c147d 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -4,11 +4,13 @@ import os import re import traceback +import warnings import zlib +from abc import ABC, abstractmethod from enum import Enum -from functools import partial, wraps +from functools import partial from http import HTTPStatus -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import ServiceError @@ -227,78 +229,20 @@ def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dic } -class ApiGatewayResolver: - """API Gateway and ALB proxy resolver - - Examples - -------- - Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator - - ```python - from aws_lambda_powertools import Tracer - from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver - - tracer = Tracer() - app = ApiGatewayResolver() - - @app.get("/get-call") - def simple_get(): - return {"message": "Foo"} - - @app.post("/post-call") - def simple_post(): - post_data: dict = app.current_event.json_body - return {"message": post_data["value"]} - - @tracer.capture_lambda_handler - def lambda_handler(event, context): - return app.resolve(event, context) - ``` - """ - +class BaseRouter(ABC): current_event: BaseProxyEvent lambda_context: LambdaContext - def __init__( + @abstractmethod + def route( self, - proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, - cors: Optional[CORSConfig] = None, - debug: Optional[bool] = None, - serializer: Optional[Callable[[Dict], str]] = None, - strip_prefixes: Optional[List[str]] = None, + rule: str, + method: Any, + cors: Optional[bool] = None, + compress: bool = False, + cache_control: Optional[str] = None, ): - """ - Parameters - ---------- - proxy_type: ProxyEventType - Proxy request type, defaults to API Gateway V1 - cors: CORSConfig - Optionally configure and enabled CORS. Not each route will need to have to cors=True - debug: Optional[bool] - Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG" - environment variable - serializer : Callable, optional - function to serialize `obj` to a JSON formatted `str`, by default json.dumps - strip_prefixes: List[str], optional - optional list of prefixes to be removed from the request path before doing the routing. This is often used - with api gateways with multiple custom mappings. - """ - self._proxy_type = proxy_type - self._routes: List[Route] = [] - self._cors = cors - self._cors_enabled: bool = cors is not None - self._cors_methods: Set[str] = {"OPTIONS"} - self._debug = resolve_truthy_env_var_choice( - env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug - ) - self._strip_prefixes = strip_prefixes - - # Allow for a custom serializer or a concise json serialization - self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) - - if self._debug: - # Always does a pretty print when in debug mode - self._serializer = partial(json.dumps, indent=4, cls=Encoder) + raise NotImplementedError() def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): """Get route decorator with GET `method` @@ -434,6 +378,78 @@ def lambda_handler(event, context): """ return self.route(rule, "PATCH", cors, compress, cache_control) + +class ApiGatewayResolver(BaseRouter): + """API Gateway and ALB proxy resolver + + Examples + -------- + Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator + + ```python + from aws_lambda_powertools import Tracer + from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver + + tracer = Tracer() + app = ApiGatewayResolver() + + @app.get("/get-call") + def simple_get(): + return {"message": "Foo"} + + @app.post("/post-call") + def simple_post(): + post_data: dict = app.current_event.json_body + return {"message": post_data["value"]} + + @tracer.capture_lambda_handler + def lambda_handler(event, context): + return app.resolve(event, context) + ``` + """ + + def __init__( + self, + proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, + cors: Optional[CORSConfig] = None, + debug: Optional[bool] = None, + serializer: Optional[Callable[[Dict], str]] = None, + strip_prefixes: Optional[List[str]] = None, + ): + """ + Parameters + ---------- + proxy_type: ProxyEventType + Proxy request type, defaults to API Gateway V1 + cors: CORSConfig + Optionally configure and enabled CORS. Not each route will need to have to cors=True + debug: Optional[bool] + Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG" + environment variable + serializer : Callable, optional + function to serialize `obj` to a JSON formatted `str`, by default json.dumps + strip_prefixes: List[str], optional + optional list of prefixes to be removed from the request path before doing the routing. This is often used + with api gateways with multiple custom mappings. + """ + self._proxy_type = proxy_type + self._routes: List[Route] = [] + self._route_keys: List[str] = [] + self._cors = cors + self._cors_enabled: bool = cors is not None + self._cors_methods: Set[str] = {"OPTIONS"} + self._debug = resolve_truthy_env_var_choice( + env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug + ) + self._strip_prefixes = strip_prefixes + + # Allow for a custom serializer or a concise json serialization + self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) + + if self._debug: + # Always does a pretty print when in debug mode + self._serializer = partial(json.dumps, indent=4, cls=Encoder) + def route( self, rule: str, @@ -451,6 +467,10 @@ def register_resolver(func: Callable): else: cors_enabled = cors self._routes.append(Route(method, self._compile_regex(rule), func, cors_enabled, compress, cache_control)) + route_key = method + rule + if route_key in self._route_keys: + warnings.warn(f"A route like this was already registered. method: '{method}' rule: '{rule}'") + self._route_keys.append(route_key) if cors_enabled: logger.debug(f"Registering method {method.upper()} to Allow Methods in CORS") self._cors_methods.add(method.upper()) @@ -474,8 +494,8 @@ def resolve(self, event, context) -> Dict[str, Any]: """ if self._debug: print(self._json_dump(event)) - self.current_event = self._to_proxy_event(event) - self.lambda_context = context + BaseRouter.current_event = self._to_proxy_event(event) + BaseRouter.lambda_context = context return self._resolve().build(self.current_event, self._cors) def __call__(self, event, context) -> Any: @@ -632,71 +652,41 @@ def _json_dump(self, obj: Any) -> str: return self._serializer(obj) def include_router(self, router: "Router", prefix: Optional[str] = None) -> None: - """Adds all routes defined in a router""" - router._app = self - for route, func in router.api.items(): - if prefix and route[0] == "/": - route = (prefix, *route[1:]) - elif prefix: - route = (f"{prefix}{route[0]}", *route[1:]) - self.route(*route)(func()) - + """Adds all routes defined in a router -class Router: - """Router helper class to allow splitting ApiGatewayResolver into multiple files""" + Parameters + ---------- + router : Router + The Router containing a list of routes to be registered after the existing routes + prefix : str, optional + An optional prefix to be added to the originally defined rule + """ + for route, func in router._routes.items(): + if prefix: + rule = route[0] + rule = prefix if rule == "/" else f"{prefix}{rule}" + route = (rule, *route[1:]) - _app: ApiGatewayResolver + self.route(*route)(func) - def __init__(self): - self.api: Dict[tuple, Callable] = {} - @property - def current_event(self) -> BaseProxyEvent: - return self._app.current_event +class Router(BaseRouter): + """Router helper class to allow splitting ApiGatewayResolver into multiple files""" - @property - def lambda_context(self) -> LambdaContext: - return self._app.lambda_context + def __init__(self): + self._routes: Dict[tuple, Callable] = {} def route( self, rule: str, - method: Union[str, Tuple[str], List[str]], + method: Union[str, List[str]], cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, ): - def actual_decorator(func: Callable): - @wraps(func) - def wrapper(): - def inner_wrapper(**kwargs): - return func(**kwargs) - - return inner_wrapper - - if isinstance(method, (list, tuple)): - for item in method: - self.api[(rule, item, cors, compress, cache_control)] = wrapper - else: - self.api[(rule, method, cors, compress, cache_control)] = wrapper - - return actual_decorator - - def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): - return self.route(rule, "GET", cors, compress, cache_control) - - def post(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): - return self.route(rule, "POST", cors, compress, cache_control) - - def put(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): - return self.route(rule, "PUT", cors, compress, cache_control) - - def delete( - self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None - ): - return self.route(rule, "DELETE", cors, compress, cache_control) + def register_route(func: Callable): + methods = method if isinstance(method, list) else [method] + for item in methods: + self._routes[(rule, item, cors, compress, cache_control)] = func - def patch( - self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None - ): - return self.route(rule, "PATCH", cors, compress, cache_control) + return register_route diff --git a/aws_lambda_powertools/utilities/validation/exceptions.py b/aws_lambda_powertools/utilities/validation/exceptions.py index 7c719ca3119..2f13ff64188 100644 --- a/aws_lambda_powertools/utilities/validation/exceptions.py +++ b/aws_lambda_powertools/utilities/validation/exceptions.py @@ -8,7 +8,7 @@ class SchemaValidationError(Exception): def __init__( self, - message: str, + message: Optional[str] = None, validation_message: Optional[str] = None, name: Optional[str] = None, path: Optional[List] = None, @@ -21,7 +21,7 @@ def __init__( Parameters ---------- - message : str + message : str, optional Powertools formatted error message validation_message : str, optional Containing human-readable information what is wrong diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index afc979065f8..f4543fa300c 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -994,3 +994,30 @@ def patch_func(): assert result["statusCode"] == 404 # AND cors headers are not returned assert "Access-Control-Allow-Origin" not in result["headers"] + + +def test_duplicate_routes(): + # GIVEN a duplicate routes + app = ApiGatewayResolver() + router = Router() + + @router.get("/my/path") + def get_func_duplicate(): + raise RuntimeError() + + @app.get("/my/path") + def get_func(): + return {} + + @router.get("/my/path") + def get_func_another_duplicate(): + raise RuntimeError() + + app.include_router(router) + + # WHEN calling the handler + result = app(LOAD_GW_EVENT, None) + + # THEN only execute the first registered route + # AND print warnings + assert result["statusCode"] == 200 diff --git a/tests/functional/test_logger.py b/tests/functional/test_logger.py index a8d92c05257..3c9a8a54189 100644 --- a/tests/functional/test_logger.py +++ b/tests/functional/test_logger.py @@ -537,11 +537,11 @@ def format(self, record: logging.LogRecord) -> str: # noqa: A003 logger = Logger(service=service_name, stream=stdout, logger_formatter=custom_formatter) # WHEN a lambda function is decorated with logger - @logger.inject_lambda_context + @logger.inject_lambda_context(correlation_id_path="foo") def handler(event, context): logger.info("Hello") - handler({}, lambda_context) + handler({"foo": "value"}, lambda_context) lambda_context_keys = ( "function_name", @@ -554,8 +554,11 @@ def handler(event, context): # THEN custom key should always be present # and lambda contextual info should also be in the logs + # and get_correlation_id should return None assert "my_default_key" in log assert all(k in log for k in lambda_context_keys) + assert log["correlation_id"] == "value" + assert logger.get_correlation_id() is None def test_logger_custom_handler(lambda_context, service_name, tmp_path):