Skip to content

Commit f861fc1

Browse files
amin-farjadiAmin Farjadileandrodamascena
committed
feat(event_handler): add route-level custom response validation in OpenAPI utility (#6341)
* feat(api-gateway-resolver): Add option for custom response validation error status code. * feat(docs): Added doc for custom response validation error responses. * feat(unit-test): Add tests for custom response validation error. * fix: Formatting. * fix(unit-test): fix failed CI. * feat(unit-test): add tests for incorrect types and invalid configs * refactor: rename response_validation_error_http_status to response_validation_error_http_code * refactor(tests): move unit tests into openapi_validation functional test file * feat: add route-specific custom response validation and tests * fix: except Route implementation * fix: put custom_response_validation_http_code before middleware * feat: route's custom response validation must take precedence over app's. * feat: added more tests. * refactor: improved error messagee and tests' descriptions. * feat: updated docs. * move veritifcation method of route custom http code to BaseRouter. * fix: add validate function for route http code to APIGatewayResolver not Router * feat: add custom_response_validation_http_code to the routes of Bedrock * fix: make mypy happy * fix: address comments * fix(openapi): add response for response validation error and definition for it * minor changes * minor changes --------- Co-authored-by: Amin Farjadi <[email protected]> Co-authored-by: Leandro Damascena <[email protected]>
1 parent 628a4dc commit f861fc1

File tree

10 files changed

+351
-11
lines changed

10 files changed

+351
-11
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+98-6
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
OpenAPIResponse,
3737
OpenAPIResponseContentModel,
3838
OpenAPIResponseContentSchema,
39+
response_validation_error_response_definition,
3940
validation_error_definition,
4041
validation_error_response_definition,
4142
)
@@ -320,6 +321,7 @@ def __init__(
320321
security: list[dict[str, list[str]]] | None = None,
321322
openapi_extensions: dict[str, Any] | None = None,
322323
deprecated: bool = False,
324+
custom_response_validation_http_code: HTTPStatus | None = None,
323325
middlewares: list[Callable[..., Response]] | None = None,
324326
):
325327
"""
@@ -361,11 +363,13 @@ def __init__(
361363
Additional OpenAPI extensions as a dictionary.
362364
deprecated: bool
363365
Whether or not to mark this route as deprecated in the OpenAPI schema
366+
custom_response_validation_http_code: int | HTTPStatus | None, optional
367+
Whether to have custom http status code for this route if response validation fails
364368
middlewares: list[Callable[..., Response]] | None
365369
The list of route middlewares to be called in order.
366370
"""
367371
self.method = method.upper()
368-
self.path = "/" if path.strip() == "" else path
372+
self.path = path if path.strip() else "/"
369373

