From 85321d11f285fd13ba795c8ca0019015e9c76a15 Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Fri, 28 Feb 2025 12:12:05 +0000 Subject: [PATCH 01/23] feat(api-gateway-resolver): Add option for custom response validation error status code. --- aws_lambda_powertools/event_handler/lambda_function_url.py | 1 + aws_lambda_powertools/event_handler/vpc_lattice.py | 1 + 2 files changed, 2 insertions(+) diff --git a/aws_lambda_powertools/event_handler/lambda_function_url.py b/aws_lambda_powertools/event_handler/lambda_function_url.py index c761834e8b3..9f1aaaea62b 100644 --- a/aws_lambda_powertools/event_handler/lambda_function_url.py +++ b/aws_lambda_powertools/event_handler/lambda_function_url.py @@ -12,6 +12,7 @@ from aws_lambda_powertools.event_handler import CORSConfig from aws_lambda_powertools.utilities.data_classes import LambdaFunctionUrlEvent + from http import HTTPStatus class LambdaFunctionUrlResolver(ApiGatewayResolver): diff --git a/aws_lambda_powertools/event_handler/vpc_lattice.py b/aws_lambda_powertools/event_handler/vpc_lattice.py index 30ee8fd86fc..94f5079dbc2 100644 --- a/aws_lambda_powertools/event_handler/vpc_lattice.py +++ b/aws_lambda_powertools/event_handler/vpc_lattice.py @@ -12,6 +12,7 @@ from aws_lambda_powertools.event_handler import CORSConfig from aws_lambda_powertools.utilities.data_classes import VPCLatticeEvent, VPCLatticeEventV2 + from http import HTTPStatus class VPCLatticeResolver(ApiGatewayResolver): From aa7cf6fb5413eca8e978e3fc39cb8b329136779a Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Fri, 28 Feb 2025 12:12:56 +0000 Subject: [PATCH 02/23] feat(docs): Added doc for custom response validation error responses. --- .../src/response_validation_error_unsanitized_output.json | 8 ++++++++ .../src/response_validation_sanitized_error_output.json | 8 ++++++++ 2 files changed, 16 insertions(+) create mode 100644 examples/event_handler_rest/src/response_validation_error_unsanitized_output.json create mode 100644 examples/event_handler_rest/src/response_validation_sanitized_error_output.json diff --git a/examples/event_handler_rest/src/response_validation_error_unsanitized_output.json b/examples/event_handler_rest/src/response_validation_error_unsanitized_output.json new file mode 100644 index 00000000000..c2fbe3df339 --- /dev/null +++ b/examples/event_handler_rest/src/response_validation_error_unsanitized_output.json @@ -0,0 +1,8 @@ +{ + "statusCode": 500, + "body": "{\"statusCode\": 500, \"detail\": [{\"type\": \"model_attributes_type\", \"loc\": [\"response\", ]}]}", + "isBase64Encoded": false, + "headers": { + "Content-Type": "application/json" + } +} \ No newline at end of file diff --git a/examples/event_handler_rest/src/response_validation_sanitized_error_output.json b/examples/event_handler_rest/src/response_validation_sanitized_error_output.json new file mode 100644 index 00000000000..79c97da7498 --- /dev/null +++ b/examples/event_handler_rest/src/response_validation_sanitized_error_output.json @@ -0,0 +1,8 @@ +{ + "statusCode": 500, + "body": "Unexpected response.", + "isBase64Encoded": false, + "headers": { + "Content-Type": "application/json" + } +} \ No newline at end of file From cb7fd6dc3d2f1e15c46ebf9642a79b9f0302764c Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Fri, 28 Feb 2025 13:24:22 +0000 Subject: [PATCH 03/23] feat(unit-test): Add tests for custom response validation error. --- .../event_handler/test_response_validation.py | 182 ++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 tests/unit/event_handler/test_response_validation.py diff --git a/tests/unit/event_handler/test_response_validation.py b/tests/unit/event_handler/test_response_validation.py new file mode 100644 index 00000000000..db18d1b95fd --- /dev/null +++ b/tests/unit/event_handler/test_response_validation.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import json +from http import HTTPStatus + +import pytest +from pydantic import BaseModel, Field + +from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.openapi.exceptions import ResponseValidationError + +app = APIGatewayRestResolver(enable_validation=True) +app_with_custom_response_validation_error = APIGatewayRestResolver( + enable_validation=True, + response_validation_error_http_status=HTTPStatus.INTERNAL_SERVER_ERROR, +) + + +class Todo(BaseModel): + userId: int + id_: int | None = Field(alias="id", default=None) + title: str + completed: bool + + +TODO_OBJECT = Todo(userId="1234", id="1", title="Write tests.", completed=True) + + +@app_with_custom_response_validation_error.get("/string_not_todo") +@app.get("/string_not_todo") +def return_string_not_todo() -> Todo: + return "hello" + + +@app_with_custom_response_validation_error.get("/incomplete_todo") +@app.get("/incomplete_todo") +def return_incomplete_todo() -> Todo: + return {"title": "fix_response_validation"} + + +@app_with_custom_response_validation_error.get("/todo") +@app.get("/todo") +def return_todo() -> Todo: + return TODO_OBJECT + + +# --- Tests below --- + + +@pytest.fixture() +def event_factory(): + def _factory(path: str): + return { + "httpMethod": "GET", + "path": path, + } + + yield _factory + + +@pytest.fixture() +def response_validation_error_factory(): + def _factory(loc: list[str], type_: str): + if not loc: + return [{"loc": ["response"], "type": type_}] + return [{"loc": ["response", location], "type": type_} for location in loc] + + yield _factory + + +class TestDefaultResponseValidation: + + def test_valid_response(self, event_factory): + event = event_factory("/todo") + + response = app.resolve(event, None) + body = json.loads(response["body"]) + + assert response["statusCode"] == HTTPStatus.OK + assert body == TODO_OBJECT.model_dump(by_alias=True) + + @pytest.mark.parametrize( + ( + "path", + "error_location", + "error_type", + ), + [ + ("/string_not_todo", [], "model_attributes_type"), + ("/incomplete_todo", ["userId", "completed"], "missing"), + ], + ids=["string_not_todo", "incomplete_todo"], + ) + def test_default_serialization_failure( + self, + path, + error_location, + error_type, + event_factory, + response_validation_error_factory, + ): + """Tests to demonstrate cases when response serialization fails, as expected.""" + event = event_factory(path) + error_detail = response_validation_error_factory(error_location, error_type) + + response = app.resolve(event, None) + body = json.loads(response["body"]) + + assert response["statusCode"] == HTTPStatus.UNPROCESSABLE_ENTITY + assert body == {"statusCode": 422, "detail": error_detail} + + +class TestCustomResponseValidation: + + def test_valid_response(self, event_factory): + + event = event_factory("/todo") + + response = app_with_custom_response_validation_error.resolve(event, None) + body = json.loads(response["body"]) + + assert response["statusCode"] == HTTPStatus.OK + assert body == TODO_OBJECT.model_dump(by_alias=True) + + @pytest.mark.parametrize( + ( + "path", + "error_location", + "error_type", + ), + [ + ("/string_not_todo", [], "model_attributes_type"), + ("/incomplete_todo", ["userId", "completed"], "missing"), + ], + ids=["string_not_todo", "incomplete_todo"], + ) + def test_custom_serialization_failure( + self, + path, + error_location, + error_type, + event_factory, + response_validation_error_factory, + ): + """Tests to demonstrate cases when response serialization fails, as expected.""" + + event = event_factory(path) + error_detail = response_validation_error_factory(error_location, error_type) + + response = app_with_custom_response_validation_error.resolve(event, None) + body = json.loads(response["body"]) + + assert response["statusCode"] == HTTPStatus.INTERNAL_SERVER_ERROR + assert body == {"statusCode": 500, "detail": error_detail} + + @pytest.mark.parametrize( + "path", + [ + ("/string_not_todo"), + ("/incomplete_todo"), + ], + ids=["string_not_todo", "incomplete_todo"], + ) + def test_sanitized_error_response( + self, + path, + event_factory, + ): + event = event_factory(path) + + @app_with_custom_response_validation_error.exception_handler(ResponseValidationError) + def handle_response_validation_error(ex: ResponseValidationError): + return Response( + status_code=500, + content_type="application/json", + body="Unexpected response.", + ) + + response = app_with_custom_response_validation_error.resolve(event, None) + + assert response["statusCode"] == HTTPStatus.INTERNAL_SERVER_ERROR + assert response["body"] == "Unexpected response." From 1228632c4874fdd64833a8b41d47a16deed9910a Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Fri, 28 Feb 2025 13:41:00 +0000 Subject: [PATCH 04/23] fix: Formatting. --- aws_lambda_powertools/event_handler/lambda_function_url.py | 1 - aws_lambda_powertools/event_handler/vpc_lattice.py | 1 - 2 files changed, 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/lambda_function_url.py b/aws_lambda_powertools/event_handler/lambda_function_url.py index 9f1aaaea62b..c761834e8b3 100644 --- a/aws_lambda_powertools/event_handler/lambda_function_url.py +++ b/aws_lambda_powertools/event_handler/lambda_function_url.py @@ -12,7 +12,6 @@ from aws_lambda_powertools.event_handler import CORSConfig from aws_lambda_powertools.utilities.data_classes import LambdaFunctionUrlEvent - from http import HTTPStatus class LambdaFunctionUrlResolver(ApiGatewayResolver): diff --git a/aws_lambda_powertools/event_handler/vpc_lattice.py b/aws_lambda_powertools/event_handler/vpc_lattice.py index 94f5079dbc2..30ee8fd86fc 100644 --- a/aws_lambda_powertools/event_handler/vpc_lattice.py +++ b/aws_lambda_powertools/event_handler/vpc_lattice.py @@ -12,7 +12,6 @@ from aws_lambda_powertools.event_handler import CORSConfig from aws_lambda_powertools.utilities.data_classes import VPCLatticeEvent, VPCLatticeEventV2 - from http import HTTPStatus class VPCLatticeResolver(ApiGatewayResolver): From b95d52180ba45b5a7d71989ebab04a2cc385ed5d Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Fri, 28 Feb 2025 15:54:38 +0000 Subject: [PATCH 05/23] fix(unit-test): fix failed CI. --- tests/unit/event_handler/test_response_validation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unit/event_handler/test_response_validation.py b/tests/unit/event_handler/test_response_validation.py index db18d1b95fd..0b49fd6f49d 100644 --- a/tests/unit/event_handler/test_response_validation.py +++ b/tests/unit/event_handler/test_response_validation.py @@ -1,7 +1,6 @@ -from __future__ import annotations - import json from http import HTTPStatus +from typing import Optional import pytest from pydantic import BaseModel, Field @@ -18,7 +17,7 @@ class Todo(BaseModel): userId: int - id_: int | None = Field(alias="id", default=None) + id_: Optional[int] = Field(alias="id", default=None) title: str completed: bool From f8499304a36c81a1dffbc80281a8e3e0230d599a Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Fri, 28 Feb 2025 17:50:33 +0000 Subject: [PATCH 06/23] feat(unit-test): add tests for incorrect types and invalid configs --- .../event_handler/test_response_validation.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/unit/event_handler/test_response_validation.py b/tests/unit/event_handler/test_response_validation.py index 0b49fd6f49d..1ee711e5784 100644 --- a/tests/unit/event_handler/test_response_validation.py +++ b/tests/unit/event_handler/test_response_validation.py @@ -179,3 +179,25 @@ def handle_response_validation_error(ex: ResponseValidationError): assert response["statusCode"] == HTTPStatus.INTERNAL_SERVER_ERROR assert response["body"] == "Unexpected response." + + def test_incorrect_resolver_config_no_validation(self): + with pytest.raises(ValueError) as exception_info: + APIGatewayRestResolver(response_validation_error_http_status=500) + + assert ( + str(exception_info.value) + == "'response_validation_error_http_status' cannot be set when enable_validation is False." + ) + + @pytest.mark.parametrize("response_validation_error_http_status", [(20), ("hi"), (1.21)]) + def test_incorrect_resolver_config_bad_http_status_code(self, response_validation_error_http_status): + with pytest.raises(ValueError) as exception_info: + APIGatewayRestResolver( + enable_validation=True, + response_validation_error_http_status=response_validation_error_http_status, + ) + + assert ( + str(exception_info.value) + == f"'{response_validation_error_http_status}' must be an integer representing an HTTP status code." + ) From bafd19c218aaeb5d5af8efb3f1eb1cf49f107520 Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Fri, 7 Mar 2025 10:34:29 +0400 Subject: [PATCH 07/23] refactor: rename response_validation_error_http_status to response_validation_error_http_code --- .../unit/event_handler/test_response_validation.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/unit/event_handler/test_response_validation.py b/tests/unit/event_handler/test_response_validation.py index 1ee711e5784..9f537b85aad 100644 --- a/tests/unit/event_handler/test_response_validation.py +++ b/tests/unit/event_handler/test_response_validation.py @@ -11,7 +11,7 @@ app = APIGatewayRestResolver(enable_validation=True) app_with_custom_response_validation_error = APIGatewayRestResolver( enable_validation=True, - response_validation_error_http_status=HTTPStatus.INTERNAL_SERVER_ERROR, + response_validation_error_http_code=HTTPStatus.INTERNAL_SERVER_ERROR, ) @@ -182,22 +182,22 @@ def handle_response_validation_error(ex: ResponseValidationError): def test_incorrect_resolver_config_no_validation(self): with pytest.raises(ValueError) as exception_info: - APIGatewayRestResolver(response_validation_error_http_status=500) + APIGatewayRestResolver(response_validation_error_http_code=500) assert ( str(exception_info.value) - == "'response_validation_error_http_status' cannot be set when enable_validation is False." + == "'response_validation_error_http_code' cannot be set when enable_validation is False." ) - @pytest.mark.parametrize("response_validation_error_http_status", [(20), ("hi"), (1.21)]) - def test_incorrect_resolver_config_bad_http_status_code(self, response_validation_error_http_status): + @pytest.mark.parametrize("response_validation_error_http_code", [(20), ("hi"), (1.21)]) + def test_incorrect_resolver_config_bad_http_status_code(self, response_validation_error_http_code): with pytest.raises(ValueError) as exception_info: APIGatewayRestResolver( enable_validation=True, - response_validation_error_http_status=response_validation_error_http_status, + response_validation_error_http_code=response_validation_error_http_code, ) assert ( str(exception_info.value) - == f"'{response_validation_error_http_status}' must be an integer representing an HTTP status code." + == f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code." ) From 9b09bb70a941fd41653b8ddb4aa5399bc88bd134 Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Fri, 7 Mar 2025 15:42:17 +0400 Subject: [PATCH 08/23] refactor(tests): move unit tests into openapi_validation functional test file --- .../event_handler/test_response_validation.py | 203 ------------------ 1 file changed, 203 deletions(-) delete mode 100644 tests/unit/event_handler/test_response_validation.py diff --git a/tests/unit/event_handler/test_response_validation.py b/tests/unit/event_handler/test_response_validation.py deleted file mode 100644 index 9f537b85aad..00000000000 --- a/tests/unit/event_handler/test_response_validation.py +++ /dev/null @@ -1,203 +0,0 @@ -import json -from http import HTTPStatus -from typing import Optional - -import pytest -from pydantic import BaseModel, Field - -from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Response -from aws_lambda_powertools.event_handler.openapi.exceptions import ResponseValidationError - -app = APIGatewayRestResolver(enable_validation=True) -app_with_custom_response_validation_error = APIGatewayRestResolver( - enable_validation=True, - response_validation_error_http_code=HTTPStatus.INTERNAL_SERVER_ERROR, -) - - -class Todo(BaseModel): - userId: int - id_: Optional[int] = Field(alias="id", default=None) - title: str - completed: bool - - -TODO_OBJECT = Todo(userId="1234", id="1", title="Write tests.", completed=True) - - -@app_with_custom_response_validation_error.get("/string_not_todo") -@app.get("/string_not_todo") -def return_string_not_todo() -> Todo: - return "hello" - - -@app_with_custom_response_validation_error.get("/incomplete_todo") -@app.get("/incomplete_todo") -def return_incomplete_todo() -> Todo: - return {"title": "fix_response_validation"} - - -@app_with_custom_response_validation_error.get("/todo") -@app.get("/todo") -def return_todo() -> Todo: - return TODO_OBJECT - - -# --- Tests below --- - - -@pytest.fixture() -def event_factory(): - def _factory(path: str): - return { - "httpMethod": "GET", - "path": path, - } - - yield _factory - - -@pytest.fixture() -def response_validation_error_factory(): - def _factory(loc: list[str], type_: str): - if not loc: - return [{"loc": ["response"], "type": type_}] - return [{"loc": ["response", location], "type": type_} for location in loc] - - yield _factory - - -class TestDefaultResponseValidation: - - def test_valid_response(self, event_factory): - event = event_factory("/todo") - - response = app.resolve(event, None) - body = json.loads(response["body"]) - - assert response["statusCode"] == HTTPStatus.OK - assert body == TODO_OBJECT.model_dump(by_alias=True) - - @pytest.mark.parametrize( - ( - "path", - "error_location", - "error_type", - ), - [ - ("/string_not_todo", [], "model_attributes_type"), - ("/incomplete_todo", ["userId", "completed"], "missing"), - ], - ids=["string_not_todo", "incomplete_todo"], - ) - def test_default_serialization_failure( - self, - path, - error_location, - error_type, - event_factory, - response_validation_error_factory, - ): - """Tests to demonstrate cases when response serialization fails, as expected.""" - event = event_factory(path) - error_detail = response_validation_error_factory(error_location, error_type) - - response = app.resolve(event, None) - body = json.loads(response["body"]) - - assert response["statusCode"] == HTTPStatus.UNPROCESSABLE_ENTITY - assert body == {"statusCode": 422, "detail": error_detail} - - -class TestCustomResponseValidation: - - def test_valid_response(self, event_factory): - - event = event_factory("/todo") - - response = app_with_custom_response_validation_error.resolve(event, None) - body = json.loads(response["body"]) - - assert response["statusCode"] == HTTPStatus.OK - assert body == TODO_OBJECT.model_dump(by_alias=True) - - @pytest.mark.parametrize( - ( - "path", - "error_location", - "error_type", - ), - [ - ("/string_not_todo", [], "model_attributes_type"), - ("/incomplete_todo", ["userId", "completed"], "missing"), - ], - ids=["string_not_todo", "incomplete_todo"], - ) - def test_custom_serialization_failure( - self, - path, - error_location, - error_type, - event_factory, - response_validation_error_factory, - ): - """Tests to demonstrate cases when response serialization fails, as expected.""" - - event = event_factory(path) - error_detail = response_validation_error_factory(error_location, error_type) - - response = app_with_custom_response_validation_error.resolve(event, None) - body = json.loads(response["body"]) - - assert response["statusCode"] == HTTPStatus.INTERNAL_SERVER_ERROR - assert body == {"statusCode": 500, "detail": error_detail} - - @pytest.mark.parametrize( - "path", - [ - ("/string_not_todo"), - ("/incomplete_todo"), - ], - ids=["string_not_todo", "incomplete_todo"], - ) - def test_sanitized_error_response( - self, - path, - event_factory, - ): - event = event_factory(path) - - @app_with_custom_response_validation_error.exception_handler(ResponseValidationError) - def handle_response_validation_error(ex: ResponseValidationError): - return Response( - status_code=500, - content_type="application/json", - body="Unexpected response.", - ) - - response = app_with_custom_response_validation_error.resolve(event, None) - - assert response["statusCode"] == HTTPStatus.INTERNAL_SERVER_ERROR - assert response["body"] == "Unexpected response." - - def test_incorrect_resolver_config_no_validation(self): - with pytest.raises(ValueError) as exception_info: - APIGatewayRestResolver(response_validation_error_http_code=500) - - assert ( - str(exception_info.value) - == "'response_validation_error_http_code' cannot be set when enable_validation is False." - ) - - @pytest.mark.parametrize("response_validation_error_http_code", [(20), ("hi"), (1.21)]) - def test_incorrect_resolver_config_bad_http_status_code(self, response_validation_error_http_code): - with pytest.raises(ValueError) as exception_info: - APIGatewayRestResolver( - enable_validation=True, - response_validation_error_http_code=response_validation_error_http_code, - ) - - assert ( - str(exception_info.value) - == f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code." - ) From bbbd9891cbda9df4111bed7f3f99763b74484a8d Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Fri, 7 Mar 2025 15:56:27 +0000 Subject: [PATCH 09/23] feat: add route-specific custom response validation and tests --- .../event_handler/api_gateway.py | 40 +++++++++++- .../middlewares/openapi_validation.py | 10 ++- .../event_handler/openapi/exceptions.py | 5 +- .../test_openapi_validation_middleware.py | 64 +++++++++++++++++++ 4 files changed, 113 insertions(+), 6 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index c8e4248fda4..f5951c4cc82 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -320,6 +320,7 @@ def __init__( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable[..., Response]] | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, ): """ Internally used Route Configuration @@ -362,6 +363,7 @@ def __init__( Whether or not to mark this route as deprecated in the OpenAPI schema middlewares: list[Callable[..., Response]] | None The list of route middlewares to be called in order. + # TODO """ self.method = method.upper() self.path = "/" if path.strip() == "" else path @@ -397,6 +399,8 @@ def __init__( # _body_field is used to cache the dependant model for the body field self._body_field: ModelField | None = None + self.custom_response_validation_http_code: int | HTTPStatus | None = custom_response_validation_http_code + def __call__( self, router_middlewares: list[Callable], @@ -565,6 +569,8 @@ def _get_openapi_path( }, } + # TODO update responses + # Add the response to the OpenAPI operation if self.responses: for status_code in list(self.responses): @@ -943,6 +949,7 @@ def route( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: raise NotImplementedError() @@ -1004,6 +1011,7 @@ def get( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Get route decorator with GET `method` @@ -1044,6 +1052,7 @@ def lambda_handler(event, context): openapi_extensions, deprecated, middlewares, + custom_response_validation_http_code, ) def post( @@ -1063,6 +1072,7 @@ def post( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Post route decorator with POST `method` @@ -1104,6 +1114,7 @@ def lambda_handler(event, context): openapi_extensions, deprecated, middlewares, + custom_response_validation_http_code, ) def put( @@ -1123,6 +1134,7 @@ def put( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Put route decorator with PUT `method` @@ -1164,6 +1176,7 @@ def lambda_handler(event, context): openapi_extensions, deprecated, middlewares, + custom_response_validation_http_code, ) def delete( @@ -1183,6 +1196,7 @@ def delete( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Delete route decorator with DELETE `method` @@ -1223,6 +1237,7 @@ def lambda_handler(event, context): openapi_extensions, deprecated, middlewares, + custom_response_validation_http_code, ) def patch( @@ -1242,6 +1257,7 @@ def patch( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable] | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Patch route decorator with PATCH `method` @@ -1285,6 +1301,7 @@ def lambda_handler(event, context): openapi_extensions, deprecated, middlewares, + custom_response_validation_http_code, ) def head( @@ -1304,6 +1321,7 @@ def head( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable] | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Head route decorator with HEAD `method` @@ -1346,6 +1364,7 @@ def lambda_handler(event, context): openapi_extensions, deprecated, middlewares, + custom_response_validation_http_code, ) def _push_processed_stack_frame(self, frame: str): @@ -2126,9 +2145,14 @@ def route( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Route decorator includes parameter `method`""" + custom_response_validation_http_code = self._validate_route_response_validation_error_http_code( + custom_response_validation_http_code, + ) + def register_resolver(func: AnyCallableT) -> AnyCallableT: methods = (method,) if isinstance(method, str) else method logger.debug(f"Adding route using rule {rule} and methods: {','.join(m.upper() for m in methods)}") @@ -2155,6 +2179,7 @@ def register_resolver(func: AnyCallableT) -> AnyCallableT: openapi_extensions, deprecated, middlewares, + custom_response_validation_http_code, ) # The more specific route wins. @@ -2523,15 +2548,20 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild ) # OpenAPIValidationMiddleware will only raise ResponseValidationError when - # 'self._response_validation_error_http_code' is not None + # 'self._response_validation_error_http_code' is not None or + # when route has custom_response_validation_http_code if isinstance(exp, ResponseValidationError): - http_code = self._response_validation_error_http_code + http_code = ( + self._response_validation_error_http_code + if exp.source == "app" + else route.custom_response_validation_http_code + ) errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()] return self._response_builder_class( response=Response( status_code=http_code.value, content_type=content_types.APPLICATION_JSON, - body={"statusCode": self._response_validation_error_http_code, "detail": errors}, + body={"statusCode": http_code, "detail": errors}, ), serializer=self._serializer, route=route, @@ -2683,6 +2713,7 @@ def route( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: def register_route(func: AnyCallableT) -> AnyCallableT: # All dict keys needs to be hashable. So we'll need to do some conversions: @@ -2708,6 +2739,7 @@ def register_route(func: AnyCallableT) -> AnyCallableT: frozen_security, frozen_openapi_extensions, deprecated, + custom_response_validation_http_code, ) # Collate Middleware for routes @@ -2795,6 +2827,7 @@ def route( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: # NOTE: see #1552 for more context. return super().route( @@ -2814,6 +2847,7 @@ def route( openapi_extensions, deprecated, middlewares, + custom_response_validation_http_code, ) # Override _compile_regex to exclude trailing slashes for route resolution diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index a1490eb23ec..289b0fa1c86 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -150,6 +150,7 @@ def _handle_response(self, *, route: Route, response: Response): response.body = self._serialize_response( field=route.dependant.return_param, response_content=response.body, + has_route_custom_response_validation=route.custom_response_validation_http_code is not None, ) return response @@ -165,6 +166,7 @@ def _serialize_response( exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, + has_route_custom_response_validation: bool = False, ) -> Any: """ Serialize the response content according to the field type. @@ -174,7 +176,13 @@ def _serialize_response( value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors) if errors: if self._has_response_validation_error: - raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content) + raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app") + if has_route_custom_response_validation: + raise ResponseValidationError( + errors=_normalize_errors(errors), + body=response_content, + source="route", + ) raise RequestValidationError(errors=_normalize_errors(errors), body=response_content) if hasattr(field, "serialize"): diff --git a/aws_lambda_powertools/event_handler/openapi/exceptions.py b/aws_lambda_powertools/event_handler/openapi/exceptions.py index 046a270cdf7..b06141af47e 100644 --- a/aws_lambda_powertools/event_handler/openapi/exceptions.py +++ b/aws_lambda_powertools/event_handler/openapi/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, Sequence +from typing import Any, Literal, Sequence class ValidationException(Exception): @@ -28,9 +28,10 @@ class ResponseValidationError(ValidationException): Raised when the response body does not match the OpenAPI schema """ - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: + def __init__(self, errors: Sequence[Any], *, body: Any = None, source: Literal["route", "app"] = "app") -> None: super().__init__(errors) self.body = body + self.source = source class SerializationError(Exception): diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 4103a301020..3c364d4a26e 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1378,3 +1378,67 @@ def test_custom_response_validation_error_bad_http_code(response_validation_erro str(exception_info.value) == f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code." ) + + +def test_custom_route_response_validation_error_http_code_invalid_response_incomplete_model(gw_event): + # GIVEN an APIGatewayRestResolver with custom response validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/incomplete_model_not_allowed") + def handler_incomplete_model_not_allowed() -> Model: + return {"age": 18} # type: ignore + + @app.get( + "/custom_incomplete_model_not_allowed", + custom_response_validation_http_code=500, + ) + def handler_custom_route_response_validation_error() -> Model: + return {"age": 18} # type: ignore + + # WHEN returning incomplete model for a non-Optional type + gw_event["path"] = "/incomplete_model_not_allowed" + result = app(gw_event, {}) + + gw_event["path"] = "/custom_incomplete_model_not_allowed" + custom_result = app(gw_event, {}) + + # THEN it should return a validation error with the custom status code provided + assert result["statusCode"] == 422 + assert custom_result["statusCode"] == 500 + assert json.loads(result["body"])["detail"] == json.loads(custom_result["body"])["detail"] + + +def test_custom_route_response_validation_error_sanitized_response(gw_event): + # GIVEN an APIGatewayRestResolver with custom response validation enabled + # with a sanitized response validation error response + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get( + "/custom_incomplete_model_not_allowed", + custom_response_validation_http_code=422, + ) + def handler_custom_route_response_validation_error() -> Model: + return {"age": 18} # type: ignore + + @app.exception_handler(ResponseValidationError) + def handle_response_validation_error(ex: ResponseValidationError): + return Response( + status_code=500, + body="Unexpected response.", + ) + + # WHEN returning incomplete model for a non-Optional type + gw_event["path"] = "/custom_incomplete_model_not_allowed" + result = app(gw_event, {}) + + # THEN it should return the sanitized response + assert result["statusCode"] == 500 + assert result["body"] == "Unexpected response." From ce7be15669d43b18b29299acf0c6c305d186663c Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Tue, 18 Mar 2025 00:46:40 +0000 Subject: [PATCH 10/23] fix: except Route implementation --- .../event_handler/api_gateway.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index f5951c4cc82..7e31d68cdc8 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -320,7 +320,7 @@ def __init__( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable[..., Response]] | None = None, - custom_response_validation_http_code: int | HTTPStatus | None = None, + custom_response_validation_http_code: HTTPStatus | None = None, ): """ Internally used Route Configuration @@ -399,7 +399,7 @@ def __init__( # _body_field is used to cache the dependant model for the body field self._body_field: ModelField | None = None - self.custom_response_validation_http_code: int | HTTPStatus | None = custom_response_validation_http_code + self.custom_response_validation_http_code: HTTPStatus | None = custom_response_validation_http_code def __call__( self, @@ -2127,6 +2127,29 @@ def swagger_handler(): body=body, ) + def _validate_route_response_validation_error_http_code( + self, + custom_response_validation_http_code: int | HTTPStatus | None, + ) -> HTTPStatus | None: + if custom_response_validation_http_code and not self._enable_validation: + msg = ( + "'custom_response_validation_http_code' cannot be set for route when enable_validation is False " + "on resolver." + ) + raise ValueError(msg) + + if ( + not isinstance(custom_response_validation_http_code, HTTPStatus) + and custom_response_validation_http_code is not None + ): + try: + custom_response_validation_http_code = HTTPStatus(custom_response_validation_http_code) + except ValueError: + msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code." + raise ValueError(msg) from None + + return custom_response_validation_http_code + def route( self, rule: str, From 95d9aee2a2abb6f13b9ee3f013281359bcedc2dd Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Fri, 21 Mar 2025 21:47:38 +0000 Subject: [PATCH 11/23] fix: put custom_response_validation_http_code before middleware --- .../event_handler/api_gateway.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 7e31d68cdc8..4df0ff71b5a 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -319,8 +319,8 @@ def __init__( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - middlewares: list[Callable[..., Response]] | None = None, custom_response_validation_http_code: HTTPStatus | None = None, + middlewares: list[Callable[..., Response]] | None = None, ): """ Internally used Route Configuration @@ -949,7 +949,6 @@ def route( openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, - custom_response_validation_http_code: int | HTTPStatus | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: raise NotImplementedError() @@ -1010,8 +1009,8 @@ def get( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - middlewares: list[Callable[..., Any]] | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, + middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Get route decorator with GET `method` @@ -1051,8 +1050,8 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, - middlewares, custom_response_validation_http_code, + middlewares, ) def post( @@ -1071,8 +1070,8 @@ def post( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - middlewares: list[Callable[..., Any]] | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, + middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Post route decorator with POST `method` @@ -1113,8 +1112,8 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, - middlewares, custom_response_validation_http_code, + middlewares, ) def put( @@ -1133,8 +1132,8 @@ def put( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - middlewares: list[Callable[..., Any]] | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, + middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Put route decorator with PUT `method` @@ -1175,8 +1174,8 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, - middlewares, custom_response_validation_http_code, + middlewares, ) def delete( @@ -1195,8 +1194,8 @@ def delete( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - middlewares: list[Callable[..., Any]] | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, + middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Delete route decorator with DELETE `method` @@ -1236,8 +1235,8 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, - middlewares, custom_response_validation_http_code, + middlewares, ) def patch( @@ -1256,8 +1255,8 @@ def patch( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - middlewares: list[Callable] | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, + middlewares: list[Callable] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Patch route decorator with PATCH `method` @@ -1300,8 +1299,8 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, - middlewares, custom_response_validation_http_code, + middlewares, ) def head( @@ -1320,8 +1319,8 @@ def head( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - middlewares: list[Callable] | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, + middlewares: list[Callable] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Head route decorator with HEAD `method` @@ -1363,8 +1362,8 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, - middlewares, custom_response_validation_http_code, + middlewares, ) def _push_processed_stack_frame(self, frame: str): @@ -2167,8 +2166,8 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - middlewares: list[Callable[..., Any]] | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, + middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Route decorator includes parameter `method`""" @@ -2201,8 +2200,8 @@ def register_resolver(func: AnyCallableT) -> AnyCallableT: security, openapi_extensions, deprecated, - middlewares, custom_response_validation_http_code, + middlewares, ) # The more specific route wins. @@ -2735,8 +2734,8 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - middlewares: list[Callable[..., Any]] | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, + middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: def register_route(func: AnyCallableT) -> AnyCallableT: # All dict keys needs to be hashable. So we'll need to do some conversions: @@ -2849,8 +2848,8 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - middlewares: list[Callable[..., Any]] | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, + middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: # NOTE: see #1552 for more context. return super().route( @@ -2869,8 +2868,8 @@ def route( security, openapi_extensions, deprecated, - middlewares, custom_response_validation_http_code, + middlewares, ) # Override _compile_regex to exclude trailing slashes for route resolution From 210b765510a9eacdbb64663f7e40c2ab60e19a46 Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Sun, 23 Mar 2025 17:43:23 +0000 Subject: [PATCH 12/23] feat: route's custom response validation must take precedence over app's. --- .../event_handler/api_gateway.py | 20 +++++++++--- .../middlewares/openapi_validation.py | 6 ++-- .../event_handler/openapi/types.py | 12 +++++++ .../test_openapi_validation_middleware.py | 31 +++++++++++++++++-- 4 files changed, 60 insertions(+), 9 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 4df0ff71b5a..0d1aceba627 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -319,7 +319,7 @@ def __init__( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - custom_response_validation_http_code: HTTPStatus | None = None, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Response]] | None = None, ): """ @@ -361,6 +361,8 @@ def __init__( Additional OpenAPI extensions as a dictionary. deprecated: bool Whether or not to mark this route as deprecated in the OpenAPI schema + custom_response_validation_http_code: int | HTTPStatus | None, optional + Whether to have custom http status code for this route if response validation fails middlewares: list[Callable[..., Response]] | None The list of route middlewares to be called in order. # TODO @@ -569,7 +571,13 @@ def _get_openapi_path( }, } - # TODO update responses + # Add custom response validation response, if exists + if self.custom_response_validation_http_code: + http_code = self.custom_response_validation_http_code.value + operation_responses[http_code] = { + "description": "Response Validation Error", + "content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}}, + } # Add the response to the OpenAPI operation if self.responses: @@ -948,6 +956,7 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: raise NotImplementedError() @@ -2573,10 +2582,11 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild # 'self._response_validation_error_http_code' is not None or # when route has custom_response_validation_http_code if isinstance(exp, ResponseValidationError): + # route validation must take precedence over app validation http_code = ( - self._response_validation_error_http_code - if exp.source == "app" - else route.custom_response_validation_http_code + route.custom_response_validation_http_code + if exp.source == "route" + else self._response_validation_error_http_code ) errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()] return self._response_builder_class( diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 289b0fa1c86..137cd703d4b 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -175,14 +175,16 @@ def _serialize_response( errors: list[dict[str, Any]] = [] value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors) if errors: - if self._has_response_validation_error: - raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app") + # route-level validation must take precedence over app-level if has_route_custom_response_validation: raise ResponseValidationError( errors=_normalize_errors(errors), body=response_content, source="route", ) + if self._has_response_validation_error: + raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app") + raise RequestValidationError(errors=_normalize_errors(errors), body=response_content) if hasattr(field, "serialize"): diff --git a/aws_lambda_powertools/event_handler/openapi/types.py b/aws_lambda_powertools/event_handler/openapi/types.py index 0f8d55e8158..428c38ab3cd 100644 --- a/aws_lambda_powertools/event_handler/openapi/types.py +++ b/aws_lambda_powertools/event_handler/openapi/types.py @@ -49,6 +49,18 @@ }, } +response_validation_error_response_definition = { + "title": "ResponseValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": f"{COMPONENT_REF_PREFIX}ValidationError"}, + }, + }, +} + class OpenAPIResponseContentSchema(TypedDict, total=False): schema: dict diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 3c364d4a26e..3c017156b28 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1380,7 +1380,7 @@ def test_custom_response_validation_error_bad_http_code(response_validation_erro ) -def test_custom_route_response_validation_error_http_code_invalid_response_incomplete_model(gw_event): +def test_custom_route_response_validation_error__custom_route_in_app_with_default_validation(gw_event): # GIVEN an APIGatewayRestResolver with custom response validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -1412,7 +1412,7 @@ def handler_custom_route_response_validation_error() -> Model: assert json.loads(result["body"])["detail"] == json.loads(custom_result["body"])["detail"] -def test_custom_route_response_validation_error_sanitized_response(gw_event): +def test_custom_route_response_validation_error__sanitized_response(gw_event): # GIVEN an APIGatewayRestResolver with custom response validation enabled # with a sanitized response validation error response app = APIGatewayRestResolver(enable_validation=True) @@ -1442,3 +1442,30 @@ def handle_response_validation_error(ex: ResponseValidationError): # THEN it should return the sanitized response assert result["statusCode"] == 500 assert result["body"] == "Unexpected response." + + +def test_custom_route_response_validation_error__in_app_with_custom_validation_code(gw_event): + # GIVEN an APIGatewayRestResolver with custom response validation enabled + app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=500) + + class Model(BaseModel): + name: str + age: int + + # and a route with custom response validation + @app.get( + "/custom_incomplete_model_not_allowed", + custom_response_validation_http_code=422, + ) + def handler_custom_route_response_validation_error() -> Model: + return {"age": 18} # type: ignore + + # WHEN returning incomplete model for a non-Optional type on route with custom response validation + gw_event["path"] = "/custom_incomplete_model_not_allowed" + result = app(gw_event, {}) + + # THEN route's custom response validation must take precedence over the app's. + assert result["statusCode"] == 422 + body = json.loads(result["body"]) + assert body["detail"][0]["type"] == "missing" + assert body["detail"][0]["loc"] == ["response", "name"] From 575e7139cc6ac708f773b8be62acbb791c29dcd9 Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Sun, 23 Mar 2025 18:23:34 +0000 Subject: [PATCH 13/23] feat: added more tests. --- .../test_openapi_validation_middleware.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 3c017156b28..5dedbc32792 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1469,3 +1469,52 @@ def handler_custom_route_response_validation_error() -> Model: body = json.loads(result["body"]) assert body["detail"][0]["type"] == "missing" assert body["detail"][0]["loc"] == ["response", "name"] + + +def test_custom_route_response_validation__error_no_app_validation(): + # GIVEN an APIGatewayRestResolver with validation not enabled + with pytest.raises(ValueError) as exception_info: + app = APIGatewayRestResolver() + + class Model(BaseModel): + name: str + age: int + + # HAVING a route with custom response validation http code + @app.get( + "/custom_incomplete_model_not_allowed", + custom_response_validation_http_code=422, + ) + def handler_custom_route_response_validation_error() -> Model: + return {"age": 18} # type: ignore + + # THEN it must raise ValueError describing the issue + assert ( + str(exception_info.value) + == "'custom_response_validation_http_code' cannot be set for route when enable_validation is False on resolver." + ) + + +@pytest.mark.parametrize("response_validation_error_http_code", [(20), ("hi"), (1.21), (True), (False)]) +def test_custom_route_response_validation__error_bad_http_code(response_validation_error_http_code): + # GIVEN an APIGatewayRestResolver with validation enabled + with pytest.raises(ValueError) as exception_info: + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + # HAVING a route with custom response validation which is not a valid HTTP code + @app.get( + "/custom_incomplete_model_not_allowed", + custom_response_validation_http_code=response_validation_error_http_code, + ) + def handler_custom_route_response_validation_error() -> Model: + return {"age": 18} # type: ignore + + # THEN it must raise ValueError describing the issue + assert ( + str(exception_info.value) + == f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501 + ) From 440a3f43b1a359ad3846ee5fed78c7244fd370cf Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Sun, 23 Mar 2025 18:25:08 +0000 Subject: [PATCH 14/23] refactor: improved error messagee and tests' descriptions. --- .../event_handler/api_gateway.py | 2 +- .../test_openapi_validation_middleware.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 0d1aceba627..33a3042b924 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -2153,7 +2153,7 @@ def _validate_route_response_validation_error_http_code( try: custom_response_validation_http_code = HTTPStatus(custom_response_validation_http_code) except ValueError: - msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code." + msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501 raise ValueError(msg) from None return custom_response_validation_http_code diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 5dedbc32792..ce535f761d1 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1381,7 +1381,7 @@ def test_custom_response_validation_error_bad_http_code(response_validation_erro def test_custom_route_response_validation_error__custom_route_in_app_with_default_validation(gw_event): - # GIVEN an APIGatewayRestResolver with custom response validation enabled + # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) class Model(BaseModel): @@ -1392,6 +1392,7 @@ class Model(BaseModel): def handler_incomplete_model_not_allowed() -> Model: return {"age": 18} # type: ignore + # HAVING route with custom response validation error @app.get( "/custom_incomplete_model_not_allowed", custom_response_validation_http_code=500, @@ -1406,7 +1407,7 @@ def handler_custom_route_response_validation_error() -> Model: gw_event["path"] = "/custom_incomplete_model_not_allowed" custom_result = app(gw_event, {}) - # THEN it should return a validation error with the custom status code provided + # THEN it must return a validation error with the custom status code provided assert result["statusCode"] == 422 assert custom_result["statusCode"] == 500 assert json.loads(result["body"])["detail"] == json.loads(custom_result["body"])["detail"] @@ -1414,7 +1415,6 @@ def handler_custom_route_response_validation_error() -> Model: def test_custom_route_response_validation_error__sanitized_response(gw_event): # GIVEN an APIGatewayRestResolver with custom response validation enabled - # with a sanitized response validation error response app = APIGatewayRestResolver(enable_validation=True) class Model(BaseModel): @@ -1428,6 +1428,7 @@ class Model(BaseModel): def handler_custom_route_response_validation_error() -> Model: return {"age": 18} # type: ignore + # HAVING a sanitized response validation error response @app.exception_handler(ResponseValidationError) def handle_response_validation_error(ex: ResponseValidationError): return Response( @@ -1439,20 +1440,20 @@ def handle_response_validation_error(ex: ResponseValidationError): gw_event["path"] = "/custom_incomplete_model_not_allowed" result = app(gw_event, {}) - # THEN it should return the sanitized response + # THEN it must return the sanitized response assert result["statusCode"] == 500 assert result["body"] == "Unexpected response." -def test_custom_route_response_validation_error__in_app_with_custom_validation_code(gw_event): - # GIVEN an APIGatewayRestResolver with custom response validation enabled +def test_custom_route_response_validation_error__with_app_custom_response_validation(gw_event): + # GIVEN an APIGatewayRestResolver with validation and custom response validation enabled app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=500) class Model(BaseModel): name: str age: int - # and a route with custom response validation + # HAVING a route with custom response validation @app.get( "/custom_incomplete_model_not_allowed", custom_response_validation_http_code=422, From 249554f97b20823e68a7f462d4f914b684826526 Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Tue, 25 Mar 2025 07:46:30 +0000 Subject: [PATCH 15/23] feat: updated docs. --- docs/core/event_handler/api_gateway.md | 13 +++++++++++++ .../src/customizing_response_validation.py | 12 ++++++++++++ 2 files changed, 25 insertions(+) diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 4919598b3ec..e058683ab0b 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -400,8 +400,21 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou You can use `response_validation_error_http_code` to set a custom HTTP code for failed response validation. When this field is set, we will raise a `ResponseValidationError` instead of a `RequestValidationError`. +For a more granular control over the failed response validation http code, the `custom_response_validation_http_code` argument can be set per route. +This value will override the value of the failed response validation http code set at app-level (if any). + === "customizing_response_validation.py" + ```python hl_lines="1 16 29 33 38" + --8<-- "examples/event_handler_rest/src/customizing_response_validation.py" + ``` + + 1. A response with status code set here will be returned if response data is not valid. + 2. Operation returns a string as oppose to a `Todo` object. This will lead to a `500` response as set in line 18. + 3. Operation will return a `422 Unprocessable Entity` response if response is not a `Todo` object. This overrides the custom http code set in line 16. + +=== "customizing_route_response_validation.py" + ```python hl_lines="1 16 29 33" --8<-- "examples/event_handler_rest/src/customizing_response_validation.py" ``` diff --git a/examples/event_handler_rest/src/customizing_response_validation.py b/examples/event_handler_rest/src/customizing_response_validation.py index 2b7b2c16c9f..25aa07bf52a 100644 --- a/examples/event_handler_rest/src/customizing_response_validation.py +++ b/examples/event_handler_rest/src/customizing_response_validation.py @@ -33,6 +33,18 @@ def get_todo_by_id(todo_id: int) -> Todo: return todo.json()["title"] # (2)! +@app.get( + "/todos_bad_response_with_custom_http_code/", + custom_response_validation_http_code=HTTPStatus.UNPROCESSABLE_ENTITY, # (3)! +) +@tracer.capture_method +def get_todo_by_id_custom(todo_id: int) -> Todo: + todo = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}") + todo.raise_for_status() + + return todo.json()["title"] + + @logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP) @tracer.capture_lambda_handler def lambda_handler(event: dict, context: LambdaContext) -> dict: From d0eadf016edd3fe35c297dcc384a1628bba7196a Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Tue, 25 Mar 2025 08:30:47 +0000 Subject: [PATCH 16/23] move veritifcation method of route custom http code to BaseRouter. --- .../event_handler/api_gateway.py | 59 ++++++++++--------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 33a3042b924..5f3b17afdd8 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -319,7 +319,7 @@ def __init__( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, - custom_response_validation_http_code: int | HTTPStatus | None = None, + custom_response_validation_http_code: HTTPStatus | None = None, middlewares: list[Callable[..., Response]] | None = None, ): """ @@ -361,7 +361,7 @@ def __init__( Additional OpenAPI extensions as a dictionary. deprecated: bool Whether or not to mark this route as deprecated in the OpenAPI schema - custom_response_validation_http_code: int | HTTPStatus | None, optional + custom_response_validation_http_code: HTTPStatus | None, optional Whether to have custom http status code for this route if response validation fails middlewares: list[Callable[..., Response]] | None The list of route middlewares to be called in order. @@ -401,7 +401,7 @@ def __init__( # _body_field is used to cache the dependant model for the body field self._body_field: ModelField | None = None - self.custom_response_validation_http_code: HTTPStatus | None = custom_response_validation_http_code + self.custom_response_validation_http_code = custom_response_validation_http_code def __call__( self, @@ -511,7 +511,7 @@ def body_field(self) -> ModelField | None: return self._body_field - def _get_openapi_path( + def _get_openapi_path( # noqa: PLR0912 self, *, dependant: Dependant, @@ -938,6 +938,29 @@ class BaseRouter(ABC): _router_middlewares: list[Callable] = [] processed_stack_frames: list[str] = [] + def _validate_route_response_validation_error_http_code( + self, + custom_response_validation_http_code: int | HTTPStatus | None, + ) -> HTTPStatus | None: + if custom_response_validation_http_code and not self._enable_validation: + msg = ( + "'custom_response_validation_http_code' cannot be set for route when enable_validation is False " + "on resolver." + ) + raise ValueError(msg) + + if ( + not isinstance(custom_response_validation_http_code, HTTPStatus) + and custom_response_validation_http_code is not None + ): + try: + custom_response_validation_http_code = HTTPStatus(custom_response_validation_http_code) + except ValueError: + msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501 + raise ValueError(msg) from None + + return custom_response_validation_http_code + @abstractmethod def route( self, @@ -2135,29 +2158,6 @@ def swagger_handler(): body=body, ) - def _validate_route_response_validation_error_http_code( - self, - custom_response_validation_http_code: int | HTTPStatus | None, - ) -> HTTPStatus | None: - if custom_response_validation_http_code and not self._enable_validation: - msg = ( - "'custom_response_validation_http_code' cannot be set for route when enable_validation is False " - "on resolver." - ) - raise ValueError(msg) - - if ( - not isinstance(custom_response_validation_http_code, HTTPStatus) - and custom_response_validation_http_code is not None - ): - try: - custom_response_validation_http_code = HTTPStatus(custom_response_validation_http_code) - except ValueError: - msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501 - raise ValueError(msg) from None - - return custom_response_validation_http_code - def route( self, rule: str, @@ -2754,6 +2754,9 @@ def register_route(func: AnyCallableT) -> AnyCallableT: frozen_tags = frozenset(tags) if tags else None frozen_security = _FrozenListDict(security) if security else None frozen_openapi_extensions = _FrozenDict(openapi_extensions) if openapi_extensions else None + response_validation_http_code = self._validate_route_response_validation_error_http_code( + custom_response_validation_http_code, + ) route_key = ( rule, @@ -2771,7 +2774,7 @@ def register_route(func: AnyCallableT) -> AnyCallableT: frozen_security, frozen_openapi_extensions, deprecated, - custom_response_validation_http_code, + response_validation_http_code, ) # Collate Middleware for routes From 59bb4aa8960df2d053f68f944cbc65635e781faa Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Tue, 25 Mar 2025 09:17:33 +0000 Subject: [PATCH 17/23] fix: add validate function for route http code to APIGatewayResolver not Router --- .../event_handler/api_gateway.py | 51 +++++++++---------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 5f3b17afdd8..4f38201ddb5 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -938,29 +938,6 @@ class BaseRouter(ABC): _router_middlewares: list[Callable] = [] processed_stack_frames: list[str] = [] - def _validate_route_response_validation_error_http_code( - self, - custom_response_validation_http_code: int | HTTPStatus | None, - ) -> HTTPStatus | None: - if custom_response_validation_http_code and not self._enable_validation: - msg = ( - "'custom_response_validation_http_code' cannot be set for route when enable_validation is False " - "on resolver." - ) - raise ValueError(msg) - - if ( - not isinstance(custom_response_validation_http_code, HTTPStatus) - and custom_response_validation_http_code is not None - ): - try: - custom_response_validation_http_code = HTTPStatus(custom_response_validation_http_code) - except ValueError: - msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501 - raise ValueError(msg) from None - - return custom_response_validation_http_code - @abstractmethod def route( self, @@ -2158,6 +2135,29 @@ def swagger_handler(): body=body, ) + def _validate_route_response_validation_error_http_code( + self, + custom_response_validation_http_code: int | HTTPStatus | None, + ) -> HTTPStatus | None: + if custom_response_validation_http_code and not self._enable_validation: + msg = ( + "'custom_response_validation_http_code' cannot be set for route when enable_validation is False " + "on resolver." + ) + raise ValueError(msg) + + if ( + not isinstance(custom_response_validation_http_code, HTTPStatus) + and custom_response_validation_http_code is not None + ): + try: + custom_response_validation_http_code = HTTPStatus(custom_response_validation_http_code) + except ValueError: + msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501 + raise ValueError(msg) from None + + return custom_response_validation_http_code + def route( self, rule: str, @@ -2754,9 +2754,6 @@ def register_route(func: AnyCallableT) -> AnyCallableT: frozen_tags = frozenset(tags) if tags else None frozen_security = _FrozenListDict(security) if security else None frozen_openapi_extensions = _FrozenDict(openapi_extensions) if openapi_extensions else None - response_validation_http_code = self._validate_route_response_validation_error_http_code( - custom_response_validation_http_code, - ) route_key = ( rule, @@ -2774,7 +2771,7 @@ def register_route(func: AnyCallableT) -> AnyCallableT: frozen_security, frozen_openapi_extensions, deprecated, - response_validation_http_code, + custom_response_validation_http_code, ) # Collate Middleware for routes From 020c973f9f09280bea04dddaa9beb97e285d05b9 Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Tue, 25 Mar 2025 09:42:08 +0000 Subject: [PATCH 18/23] feat: add custom_response_validation_http_code to the routes of Bedrock --- aws_lambda_powertools/event_handler/bedrock_agent.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 215199e0022..e4f41bd38db 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -14,6 +14,7 @@ from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION if TYPE_CHECKING: + from http import HTTPStatus from re import Match from aws_lambda_powertools.event_handler.openapi.models import Contact, License, SecurityScheme, Server, Tag @@ -109,6 +110,7 @@ def get( # type: ignore[override] operation_id: str | None = None, include_in_schema: bool = True, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: openapi_extensions = None @@ -129,6 +131,7 @@ def get( # type: ignore[override] security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -148,6 +151,7 @@ def post( # type: ignore[override] operation_id: str | None = None, include_in_schema: bool = True, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ): openapi_extensions = None @@ -168,6 +172,7 @@ def post( # type: ignore[override] security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -187,6 +192,7 @@ def put( # type: ignore[override] operation_id: str | None = None, include_in_schema: bool = True, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ): openapi_extensions = None @@ -207,6 +213,7 @@ def put( # type: ignore[override] security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -226,6 +233,7 @@ def patch( # type: ignore[override] operation_id: str | None = None, include_in_schema: bool = True, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable] | None = None, ): openapi_extensions = None @@ -246,6 +254,7 @@ def patch( # type: ignore[override] security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -265,6 +274,7 @@ def delete( # type: ignore[override] operation_id: str | None = None, include_in_schema: bool = True, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ): openapi_extensions = None @@ -285,6 +295,7 @@ def delete( # type: ignore[override] security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) From 5ea8ffafa18aae466b37544cce082c51f3668321 Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Tue, 25 Mar 2025 09:42:31 +0000 Subject: [PATCH 19/23] fix: make mypy happy --- aws_lambda_powertools/event_handler/api_gateway.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 4f38201ddb5..33a6cab52e5 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -2583,9 +2583,10 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild # when route has custom_response_validation_http_code if isinstance(exp, ResponseValidationError): # route validation must take precedence over app validation + route_response_validation_http_code = route.custom_response_validation_http_code http_code = ( - route.custom_response_validation_http_code - if exp.source == "route" + route_response_validation_http_code + if route_response_validation_http_code else self._response_validation_error_http_code ) errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()] From fca7db0f0ac8078f35d5ac91d5e0670205c7b18e Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Wed, 9 Apr 2025 21:11:59 +0100 Subject: [PATCH 20/23] fix: address comments --- aws_lambda_powertools/event_handler/api_gateway.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 33a6cab52e5..331be42c8c3 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -35,6 +35,7 @@ OpenAPIResponse, OpenAPIResponseContentModel, OpenAPIResponseContentSchema, + response_validation_error_response_definition, validation_error_definition, validation_error_response_definition, ) @@ -365,7 +366,6 @@ def __init__( Whether to have custom http status code for this route if response validation fails middlewares: list[Callable[..., Response]] | None The list of route middlewares to be called in order. - # TODO """ self.method = method.upper() self.path = "/" if path.strip() == "" else path @@ -650,6 +650,7 @@ def _get_openapi_path( # noqa: PLR0912 { "ValidationError": validation_error_definition, "HTTPValidationError": validation_error_response_definition, + "ResponseValidationError": response_validation_error_response_definition, }, ) From 2d7a73d1ab684a09d77e42f5d74a144d191f0625 Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Wed, 9 Apr 2025 23:08:48 +0100 Subject: [PATCH 21/23] fix(openapi): add response for response validation error and definition for it --- .../event_handler/api_gateway.py | 34 +++++++++++-- .../_pydantic/test_openapi_responses.py | 48 +++++++++++++++++++ 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 331be42c8c3..4334630bbcf 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -578,6 +578,8 @@ def _get_openapi_path( # noqa: PLR0912 "description": "Response Validation Error", "content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}}, } + # Add model definition + definitions.update({"ResponseValidationError": response_validation_error_response_definition}) # Add the response to the OpenAPI operation if self.responses: @@ -650,7 +652,6 @@ def _get_openapi_path( # noqa: PLR0912 { "ValidationError": validation_error_definition, "HTTPValidationError": validation_error_response_definition, - "ResponseValidationError": response_validation_error_response_definition, }, ) @@ -1616,6 +1617,33 @@ def _validate_response_validation_error_http_code( return response_validation_error_http_code or HTTPStatus.UNPROCESSABLE_ENTITY + def _add_resolver_response_validation_error_response_to_route( + self, + route_openapi_path: tuple[dict[str, Any], dict[str, Any]], + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Adds resolver response validation error response to route's operations.""" + path, path_definitions = route_openapi_path + if self._has_response_validation_error and "ResponseValidationError" not in path_definitions: + response_validation_error_response = { + "description": "Response Validation Error", + "content": { + "application/json": { + "schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}, + }, + }, + } + http_code = self._response_validation_error_http_code.value + for operation in path.values(): + operation["responses"][http_code] = response_validation_error_response + return path, path_definitions + + def _generate_schemas(self, definitions: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]: + schemas = {k: definitions[k] for k in sorted(definitions)} + # add response validation error definition + if self._response_validation_error_http_code: + schemas.setdefault("ResponseValidationError", response_validation_error_response_definition) + return schemas + def get_openapi_schema( self, *, @@ -1767,14 +1795,14 @@ def get_openapi_schema( field_mapping=field_mapping, ) if result: - path, path_definitions = result + path, path_definitions = self._add_resolver_response_validation_error_response_to_route(result) if path: paths.setdefault(route.openapi_path, {}).update(path) if path_definitions: definitions.update(path_definitions) if definitions: - components["schemas"] = {k: definitions[k] for k in sorted(definitions)} + components["schemas"] = self._generate_schemas(definitions) if security_schemes: components["securitySchemes"] = security_schemes if components: diff --git a/tests/functional/event_handler/_pydantic/test_openapi_responses.py b/tests/functional/event_handler/_pydantic/test_openapi_responses.py index c2ab8008b5c..0ac24dcc96b 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_responses.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_responses.py @@ -170,3 +170,51 @@ def handler() -> Response[Union[User, Order]]: assert 202 in responses.keys() assert responses[202].description == "202 Response" assert responses[202].content["application/json"].schema_.ref == "#/components/schemas/Order" + + +def test_openapi_route_with_custom_response_validation(): + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/", custom_response_validation_http_code=418) + def handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 418 in responses + assert responses[418].description == "Response Validation Error" + + +def test_openapi_resolver_with_custom_response_validation(): + app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=418) + + @app.get("/") + def handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 418 in responses + assert responses[418].description == "Response Validation Error" + + +def test_openapi_route_and_resolver_with_custom_response_validation(): + app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=417) + + @app.get("/", custom_response_validation_http_code=418) + def handler(): + return {"message": "hello world"} + + @app.get("/hi") + def another_handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses_with_route_response_validation = schema.paths["/"].get.responses + responses_with_resolver_response_validation = schema.paths["/hi"].get.responses + assert 418 in responses_with_route_response_validation + assert 417 not in responses_with_route_response_validation + assert responses_with_route_response_validation[418].description == "Response Validation Error" + assert 417 in responses_with_resolver_response_validation + assert 418 not in responses_with_resolver_response_validation + assert responses_with_resolver_response_validation[417].description == "Response Validation Error" From ba7d6c7ea9f11e86975e6fd3b38ef5db18adf95d Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 10 Apr 2025 17:57:17 +0100 Subject: [PATCH 22/23] minor changes --- .../event_handler/api_gateway.py | 16 ++++++---------- docs/core/event_handler/api_gateway.md | 2 +- ruff.toml | 4 ++-- .../test_openapi_validation_middleware.py | 10 +++++----- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 4334630bbcf..013e00ce474 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -362,13 +362,13 @@ def __init__( Additional OpenAPI extensions as a dictionary. deprecated: bool Whether or not to mark this route as deprecated in the OpenAPI schema - custom_response_validation_http_code: HTTPStatus | None, optional + custom_response_validation_http_code: int | HTTPStatus | None, optional Whether to have custom http status code for this route if response validation fails middlewares: list[Callable[..., Response]] | None The list of route middlewares to be called in order. """ self.method = method.upper() - self.path = "/" if path.strip() == "" else path + self.path = path if path.strip() else "/" # OpenAPI spec only understands paths with { }. So we'll have to convert Powertools' < >. # https://swagger.io/specification/#path-templating @@ -511,7 +511,7 @@ def body_field(self) -> ModelField | None: return self._body_field - def _get_openapi_path( # noqa: PLR0912 + def _get_openapi_path( self, *, dependant: Dependant, @@ -579,7 +579,7 @@ def _get_openapi_path( # noqa: PLR0912 "content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}}, } # Add model definition - definitions.update({"ResponseValidationError": response_validation_error_response_definition}) + definitions["ResponseValidationError"] = response_validation_error_response_definition # Add the response to the OpenAPI operation if self.responses: @@ -1600,6 +1600,7 @@ def _validate_response_validation_error_http_code( response_validation_error_http_code: HTTPStatus | int | None, enable_validation: bool, ) -> HTTPStatus: + if response_validation_error_http_code and not enable_validation: msg = "'response_validation_error_http_code' cannot be set when enable_validation is False." raise ValueError(msg) @@ -2612,12 +2613,7 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild # when route has custom_response_validation_http_code if isinstance(exp, ResponseValidationError): # route validation must take precedence over app validation - route_response_validation_http_code = route.custom_response_validation_http_code - http_code = ( - route_response_validation_http_code - if route_response_validation_http_code - else self._response_validation_error_http_code - ) + http_code = route.custom_response_validation_http_code or self._response_validation_error_http_code errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()] return self._response_builder_class( response=Response( diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index e058683ab0b..72904b6e172 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -401,7 +401,7 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou You can use `response_validation_error_http_code` to set a custom HTTP code for failed response validation. When this field is set, we will raise a `ResponseValidationError` instead of a `RequestValidationError`. For a more granular control over the failed response validation http code, the `custom_response_validation_http_code` argument can be set per route. -This value will override the value of the failed response validation http code set at app-level (if any). +This value will override the value of the failed response validation http code set at constructor level (if any). === "customizing_response_validation.py" diff --git a/ruff.toml b/ruff.toml index b415c63f949..2af94434185 100644 --- a/ruff.toml +++ b/ruff.toml @@ -74,11 +74,11 @@ lint.typing-modules = [ [lint.mccabe] # Maximum cyclomatic complexity -max-complexity = 15 +max-complexity = 16 [lint.pylint] # Maximum number of nested blocks -max-branches = 15 +max-branches = 16 # Maximum number of if statements in a function max-statements = 70 diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index ce535f761d1..c1cc0462bf7 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1380,7 +1380,7 @@ def test_custom_response_validation_error_bad_http_code(response_validation_erro ) -def test_custom_route_response_validation_error__custom_route_in_app_with_default_validation(gw_event): +def test_custom_route_response_validation_error_custom_route_and_app_with_default_validation(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -1413,7 +1413,7 @@ def handler_custom_route_response_validation_error() -> Model: assert json.loads(result["body"])["detail"] == json.loads(custom_result["body"])["detail"] -def test_custom_route_response_validation_error__sanitized_response(gw_event): +def test_custom_route_response_validation_error_sanitized_response(gw_event): # GIVEN an APIGatewayRestResolver with custom response validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -1445,7 +1445,7 @@ def handle_response_validation_error(ex: ResponseValidationError): assert result["body"] == "Unexpected response." -def test_custom_route_response_validation_error__with_app_custom_response_validation(gw_event): +def test_custom_route_response_validation_error_with_app_custom_response_validation(gw_event): # GIVEN an APIGatewayRestResolver with validation and custom response validation enabled app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=500) @@ -1472,7 +1472,7 @@ def handler_custom_route_response_validation_error() -> Model: assert body["detail"][0]["loc"] == ["response", "name"] -def test_custom_route_response_validation__error_no_app_validation(): +def test_custom_route_response_validation_error_no_app_validation(): # GIVEN an APIGatewayRestResolver with validation not enabled with pytest.raises(ValueError) as exception_info: app = APIGatewayRestResolver() @@ -1497,7 +1497,7 @@ def handler_custom_route_response_validation_error() -> Model: @pytest.mark.parametrize("response_validation_error_http_code", [(20), ("hi"), (1.21), (True), (False)]) -def test_custom_route_response_validation__error_bad_http_code(response_validation_error_http_code): +def test_custom_route_response_validation_error_bad_http_code(response_validation_error_http_code): # GIVEN an APIGatewayRestResolver with validation enabled with pytest.raises(ValueError) as exception_info: app = APIGatewayRestResolver(enable_validation=True) From 9e04a636779ebd47b3364ccb64967d568f9da5c3 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 10 Apr 2025 18:04:19 +0100 Subject: [PATCH 23/23] minor changes --- docs/core/event_handler/api_gateway.md | 2 +- .../src/response_validation_error_unsanitized_output.json | 8 -------- .../src/response_validation_sanitized_error_output.json | 8 -------- 3 files changed, 1 insertion(+), 17 deletions(-) delete mode 100644 examples/event_handler_rest/src/response_validation_error_unsanitized_output.json delete mode 100644 examples/event_handler_rest/src/response_validation_sanitized_error_output.json diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 72904b6e172..12a2a77b48b 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -410,7 +410,7 @@ This value will override the value of the failed response validation http code s ``` 1. A response with status code set here will be returned if response data is not valid. - 2. Operation returns a string as oppose to a `Todo` object. This will lead to a `500` response as set in line 18. + 2. Operation returns a string as oppose to a `Todo` object. This will lead to a `500` response as set in line 16. 3. Operation will return a `422 Unprocessable Entity` response if response is not a `Todo` object. This overrides the custom http code set in line 16. === "customizing_route_response_validation.py" diff --git a/examples/event_handler_rest/src/response_validation_error_unsanitized_output.json b/examples/event_handler_rest/src/response_validation_error_unsanitized_output.json deleted file mode 100644 index c2fbe3df339..00000000000 --- a/examples/event_handler_rest/src/response_validation_error_unsanitized_output.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "statusCode": 500, - "body": "{\"statusCode\": 500, \"detail\": [{\"type\": \"model_attributes_type\", \"loc\": [\"response\", ]}]}", - "isBase64Encoded": false, - "headers": { - "Content-Type": "application/json" - } -} \ No newline at end of file diff --git a/examples/event_handler_rest/src/response_validation_sanitized_error_output.json b/examples/event_handler_rest/src/response_validation_sanitized_error_output.json deleted file mode 100644 index 79c97da7498..00000000000 --- a/examples/event_handler_rest/src/response_validation_sanitized_error_output.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "statusCode": 500, - "body": "Unexpected response.", - "isBase64Encoded": false, - "headers": { - "Content-Type": "application/json" - } -} \ No newline at end of file