Skip to content

feat(event_handler): add route-level custom response validation in OpenAPI utility #6341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
85321d1
feat(api-gateway-resolver): Add option for custom response validation…
Feb 28, 2025
aa7cf6f
feat(docs): Added doc for custom response validation error responses.
Feb 28, 2025
cb7fd6d
feat(unit-test): Add tests for custom response validation error.
Feb 28, 2025
1228632
fix: Formatting.
Feb 28, 2025
b95d521
fix(unit-test): fix failed CI.
Feb 28, 2025
f849930
feat(unit-test): add tests for incorrect types and invalid configs
Feb 28, 2025
bafd19c
refactor: rename response_validation_error_http_status to response_va…
amin-farjadi Mar 7, 2025
9b09bb7
refactor(tests): move unit tests into openapi_validation functional t…
amin-farjadi Mar 7, 2025
bbbd989
feat: add route-specific custom response validation and tests
amin-farjadi Mar 7, 2025
ce7be15
fix: except Route implementation
amin-farjadi Mar 18, 2025
95d9aee
fix: put custom_response_validation_http_code before middleware
amin-farjadi Mar 21, 2025
210b765
feat: route's custom response validation must take precedence over ap…
amin-farjadi Mar 23, 2025
575e713
feat: added more tests.
amin-farjadi Mar 23, 2025
440a3f4
refactor: improved error messagee and tests' descriptions.
amin-farjadi Mar 23, 2025
249554f
feat: updated docs.
amin-farjadi Mar 25, 2025
d0eadf0
move veritifcation method of route custom http code to BaseRouter.
amin-farjadi Mar 25, 2025
2316637
Merge branch 'develop' into feature/route-custom-response-validation
amin-farjadi Mar 25, 2025
59bb4aa
fix: add validate function for route http code to APIGatewayResolver …
amin-farjadi Mar 25, 2025
020c973
feat: add custom_response_validation_http_code to the routes of Bedrock
amin-farjadi Mar 25, 2025
5ea8ffa
fix: make mypy happy
amin-farjadi Mar 25, 2025
ac5dbf4
Merge branch 'develop' into feature/route-custom-response-validation
amin-farjadi Mar 25, 2025
d23be99
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Mar 25, 2025
ec113cb
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Mar 27, 2025
5794c27
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Mar 31, 2025
3e5fb6e
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Apr 2, 2025
6d00446
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Apr 4, 2025
fca7db0
fix: address comments
amin-farjadi Apr 9, 2025
afbda76
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Apr 9, 2025
2d7a73d
fix(openapi): add response for response validation error and definiti…
amin-farjadi Apr 9, 2025
1686c5c
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Apr 10, 2025
ba7d6c7
minor changes
leandrodamascena Apr 10, 2025
9e04a63
minor changes
leandrodamascena Apr 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 71 additions & 4 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: HTTPStatus | None = None,
middlewares: list[Callable[..., Response]] | None = None,
):
"""
Expand Down Expand Up @@ -360,8 +361,11 @@
Additional OpenAPI extensions as a dictionary.
deprecated: bool
Whether or not to mark this route as deprecated in the OpenAPI schema
custom_response_validation_http_code: HTTPStatus | None, optional
Whether to have custom http status code for this route if response validation fails
middlewares: list[Callable[..., Response]] | None
The list of route middlewares to be called in order.
# TODO
"""
self.method = method.upper()
self.path = "/" if path.strip() == "" else path
Expand Down Expand Up @@ -397,6 +401,8 @@
# _body_field is used to cache the dependant model for the body field
self._body_field: ModelField | None = None

self.custom_response_validation_http_code = custom_response_validation_http_code

def __call__(
self,
router_middlewares: list[Callable],
Expand Down Expand Up @@ -505,7 +511,7 @@

return self._body_field

def _get_openapi_path(
def _get_openapi_path( # noqa: PLR0912
self,
*,
dependant: Dependant,
Expand Down Expand Up @@ -565,6 +571,14 @@
},
}

# Add custom response validation response, if exists
if self.custom_response_validation_http_code:
http_code = self.custom_response_validation_http_code.value
operation_responses[http_code] = {

Check warning on line 577 in aws_lambda_powertools/event_handler/api_gateway.py

View check run for this annotation

Codecov / codecov/patch

aws_lambda_powertools/event_handler/api_gateway.py#L576-L577

Added lines #L576 - L577 were not covered by tests
"description": "Response Validation Error",
"content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}},
}

# Add the response to the OpenAPI operation
if self.responses:
for status_code in list(self.responses):
Expand Down Expand Up @@ -942,6 +956,7 @@
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
raise NotImplementedError()
Expand Down Expand Up @@ -1003,6 +1018,7 @@
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Get route decorator with GET `method`
Expand Down Expand Up @@ -1043,6 +1059,7 @@
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand All @@ -1062,6 +1079,7 @@
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Post route decorator with POST `method`
Expand Down Expand Up @@ -1103,6 +1121,7 @@
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand All @@ -1122,6 +1141,7 @@
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Put route decorator with PUT `method`
Expand Down Expand Up @@ -1163,6 +1183,7 @@
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand All @@ -1182,6 +1203,7 @@
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Delete route decorator with DELETE `method`
Expand Down Expand Up @@ -1222,6 +1244,7 @@
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand All @@ -1241,6 +1264,7 @@
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Patch route decorator with PATCH `method`
Expand Down Expand Up @@ -1284,6 +1308,7 @@
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand All @@ -1303,6 +1328,7 @@
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Head route decorator with HEAD `method`
Expand Down Expand Up @@ -1345,6 +1371,7 @@
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand Down Expand Up @@ -2108,6 +2135,29 @@
body=body,
)

