Skip to content

Commit db4268b

Browse files
authored
feat(apigateway): multiple exceptions in exception_handler (#1707)
Co-authored-by: Steve Parker <[email protected]>
1 parent 8f44e49 commit db4268b

File tree

3 files changed

+69
-2
lines changed

3 files changed

+69
-2
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -689,9 +689,14 @@ def not_found(self, func: Optional[Callable] = None):
689689
return self.exception_handler(NotFoundError)
690690
return self.exception_handler(NotFoundError)(func)
691691

692-
def exception_handler(self, exc_class: Type[Exception]):
692+
def exception_handler(self, exc_class: Union[Type[Exception], List[Type[Exception]]]):
693693
def register_exception_handler(func: Callable):
694-
self._exception_handlers[exc_class] = func
694+
if isinstance(exc_class, list):
695+
for exp in exc_class:
696+
self._exception_handlers[exp] = func
697+
else:
698+
self._exception_handlers[exc_class] = func
699+
return func
695700

696701
return register_exception_handler
697702

docs/core/event_handler/api_gateway.md

+3
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ You can use **`exception_handler`** decorator with any Python exception. This al
226226
--8<-- "examples/event_handler_rest/src/exception_handling.py"
227227
```
228228

229+
???+ info
230+
The `exception_handler` also supports passing a list of exception types you wish to handle with one handler.
231+
229232
### Raising HTTP errors
230233

231234
You can easily raise any HTTP Error back to the client using `ServiceError` exception. This ensures your Lambda function doesn't fail but return the correct HTTP response signalling the error.

tests/functional/event_handler/test_api_gateway.py

+59
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,65 @@ def get_lambda() -> Response:
13881388
assert result["body"] == json_dump(expected)
13891389

13901390

1391+
def test_exception_handler_supports_list(json_dump):
1392+
# GIVEN a resolver with an exception handler defined for a multiple exceptions in a list
1393+
app = ApiGatewayResolver()
1394+
event = deepcopy(LOAD_GW_EVENT)
1395+
1396+
@app.exception_handler([ValueError, NotFoundError])
1397+
def multiple_error(ex: Exception):
1398+
raise BadRequestError("Bad request")
1399+
1400+
@app.get("/path/a")
1401+
def path_a() -> Response:
1402+
raise ValueError("foo")
1403+
1404+
@app.get("/path/b")
1405+
def path_b() -> Response:
1406+
raise NotFoundError
1407+
1408+
# WHEN calling the app generating each exception
1409+
for route in ["/path/a", "/path/b"]:
1410+
event["path"] = route
1411+
result = app(event, {})
1412+
1413+
# THEN call the exception handler in the same way for both exceptions
1414+
assert result["statusCode"] == 400
1415+
assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON]
1416+
expected = {"statusCode": 400, "message": "Bad request"}
1417+
assert result["body"] == json_dump(expected)
1418+
1419+
1420+
def test_exception_handler_supports_multiple_decorators(json_dump):
1421+
# GIVEN a resolver with an exception handler defined with multiple decorators
1422+
app = ApiGatewayResolver()
1423+
event = deepcopy(LOAD_GW_EVENT)
1424+
1425+
@app.exception_handler(ValueError)
1426+
@app.exception_handler(NotFoundError)
1427+
def multiple_error(ex: Exception):
1428+
raise BadRequestError("Bad request")
1429+
1430+
@app.get("/path/a")
1431+
def path_a() -> Response:
1432+
raise ValueError("foo")
1433+
1434+
@app.get("/path/b")
1435+
def path_b() -> Response:
1436+
raise NotFoundError
1437+
1438+
# WHEN calling the app generating each exception
1439+
for route in ["/path/a", "/path/b"]:
1440+
event["path"] = route
1441+
result = app(event, {})
1442+
1443+
# THEN call the exception handler in the same way for both exceptions
1444+
assert result["statusCode"] == 400
1445+
assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON]
1446+
expected = {"statusCode": 400, "message": "Bad request"}
1447+
assert result["body"] == json_dump(expected)
1448+
1449+
13911450
def test_event_source_compatibility():
13921451
# GIVEN
13931452
app = APIGatewayHttpResolver()

0 commit comments

Comments
 (0)