diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 5eba4220c22..d4cff69423d 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -19,7 +19,11 @@ from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION -from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, SchemaValidationError +from aws_lambda_powertools.event_handler.openapi.exceptions import ( + RequestValidationError, + ResponseValidationError, + SchemaValidationError, +) from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_PREFIX, METHODS_WITH_BODY, @@ -1496,6 +1500,7 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, + response_validation_error_http_code: HTTPStatus | int | None = None, ): """ Parameters @@ -1515,6 +1520,8 @@ def __init__( Each prefix can be a static string or a compiled regex pattern enable_validation: bool | None Enables validation of the request body against the route schema, by default False. + response_validation_error_http_code + Sets the returned status code if response is not validated. enable_validation must be True. """ self._proxy_type = proxy_type self._dynamic_routes: list[Route] = [] @@ -1530,6 +1537,11 @@ def __init__( self.context: dict = {} # early init as customers might add context before event resolution self.processed_stack_frames = [] self._response_builder_class = ResponseBuilder[BaseProxyEvent] + self._has_response_validation_error = response_validation_error_http_code is not None + self._response_validation_error_http_code = self._validate_response_validation_error_http_code( + response_validation_error_http_code, + enable_validation, + ) # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) @@ -1539,7 +1551,40 @@ def __init__( # Note the serializer argument: only use custom serializer if provided by the caller # Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation. - self.use([OpenAPIValidationMiddleware(validation_serializer=serializer)]) + self.use( + [ + OpenAPIValidationMiddleware( + validation_serializer=serializer, + has_response_validation_error=self._has_response_validation_error, + ), + ], + ) + + def _validate_response_validation_error_http_code( + self, + response_validation_error_http_code: HTTPStatus | int | None, + enable_validation: bool, + ) -> HTTPStatus: + if response_validation_error_http_code and not enable_validation: + msg = "'response_validation_error_http_code' cannot be set when enable_validation is False." + raise ValueError(msg) + + if ( + not isinstance(response_validation_error_http_code, HTTPStatus) + and response_validation_error_http_code is not None + ): + + try: + response_validation_error_http_code = HTTPStatus(response_validation_error_http_code) + except ValueError: + msg = f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code." + raise ValueError(msg) from None + + return ( + response_validation_error_http_code + if response_validation_error_http_code + else HTTPStatus.UNPROCESSABLE_ENTITY + ) def get_openapi_schema( self, @@ -2370,6 +2415,21 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild route=route, ) + # OpenAPIValidationMiddleware will only raise ResponseValidationError when + # 'self._response_validation_error_http_code' is not None + if isinstance(exp, ResponseValidationError): + http_code = self._response_validation_error_http_code + errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()] + return self._response_builder_class( + response=Response( + status_code=http_code.value, + content_type=content_types.APPLICATION_JSON, + body={"statusCode": self._response_validation_error_http_code, "detail": errors}, + ), + serializer=self._serializer, + route=route, + ) + if isinstance(exp, ServiceError): return self._response_builder_class( response=Response( @@ -2582,6 +2642,7 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, + response_validation_error_http_code: HTTPStatus | int | None = None, ): """Amazon API Gateway REST and HTTP API v1 payload resolver""" super().__init__( @@ -2591,6 +2652,7 @@ def __init__( serializer, strip_prefixes, enable_validation, + response_validation_error_http_code, ) def _get_base_path(self) -> str: @@ -2664,6 +2726,7 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, + response_validation_error_http_code: HTTPStatus | int | None = None, ): """Amazon API Gateway HTTP API v2 payload resolver""" super().__init__( @@ -2673,6 +2736,7 @@ def __init__( serializer, strip_prefixes, enable_validation, + response_validation_error_http_code, ) def _get_base_path(self) -> str: @@ -2701,9 +2765,18 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, + response_validation_error_http_code: HTTPStatus | int | None = None, ): """Amazon Application Load Balancer (ALB) resolver""" - super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes, enable_validation) + super().__init__( + ProxyEventType.ALBEvent, + cors, + debug, + serializer, + strip_prefixes, + enable_validation, + response_validation_error_http_code, + ) def _get_base_path(self) -> str: # ALB doesn't have a stage variable, so we just return an empty string diff --git a/aws_lambda_powertools/event_handler/lambda_function_url.py b/aws_lambda_powertools/event_handler/lambda_function_url.py index c7075cd9fc6..c761834e8b3 100644 --- a/aws_lambda_powertools/event_handler/lambda_function_url.py +++ b/aws_lambda_powertools/event_handler/lambda_function_url.py @@ -8,6 +8,8 @@ ) if TYPE_CHECKING: + from http import HTTPStatus + from aws_lambda_powertools.event_handler import CORSConfig from aws_lambda_powertools.utilities.data_classes import LambdaFunctionUrlEvent @@ -57,6 +59,7 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, + response_validation_error_http_code: HTTPStatus | int | None = None, ): super().__init__( ProxyEventType.LambdaFunctionUrlEvent, @@ -65,6 +68,7 @@ def __init__( serializer, strip_prefixes, enable_validation, + response_validation_error_http_code, ) def _get_base_path(self) -> str: diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index fc0db9a3b3f..a1490eb23ec 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -17,7 +17,7 @@ ) from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder -from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError +from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError from aws_lambda_powertools.event_handler.openapi.params import Param if TYPE_CHECKING: @@ -58,7 +58,11 @@ def get_todos(): list[Todo]: ``` """ - def __init__(self, validation_serializer: Callable[[Any], str] | None = None): + def __init__( + self, + validation_serializer: Callable[[Any], str] | None = None, + has_response_validation_error: bool = False, + ): """ Initialize the OpenAPIValidationMiddleware. @@ -67,8 +71,13 @@ def __init__(self, validation_serializer: Callable[[Any], str] | None = None): validation_serializer : Callable, optional Optional serializer to use when serializing the response for validation. Use it when you have a custom type that cannot be serialized by the default jsonable_encoder. + + has_response_validation_error: bool, optional + Optional flag used to distinguish between payload and validation errors. + By setting this flag to True, ResponseValidationError will be raised if response could not be validated. """ self._validation_serializer = validation_serializer + self._has_response_validation_error = has_response_validation_error def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: logger.debug("OpenAPIValidationMiddleware handler") @@ -164,6 +173,8 @@ def _serialize_response( errors: list[dict[str, Any]] = [] value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors) if errors: + if self._has_response_validation_error: + raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content) raise RequestValidationError(errors=_normalize_errors(errors), body=response_content) if hasattr(field, "serialize"): diff --git a/aws_lambda_powertools/event_handler/openapi/exceptions.py b/aws_lambda_powertools/event_handler/openapi/exceptions.py index e1ed33e67fd..046a270cdf7 100644 --- a/aws_lambda_powertools/event_handler/openapi/exceptions.py +++ b/aws_lambda_powertools/event_handler/openapi/exceptions.py @@ -23,6 +23,16 @@ def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: self.body = body +class ResponseValidationError(ValidationException): + """ + Raised when the response body does not match the OpenAPI schema + """ + + def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: + super().__init__(errors) + self.body = body + + class SerializationError(Exception): """ Base exception for all encoding errors diff --git a/aws_lambda_powertools/event_handler/vpc_lattice.py b/aws_lambda_powertools/event_handler/vpc_lattice.py index f145c4342e8..30ee8fd86fc 100644 --- a/aws_lambda_powertools/event_handler/vpc_lattice.py +++ b/aws_lambda_powertools/event_handler/vpc_lattice.py @@ -8,6 +8,8 @@ ) if TYPE_CHECKING: + from http import HTTPStatus + from aws_lambda_powertools.event_handler import CORSConfig from aws_lambda_powertools.utilities.data_classes import VPCLatticeEvent, VPCLatticeEventV2 @@ -53,9 +55,18 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, + response_validation_error_http_code: HTTPStatus | int | None = None, ): """Amazon VPC Lattice resolver""" - super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes, enable_validation) + super().__init__( + ProxyEventType.VPCLatticeEvent, + cors, + debug, + serializer, + strip_prefixes, + enable_validation, + response_validation_error_http_code, + ) def _get_base_path(self) -> str: return "" @@ -102,9 +113,18 @@ def __init__( serializer: Callable[[dict], str] | None = None, strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, + response_validation_error_http_code: HTTPStatus | int | None = None, ): """Amazon VPC Lattice resolver""" - super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes, enable_validation) + super().__init__( + ProxyEventType.VPCLatticeEventV2, + cors, + debug, + serializer, + strip_prefixes, + enable_validation, + response_validation_error_http_code, + ) def _get_base_path(self) -> str: return "" diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index f2a60697740..70eef0a2b86 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -309,7 +309,7 @@ Let's rewrite the previous examples to signal our resolver what shape we expect !!! info "By default, we hide extended error details for security reasons _(e.g., pydantic url, Pydantic code)_." -Any incoming request that fails validation will lead to a `HTTP 422: Unprocessable Entity error` response that will look similar to this: +Any incoming request or and outgoing response that fails validation will lead to a `HTTP 422: Unprocessable Entity error` response that will look similar to this: ```json hl_lines="2 3" title="data_validation_error_unsanitized_output.json" --8<-- "examples/event_handler_rest/src/data_validation_error_unsanitized_output.json" @@ -321,8 +321,6 @@ Here's an example where we catch validation errors, log all details for further === "data_validation_sanitized_error.py" - Note that Pydantic versions [1](https://docs.pydantic.dev/1.10/usage/models/#error-handling){target="_blank" rel="nofollow"} and [2](https://docs.pydantic.dev/latest/errors/errors/){target="_blank" rel="nofollow"} report validation detailed errors differently. - ```python hl_lines="8 24-25 31" --8<-- "examples/event_handler_rest/src/data_validation_sanitized_error.py" ``` @@ -398,6 +396,27 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou --8<-- "examples/event_handler_rest/src/validating_payload_subset_output.json" ``` +#### Validating responses + +You can use `response_validation_error_http_code` to set a custom HTTP code for failed response validation. When this field is set, we will raise a `ResponseValidationError` instead of a `RequestValidationError`. + +=== "customizing_response_validation.py" + + ```python hl_lines="1 16 29 33" + --8<-- "examples/event_handler_rest/src/customizing_response_validation.py" + ``` + + 1. A response with status code set here will be returned if response data is not valid. + 2. Operation returns a string as oppose to a `Todo` object. This will lead to a `500` response as set in line 18. + +=== "customizing_response_validation_exception.py" + + ```python hl_lines="1 18 38 39" + --8<-- "examples/event_handler_rest/src/customizing_response_validation_exception.py" + ``` + + 1. The distinct `ResponseValidationError` exception can be caught to customise the response. + #### Validating query strings !!! info "We will automatically validate and inject incoming query strings via type annotation." diff --git a/examples/event_handler_rest/src/customizing_response_validation.py b/examples/event_handler_rest/src/customizing_response_validation.py new file mode 100644 index 00000000000..2b7b2c16c9f --- /dev/null +++ b/examples/event_handler_rest/src/customizing_response_validation.py @@ -0,0 +1,39 @@ +from http import HTTPStatus +from typing import Optional + +import requests +from pydantic import BaseModel, Field + +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.logging import correlation_paths +from aws_lambda_powertools.utilities.typing import LambdaContext + +tracer = Tracer() +logger = Logger() +app = APIGatewayRestResolver( + enable_validation=True, + response_validation_error_http_code=HTTPStatus.INTERNAL_SERVER_ERROR, # (1)! +) + + +class Todo(BaseModel): + userId: int + id_: Optional[int] = Field(alias="id", default=None) + title: str + completed: bool + + +@app.get("/todos_bad_response/") +@tracer.capture_method +def get_todo_by_id(todo_id: int) -> Todo: + todo = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}") + todo.raise_for_status() + + return todo.json()["title"] # (2)! + + +@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP) +@tracer.capture_lambda_handler +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/customizing_response_validation_exception.py b/examples/event_handler_rest/src/customizing_response_validation_exception.py new file mode 100644 index 00000000000..c94ace290d2 --- /dev/null +++ b/examples/event_handler_rest/src/customizing_response_validation_exception.py @@ -0,0 +1,52 @@ +from http import HTTPStatus +from typing import Optional + +import requests +from pydantic import BaseModel, Field + +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, content_types +from aws_lambda_powertools.event_handler.api_gateway import Response +from aws_lambda_powertools.event_handler.openapi.exceptions import ResponseValidationError +from aws_lambda_powertools.logging import correlation_paths +from aws_lambda_powertools.utilities.typing import LambdaContext + +tracer = Tracer() +logger = Logger() +app = APIGatewayRestResolver( + enable_validation=True, + response_validation_error_http_code=HTTPStatus.INTERNAL_SERVER_ERROR, +) + + +class Todo(BaseModel): + userId: int + id_: Optional[int] = Field(alias="id", default=None) + title: str + completed: bool + + +@app.get("/todos_bad_response/") +@tracer.capture_method +def get_todo_by_id(todo_id: int) -> Todo: + todo = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}") + todo.raise_for_status() + + return todo.json()["title"] + + +@app.exception_handler(ResponseValidationError) # (1)! +def handle_response_validation_error(ex: ResponseValidationError): + logger.error("Request failed validation", path=app.current_event.path, errors=ex.errors()) + + return Response( + status_code=500, + content_type=content_types.APPLICATION_JSON, + body="Unexpected response.", + ) + + +@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP) +@tracer.capture_lambda_handler +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 54425f34986..4103a301020 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -17,6 +17,7 @@ VPCLatticeResolver, VPCLatticeV2Resolver, ) +from aws_lambda_powertools.event_handler.openapi.exceptions import ResponseValidationError from aws_lambda_powertools.event_handler.openapi.params import Body, Header, Query @@ -1128,3 +1129,252 @@ def handler(user_id: int = 123): # THEN the handler should be invoked and return 200 result = app(minimal_event, {}) assert result["statusCode"] == 200 + + +@pytest.mark.skipif(reason="Test temporarily disabled until falsy return is fixed") +def test_validation_error_none_returned_non_optional_type(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/none_not_allowed") + def handler_none_not_allowed() -> Model: + return None # type: ignore + + # WHEN returning None for a non-Optional type + gw_event["path"] = "/none_not_allowed" + result = app(gw_event, {}) + + # THEN it should return a validation error + assert result["statusCode"] == 422 + body = json.loads(result["body"]) + assert body["detail"][0]["type"] == "model_attributes_type" + assert body["detail"][0]["loc"] == ["response"] + + +def test_validation_error_incomplete_model_returned_non_optional_type(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/incomplete_model_not_allowed") + def handler_incomplete_model_not_allowed() -> Model: + return {"age": 18} # type: ignore + + # WHEN returning incomplete model for a non-Optional type + gw_event["path"] = "/incomplete_model_not_allowed" + result = app(gw_event, {}) + + # THEN it should return a validation error + assert result["statusCode"] == 422 + body = json.loads(result["body"]) + assert "missing" in body["detail"][0]["type"] + assert "name" in body["detail"][0]["loc"] + + +def test_none_returned_for_optional_type(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/none_allowed") + def handler_none_allowed() -> Optional[Model]: + return None + + # WHEN returning None for an Optional type + gw_event["path"] = "/none_allowed" + result = app(gw_event, {}) + + # THEN it should succeed + assert result["statusCode"] == 200 + assert result["body"] == "null" + + +@pytest.mark.skipif(reason="Test temporarily disabled until falsy return is fixed") +@pytest.mark.parametrize( + "path, body", + [ + ("/empty_dict", {}), + ("/empty_list", []), + ("/none", "null"), + ("/empty_string", ""), + ], + ids=["empty_dict", "empty_list", "none", "empty_string"], +) +def test_none_returned_for_falsy_return(gw_event, path, body): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get(path) + def handler_none_allowed() -> Model: + return body + + # WHEN returning None for an Optional type + gw_event["path"] = path + result = app(gw_event, {}) + + # THEN it should succeed + assert result["statusCode"] == 422 + + +def test_custom_response_validation_error_http_code_valid_response(gw_event): + # GIVEN an APIGatewayRestResolver with custom response validation enabled + app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=422) + + class Model(BaseModel): + name: str + age: int + + @app.get("/valid_response") + def handler_valid_response() -> Model: + return { + "name": "Joe", + "age": 18, + } # type: ignore + + # WHEN returning the expected type + gw_event["path"] = "/valid_response" + result = app(gw_event, {}) + + # THEN it should return a 200 OK + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body == {"name": "Joe", "age": 18} + + +@pytest.mark.skipif(reason="Test temporarily disabled until falsy return is fixed") +@pytest.mark.parametrize( + "http_code", + (422, 500, 510), +) +def test_custom_response_validation_error_http_code_invalid_response_none( + http_code, + gw_event, +): + # GIVEN an APIGatewayRestResolver with custom response validation enabled + app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=http_code) + + class Model(BaseModel): + name: str + age: int + + @app.get("/none_not_allowed") + def handler_none_not_allowed() -> Model: + return None # type: ignore + + # WHEN returning None for a non-Optional type + gw_event["path"] = "/none_not_allowed" + result = app(gw_event, {}) + + # THEN it should return a validation error with the custom status code provided + assert result["statusCode"] == http_code + body = json.loads(result["body"]) + assert body["detail"][0]["type"] == "model_attributes_type" + assert body["detail"][0]["loc"] == ["response"] + + +@pytest.mark.parametrize( + "http_code", + (422, 500, 510), +) +def test_custom_response_validation_error_http_code_invalid_response_incomplete_model( + http_code, + gw_event, +): + # GIVEN an APIGatewayRestResolver with custom response validation enabled + app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=http_code) + + class Model(BaseModel): + name: str + age: int + + @app.get("/incomplete_model_not_allowed") + def handler_incomplete_model_not_allowed() -> Model: + return {"age": 18} # type: ignore + + # WHEN returning incomplete model for a non-Optional type + gw_event["path"] = "/incomplete_model_not_allowed" + result = app(gw_event, {}) + + # THEN it should return a validation error with the custom status code provided + assert result["statusCode"] == http_code + body = json.loads(result["body"]) + assert body["detail"][0]["type"] == "missing" + assert body["detail"][0]["loc"] == ["response", "name"] + + +@pytest.mark.parametrize( + "http_code", + (422, 500, 510), +) +def test_custom_response_validation_error_sanitized_response( + http_code, + gw_event, +): + # GIVEN an APIGatewayRestResolver with custom response validation enabled + # with a sanitized response validation error response + app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=http_code) + + class Model(BaseModel): + name: str + age: int + + @app.get("/incomplete_model_not_allowed") + def handler_incomplete_model_not_allowed() -> Model: + return {"age": 18} # type: ignore + + @app.exception_handler(ResponseValidationError) + def handle_response_validation_error(ex: ResponseValidationError): + return Response( + status_code=500, + body="Unexpected response.", + ) + + # WHEN returning incomplete model for a non-Optional type + gw_event["path"] = "/incomplete_model_not_allowed" + result = app(gw_event, {}) + + # THEN it should return the sanitized response + assert result["statusCode"] == 500 + assert result["body"] == "Unexpected response." + + +def test_custom_response_validation_error_no_validation(): + # GIVEN an APIGatewayRestResolver with validation not enabled + # setting a custom http status code for response validation must raise a ValueError + with pytest.raises(ValueError) as exception_info: + APIGatewayRestResolver(response_validation_error_http_code=500) + + assert ( + str(exception_info.value) + == "'response_validation_error_http_code' cannot be set when enable_validation is False." + ) + + +@pytest.mark.parametrize("response_validation_error_http_code", [(20), ("hi"), (1.21)]) +def test_custom_response_validation_error_bad_http_code(response_validation_error_http_code): + # GIVEN an APIGatewayRestResolver with validation enabled + # setting custom status code for response validation that is not a valid HTTP code must raise a ValueError + with pytest.raises(ValueError) as exception_info: + APIGatewayRestResolver( + enable_validation=True, + response_validation_error_http_code=response_validation_error_http_code, + ) + + assert ( + str(exception_info.value) + == f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code." + )