def _validate_route_response_validation_error_http_code(
self,
custom_response_validation_http_code: int | HTTPStatus | None,
) -> HTTPStatus | None:
if custom_response_validation_http_code and not self._enable_validation:
msg = (
"'custom_response_validation_http_code' cannot be set for route when enable_validation is False "
"on resolver."
)
raise ValueError(msg)

if (
not isinstance(custom_response_validation_http_code, HTTPStatus)
and custom_response_validation_http_code is not None
):
try:
custom_response_validation_http_code = HTTPStatus(custom_response_validation_http_code)
except ValueError:
msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501
raise ValueError(msg) from None

return custom_response_validation_http_code

def route(
self,
rule: str,
Expand All @@ -2125,10 +2175,15 @@
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Route decorator includes parameter `method`"""

custom_response_validation_http_code = self._validate_route_response_validation_error_http_code(
custom_response_validation_http_code,
)

def register_resolver(func: AnyCallableT) -> AnyCallableT:
methods = (method,) if isinstance(method, str) else method
logger.debug(f"Adding route using rule {rule} and methods: {','.join(m.upper() for m in methods)}")
Expand All @@ -2154,6 +2209,7 @@
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand Down Expand Up @@ -2523,15 +2579,22 @@
)

# OpenAPIValidationMiddleware will only raise ResponseValidationError when
# 'self._response_validation_error_http_code' is not None
# 'self._response_validation_error_http_code' is not None or
# when route has custom_response_validation_http_code
if isinstance(exp, ResponseValidationError):
http_code = self._response_validation_error_http_code
# route validation must take precedence over app validation
route_response_validation_http_code = route.custom_response_validation_http_code
http_code = (
route_response_validation_http_code
if route_response_validation_http_code
else self._response_validation_error_http_code
)
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
return self._response_builder_class(
response=Response(
status_code=http_code.value,
content_type=content_types.APPLICATION_JSON,
body={"statusCode": self._response_validation_error_http_code, "detail": errors},
body={"statusCode": http_code, "detail": errors},
),
serializer=self._serializer,
route=route,
Expand Down Expand Up @@ -2682,6 +2745,7 @@
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
def register_route(func: AnyCallableT) -> AnyCallableT:
Expand All @@ -2708,6 +2772,7 @@
frozen_security,
frozen_openapi_extensions,
deprecated,
custom_response_validation_http_code,
)

# Collate Middleware for routes
Expand Down Expand Up @@ -2794,6 +2859,7 @@
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
# NOTE: see #1552 for more context.
Expand All @@ -2813,6 +2879,7 @@
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand Down
11 changes: 11 additions & 0 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION

if TYPE_CHECKING:
from http import HTTPStatus
from re import Match

from aws_lambda_powertools.event_handler.openapi.models import Contact, License, SecurityScheme, Server, Tag
Expand Down Expand Up @@ -109,6 +110,7 @@ def get( # type: ignore[override]
operation_id: str | None = None,
include_in_schema: bool = True,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
openapi_extensions = None
Expand All @@ -129,6 +131,7 @@ def get( # type: ignore[override]
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand All @@ -148,6 +151,7 @@ def post( # type: ignore[override]
operation_id: str | None = None,
include_in_schema: bool = True,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
openapi_extensions = None
Expand All @@ -168,6 +172,7 @@ def post( # type: ignore[override]
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand All @@ -187,6 +192,7 @@ def put( # type: ignore[override]
operation_id: str | None = None,
include_in_schema: bool = True,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
openapi_extensions = None
Expand All @@ -207,6 +213,7 @@ def put( # type: ignore[override]
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand All @@ -226,6 +233,7 @@ def patch( # type: ignore[override]
operation_id: str | None = None,
include_in_schema: bool = True,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable] | None = None,
):
openapi_extensions = None
Expand All @@ -246,6 +254,7 @@ def patch( # type: ignore[override]
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand All @@ -265,6 +274,7 @@ def delete( # type: ignore[override]
operation_id: str | None = None,
include_in_schema: bool = True,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
openapi_extensions = None
Expand All @@ -285,6 +295,7 @@ def delete( # type: ignore[override]
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _handle_response(self, *, route: Route, response: Response):
response.body = self._serialize_response(
field=route.dependant.return_param,
response_content=response.body,
has_route_custom_response_validation=route.custom_response_validation_http_code is not None,
)

return response
Expand All @@ -165,6 +166,7 @@ def _serialize_response(
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
has_route_custom_response_validation: bool = False,
) -> Any:
"""
Serialize the response content according to the field type.
Expand All @@ -173,8 +175,16 @@ def _serialize_response(
errors: list[dict[str, Any]] = []
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
if errors:
# route-level validation must take precedence over app-level
if has_route_custom_response_validation:
raise ResponseValidationError(
errors=_normalize_errors(errors),
body=response_content,
source="route",
)
if self._has_response_validation_error:
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content)
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app")

raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)

if hasattr(field, "serialize"):
Expand Down
Loading
Loading