370374
# OpenAPI spec only understands paths with { }. So we'll have to convert Powertools' < >.
371375
# https://swagger.io/specification/#path-templating
@@ -398,6 +402,8 @@ def __init__(
398402
# _body_field is used to cache the dependant model for the body field
399403
self._body_field: ModelField | None = None
400404

405+
self.custom_response_validation_http_code = custom_response_validation_http_code
406+
401407
def __call__(
402408
self,
403409
router_middlewares: list[Callable],
@@ -566,6 +572,16 @@ def _get_openapi_path(
566572
},
567573
}
568574

575+
# Add custom response validation response, if exists
576+
if self.custom_response_validation_http_code:
577+
http_code = self.custom_response_validation_http_code.value
578+
operation_responses[http_code] = {
579+
"description": "Response Validation Error",
580+
"content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}},
581+
}
582+
# Add model definition
583+
definitions["ResponseValidationError"] = response_validation_error_response_definition
584+
569585
# Add the response to the OpenAPI operation
570586
if self.responses:
571587
for status_code in list(self.responses):
@@ -943,6 +959,7 @@ def route(
943959
security: list[dict[str, list[str]]] | None = None,
944960
openapi_extensions: dict[str, Any] | None = None,
945961
deprecated: bool = False,
962+
custom_response_validation_http_code: int | HTTPStatus | None = None,
946963
middlewares: list[Callable[..., Any]] | None = None,
947964
) -> Callable[[AnyCallableT], AnyCallableT]:
948965
raise NotImplementedError()
@@ -1004,6 +1021,7 @@ def get(
10041021
security: list[dict[str, list[str]]] | None = None,
10051022
openapi_extensions: dict[str, Any] | None = None,
10061023
deprecated: bool = False,
1024+
custom_response_validation_http_code: int | HTTPStatus | None = None,
10071025
middlewares: list[Callable[..., Any]] | None = None,
10081026
) -> Callable[[AnyCallableT], AnyCallableT]:
10091027
"""Get route decorator with GET `method`
@@ -1044,6 +1062,7 @@ def lambda_handler(event, context):
10441062
security,
10451063
openapi_extensions,
10461064
deprecated,
1065+
custom_response_validation_http_code,
10471066
middlewares,
10481067
)
10491068

@@ -1063,6 +1082,7 @@ def post(
10631082
security: list[dict[str, list[str]]] | None = None,
10641083
openapi_extensions: dict[str, Any] | None = None,
10651084
deprecated: bool = False,
1085+
custom_response_validation_http_code: int | HTTPStatus | None = None,
10661086
middlewares: list[Callable[..., Any]] | None = None,
10671087
) -> Callable[[AnyCallableT], AnyCallableT]:
10681088
"""Post route decorator with POST `method`
@@ -1104,6 +1124,7 @@ def lambda_handler(event, context):
11041124
security,
11051125
openapi_extensions,
11061126
deprecated,
1127+
custom_response_validation_http_code,
11071128
middlewares,
11081129
)
11091130

@@ -1123,6 +1144,7 @@ def put(
11231144
security: list[dict[str, list[str]]] | None = None,
11241145
openapi_extensions: dict[str, Any] | None = None,
11251146
deprecated: bool = False,
1147+
custom_response_validation_http_code: int | HTTPStatus | None = None,
11261148
middlewares: list[Callable[..., Any]] | None = None,
11271149
) -> Callable[[AnyCallableT], AnyCallableT]:
11281150
"""Put route decorator with PUT `method`
@@ -1164,6 +1186,7 @@ def lambda_handler(event, context):
11641186
security,
11651187
openapi_extensions,
11661188
deprecated,
1189+
custom_response_validation_http_code,
11671190
middlewares,
11681191
)
11691192

@@ -1183,6 +1206,7 @@ def delete(
11831206
security: list[dict[str, list[str]]] | None = None,
11841207
openapi_extensions: dict[str, Any] | None = None,
11851208
deprecated: bool = False,
1209+
custom_response_validation_http_code: int | HTTPStatus | None = None,
11861210
middlewares: list[Callable[..., Any]] | None = None,
11871211
) -> Callable[[AnyCallableT], AnyCallableT]:
11881212
"""Delete route decorator with DELETE `method`
@@ -1223,6 +1247,7 @@ def lambda_handler(event, context):
12231247
security,
12241248
openapi_extensions,
12251249
deprecated,
1250+
custom_response_validation_http_code,
12261251
middlewares,
12271252
)
12281253

@@ -1242,6 +1267,7 @@ def patch(
12421267
security: list[dict[str, list[str]]] | None = None,
12431268
openapi_extensions: dict[str, Any] | None = None,
12441269
deprecated: bool = False,
1270+
custom_response_validation_http_code: int | HTTPStatus | None = None,
12451271
middlewares: list[Callable] | None = None,
12461272
) -> Callable[[AnyCallableT], AnyCallableT]:
12471273
"""Patch route decorator with PATCH `method`
@@ -1285,6 +1311,7 @@ def lambda_handler(event, context):
12851311
security,
12861312
openapi_extensions,
12871313
deprecated,
1314+
custom_response_validation_http_code,
12881315
middlewares,
12891316
)
12901317

@@ -1304,6 +1331,7 @@ def head(
13041331
security: list[dict[str, list[str]]] | None = None,
13051332
openapi_extensions: dict[str, Any] | None = None,
13061333
deprecated: bool = False,
1334+
custom_response_validation_http_code: int | HTTPStatus | None = None,
13071335
middlewares: list[Callable] | None = None,
13081336
) -> Callable[[AnyCallableT], AnyCallableT]:
13091337
"""Head route decorator with HEAD `method`
@@ -1346,6 +1374,7 @@ def lambda_handler(event, context):
13461374
security,
13471375
openapi_extensions,
13481376
deprecated,
1377+
custom_response_validation_http_code,
13491378
middlewares,
13501379
)
13511380

