Skip to content

Commit f8ead84

Browse files
author
Amin Farjadi
committed
refactor: make response_validation_error_http_status accept more types and add more detailed error messages.
1 parent 1c33611 commit f8ead84

File tree

3 files changed

+38
-12
lines changed

3 files changed

+38
-12
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,7 +1500,7 @@ def __init__(
15001500
serializer: Callable[[dict], str] | None = None,
15011501
strip_prefixes: list[str | Pattern] | None = None,
15021502
enable_validation: bool = False,
1503-
response_validation_error_http_status: HTTPStatus | None = None,
1503+
response_validation_error_http_status=None,
15041504
):
15051505
"""
15061506
Parameters
@@ -1520,6 +1520,8 @@ def __init__(
15201520
Each prefix can be a static string or a compiled regex pattern
15211521
enable_validation: bool | None
15221522
Enables validation of the request body against the route schema, by default False.
1523+
response_validation_error_http_status
1524+
Enables response validation and sets returned status code if response is not validated.
15231525
"""
15241526
self._proxy_type = proxy_type
15251527
self._dynamic_routes: list[Route] = []
@@ -1535,7 +1537,28 @@ def __init__(
15351537
self.context: dict = {} # early init as customers might add context before event resolution
15361538
self.processed_stack_frames = []
15371539
self._response_builder_class = ResponseBuilder[BaseProxyEvent]
1538-
self._response_validation_error_http_status = response_validation_error_http_status
1540+
self._has_response_validation_error = response_validation_error_http_status is not None
1541+
1542+
if response_validation_error_http_status and not enable_validation:
1543+
msg = "'response_validation_error_http_status' cannot be set when enable_validation is False."
1544+
raise ValueError(msg)
1545+
1546+
if (
1547+
not isinstance(response_validation_error_http_status, HTTPStatus)
1548+
and response_validation_error_http_status is not None
1549+
):
1550+
1551+
try:
1552+
response_validation_error_http_status = HTTPStatus(response_validation_error_http_status)
1553+
except ValueError:
1554+
msg = f"'{response_validation_error_http_status}' must be an integer representing an HTTP status code."
1555+
raise ValueError(msg) from None
1556+
1557+
self._response_validation_error_http_status = (
1558+
response_validation_error_http_status
1559+
if response_validation_error_http_status
1560+
else HTTPStatus.UNPROCESSABLE_ENTITY
1561+
)
15391562

15401563
# Allow for a custom serializer or a concise json serialization
15411564
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
@@ -1549,7 +1572,7 @@ def __init__(
15491572
[
15501573
OpenAPIValidationMiddleware(
15511574
validation_serializer=serializer,
1552-
has_response_validation_error=self._response_validation_error_http_status is not None,
1575+
has_response_validation_error=self._has_response_validation_error,
15531576
),
15541577
],
15551578
)
@@ -2386,12 +2409,15 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
23862409
# OpenAPIValidationMiddleware will only raise ResponseValidationError when
23872410
# 'self._response_validation_error_http_status' is not None
23882411
if isinstance(exp, ResponseValidationError):
2389-
if self._response_validation_error_http_status is None:
2390-
raise TypeError
2412+
http_status = (
2413+
self._response_validation_error_http_status
2414+
if self._response_validation_error_http_status
2415+
else HTTPStatus.UNPROCESSABLE_ENTITY
2416+
)
23912417
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
23922418
return self._response_builder_class(
23932419
response=Response(
2394-
status_code=self._response_validation_error_http_status,
2420+
status_code=http_status.value,
23952421
content_type=content_types.APPLICATION_JSON,
23962422
body={"statusCode": self._response_validation_error_http_status, "detail": errors},
23972423
),
@@ -2611,7 +2637,7 @@ def __init__(
26112637
serializer: Callable[[dict], str] | None = None,
26122638
strip_prefixes: list[str | Pattern] | None = None,
26132639
enable_validation: bool = False,
2614-
response_validation_error_http_status: HTTPStatus | None = None,
2640+
response_validation_error_http_status: HTTPStatus | int | None = None,
26152641
):
26162642
"""Amazon API Gateway REST and HTTP API v1 payload resolver"""
26172643
super().__init__(
@@ -2695,7 +2721,7 @@ def __init__(
26952721
serializer: Callable[[dict], str] | None = None,
26962722
strip_prefixes: list[str | Pattern] | None = None,
26972723
enable_validation: bool = False,
2698-
response_validation_error_http_status: HTTPStatus | None = None,
2724+
response_validation_error_http_status: HTTPStatus | int | None = None,
26992725
):
27002726
"""Amazon API Gateway HTTP API v2 payload resolver"""
27012727
super().__init__(
@@ -2734,7 +2760,7 @@ def __init__(
27342760
serializer: Callable[[dict], str] | None = None,
27352761
strip_prefixes: list[str | Pattern] | None = None,
27362762
enable_validation: bool = False,
2737-
response_validation_error_http_status: HTTPStatus | None = None,
2763+
response_validation_error_http_status: HTTPStatus | int | None = None,
27382764
):
27392765
"""Amazon Application Load Balancer (ALB) resolver"""
27402766
super().__init__(

aws_lambda_powertools/event_handler/lambda_function_url.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(
5959
serializer: Callable[[dict], str] | None = None,
6060
strip_prefixes: list[str | Pattern] | None = None,
6161
enable_validation: bool = False,
62-
response_validation_error_http_status: HTTPStatus | None = None,
62+
response_validation_error_http_status: HTTPStatus | int | None = None,
6363
):
6464
super().__init__(
6565
ProxyEventType.LambdaFunctionUrlEvent,

aws_lambda_powertools/event_handler/vpc_lattice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
serializer: Callable[[dict], str] | None = None,
5656
strip_prefixes: list[str | Pattern] | None = None,
5757
enable_validation: bool = False,
58-
response_validation_error_http_status: HTTPStatus | None = None,
58+
response_validation_error_http_status: HTTPStatus | int | None = None,
5959
):
6060
"""Amazon VPC Lattice resolver"""
6161
super().__init__(
@@ -113,7 +113,7 @@ def __init__(
113113
serializer: Callable[[dict], str] | None = None,
114114
strip_prefixes: list[str | Pattern] | None = None,
115115
enable_validation: bool = False,
116-
response_validation_error_http_status: HTTPStatus | None = None,
116+
response_validation_error_http_status: HTTPStatus | int | None = None,
117117
):
118118
"""Amazon VPC Lattice resolver"""
119119
super().__init__(

0 commit comments

Comments
 (0)