From 3ed9443df084fde03351c7f3cab72ad91f48a687 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 18 Mar 2024 21:29:04 +0000 Subject: [PATCH] Adding router.exception_handler --- .../event_handler/api_gateway.py | 15 ++++++++++ .../event_handler/test_api_gateway.py | 30 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index bdbdcc564fb..fe51d68dab9 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -2133,6 +2133,9 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None logger.debug("Appending Router middlewares into App middlewares.") self._router_middlewares = self._router_middlewares + router._router_middlewares + logger.debug("Appending Router exception_handler into App exception_handler.") + self._exception_handlers.update(router._exception_handlers) + # use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx) router.context = self.context @@ -2198,6 +2201,7 @@ def __init__(self): self._routes_with_middleware: Dict[tuple, List[Callable]] = {} self.api_resolver: Optional[BaseRouter] = None self.context = {} # early init as customers might add context before event resolution + self._exception_handlers: Dict[Type, Callable] = {} def route( self, @@ -2252,6 +2256,17 @@ def register_route(func: Callable): return register_route + def exception_handler(self, exc_class: Union[Type[Exception], List[Type[Exception]]]): + def register_exception_handler(func: Callable): + if isinstance(exc_class, list): + for exp in exc_class: + self._exception_handlers[exp] = func + else: + self._exception_handlers[exc_class] = func + return func + + return register_exception_handler + class APIGatewayRestResolver(ApiGatewayResolver): current_event: APIGatewayProxyEvent diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index fa166bac77e..3929496be50 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1504,6 +1504,36 @@ def get_lambda(param: int): ... assert result["body"] == '{"msg":"Invalid data. Number of errors: 1"}' +def test_exception_handler_with_route(): + app = ApiGatewayResolver() + # GIVEN a Router object with an exception handler defined for ValueError + router = Router() + + @router.exception_handler(ValueError) + def handle_value_error(ex: ValueError): + print(f"request path is '{app.current_event.path}'") + return Response( + status_code=418, + content_type=content_types.TEXT_HTML, + body=str(ex), + ) + + @router.get("/my/path") + def get_lambda() -> Response: + raise ValueError("Foo!") + + app.include_router(router) + + # WHEN calling the event handler + # AND a ValueError is raised + result = app(LOAD_GW_EVENT, {}) + + # THEN call the exception_handler from Router + assert result["statusCode"] == 418 + assert result["multiValueHeaders"]["Content-Type"] == [content_types.TEXT_HTML] + assert result["body"] == "Foo!" + + def test_data_validation_error(): # GIVEN a resolver without an exception handler app = ApiGatewayResolver(enable_validation=True)