@@ -1573,6 +1602,7 @@ def _validate_response_validation_error_http_code(
15731602
response_validation_error_http_code: HTTPStatus | int | None,
15741603
enable_validation: bool,
15751604
) -> HTTPStatus:
1605+
15761606
if response_validation_error_http_code and not enable_validation:
15771607
msg = "'response_validation_error_http_code' cannot be set when enable_validation is False."
15781608
raise ValueError(msg)
@@ -1590,6 +1620,33 @@ def _validate_response_validation_error_http_code(
15901620

15911621
return response_validation_error_http_code or HTTPStatus.UNPROCESSABLE_ENTITY
15921622

1623+
def _add_resolver_response_validation_error_response_to_route(
1624+
self,
1625+
route_openapi_path: tuple[dict[str, Any], dict[str, Any]],
1626+
) -> tuple[dict[str, Any], dict[str, Any]]:
1627+
"""Adds resolver response validation error response to route's operations."""
1628+
path, path_definitions = route_openapi_path
1629+
if self._has_response_validation_error and "ResponseValidationError" not in path_definitions:
1630+
response_validation_error_response = {
1631+
"description": "Response Validation Error",
1632+
"content": {
1633+
"application/json": {
1634+
"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"},
1635+
},
1636+
},
1637+
}
1638+
http_code = self._response_validation_error_http_code.value
1639+
for operation in path.values():
1640+
operation["responses"][http_code] = response_validation_error_response
1641+
return path, path_definitions
1642+
1643+
def _generate_schemas(self, definitions: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
1644+
schemas = {k: definitions[k] for k in sorted(definitions)}
1645+
# add response validation error definition
1646+
if self._response_validation_error_http_code:
1647+
schemas.setdefault("ResponseValidationError", response_validation_error_response_definition)
1648+
return schemas
1649+
15931650
def get_openapi_schema(
15941651
self,
15951652
*,
@@ -1741,14 +1798,14 @@ def get_openapi_schema(
17411798
field_mapping=field_mapping,
17421799
)
17431800
if result:
1744-
path, path_definitions = result
1801+
path, path_definitions = self._add_resolver_response_validation_error_response_to_route(result)
17451802
if path:
17461803
paths.setdefault(route.openapi_path, {}).update(path)
17471804
if path_definitions:
17481805
definitions.update(path_definitions)
17491806

17501807
if definitions:
1751-
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
1808+
components["schemas"] = self._generate_schemas(definitions)
17521809
if security_schemes:
17531810
components["securitySchemes"] = security_schemes
17541811
if components:
@@ -2110,6 +2167,29 @@ def swagger_handler():
21102167
body=body,
21112168
)
21122169

2170+
def _validate_route_response_validation_error_http_code(
2171+
self,
2172+
custom_response_validation_http_code: int | HTTPStatus | None,
2173+
) -> HTTPStatus | None:
2174+
if custom_response_validation_http_code and not self._enable_validation:
2175+
msg = (
2176+
"'custom_response_validation_http_code' cannot be set for route when enable_validation is False "
2177+
"on resolver."
2178+
)
2179+
raise ValueError(msg)
2180+
2181+
if (
2182+
not isinstance(custom_response_validation_http_code, HTTPStatus)
2183+
and custom_response_validation_http_code is not None
2184+
):
2185+
try:
2186+
custom_response_validation_http_code = HTTPStatus(custom_response_validation_http_code)
2187+
except ValueError:
2188+
msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501
2189+
raise ValueError(msg) from None
2190+
2191+
return custom_response_validation_http_code
2192+
21132193
def route(
21142194
self,
21152195
rule: str,
@@ -2127,10 +2207,15 @@ def route(
21272207
security: list[dict[str, list[str]]] | None = None,
21282208
openapi_extensions: dict[str, Any] | None = None,
21292209
deprecated: bool = False,
2210+
custom_response_validation_http_code: int | HTTPStatus | None = None,
21302211
middlewares: list[Callable[..., Any]] | None = None,
21312212
) -> Callable[[AnyCallableT], AnyCallableT]:
21322213
"""Route decorator includes parameter `method`"""
21332214

2215+
custom_response_validation_http_code = self._validate_route_response_validation_error_http_code(
2216+
custom_response_validation_http_code,
2217+
)
2218+
21342219
def register_resolver(func: AnyCallableT) -> AnyCallableT:
21352220
methods = (method,) if isinstance(method, str) else method
21362221
logger.debug(f"Adding route using rule {rule} and methods: {','.join(m.upper() for m in methods)}")
@@ -2156,6 +2241,7 @@ def register_resolver(func: AnyCallableT) -> AnyCallableT:
21562241
security,
21572242
openapi_extensions,
21582243
deprecated,
2244+
custom_response_validation_http_code,
21592245
middlewares,
21602246
)
21612247

@@ -2509,15 +2595,17 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
25092595
)
25102596

