Skip to content

Commit 89eb55b

Browse files
amin-farjadiAmin Farjadileandrodamascena
authored
feat(event_handler): add route-level custom response validation in OpenAPI utility (#6341)
* feat(api-gateway-resolver): Add option for custom response validation error status code. * feat(docs): Added doc for custom response validation error responses. * feat(unit-test): Add tests for custom response validation error. * fix: Formatting. * fix(unit-test): fix failed CI. * feat(unit-test): add tests for incorrect types and invalid configs * refactor: rename response_validation_error_http_status to response_validation_error_http_code * refactor(tests): move unit tests into openapi_validation functional test file * feat: add route-specific custom response validation and tests * fix: except Route implementation * fix: put custom_response_validation_http_code before middleware * feat: route's custom response validation must take precedence over app's. * feat: added more tests. * refactor: improved error messagee and tests' descriptions. * feat: updated docs. * move veritifcation method of route custom http code to BaseRouter. * fix: add validate function for route http code to APIGatewayResolver not Router * feat: add custom_response_validation_http_code to the routes of Bedrock * fix: make mypy happy * fix: address comments * fix(openapi): add response for response validation error and definition for it * minor changes * minor changes --------- Co-authored-by: Amin Farjadi <[email protected]> Co-authored-by: Leandro Damascena <[email protected]>
1 parent 376f4f1 commit 89eb55b

File tree

10 files changed

+351
-11
lines changed

10 files changed

+351
-11
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+98-6
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
OpenAPIResponse,
3636
OpenAPIResponseContentModel,
3737
OpenAPIResponseContentSchema,
38+
response_validation_error_response_definition,
3839
validation_error_definition,
3940
validation_error_response_definition,
4041
)
@@ -319,6 +320,7 @@ def __init__(
319320
security: list[dict[str, list[str]]] | None = None,
320321
openapi_extensions: dict[str, Any] | None = None,
321322
deprecated: bool = False,
323+
custom_response_validation_http_code: HTTPStatus | None = None,
322324
middlewares: list[Callable[..., Response]] | None = None,
323325
):
324326
"""
@@ -360,11 +362,13 @@ def __init__(
360362
Additional OpenAPI extensions as a dictionary.
361363
deprecated: bool
362364
Whether or not to mark this route as deprecated in the OpenAPI schema
365+
custom_response_validation_http_code: int | HTTPStatus | None, optional
366+
Whether to have custom http status code for this route if response validation fails
363367
middlewares: list[Callable[..., Response]] | None
364368
The list of route middlewares to be called in order.
365369
"""
366370
self.method = method.upper()
367-
self.path = "/" if path.strip() == "" else path
371+
self.path = path if path.strip() else "/"
368372

