Skip to content

Commit a403f4d

Browse files
feat(event_handler): define exception_handler directly from the router (#3979)
Adding router.exception_handler
1 parent 1da46f8 commit a403f4d

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+15
Original file line numberDiff line numberDiff line change
@@ -2133,6 +2133,9 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None
21332133
logger.debug("Appending Router middlewares into App middlewares.")
21342134
self._router_middlewares = self._router_middlewares + router._router_middlewares
21352135

2136+
logger.debug("Appending Router exception_handler into App exception_handler.")
2137+
self._exception_handlers.update(router._exception_handlers)
2138+
21362139
# use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx)
21372140
router.context = self.context
21382141

@@ -2198,6 +2201,7 @@ def __init__(self):
21982201
self._routes_with_middleware: Dict[tuple, List[Callable]] = {}
21992202
self.api_resolver: Optional[BaseRouter] = None
22002203
self.context = {} # early init as customers might add context before event resolution
2204+
self._exception_handlers: Dict[Type, Callable] = {}
22012205

22022206
def route(
22032207
self,
@@ -2252,6 +2256,17 @@ def register_route(func: Callable):
22522256

22532257
return register_route
22542258

2259+
def exception_handler(self, exc_class: Union[Type[Exception], List[Type[Exception]]]):
2260+
def register_exception_handler(func: Callable):
2261+
if isinstance(exc_class, list):
2262+
for exp in exc_class:
2263+
self._exception_handlers[exp] = func
2264+
else:
2265+
self._exception_handlers[exc_class] = func
2266+
return func
2267+
2268+
return register_exception_handler
2269+
22552270

22562271
class APIGatewayRestResolver(ApiGatewayResolver):
22572272
current_event: APIGatewayProxyEvent

tests/functional/event_handler/test_api_gateway.py

+30
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,36 @@ def get_lambda(param: int): ...
15041504
assert result["body"] == '{"msg":"Invalid data. Number of errors: 1"}'
15051505

15061506

1507+
def test_exception_handler_with_route():
1508+
app = ApiGatewayResolver()
1509+
# GIVEN a Router object with an exception handler defined for ValueError
1510+
router = Router()
1511+
1512+
@router.exception_handler(ValueError)
1513+
def handle_value_error(ex: ValueError):
1514+
print(f"request path is '{app.current_event.path}'")
1515+
return Response(
1516+
status_code=418,
1517+
content_type=content_types.TEXT_HTML,
1518+
body=str(ex),
1519+
)
1520+
1521+
@router.get("/my/path")
1522+
def get_lambda() -> Response:
1523+
raise ValueError("Foo!")
1524+
1525+
app.include_router(router)
1526+
1527+
# WHEN calling the event handler
1528+
# AND a ValueError is raised
1529+
result = app(LOAD_GW_EVENT, {})
1530+
1531+
# THEN call the exception_handler from Router
1532+
assert result["statusCode"] == 418
1533+
assert result["multiValueHeaders"]["Content-Type"] == [content_types.TEXT_HTML]
1534+
assert result["body"] == "Foo!"
1535+
1536+
15071537
def test_data_validation_error():
15081538
# GIVEN a resolver without an exception handler
15091539
app = ApiGatewayResolver(enable_validation=True)

0 commit comments

Comments
 (0)