19
19
from aws_lambda_powertools .event_handler import content_types
20
20
from aws_lambda_powertools .event_handler .exceptions import NotFoundError , ServiceError
21
21
from aws_lambda_powertools .event_handler .openapi .constants import DEFAULT_API_VERSION , DEFAULT_OPENAPI_VERSION
22
- from aws_lambda_powertools .event_handler .openapi .exceptions import RequestValidationError , SchemaValidationError
22
+ from aws_lambda_powertools .event_handler .openapi .exceptions import (
23
+ RequestValidationError ,
24
+ ResponseValidationError ,
25
+ SchemaValidationError ,
26
+ )
23
27
from aws_lambda_powertools .event_handler .openapi .types import (
24
28
COMPONENT_REF_PREFIX ,
25
29
METHODS_WITH_BODY ,
@@ -1496,6 +1500,7 @@ def __init__(
1496
1500
serializer : Callable [[dict ], str ] | None = None ,
1497
1501
strip_prefixes : list [str | Pattern ] | None = None ,
1498
1502
enable_validation : bool = False ,
1503
+ response_validation_error_http_code : HTTPStatus | int | None = None ,
1499
1504
):
1500
1505
"""
1501
1506
Parameters
@@ -1515,6 +1520,8 @@ def __init__(
1515
1520
Each prefix can be a static string or a compiled regex pattern
1516
1521
enable_validation: bool | None
1517
1522
Enables validation of the request body against the route schema, by default False.
1523
+ response_validation_error_http_code
1524
+ Sets the returned status code if response is not validated. enable_validation must be True.
1518
1525
"""
1519
1526
self ._proxy_type = proxy_type
1520
1527
self ._dynamic_routes : list [Route ] = []
@@ -1530,6 +1537,11 @@ def __init__(
1530
1537
self .context : dict = {} # early init as customers might add context before event resolution
1531
1538
self .processed_stack_frames = []
1532
1539
self ._response_builder_class = ResponseBuilder [BaseProxyEvent ]
1540
+ self ._has_response_validation_error = response_validation_error_http_code is not None
1541
+ self ._response_validation_error_http_code = self ._validate_response_validation_error_http_code (
1542
+ response_validation_error_http_code ,
1543
+ enable_validation ,
1544
+ )
1533
1545
1534
1546
# Allow for a custom serializer or a concise json serialization
1535
1547
self ._serializer = serializer or partial (json .dumps , separators = ("," , ":" ), cls = Encoder )
@@ -1539,7 +1551,40 @@ def __init__(
1539
1551
1540
1552
# Note the serializer argument: only use custom serializer if provided by the caller
1541
1553
# Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation.
1542
- self .use ([OpenAPIValidationMiddleware (validation_serializer = serializer )])
1554
+ self .use (
1555
+ [
1556
+ OpenAPIValidationMiddleware (
1557
+ validation_serializer = serializer ,
1558
+ has_response_validation_error = self ._has_response_validation_error ,
1559
+ ),
1560
+ ],
1561
+ )
1562
+
1563
+ def _validate_response_validation_error_http_code (
1564
+ self ,
1565
+ response_validation_error_http_code : HTTPStatus | int | None ,
1566
+ enable_validation : bool ,
1567
+ ) -> HTTPStatus :
1568
+ if response_validation_error_http_code and not enable_validation :
1569
+ msg = "'response_validation_error_http_code' cannot be set when enable_validation is False."
1570
+ raise ValueError (msg )
1571
+
1572
+ if (
1573
+ not isinstance (response_validation_error_http_code , HTTPStatus )
1574
+ and response_validation_error_http_code is not None
1575
+ ):
1576
+
1577
+ try :
1578
+ response_validation_error_http_code = HTTPStatus (response_validation_error_http_code )
1579
+ except ValueError :
1580
+ msg = f"'{ response_validation_error_http_code } ' must be an integer representing an HTTP status code."
1581
+ raise ValueError (msg ) from None
1582
+
1583
+ return (
1584
+ response_validation_error_http_code
1585
+ if response_validation_error_http_code
1586
+ else HTTPStatus .UNPROCESSABLE_ENTITY
1587
+ )
1543
1588
1544
1589
def get_openapi_schema (
1545
1590
self ,
@@ -2370,6 +2415,21 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
2370
2415
route = route ,
2371
2416
)
2372
2417
2418
+ # OpenAPIValidationMiddleware will only raise ResponseValidationError when
2419
+ # 'self._response_validation_error_http_code' is not None
2420
+ if isinstance (exp , ResponseValidationError ):
2421
+ http_code = self ._response_validation_error_http_code
2422
+ errors = [{"loc" : e ["loc" ], "type" : e ["type" ]} for e in exp .errors ()]
2423
+ return self ._response_builder_class (
2424
+ response = Response (
2425
+ status_code = http_code .value ,
2426
+ content_type = content_types .APPLICATION_JSON ,
2427
+ body = {"statusCode" : self ._response_validation_error_http_code , "detail" : errors },
2428
+ ),
2429
+ serializer = self ._serializer ,
2430
+ route = route ,
2431
+ )
2432
+
2373
2433
if isinstance (exp , ServiceError ):
2374
2434
return self ._response_builder_class (
2375
2435
response = Response (
@@ -2582,6 +2642,7 @@ def __init__(
2582
2642
serializer : Callable [[dict ], str ] | None = None ,
2583
2643
strip_prefixes : list [str | Pattern ] | None = None ,
2584
2644
enable_validation : bool = False ,
2645
+ response_validation_error_http_code : HTTPStatus | int | None = None ,
2585
2646
):
2586
2647
"""Amazon API Gateway REST and HTTP API v1 payload resolver"""
2587
2648
super ().__init__ (
@@ -2591,6 +2652,7 @@ def __init__(
2591
2652
serializer ,
2592
2653
strip_prefixes ,
2593
2654
enable_validation ,
2655
+ response_validation_error_http_code ,
2594
2656
)
2595
2657
2596
2658
def _get_base_path (self ) -> str :
@@ -2664,6 +2726,7 @@ def __init__(
2664
2726
serializer : Callable [[dict ], str ] | None = None ,
2665
2727
strip_prefixes : list [str | Pattern ] | None = None ,
2666
2728
enable_validation : bool = False ,
2729
+ response_validation_error_http_code : HTTPStatus | int | None = None ,
2667
2730
):
2668
2731
"""Amazon API Gateway HTTP API v2 payload resolver"""
2669
2732
super ().__init__ (
@@ -2673,6 +2736,7 @@ def __init__(
2673
2736
serializer ,
2674
2737
strip_prefixes ,
2675
2738
enable_validation ,
2739
+ response_validation_error_http_code ,
2676
2740
)
2677
2741
2678
2742
def _get_base_path (self ) -> str :
@@ -2701,9 +2765,18 @@ def __init__(
2701
2765
serializer : Callable [[dict ], str ] | None = None ,
2702
2766
strip_prefixes : list [str | Pattern ] | None = None ,
2703
2767
enable_validation : bool = False ,
2768
+ response_validation_error_http_code : HTTPStatus | int | None = None ,
2704
2769
):
2705
2770
"""Amazon Application Load Balancer (ALB) resolver"""
2706
- super ().__init__ (ProxyEventType .ALBEvent , cors , debug , serializer , strip_prefixes , enable_validation )
2771
+ super ().__init__ (
2772
+ ProxyEventType .ALBEvent ,
2773
+ cors ,
2774
+ debug ,
2775
+ serializer ,
2776
+ strip_prefixes ,
2777
+ enable_validation ,
2778
+ response_validation_error_http_code ,
2779
+ )
2707
2780
2708
2781
def _get_base_path (self ) -> str :
2709
2782
# ALB doesn't have a stage variable, so we just return an empty string
0 commit comments