369373
# OpenAPI spec only understands paths with { }. So we'll have to convert Powertools' < >.
370374
# https://swagger.io/specification/#path-templating
@@ -397,6 +401,8 @@ def __init__(
397401
# _body_field is used to cache the dependant model for the body field
398402
self._body_field: ModelField | None = None
399403

404+
self.custom_response_validation_http_code = custom_response_validation_http_code
405+
400406
def __call__(
401407
self,
402408
router_middlewares: list[Callable],
@@ -565,6 +571,16 @@ def _get_openapi_path(
565571
},
566572
}
567573

574+
# Add custom response validation response, if exists
575+
if self.custom_response_validation_http_code:
576+
http_code = self.custom_response_validation_http_code.value
577+
operation_responses[http_code] = {
578+
"description": "Response Validation Error",
579+
"content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}},
580+
}
581+
# Add model definition
582+
definitions["ResponseValidationError"] = response_validation_error_response_definition
583+
568584
# Add the response to the OpenAPI operation
569585
if self.responses:
570586
for status_code in list(self.responses):
@@ -942,6 +958,7 @@ def route(
942958
security: list[dict[str, list[str]]] | None = None,
943959
openapi_extensions: dict[str, Any] | None = None,
944960
deprecated: bool = False,
961+
custom_response_validation_http_code: int | HTTPStatus | None = None,
945962
middlewares: list[Callable[..., Any]] | None = None,
946963
) -> Callable[[AnyCallableT], AnyCallableT]:
947964
raise NotImplementedError()
@@ -1003,6 +1020,7 @@ def get(
10031020
security: list[dict[str, list[str]]] | None = None,
10041021
openapi_extensions: dict[str, Any] | None = None,
10051022
deprecated: bool = False,
1023+
custom_response_validation_http_code: int | HTTPStatus | None = None,
10061024
middlewares: list[Callable[..., Any]] | None = None,
10071025
) -> Callable[[AnyCallableT], AnyCallableT]:
10081026
"""Get route decorator with GET `method`
@@ -1043,6 +1061,7 @@ def lambda_handler(event, context):
10431061
security,
10441062
openapi_extensions,
10451063
deprecated,
1064+
custom_response_validation_http_code,
10461065
middlewares,
10471066
)
10481067

@@ -1062,6 +1081,7 @@ def post(
10621081
security: list[dict[str, list[str]]] | None = None,
10631082
openapi_extensions: dict[str, Any] | None = None,
10641083
deprecated: bool = False,
1084+
custom_response_validation_http_code: int | HTTPStatus | None = None,
10651085
middlewares: list[Callable[..., Any]] | None = None,
10661086
) -> Callable[[AnyCallableT], AnyCallableT]:
10671087
"""Post route decorator with POST `method`
@@ -1103,6 +1123,7 @@ def lambda_handler(event, context):
11031123
security,
11041124
openapi_extensions,
11051125
deprecated,
1126+
custom_response_validation_http_code,
11061127
middlewares,
11071128
)
11081129

@@ -1122,6 +1143,7 @@ def put(
11221143
security: list[dict[str, list[str]]] | None = None,
11231144
openapi_extensions: dict[str, Any] | None = None,
11241145
deprecated: bool = False,
1146+
custom_response_validation_http_code: int | HTTPStatus | None = None,
11251147
middlewares: list[Callable[..., Any]] | None = None,
11261148
) -> Callable[[AnyCallableT], AnyCallableT]:
11271149
"""Put route decorator with PUT `method`
@@ -1163,6 +1185,7 @@ def lambda_handler(event, context):
11631185
security,
11641186
openapi_extensions,
11651187
deprecated,
1188+
custom_response_validation_http_code,
11661189
middlewares,
11671190
)
11681191

@@ -1182,6 +1205,7 @@ def delete(
11821205
security: list[dict[str, list[str]]] | None = None,
11831206
openapi_extensions: dict[str, Any] | None = None,
11841207
deprecated: bool = False,
1208+
custom_response_validation_http_code: int | HTTPStatus | None = None,
11851209
middlewares: list[Callable[..., Any]] | None = None,
11861210
) -> Callable[[AnyCallableT], AnyCallableT]:
11871211
"""Delete route decorator with DELETE `method`
@@ -1222,6 +1246,7 @@ def lambda_handler(event, context):
12221246
security,
12231247
openapi_extensions,
12241248
deprecated,
1249+
custom_response_validation_http_code,
12251250
middlewares,
12261251
)
12271252

@@ -1241,6 +1266,7 @@ def patch(
12411266
security: list[dict[str, list[str]]] | None = None,
12421267
openapi_extensions: dict[str, Any] | None = None,
12431268
deprecated: bool = False,
1269+
custom_response_validation_http_code: int | HTTPStatus | None = None,
12441270
middlewares: list[Callable] | None = None,
12451271
) -> Callable[[AnyCallableT], AnyCallableT]:
12461272
"""Patch route decorator with PATCH `method`
@@ -1284,6 +1310,7 @@ def lambda_handler(event, context):
12841310
security,
12851311
openapi_extensions,
12861312
deprecated,
1313+
custom_response_validation_http_code,
12871314
middlewares,
12881315
)
12891316

@@ -1303,6 +1330,7 @@ def head(
13031330
security: list[dict[str, list[str]]] | None = None,
13041331
openapi_extensions: dict[str, Any] | None = None,
13051332
deprecated: bool = False,
1333+
custom_response_validation_http_code: int | HTTPStatus | None = None,
13061334
middlewares: list[Callable] | None = None,
13071335
) -> Callable[[AnyCallableT], AnyCallableT]:
13081336
"""Head route decorator with HEAD `method`
@@ -1345,6 +1373,7 @@ def lambda_handler(event, context):
13451373
security,
13461374
openapi_extensions,
13471375
deprecated,
1376+
custom_response_validation_http_code,
13481377
middlewares,
13491378
)
13501379

