Skip to content

Commit 7e56fe1

Browse files
amin-farjadiAmin Farjadileandrodamascena
authored
feat(event_handler): add custom response validation in OpenAPI utility (#6189)
* feat(openapi-validation): Add response validation flag and distinct exception. * feat(api-gateway-resolver): Add option for custom response validation error status code. * feat(docs): Added doc for custom response validation error responses. * refactor(docs): Make exception handler function name better. * feat(unit-test): Add tests for custom response validation error. * fix: Formatting. * fix(docs): Fix grammar in response validation docs * fix(unit-test): fix failed CI. * bugfix(lint): Ignore lint error FA102, irrelevant for python >=3.9 * refactor: make response_validation_error_http_status accept more types and add more detailed error messages. * feat(unit-test): add tests for incorrect types and invalid configs * refactor: rename response_validation_error_http_status to response_validation_error_http_code * refactor(api_gateway): add method for validating response_validation_error_http_code param. * fix(api_gateway): fix type and docstring for response_validation_error_http_code param. * fix(api_gateway): remove unncessary check of response_validation_error_http_code param being None. * fix(openapi-validation): docstring for has_response_validation_error param. * refactor(tests): move unit tests into openapi_validation functional test file * fix(tests): skipping validation for falsy response * Refactoring documentation --------- Co-authored-by: Amin Farjadi <[email protected]> Co-authored-by: Leandro Damascena <[email protected]>
1 parent bba7e5c commit 7e56fe1

File tree

9 files changed

+488
-10
lines changed

9 files changed

+488
-10
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+76-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
from aws_lambda_powertools.event_handler import content_types
2020
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
2121
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION
22-
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, SchemaValidationError
22+
from aws_lambda_powertools.event_handler.openapi.exceptions import (
23+
RequestValidationError,
24+
ResponseValidationError,
25+
SchemaValidationError,
26+
)
2327
from aws_lambda_powertools.event_handler.openapi.types import (
2428
COMPONENT_REF_PREFIX,
2529
METHODS_WITH_BODY,
@@ -1496,6 +1500,7 @@ def __init__(
14961500
serializer: Callable[[dict], str] | None = None,
14971501
strip_prefixes: list[str | Pattern] | None = None,
14981502
enable_validation: bool = False,
1503+
response_validation_error_http_code: HTTPStatus | int | None = None,
14991504
):
15001505
"""
15011506
Parameters
@@ -1515,6 +1520,8 @@ def __init__(
15151520
Each prefix can be a static string or a compiled regex pattern
15161521
enable_validation: bool | None
15171522
Enables validation of the request body against the route schema, by default False.
1523+
response_validation_error_http_code
1524+
Sets the returned status code if response is not validated. enable_validation must be True.
15181525
"""
15191526
self._proxy_type = proxy_type
15201527
self._dynamic_routes: list[Route] = []
@@ -1530,6 +1537,11 @@ def __init__(
15301537
self.context: dict = {} # early init as customers might add context before event resolution
15311538
self.processed_stack_frames = []
15321539
self._response_builder_class = ResponseBuilder[BaseProxyEvent]
1540+
self._has_response_validation_error = response_validation_error_http_code is not None
1541+
self._response_validation_error_http_code = self._validate_response_validation_error_http_code(
1542+
response_validation_error_http_code,
1543+
enable_validation,
1544+
)
15331545

15341546
# Allow for a custom serializer or a concise json serialization
15351547
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
@@ -1539,7 +1551,40 @@ def __init__(
15391551

15401552
# Note the serializer argument: only use custom serializer if provided by the caller
15411553
# Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation.
1542-
self.use([OpenAPIValidationMiddleware(validation_serializer=serializer)])
1554+
self.use(
1555+
[
1556+
OpenAPIValidationMiddleware(
1557+
validation_serializer=serializer,
1558+
has_response_validation_error=self._has_response_validation_error,
1559+
),
1560+
],
1561+
)
1562+
1563+
def _validate_response_validation_error_http_code(
1564+
self,
1565+
response_validation_error_http_code: HTTPStatus | int | None,
1566+
enable_validation: bool,
1567+
) -> HTTPStatus:
1568+
if response_validation_error_http_code and not enable_validation:
1569+
msg = "'response_validation_error_http_code' cannot be set when enable_validation is False."
1570+
raise ValueError(msg)
1571+
1572+
if (
1573+
not isinstance(response_validation_error_http_code, HTTPStatus)
1574+
and response_validation_error_http_code is not None
1575+
):
1576+
1577+
try:
1578+
response_validation_error_http_code = HTTPStatus(response_validation_error_http_code)
1579+
except ValueError:
1580+
msg = f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code."
1581+
raise ValueError(msg) from None
1582+
1583+
return (
1584+
response_validation_error_http_code
1585+
if response_validation_error_http_code
1586+
else HTTPStatus.UNPROCESSABLE_ENTITY
1587+
)
15431588

15441589
def get_openapi_schema(
15451590
self,
@@ -2370,6 +2415,21 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
23702415
route=route,
23712416
)
23722417

2418+
# OpenAPIValidationMiddleware will only raise ResponseValidationError when
2419+
# 'self._response_validation_error_http_code' is not None
2420+
if isinstance(exp, ResponseValidationError):
2421+
http_code = self._response_validation_error_http_code
2422+
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
2423+
return self._response_builder_class(
2424+
response=Response(
2425+
status_code=http_code.value,
2426+
content_type=content_types.APPLICATION_JSON,
2427+
body={"statusCode": self._response_validation_error_http_code, "detail": errors},
2428+
),
2429+
serializer=self._serializer,
2430+
route=route,
2431+
)
2432+
23732433
if isinstance(exp, ServiceError):
23742434
return self._response_builder_class(
23752435
response=Response(
@@ -2582,6 +2642,7 @@ def __init__(
25822642
serializer: Callable[[dict], str] | None = None,
25832643
strip_prefixes: list[str | Pattern] | None = None,
25842644
enable_validation: bool = False,
2645+
response_validation_error_http_code: HTTPStatus | int | None = None,
25852646
):
25862647
"""Amazon API Gateway REST and HTTP API v1 payload resolver"""
25872648
super().__init__(
@@ -2591,6 +2652,7 @@ def __init__(
25912652
serializer,
25922653
strip_prefixes,
25932654
enable_validation,
2655+
response_validation_error_http_code,
25942656
)
25952657

25962658
def _get_base_path(self) -> str:
@@ -2664,6 +2726,7 @@ def __init__(
26642726
serializer: Callable[[dict], str] | None = None,
26652727
strip_prefixes: list[str | Pattern] | None = None,
26662728
enable_validation: bool = False,
2729+
response_validation_error_http_code: HTTPStatus | int | None = None,
26672730
):
26682731
"""Amazon API Gateway HTTP API v2 payload resolver"""
26692732
super().__init__(
@@ -2673,6 +2736,7 @@ def __init__(
26732736
serializer,
26742737
strip_prefixes,
26752738
enable_validation,
2739+
response_validation_error_http_code,
26762740
)
26772741

26782742
def _get_base_path(self) -> str:
@@ -2701,9 +2765,18 @@ def __init__(
27012765
serializer: Callable[[dict], str] | None = None,
27022766
strip_prefixes: list[str | Pattern] | None = None,
27032767
enable_validation: bool = False,
2768+
response_validation_error_http_code: HTTPStatus | int | None = None,
27042769
):
27052770
"""Amazon Application Load Balancer (ALB) resolver"""
2706-
super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes, enable_validation)
2771+
super().__init__(
2772+
ProxyEventType.ALBEvent,
2773+
cors,
2774+
debug,
2775+
serializer,
2776+
strip_prefixes,
2777+
enable_validation,
2778+
response_validation_error_http_code,
2779+
)
27072780

27082781
def _get_base_path(self) -> str:
27092782
# ALB doesn't have a stage variable, so we just return an empty string

aws_lambda_powertools/event_handler/lambda_function_url.py

+4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
)
99

1010
if TYPE_CHECKING:
11+
from http import HTTPStatus
12+
1113
from aws_lambda_powertools.event_handler import CORSConfig
1214
from aws_lambda_powertools.utilities.data_classes import LambdaFunctionUrlEvent
1315

@@ -57,6 +59,7 @@ def __init__(
5759
serializer: Callable[[dict], str] | None = None,
5860
strip_prefixes: list[str | Pattern] | None = None,
5961
enable_validation: bool = False,
62+
response_validation_error_http_code: HTTPStatus | int | None = None,
6063
):
6164
super().__init__(
6265
ProxyEventType.LambdaFunctionUrlEvent,
@@ -65,6 +68,7 @@ def __init__(
6568
serializer,
6669
strip_prefixes,
6770
enable_validation,
71+
response_validation_error_http_code,
6872
)
6973

7074
def _get_base_path(self) -> str:

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
1919
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
20-
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
20+
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError
2121
from aws_lambda_powertools.event_handler.openapi.params import Param
2222

2323
if TYPE_CHECKING:
@@ -58,7 +58,11 @@ def get_todos(): list[Todo]:
5858
```
5959
"""
6060

61-
def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
61+
def __init__(
62+
self,
63+
validation_serializer: Callable[[Any], str] | None = None,
64+
has_response_validation_error: bool = False,
65+
):
6266
"""
6367
Initialize the OpenAPIValidationMiddleware.
6468
@@ -67,8 +71,13 @@ def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
6771
validation_serializer : Callable, optional
6872
Optional serializer to use when serializing the response for validation.
6973
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
74+
75+
has_response_validation_error: bool, optional
76+
Optional flag used to distinguish between payload and validation errors.
77+
By setting this flag to True, ResponseValidationError will be raised if response could not be validated.
7078
"""
7179
self._validation_serializer = validation_serializer
80+
self._has_response_validation_error = has_response_validation_error
7281

7382
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
7483
logger.debug("OpenAPIValidationMiddleware handler")
@@ -164,6 +173,8 @@ def _serialize_response(
164173
errors: list[dict[str, Any]] = []
165174
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
166175
if errors:
176+
if self._has_response_validation_error:
177+
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content)
167178
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
168179

169180
if hasattr(field, "serialize"):

aws_lambda_powertools/event_handler/openapi/exceptions.py

+10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
2323
self.body = body
2424

2525

26+
class ResponseValidationError(ValidationException):
27+
"""
28+
Raised when the response body does not match the OpenAPI schema
29+
"""
30+
31+
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
32+
super().__init__(errors)
33+
self.body = body
34+
35+
2636
class SerializationError(Exception):
2737
"""
2838
Base exception for all encoding errors

aws_lambda_powertools/event_handler/vpc_lattice.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
)
99

1010
if TYPE_CHECKING:
11+
from http import HTTPStatus
12+
1113
from aws_lambda_powertools.event_handler import CORSConfig
1214
from aws_lambda_powertools.utilities.data_classes import VPCLatticeEvent, VPCLatticeEventV2
1315

@@ -53,9 +55,18 @@ def __init__(
5355
serializer: Callable[[dict], str] | None = None,
5456
strip_prefixes: list[str | Pattern] | None = None,
5557
enable_validation: bool = False,
58+
response_validation_error_http_code: HTTPStatus | int | None = None,
5659
):
5760
"""Amazon VPC Lattice resolver"""
58-
super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes, enable_validation)
61+
super().__init__(
62+
ProxyEventType.VPCLatticeEvent,
63+
cors,
64+
debug,
65+
serializer,
66+
strip_prefixes,
67+
enable_validation,
68+
response_validation_error_http_code,
69+
)
5970

6071
def _get_base_path(self) -> str:
6172
return ""
@@ -102,9 +113,18 @@ def __init__(
102113
serializer: Callable[[dict], str] | None = None,
103114
strip_prefixes: list[str | Pattern] | None = None,
104115
enable_validation: bool = False,
116+
response_validation_error_http_code: HTTPStatus | int | None = None,
105117
):
106118
"""Amazon VPC Lattice resolver"""
107-
super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes, enable_validation)
119+
super().__init__(
120+
ProxyEventType.VPCLatticeEventV2,
121+
cors,
122+
debug,
123+
serializer,
124+
strip_prefixes,
125+
enable_validation,
126+
response_validation_error_http_code,
127+
)
108128

109129
def _get_base_path(self) -> str:
110130
return ""

docs/core/event_handler/api_gateway.md

+22-3
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ Let's rewrite the previous examples to signal our resolver what shape we expect
309309

310310
!!! info "By default, we hide extended error details for security reasons _(e.g., pydantic url, Pydantic code)_."
311311

312-
Any incoming request that fails validation will lead to a `HTTP 422: Unprocessable Entity error` response that will look similar to this:
312+
Any incoming request or and outgoing response that fails validation will lead to a `HTTP 422: Unprocessable Entity error` response that will look similar to this:
313313

314314
```json hl_lines="2 3" title="data_validation_error_unsanitized_output.json"
315315
--8<-- "examples/event_handler_rest/src/data_validation_error_unsanitized_output.json"
@@ -321,8 +321,6 @@ Here's an example where we catch validation errors, log all details for further
321321

322322
=== "data_validation_sanitized_error.py"
323323

324-
Note that Pydantic versions [1](https://docs.pydantic.dev/1.10/usage/models/#error-handling){target="_blank" rel="nofollow"} and [2](https://docs.pydantic.dev/latest/errors/errors/){target="_blank" rel="nofollow"} report validation detailed errors differently.
325-
326324
```python hl_lines="8 24-25 31"
327325
--8<-- "examples/event_handler_rest/src/data_validation_sanitized_error.py"
328326
```
@@ -398,6 +396,27 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou
398396
--8<-- "examples/event_handler_rest/src/validating_payload_subset_output.json"
399397
```
400398

399+
#### Validating responses
400+
401+
You can use `response_validation_error_http_code` to set a custom HTTP code for failed response validation. When this field is set, we will raise a `ResponseValidationError` instead of a `RequestValidationError`.
402+
403+
=== "customizing_response_validation.py"
404+
405+
```python hl_lines="1 16 29 33"
406+
--8<-- "examples/event_handler_rest/src/customizing_response_validation.py"
407+
```
408+
409+
1. A response with status code set here will be returned if response data is not valid.
410+
2. Operation returns a string as oppose to a `Todo` object. This will lead to a `500` response as set in line 18.
411+
412+
=== "customizing_response_validation_exception.py"
413+
414+
```python hl_lines="1 18 38 39"
415+
--8<-- "examples/event_handler_rest/src/customizing_response_validation_exception.py"
416+
```
417+
418+
1. The distinct `ResponseValidationError` exception can be caught to customise the response.
419+
401420
#### Validating query strings
402421

403422
!!! info "We will automatically validate and inject incoming query strings via type annotation."
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from http import HTTPStatus
2+
from typing import Optional
3+
4+
import requests
5+
from pydantic import BaseModel, Field
6+
7+
from aws_lambda_powertools import Logger, Tracer
8+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
9+
from aws_lambda_powertools.logging import correlation_paths
10+
from aws_lambda_powertools.utilities.typing import LambdaContext
11+
12+
tracer = Tracer()
13+
logger = Logger()
14+
app = APIGatewayRestResolver(
15+
enable_validation=True,
16+
response_validation_error_http_code=HTTPStatus.INTERNAL_SERVER_ERROR, # (1)!
17+
)
18+
19+
20+
class Todo(BaseModel):
21+
userId: int
22+
id_: Optional[int] = Field(alias="id", default=None)
23+
title: str
24+
completed: bool
25+
26+
27+
@app.get("/todos_bad_response/<todo_id>")
28+
@tracer.capture_method
29+
def get_todo_by_id(todo_id: int) -> Todo:
30+
todo = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}")
31+
todo.raise_for_status()
32+
33+
return todo.json()["title"] # (2)!
34+
35+
36+
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP)
37+
@tracer.capture_lambda_handler
38+
def lambda_handler(event: dict, context: LambdaContext) -> dict:
39+
return app.resolve(event, context)

0 commit comments

Comments
 (0)