Skip to content

Commit f943f45

Browse files
Merging from develop
1 parent 26c12d4 commit f943f45

File tree

9 files changed

+484
-10
lines changed

9 files changed

+484
-10
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+72-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424
DEFAULT_OPENAPI_TITLE,
2525
DEFAULT_OPENAPI_VERSION,
2626
)
27-
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, SchemaValidationError
27+
from aws_lambda_powertools.event_handler.openapi.exceptions import (
28+
RequestValidationError,
29+
ResponseValidationError,
30+
SchemaValidationError,
31+
)
2832
from aws_lambda_powertools.event_handler.openapi.types import (
2933
COMPONENT_REF_PREFIX,
3034
METHODS_WITH_BODY,
@@ -1501,6 +1505,7 @@ def __init__(
15011505
serializer: Callable[[dict], str] | None = None,
15021506
strip_prefixes: list[str | Pattern] | None = None,
15031507
enable_validation: bool = False,
1508+
response_validation_error_http_code: HTTPStatus | int | None = None,
15041509
):
15051510
"""
15061511
Parameters
@@ -1520,6 +1525,8 @@ def __init__(
15201525
Each prefix can be a static string or a compiled regex pattern
15211526
enable_validation: bool | None
15221527
Enables validation of the request body against the route schema, by default False.
1528+
response_validation_error_http_code
1529+
Sets the returned status code if response is not validated. enable_validation must be True.
15231530
"""
15241531
self._proxy_type = proxy_type
15251532
self._dynamic_routes: list[Route] = []
@@ -1536,6 +1543,11 @@ def __init__(
15361543
self.processed_stack_frames = []
15371544
self._response_builder_class = ResponseBuilder[BaseProxyEvent]
15381545
self.openapi_config = OpenAPIConfig() # starting an empty dataclass
1546+
self._has_response_validation_error = response_validation_error_http_code is not None
1547+
self._response_validation_error_http_code = self._validate_response_validation_error_http_code(
1548+
response_validation_error_http_code,
1549+
enable_validation,
1550+
)
15391551

15401552
# Allow for a custom serializer or a concise json serialization
15411553
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
@@ -1545,7 +1557,36 @@ def __init__(
15451557

15461558
# Note the serializer argument: only use custom serializer if provided by the caller
15471559
# Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation.
1548-
self.use([OpenAPIValidationMiddleware(validation_serializer=serializer)])
1560+
self.use(
1561+
[
1562+
OpenAPIValidationMiddleware(
1563+
validation_serializer=serializer,
1564+
has_response_validation_error=self._has_response_validation_error,
1565+
),
1566+
],
1567+
)
1568+
1569+
def _validate_response_validation_error_http_code(
1570+
self,
1571+
response_validation_error_http_code: HTTPStatus | int | None,
1572+
enable_validation: bool,
1573+
) -> HTTPStatus:
1574+
if response_validation_error_http_code and not enable_validation:
1575+
msg = "'response_validation_error_http_code' cannot be set when enable_validation is False."
1576+
raise ValueError(msg)
1577+
1578+
if (
1579+
not isinstance(response_validation_error_http_code, HTTPStatus)
1580+
and response_validation_error_http_code is not None
1581+
):
1582+
1583+
try:
1584+
response_validation_error_http_code = HTTPStatus(response_validation_error_http_code)
1585+
except ValueError:
1586+
msg = f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code."
1587+
raise ValueError(msg) from None
1588+
1589+
return response_validation_error_http_code or HTTPStatus.UNPROCESSABLE_ENTITY
15491590

15501591
def get_openapi_schema(
15511592
self,
@@ -2484,6 +2525,21 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
24842525
route=route,
24852526
)
24862527

2528+
# OpenAPIValidationMiddleware will only raise ResponseValidationError when
2529+
# 'self._response_validation_error_http_code' is not None
2530+
if isinstance(exp, ResponseValidationError):
2531+
http_code = self._response_validation_error_http_code
2532+
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
2533+
return self._response_builder_class(
2534+
response=Response(
2535+
status_code=http_code.value,
2536+
content_type=content_types.APPLICATION_JSON,
2537+
body={"statusCode": self._response_validation_error_http_code, "detail": errors},
2538+
),
2539+
serializer=self._serializer,
2540+
route=route,
2541+
)
2542+
24872543
if isinstance(exp, ServiceError):
24882544
return self._response_builder_class(
24892545
response=Response(
@@ -2696,6 +2752,7 @@ def __init__(
26962752
serializer: Callable[[dict], str] | None = None,
26972753
strip_prefixes: list[str | Pattern] | None = None,
26982754
enable_validation: bool = False,
2755+
response_validation_error_http_code: HTTPStatus | int | None = None,
26992756
):
27002757
"""Amazon API Gateway REST and HTTP API v1 payload resolver"""
27012758
super().__init__(
@@ -2705,6 +2762,7 @@ def __init__(
27052762
serializer,
27062763
strip_prefixes,
27072764
enable_validation,
2765+
response_validation_error_http_code,
27082766
)
27092767

27102768
def _get_base_path(self) -> str:
@@ -2778,6 +2836,7 @@ def __init__(
27782836
serializer: Callable[[dict], str] | None = None,
27792837
strip_prefixes: list[str | Pattern] | None = None,
27802838
enable_validation: bool = False,
2839+
response_validation_error_http_code: HTTPStatus | int | None = None,
27812840
):
27822841
"""Amazon API Gateway HTTP API v2 payload resolver"""
27832842
super().__init__(
@@ -2787,6 +2846,7 @@ def __init__(
27872846
serializer,
27882847
strip_prefixes,
27892848
enable_validation,
2849+
response_validation_error_http_code,
27902850
)
27912851

27922852
def _get_base_path(self) -> str:
@@ -2815,9 +2875,18 @@ def __init__(
28152875
serializer: Callable[[dict], str] | None = None,
28162876
strip_prefixes: list[str | Pattern] | None = None,
28172877
enable_validation: bool = False,
2878+
response_validation_error_http_code: HTTPStatus | int | None = None,
28182879
):
28192880
"""Amazon Application Load Balancer (ALB) resolver"""
2820-
super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes, enable_validation)
2881+
super().__init__(
2882+
ProxyEventType.ALBEvent,
2883+
cors,
2884+
debug,
2885+
serializer,
2886+
strip_prefixes,
2887+
enable_validation,
2888+
response_validation_error_http_code,
2889+
)
28212890

28222891
def _get_base_path(self) -> str:
28232892
# ALB doesn't have a stage variable, so we just return an empty string

aws_lambda_powertools/event_handler/lambda_function_url.py

+4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
)
99

1010
if TYPE_CHECKING:
11+
from http import HTTPStatus
12+
1113
from aws_lambda_powertools.event_handler import CORSConfig
1214
from aws_lambda_powertools.utilities.data_classes import LambdaFunctionUrlEvent
1315

@@ -57,6 +59,7 @@ def __init__(
5759
serializer: Callable[[dict], str] | None = None,
5860
strip_prefixes: list[str | Pattern] | None = None,
5961
enable_validation: bool = False,
62+
response_validation_error_http_code: HTTPStatus | int | None = None,
6063
):
6164
super().__init__(
6265
ProxyEventType.LambdaFunctionUrlEvent,
@@ -65,6 +68,7 @@ def __init__(
6568
serializer,
6669
strip_prefixes,
6770
enable_validation,
71+
response_validation_error_http_code,
6872
)
6973

7074
def _get_base_path(self) -> str:

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
1919
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
20-
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
20+
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError
2121
from aws_lambda_powertools.event_handler.openapi.params import Param
2222

2323
if TYPE_CHECKING:
@@ -58,7 +58,11 @@ def get_todos(): list[Todo]:
5858
```
5959
"""
6060

61-
def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
61+
def __init__(
62+
self,
63+
validation_serializer: Callable[[Any], str] | None = None,
64+
has_response_validation_error: bool = False,
65+
):
6266
"""
6367
Initialize the OpenAPIValidationMiddleware.
6468
@@ -67,8 +71,13 @@ def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
6771
validation_serializer : Callable, optional
6872
Optional serializer to use when serializing the response for validation.
6973
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
74+
75+
has_response_validation_error: bool, optional
76+
Optional flag used to distinguish between payload and validation errors.
77+
By setting this flag to True, ResponseValidationError will be raised if response could not be validated.
7078
"""
7179
self._validation_serializer = validation_serializer
80+
self._has_response_validation_error = has_response_validation_error
7281

7382
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
7483
logger.debug("OpenAPIValidationMiddleware handler")
@@ -164,6 +173,8 @@ def _serialize_response(
164173
errors: list[dict[str, Any]] = []
165174
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
166175
if errors:
176+
if self._has_response_validation_error:
177+
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content)
167178
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
168179

169180
if hasattr(field, "serialize"):

aws_lambda_powertools/event_handler/openapi/exceptions.py

+10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
2323
self.body = body
2424

2525

26+
class ResponseValidationError(ValidationException):
27+
"""
28+
Raised when the response body does not match the OpenAPI schema
29+
"""
30+
31+
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
32+
super().__init__(errors)
33+
self.body = body
34+
35+
2636
class SerializationError(Exception):
2737
"""
2838
Base exception for all encoding errors

aws_lambda_powertools/event_handler/vpc_lattice.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
)
99

1010
if TYPE_CHECKING:
11+
from http import HTTPStatus
12+
1113
from aws_lambda_powertools.event_handler import CORSConfig
1214
from aws_lambda_powertools.utilities.data_classes import VPCLatticeEvent, VPCLatticeEventV2
1315

@@ -53,9 +55,18 @@ def __init__(
5355
serializer: Callable[[dict], str] | None = None,
5456
strip_prefixes: list[str | Pattern] | None = None,
5557
enable_validation: bool = False,
58+
response_validation_error_http_code: HTTPStatus | int | None = None,
5659
):
5760
"""Amazon VPC Lattice resolver"""
58-
super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes, enable_validation)
61+
super().__init__(
62+
ProxyEventType.VPCLatticeEvent,
63+
cors,
64+
debug,
65+
serializer,
66+
strip_prefixes,
67+
enable_validation,
68+
response_validation_error_http_code,
69+
)
5970

6071
def _get_base_path(self) -> str:
6172
return ""
@@ -102,9 +113,18 @@ def __init__(
102113
serializer: Callable[[dict], str] | None = None,
103114
strip_prefixes: list[str | Pattern] | None = None,
104115
enable_validation: bool = False,
116+
response_validation_error_http_code: HTTPStatus | int | None = None,
105117
):
106118
"""Amazon VPC Lattice resolver"""
107-
super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes, enable_validation)
119+
super().__init__(
120+
ProxyEventType.VPCLatticeEventV2,
121+
cors,
122+
debug,
123+
serializer,
124+
strip_prefixes,
125+
enable_validation,
126+
response_validation_error_http_code,
127+
)
108128

109129
def _get_base_path(self) -> str:
110130
return ""

docs/core/event_handler/api_gateway.md

+22-3
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ Let's rewrite the previous examples to signal our resolver what shape we expect
309309

310310
!!! info "By default, we hide extended error details for security reasons _(e.g., pydantic url, Pydantic code)_."
311311

312-
Any incoming request that fails validation will lead to a `HTTP 422: Unprocessable Entity error` response that will look similar to this:
312+
Any incoming request or and outgoing response that fails validation will lead to a `HTTP 422: Unprocessable Entity error` response that will look similar to this:
313313

314314
```json hl_lines="2 3" title="data_validation_error_unsanitized_output.json"
315315
--8<-- "examples/event_handler_rest/src/data_validation_error_unsanitized_output.json"
@@ -321,8 +321,6 @@ Here's an example where we catch validation errors, log all details for further
321321

322322
=== "data_validation_sanitized_error.py"
323323

324-
Note that Pydantic versions [1](https://docs.pydantic.dev/1.10/usage/models/#error-handling){target="_blank" rel="nofollow"} and [2](https://docs.pydantic.dev/latest/errors/errors/){target="_blank" rel="nofollow"} report validation detailed errors differently.
325-
326324
```python hl_lines="8 24-25 31"
327325
--8<-- "examples/event_handler_rest/src/data_validation_sanitized_error.py"
328326
```
@@ -398,6 +396,27 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou
398396
--8<-- "examples/event_handler_rest/src/validating_payload_subset_output.json"
399397
```
400398

399+
#### Validating responses
400+
401+
You can use `response_validation_error_http_code` to set a custom HTTP code for failed response validation. When this field is set, we will raise a `ResponseValidationError` instead of a `RequestValidationError`.
402+
403+
=== "customizing_response_validation.py"
404+
405+
```python hl_lines="1 16 29 33"
406+
--8<-- "examples/event_handler_rest/src/customizing_response_validation.py"
407+
```
408+
409+
1. A response with status code set here will be returned if response data is not valid.
410+
2. Operation returns a string as oppose to a `Todo` object. This will lead to a `500` response as set in line 18.
411+
412+
=== "customizing_response_validation_exception.py"
413+
414+
```python hl_lines="1 18 38 39"
415+
--8<-- "examples/event_handler_rest/src/customizing_response_validation_exception.py"
416+
```
417+
418+
1. The distinct `ResponseValidationError` exception can be caught to customise the response.
419+
401420
#### Validating query strings
402421

403422
!!! info "We will automatically validate and inject incoming query strings via type annotation."
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from http import HTTPStatus
2+
from typing import Optional
3+
4+
import requests
5+
from pydantic import BaseModel, Field
6+
7+
from aws_lambda_powertools import Logger, Tracer
8+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
9+
from aws_lambda_powertools.logging import correlation_paths
10+
from aws_lambda_powertools.utilities.typing import LambdaContext
11+
12+
tracer = Tracer()
13+
logger = Logger()
14+
app = APIGatewayRestResolver(
15+
enable_validation=True,
16+
response_validation_error_http_code=HTTPStatus.INTERNAL_SERVER_ERROR, # (1)!
17+
)
18+
19+
20+
class Todo(BaseModel):
21+
userId: int
22+
id_: Optional[int] = Field(alias="id", default=None)
23+
title: str
24+
completed: bool
25+
26+
27+
@app.get("/todos_bad_response/<todo_id>")
28+
@tracer.capture_method
29+
def get_todo_by_id(todo_id: int) -> Todo:
30+
todo = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}")
31+
todo.raise_for_status()
32+
33+
return todo.json()["title"] # (2)!
34+
35+
36+
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP)
37+
@tracer.capture_lambda_handler
38+
def lambda_handler(event: dict, context: LambdaContext) -> dict:
39+
return app.resolve(event, context)

0 commit comments

Comments
 (0)