@@ -1571,6 +1600,7 @@ def _validate_response_validation_error_http_code(
15711600
response_validation_error_http_code: HTTPStatus | int | None,
15721601
enable_validation: bool,
15731602
) -> HTTPStatus:
1603+
15741604
if response_validation_error_http_code and not enable_validation:
15751605
msg = "'response_validation_error_http_code' cannot be set when enable_validation is False."
15761606
raise ValueError(msg)
@@ -1588,6 +1618,33 @@ def _validate_response_validation_error_http_code(
15881618

15891619
return response_validation_error_http_code or HTTPStatus.UNPROCESSABLE_ENTITY
15901620

1621+
def _add_resolver_response_validation_error_response_to_route(
1622+
self,
1623+
route_openapi_path: tuple[dict[str, Any], dict[str, Any]],
1624+
) -> tuple[dict[str, Any], dict[str, Any]]:
1625+
"""Adds resolver response validation error response to route's operations."""
1626+
path, path_definitions = route_openapi_path
1627+
if self._has_response_validation_error and "ResponseValidationError" not in path_definitions:
1628+
response_validation_error_response = {
1629+
"description": "Response Validation Error",
1630+
"content": {
1631+
"application/json": {
1632+
"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"},
1633+
},
1634+
},
1635+
}
1636+
http_code = self._response_validation_error_http_code.value
1637+
for operation in path.values():
1638+
operation["responses"][http_code] = response_validation_error_response
1639+
return path, path_definitions
1640+
1641+
def _generate_schemas(self, definitions: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
1642+
schemas = {k: definitions[k] for k in sorted(definitions)}
1643+
# add response validation error definition
1644+
if self._response_validation_error_http_code:
1645+
schemas.setdefault("ResponseValidationError", response_validation_error_response_definition)
1646+
return schemas
1647+
15911648
def get_openapi_schema(
15921649
self,
15931650
*,
@@ -1739,14 +1796,14 @@ def get_openapi_schema(
17391796
field_mapping=field_mapping,
17401797
)
17411798
if result:
1742-
path, path_definitions = result
1799+
path, path_definitions = self._add_resolver_response_validation_error_response_to_route(result)
17431800
if path:
17441801
paths.setdefault(route.openapi_path, {}).update(path)
17451802
if path_definitions:
17461803
definitions.update(path_definitions)
17471804

17481805
if definitions:
1749-
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
1806+
components["schemas"] = self._generate_schemas(definitions)
17501807
if security_schemes:
17511808
components["securitySchemes"] = security_schemes
17521809
if components:
@@ -2108,6 +2165,29 @@ def swagger_handler():
21082165
body=body,
21092166
)
21102167

2168+
def _validate_route_response_validation_error_http_code(
2169+
self,
2170+
custom_response_validation_http_code: int | HTTPStatus | None,
2171+
) -> HTTPStatus | None:
2172+
if custom_response_validation_http_code and not self._enable_validation:
2173+
msg = (
2174+
"'custom_response_validation_http_code' cannot be set for route when enable_validation is False "
2175+
"on resolver."
2176+
)
2177+
raise ValueError(msg)
2178+
2179+
if (
2180+
not isinstance(custom_response_validation_http_code, HTTPStatus)
2181+
and custom_response_validation_http_code is not None
2182+
):
2183+
try:
2184+
custom_response_validation_http_code = HTTPStatus(custom_response_validation_http_code)
2185+
except ValueError:
2186+
msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501
2187+
raise ValueError(msg) from None
2188+
2189+
return custom_response_validation_http_code
2190+
21112191
def route(
21122192
self,
21132193
rule: str,
@@ -2125,10 +2205,15 @@ def route(
21252205
security: list[dict[str, list[str]]] | None = None,
21262206
openapi_extensions: dict[str, Any] | None = None,
21272207
deprecated: bool = False,
2208+
custom_response_validation_http_code: int | HTTPStatus | None = None,
21282209
middlewares: list[Callable[..., Any]] | None = None,
21292210
) -> Callable[[AnyCallableT], AnyCallableT]:
21302211
"""Route decorator includes parameter `method`"""
21312212

2213+
custom_response_validation_http_code = self._validate_route_response_validation_error_http_code(
2214+
custom_response_validation_http_code,
2215+
)
2216+
21322217
def register_resolver(func: AnyCallableT) -> AnyCallableT:
21332218
methods = (method,) if isinstance(method, str) else method
21342219
logger.debug(f"Adding route using rule {rule} and methods: {','.join(m.upper() for m in methods)}")
@@ -2154,6 +2239,7 @@ def register_resolver(func: AnyCallableT) -> AnyCallableT:
21542239
security,
21552240
openapi_extensions,
21562241
deprecated,
2242+
custom_response_validation_http_code,
21572243
middlewares,
21582244
)
21592245

@@ -2523,15 +2609,17 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
25232609
)
25242610

