diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 112bcd92dfe..c1c92ddabf7 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -689,9 +689,14 @@ def not_found(self, func: Optional[Callable] = None): return self.exception_handler(NotFoundError) return self.exception_handler(NotFoundError)(func) - def exception_handler(self, exc_class: Type[Exception]): + def exception_handler(self, exc_class: Union[Type[Exception], List[Type[Exception]]]): def register_exception_handler(func: Callable): - self._exception_handlers[exc_class] = func + if isinstance(exc_class, list): + for exp in exc_class: + self._exception_handlers[exp] = func + else: + self._exception_handlers[exc_class] = func + return func return register_exception_handler diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 88ffe3cb0bd..ca092e30c04 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -226,6 +226,9 @@ You can use **`exception_handler`** decorator with any Python exception. This al --8<-- "examples/event_handler_rest/src/exception_handling.py" ``` +???+ info + The `exception_handler` also supports passing a list of exception types you wish to handle with one handler. + ### Raising HTTP errors You can easily raise any HTTP Error back to the client using `ServiceError` exception. This ensures your Lambda function doesn't fail but return the correct HTTP response signalling the error. diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 6b343dd1f0f..2d439bc0e0b 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1388,6 +1388,65 @@ def get_lambda() -> Response: assert result["body"] == json_dump(expected) +def test_exception_handler_supports_list(json_dump): + # GIVEN a resolver with an exception handler defined for a multiple exceptions in a list + app = ApiGatewayResolver() + event = deepcopy(LOAD_GW_EVENT) + + @app.exception_handler([ValueError, NotFoundError]) + def multiple_error(ex: Exception): + raise BadRequestError("Bad request") + + @app.get("/path/a") + def path_a() -> Response: + raise ValueError("foo") + + @app.get("/path/b") + def path_b() -> Response: + raise NotFoundError + + # WHEN calling the app generating each exception + for route in ["/path/a", "/path/b"]: + event["path"] = route + result = app(event, {}) + + # THEN call the exception handler in the same way for both exceptions + assert result["statusCode"] == 400 + assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] + expected = {"statusCode": 400, "message": "Bad request"} + assert result["body"] == json_dump(expected) + + +def test_exception_handler_supports_multiple_decorators(json_dump): + # GIVEN a resolver with an exception handler defined with multiple decorators + app = ApiGatewayResolver() + event = deepcopy(LOAD_GW_EVENT) + + @app.exception_handler(ValueError) + @app.exception_handler(NotFoundError) + def multiple_error(ex: Exception): + raise BadRequestError("Bad request") + + @app.get("/path/a") + def path_a() -> Response: + raise ValueError("foo") + + @app.get("/path/b") + def path_b() -> Response: + raise NotFoundError + + # WHEN calling the app generating each exception + for route in ["/path/a", "/path/b"]: + event["path"] = route + result = app(event, {}) + + # THEN call the exception handler in the same way for both exceptions + assert result["statusCode"] == 400 + assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] + expected = {"statusCode": 400, "message": "Bad request"} + assert result["body"] == json_dump(expected) + + def test_event_source_compatibility(): # GIVEN app = APIGatewayHttpResolver()