diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index c8e4248fda4..013e00ce474 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -35,6 +35,7 @@ OpenAPIResponse, OpenAPIResponseContentModel, OpenAPIResponseContentSchema, + response_validation_error_response_definition, validation_error_definition, validation_error_response_definition, ) @@ -319,6 +320,7 @@ def __init__( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: HTTPStatus | None = None, middlewares: list[Callable[..., Response]] | None = None, ): """ @@ -360,11 +362,13 @@ def __init__( Additional OpenAPI extensions as a dictionary. deprecated: bool Whether or not to mark this route as deprecated in the OpenAPI schema + custom_response_validation_http_code: int | HTTPStatus | None, optional + Whether to have custom http status code for this route if response validation fails middlewares: list[Callable[..., Response]] | None The list of route middlewares to be called in order. """ self.method = method.upper() - self.path = "/" if path.strip() == "" else path + self.path = path if path.strip() else "/" # OpenAPI spec only understands paths with { }. So we'll have to convert Powertools' < >. # https://swagger.io/specification/#path-templating @@ -397,6 +401,8 @@ def __init__( # _body_field is used to cache the dependant model for the body field self._body_field: ModelField | None = None + self.custom_response_validation_http_code = custom_response_validation_http_code + def __call__( self, router_middlewares: list[Callable], @@ -565,6 +571,16 @@ def _get_openapi_path( }, } + # Add custom response validation response, if exists + if self.custom_response_validation_http_code: + http_code = self.custom_response_validation_http_code.value + operation_responses[http_code] = { + "description": "Response Validation Error", + "content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}}, + } + # Add model definition + definitions["ResponseValidationError"] = response_validation_error_response_definition + # Add the response to the OpenAPI operation if self.responses: for status_code in list(self.responses): @@ -942,6 +958,7 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: raise NotImplementedError() @@ -1003,6 +1020,7 @@ def get( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Get route decorator with GET `method` @@ -1043,6 +1061,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -1062,6 +1081,7 @@ def post( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Post route decorator with POST `method` @@ -1103,6 +1123,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -1122,6 +1143,7 @@ def put( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Put route decorator with PUT `method` @@ -1163,6 +1185,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -1182,6 +1205,7 @@ def delete( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Delete route decorator with DELETE `method` @@ -1222,6 +1246,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -1241,6 +1266,7 @@ def patch( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Patch route decorator with PATCH `method` @@ -1284,6 +1310,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -1303,6 +1330,7 @@ def head( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Head route decorator with HEAD `method` @@ -1345,6 +1373,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -1571,6 +1600,7 @@ def _validate_response_validation_error_http_code( response_validation_error_http_code: HTTPStatus | int | None, enable_validation: bool, ) -> HTTPStatus: + if response_validation_error_http_code and not enable_validation: msg = "'response_validation_error_http_code' cannot be set when enable_validation is False." raise ValueError(msg) @@ -1588,6 +1618,33 @@ def _validate_response_validation_error_http_code( return response_validation_error_http_code or HTTPStatus.UNPROCESSABLE_ENTITY + def _add_resolver_response_validation_error_response_to_route( + self, + route_openapi_path: tuple[dict[str, Any], dict[str, Any]], + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Adds resolver response validation error response to route's operations.""" + path, path_definitions = route_openapi_path + if self._has_response_validation_error and "ResponseValidationError" not in path_definitions: + response_validation_error_response = { + "description": "Response Validation Error", + "content": { + "application/json": { + "schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}, + }, + }, + } + http_code = self._response_validation_error_http_code.value + for operation in path.values(): + operation["responses"][http_code] = response_validation_error_response + return path, path_definitions + + def _generate_schemas(self, definitions: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]: + schemas = {k: definitions[k] for k in sorted(definitions)} + # add response validation error definition + if self._response_validation_error_http_code: + schemas.setdefault("ResponseValidationError", response_validation_error_response_definition) + return schemas + def get_openapi_schema( self, *, @@ -1739,14 +1796,14 @@ def get_openapi_schema( field_mapping=field_mapping, ) if result: - path, path_definitions = result + path, path_definitions = self._add_resolver_response_validation_error_response_to_route(result) if path: paths.setdefault(route.openapi_path, {}).update(path) if path_definitions: definitions.update(path_definitions) if definitions: - components["schemas"] = {k: definitions[k] for k in sorted(definitions)} + components["schemas"] = self._generate_schemas(definitions) if security_schemes: components["securitySchemes"] = security_schemes if components: @@ -2108,6 +2165,29 @@ def swagger_handler(): body=body, ) + def _validate_route_response_validation_error_http_code( + self, + custom_response_validation_http_code: int | HTTPStatus | None, + ) -> HTTPStatus | None: + if custom_response_validation_http_code and not self._enable_validation: + msg = ( + "'custom_response_validation_http_code' cannot be set for route when enable_validation is False " + "on resolver." + ) + raise ValueError(msg) + + if ( + not isinstance(custom_response_validation_http_code, HTTPStatus) + and custom_response_validation_http_code is not None + ): + try: + custom_response_validation_http_code = HTTPStatus(custom_response_validation_http_code) + except ValueError: + msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501 + raise ValueError(msg) from None + + return custom_response_validation_http_code + def route( self, rule: str, @@ -2125,10 +2205,15 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Route decorator includes parameter `method`""" + custom_response_validation_http_code = self._validate_route_response_validation_error_http_code( + custom_response_validation_http_code, + ) + def register_resolver(func: AnyCallableT) -> AnyCallableT: methods = (method,) if isinstance(method, str) else method logger.debug(f"Adding route using rule {rule} and methods: {','.join(m.upper() for m in methods)}") @@ -2154,6 +2239,7 @@ def register_resolver(func: AnyCallableT) -> AnyCallableT: security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -2523,15 +2609,17 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild ) # OpenAPIValidationMiddleware will only raise ResponseValidationError when - # 'self._response_validation_error_http_code' is not None + # 'self._response_validation_error_http_code' is not None or + # when route has custom_response_validation_http_code if isinstance(exp, ResponseValidationError): - http_code = self._response_validation_error_http_code + # route validation must take precedence over app validation + http_code = route.custom_response_validation_http_code or self._response_validation_error_http_code errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()] return self._response_builder_class( response=Response( status_code=http_code.value, content_type=content_types.APPLICATION_JSON, - body={"statusCode": self._response_validation_error_http_code, "detail": errors}, + body={"statusCode": http_code, "detail": errors}, ), serializer=self._serializer, route=route, @@ -2682,6 +2770,7 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: def register_route(func: AnyCallableT) -> AnyCallableT: @@ -2708,6 +2797,7 @@ def register_route(func: AnyCallableT) -> AnyCallableT: frozen_security, frozen_openapi_extensions, deprecated, + custom_response_validation_http_code, ) # Collate Middleware for routes @@ -2794,6 +2884,7 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: # NOTE: see #1552 for more context. @@ -2813,6 +2904,7 @@ def route( security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 215199e0022..e4f41bd38db 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -14,6 +14,7 @@ from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION if TYPE_CHECKING: + from http import HTTPStatus from re import Match from aws_lambda_powertools.event_handler.openapi.models import Contact, License, SecurityScheme, Server, Tag @@ -109,6 +110,7 @@ def get( # type: ignore[override] operation_id: str | None = None, include_in_schema: bool = True, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: openapi_extensions = None @@ -129,6 +131,7 @@ def get( # type: ignore[override] security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -148,6 +151,7 @@ def post( # type: ignore[override] operation_id: str | None = None, include_in_schema: bool = True, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ): openapi_extensions = None @@ -168,6 +172,7 @@ def post( # type: ignore[override] security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -187,6 +192,7 @@ def put( # type: ignore[override] operation_id: str | None = None, include_in_schema: bool = True, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ): openapi_extensions = None @@ -207,6 +213,7 @@ def put( # type: ignore[override] security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -226,6 +233,7 @@ def patch( # type: ignore[override] operation_id: str | None = None, include_in_schema: bool = True, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable] | None = None, ): openapi_extensions = None @@ -246,6 +254,7 @@ def patch( # type: ignore[override] security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) @@ -265,6 +274,7 @@ def delete( # type: ignore[override] operation_id: str | None = None, include_in_schema: bool = True, deprecated: bool = False, + custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ): openapi_extensions = None @@ -285,6 +295,7 @@ def delete( # type: ignore[override] security, openapi_extensions, deprecated, + custom_response_validation_http_code, middlewares, ) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index a1490eb23ec..137cd703d4b 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -150,6 +150,7 @@ def _handle_response(self, *, route: Route, response: Response): response.body = self._serialize_response( field=route.dependant.return_param, response_content=response.body, + has_route_custom_response_validation=route.custom_response_validation_http_code is not None, ) return response @@ -165,6 +166,7 @@ def _serialize_response( exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, + has_route_custom_response_validation: bool = False, ) -> Any: """ Serialize the response content according to the field type. @@ -173,8 +175,16 @@ def _serialize_response( errors: list[dict[str, Any]] = [] value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors) if errors: + # route-level validation must take precedence over app-level + if has_route_custom_response_validation: + raise ResponseValidationError( + errors=_normalize_errors(errors), + body=response_content, + source="route", + ) if self._has_response_validation_error: - raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content) + raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app") + raise RequestValidationError(errors=_normalize_errors(errors), body=response_content) if hasattr(field, "serialize"): diff --git a/aws_lambda_powertools/event_handler/openapi/exceptions.py b/aws_lambda_powertools/event_handler/openapi/exceptions.py index 046a270cdf7..b06141af47e 100644 --- a/aws_lambda_powertools/event_handler/openapi/exceptions.py +++ b/aws_lambda_powertools/event_handler/openapi/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, Sequence +from typing import Any, Literal, Sequence class ValidationException(Exception): @@ -28,9 +28,10 @@ class ResponseValidationError(ValidationException): Raised when the response body does not match the OpenAPI schema """ - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: + def __init__(self, errors: Sequence[Any], *, body: Any = None, source: Literal["route", "app"] = "app") -> None: super().__init__(errors) self.body = body + self.source = source class SerializationError(Exception): diff --git a/aws_lambda_powertools/event_handler/openapi/types.py b/aws_lambda_powertools/event_handler/openapi/types.py index 0f8d55e8158..428c38ab3cd 100644 --- a/aws_lambda_powertools/event_handler/openapi/types.py +++ b/aws_lambda_powertools/event_handler/openapi/types.py @@ -49,6 +49,18 @@ }, } +response_validation_error_response_definition = { + "title": "ResponseValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": f"{COMPONENT_REF_PREFIX}ValidationError"}, + }, + }, +} + class OpenAPIResponseContentSchema(TypedDict, total=False): schema: dict diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 4919598b3ec..12a2a77b48b 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -400,8 +400,21 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou You can use `response_validation_error_http_code` to set a custom HTTP code for failed response validation. When this field is set, we will raise a `ResponseValidationError` instead of a `RequestValidationError`. +For a more granular control over the failed response validation http code, the `custom_response_validation_http_code` argument can be set per route. +This value will override the value of the failed response validation http code set at constructor level (if any). + === "customizing_response_validation.py" + ```python hl_lines="1 16 29 33 38" + --8<-- "examples/event_handler_rest/src/customizing_response_validation.py" + ``` + + 1. A response with status code set here will be returned if response data is not valid. + 2. Operation returns a string as oppose to a `Todo` object. This will lead to a `500` response as set in line 16. + 3. Operation will return a `422 Unprocessable Entity` response if response is not a `Todo` object. This overrides the custom http code set in line 16. + +=== "customizing_route_response_validation.py" + ```python hl_lines="1 16 29 33" --8<-- "examples/event_handler_rest/src/customizing_response_validation.py" ``` diff --git a/examples/event_handler_rest/src/customizing_response_validation.py b/examples/event_handler_rest/src/customizing_response_validation.py index 2b7b2c16c9f..25aa07bf52a 100644 --- a/examples/event_handler_rest/src/customizing_response_validation.py +++ b/examples/event_handler_rest/src/customizing_response_validation.py @@ -33,6 +33,18 @@ def get_todo_by_id(todo_id: int) -> Todo: return todo.json()["title"] # (2)! +@app.get( + "/todos_bad_response_with_custom_http_code/", + custom_response_validation_http_code=HTTPStatus.UNPROCESSABLE_ENTITY, # (3)! +) +@tracer.capture_method +def get_todo_by_id_custom(todo_id: int) -> Todo: + todo = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}") + todo.raise_for_status() + + return todo.json()["title"] + + @logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP) @tracer.capture_lambda_handler def lambda_handler(event: dict, context: LambdaContext) -> dict: diff --git a/ruff.toml b/ruff.toml index b415c63f949..2af94434185 100644 --- a/ruff.toml +++ b/ruff.toml @@ -74,11 +74,11 @@ lint.typing-modules = [ [lint.mccabe] # Maximum cyclomatic complexity -max-complexity = 15 +max-complexity = 16 [lint.pylint] # Maximum number of nested blocks -max-branches = 15 +max-branches = 16 # Maximum number of if statements in a function max-statements = 70 diff --git a/tests/functional/event_handler/_pydantic/test_openapi_responses.py b/tests/functional/event_handler/_pydantic/test_openapi_responses.py index c2ab8008b5c..0ac24dcc96b 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_responses.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_responses.py @@ -170,3 +170,51 @@ def handler() -> Response[Union[User, Order]]: assert 202 in responses.keys() assert responses[202].description == "202 Response" assert responses[202].content["application/json"].schema_.ref == "#/components/schemas/Order" + + +def test_openapi_route_with_custom_response_validation(): + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/", custom_response_validation_http_code=418) + def handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 418 in responses + assert responses[418].description == "Response Validation Error" + + +def test_openapi_resolver_with_custom_response_validation(): + app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=418) + + @app.get("/") + def handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 418 in responses + assert responses[418].description == "Response Validation Error" + + +def test_openapi_route_and_resolver_with_custom_response_validation(): + app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=417) + + @app.get("/", custom_response_validation_http_code=418) + def handler(): + return {"message": "hello world"} + + @app.get("/hi") + def another_handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses_with_route_response_validation = schema.paths["/"].get.responses + responses_with_resolver_response_validation = schema.paths["/hi"].get.responses + assert 418 in responses_with_route_response_validation + assert 417 not in responses_with_route_response_validation + assert responses_with_route_response_validation[418].description == "Response Validation Error" + assert 417 in responses_with_resolver_response_validation + assert 418 not in responses_with_resolver_response_validation + assert responses_with_resolver_response_validation[417].description == "Response Validation Error" diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 4103a301020..c1cc0462bf7 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1378,3 +1378,144 @@ def test_custom_response_validation_error_bad_http_code(response_validation_erro str(exception_info.value) == f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code." ) + + +def test_custom_route_response_validation_error_custom_route_and_app_with_default_validation(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/incomplete_model_not_allowed") + def handler_incomplete_model_not_allowed() -> Model: + return {"age": 18} # type: ignore + + # HAVING route with custom response validation error + @app.get( + "/custom_incomplete_model_not_allowed", + custom_response_validation_http_code=500, + ) + def handler_custom_route_response_validation_error() -> Model: + return {"age": 18} # type: ignore + + # WHEN returning incomplete model for a non-Optional type + gw_event["path"] = "/incomplete_model_not_allowed" + result = app(gw_event, {}) + + gw_event["path"] = "/custom_incomplete_model_not_allowed" + custom_result = app(gw_event, {}) + + # THEN it must return a validation error with the custom status code provided + assert result["statusCode"] == 422 + assert custom_result["statusCode"] == 500 + assert json.loads(result["body"])["detail"] == json.loads(custom_result["body"])["detail"] + + +def test_custom_route_response_validation_error_sanitized_response(gw_event): + # GIVEN an APIGatewayRestResolver with custom response validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get( + "/custom_incomplete_model_not_allowed", + custom_response_validation_http_code=422, + ) + def handler_custom_route_response_validation_error() -> Model: + return {"age": 18} # type: ignore + + # HAVING a sanitized response validation error response + @app.exception_handler(ResponseValidationError) + def handle_response_validation_error(ex: ResponseValidationError): + return Response( + status_code=500, + body="Unexpected response.", + ) + + # WHEN returning incomplete model for a non-Optional type + gw_event["path"] = "/custom_incomplete_model_not_allowed" + result = app(gw_event, {}) + + # THEN it must return the sanitized response + assert result["statusCode"] == 500 + assert result["body"] == "Unexpected response." + + +def test_custom_route_response_validation_error_with_app_custom_response_validation(gw_event): + # GIVEN an APIGatewayRestResolver with validation and custom response validation enabled + app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=500) + + class Model(BaseModel): + name: str + age: int + + # HAVING a route with custom response validation + @app.get( + "/custom_incomplete_model_not_allowed", + custom_response_validation_http_code=422, + ) + def handler_custom_route_response_validation_error() -> Model: + return {"age": 18} # type: ignore + + # WHEN returning incomplete model for a non-Optional type on route with custom response validation + gw_event["path"] = "/custom_incomplete_model_not_allowed" + result = app(gw_event, {}) + + # THEN route's custom response validation must take precedence over the app's. + assert result["statusCode"] == 422 + body = json.loads(result["body"]) + assert body["detail"][0]["type"] == "missing" + assert body["detail"][0]["loc"] == ["response", "name"] + + +def test_custom_route_response_validation_error_no_app_validation(): + # GIVEN an APIGatewayRestResolver with validation not enabled + with pytest.raises(ValueError) as exception_info: + app = APIGatewayRestResolver() + + class Model(BaseModel): + name: str + age: int + + # HAVING a route with custom response validation http code + @app.get( + "/custom_incomplete_model_not_allowed", + custom_response_validation_http_code=422, + ) + def handler_custom_route_response_validation_error() -> Model: + return {"age": 18} # type: ignore + + # THEN it must raise ValueError describing the issue + assert ( + str(exception_info.value) + == "'custom_response_validation_http_code' cannot be set for route when enable_validation is False on resolver." + ) + + +@pytest.mark.parametrize("response_validation_error_http_code", [(20), ("hi"), (1.21), (True), (False)]) +def test_custom_route_response_validation_error_bad_http_code(response_validation_error_http_code): + # GIVEN an APIGatewayRestResolver with validation enabled + with pytest.raises(ValueError) as exception_info: + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + # HAVING a route with custom response validation which is not a valid HTTP code + @app.get( + "/custom_incomplete_model_not_allowed", + custom_response_validation_http_code=response_validation_error_http_code, + ) + def handler_custom_route_response_validation_error() -> Model: + return {"age": 18} # type: ignore + + # THEN it must raise ValueError describing the issue + assert ( + str(exception_info.value) + == f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501 + )