25252611
# OpenAPIValidationMiddleware will only raise ResponseValidationError when
2526-
# 'self._response_validation_error_http_code' is not None
2612+
# 'self._response_validation_error_http_code' is not None or
2613+
# when route has custom_response_validation_http_code
25272614
if isinstance(exp, ResponseValidationError):
2528-
http_code = self._response_validation_error_http_code
2615+
# route validation must take precedence over app validation
2616+
http_code = route.custom_response_validation_http_code or self._response_validation_error_http_code
25292617
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
25302618
return self._response_builder_class(
25312619
response=Response(
25322620
status_code=http_code.value,
25332621
content_type=content_types.APPLICATION_JSON,
2534-
body={"statusCode": self._response_validation_error_http_code, "detail": errors},
2622+
body={"statusCode": http_code, "detail": errors},
25352623
),
25362624
serializer=self._serializer,
25372625
route=route,
@@ -2682,6 +2770,7 @@ def route(
26822770
security: list[dict[str, list[str]]] | None = None,
26832771
openapi_extensions: dict[str, Any] | None = None,
26842772
deprecated: bool = False,
2773+
custom_response_validation_http_code: int | HTTPStatus | None = None,
26852774
middlewares: list[Callable[..., Any]] | None = None,
26862775
) -> Callable[[AnyCallableT], AnyCallableT]:
26872776
def register_route(func: AnyCallableT) -> AnyCallableT:
@@ -2708,6 +2797,7 @@ def register_route(func: AnyCallableT) -> AnyCallableT:
27082797
frozen_security,
27092798
frozen_openapi_extensions,
27102799
deprecated,
2800+
custom_response_validation_http_code,
27112801
)
27122802

27132803
# Collate Middleware for routes
@@ -2794,6 +2884,7 @@ def route(
27942884
security: list[dict[str, list[str]]] | None = None,
27952885
openapi_extensions: dict[str, Any] | None = None,
27962886
deprecated: bool = False,
2887+
custom_response_validation_http_code: int | HTTPStatus | None = None,
27972888
middlewares: list[Callable[..., Any]] | None = None,
27982889
) -> Callable[[AnyCallableT], AnyCallableT]:
27992890
# NOTE: see #1552 for more context.
@@ -2813,6 +2904,7 @@ def route(
28132904
security,
28142905
openapi_extensions,
28152906
deprecated,
2907+
custom_response_validation_http_code,
28162908
middlewares,
28172909
)
28182910

0 commit comments

Comments
 (0)