25112597
# OpenAPIValidationMiddleware will only raise ResponseValidationError when
2512-
# 'self._response_validation_error_http_code' is not None
2598+
# 'self._response_validation_error_http_code' is not None or
2599+
# when route has custom_response_validation_http_code
25132600
if isinstance(exp, ResponseValidationError):
2514-
http_code = self._response_validation_error_http_code
2601+
# route validation must take precedence over app validation
2602+
http_code = route.custom_response_validation_http_code or self._response_validation_error_http_code
25152603
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
25162604
return self._response_builder_class(
25172605
response=Response(
25182606
status_code=http_code.value,
25192607
content_type=content_types.APPLICATION_JSON,
2520-
body={"statusCode": self._response_validation_error_http_code, "detail": errors},
2608+
body={"statusCode": http_code, "detail": errors},
25212609
),
25222610
serializer=self._serializer,
25232611
route=route,
@@ -2668,6 +2756,7 @@ def route(
26682756
security: list[dict[str, list[str]]] | None = None,
26692757
openapi_extensions: dict[str, Any] | None = None,
26702758
deprecated: bool = False,
2759+
custom_response_validation_http_code: int | HTTPStatus | None = None,
26712760
middlewares: list[Callable[..., Any]] | None = None,
26722761
) -> Callable[[AnyCallableT], AnyCallableT]:
26732762
def register_route(func: AnyCallableT) -> AnyCallableT:
@@ -2694,6 +2783,7 @@ def register_route(func: AnyCallableT) -> AnyCallableT:
26942783
frozen_security,
26952784
frozen_openapi_extensions,
26962785
deprecated,
2786+
custom_response_validation_http_code,
26972787
)
26982788

26992789
# Collate Middleware for routes
@@ -2780,6 +2870,7 @@ def route(
27802870
security: list[dict[str, list[str]]] | None = None,
27812871
openapi_extensions: dict[str, Any] | None = None,
27822872
deprecated: bool = False,
2873+
custom_response_validation_http_code: int | HTTPStatus | None = None,
27832874
middlewares: list[Callable[..., Any]] | None = None,
27842875
) -> Callable[[AnyCallableT], AnyCallableT]:
27852876
# NOTE: see #1552 for more context.
@@ -2799,6 +2890,7 @@ def route(
27992890
security,
28002891
openapi_extensions,
28012892
deprecated,
2893+
custom_response_validation_http_code,
28022894
middlewares,
28032895
)
28042896

0 commit comments

Comments
 (0)