Skip to content

feat(event_handler): add custom response validation in OpenAPI utility #6189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
cc27ba7
feat(openapi-validation): Add response validation flag and distinct e…
Feb 28, 2025
bc69d18
feat(api-gateway-resolver): Add option for custom response validation…
Feb 28, 2025
6ddfdc0
feat(docs): Added doc for custom response validation error responses.
Feb 28, 2025
a9be196
refactor(docs): Make exception handler function name better.
Feb 28, 2025
276d7cd
feat(unit-test): Add tests for custom response validation error.
Feb 28, 2025
fb49e9b
fix: Formatting.
Feb 28, 2025
df105dc
fix(docs): Fix grammar in response validation docs
Feb 28, 2025
63fd201
fix(unit-test): fix failed CI.
Feb 28, 2025
1c33611
bugfix(lint): Ignore lint error FA102, irrelevant for python >=3.9
Feb 28, 2025
f8ead84
refactor: make response_validation_error_http_status accept more type…
Feb 28, 2025
eb2430b
feat(unit-test): add tests for incorrect types and invalid configs
Feb 28, 2025
d6b7638
Merge branch 'develop' into feature/response-validation
leandrodamascena Mar 1, 2025
0090692
Merge branch 'develop' into feature/response-validation
leandrodamascena Mar 3, 2025
13b7380
Merge branch 'develop' into feature/response-validation
leandrodamascena Mar 4, 2025
218c666
Merge branch 'develop' into feature/response-validation
leandrodamascena Mar 4, 2025
82918c7
Merge branch 'develop' into feature/response-validation
leandrodamascena Mar 9, 2025
2a4d57f
refactor: rename response_validation_error_http_status to response_va…
amin-farjadi Mar 7, 2025
fece0e8
refactor(api_gateway): add method for validating response_validation_…
amin-farjadi Mar 7, 2025
f85c749
fix(api_gateway): fix type and docstring for response_validation_erro…
amin-farjadi Mar 7, 2025
c4f0819
fix(api_gateway): remove unncessary check of response_validation_erro…
amin-farjadi Mar 7, 2025
8fe4edc
fix(openapi-validation): docstring for has_response_validation_error …
amin-farjadi Mar 7, 2025
f89b598
refactor(tests): move unit tests into openapi_validation functional t…
amin-farjadi Mar 7, 2025
f60b812
Merge branch 'develop' into feature/response-validation
leandrodamascena Mar 10, 2025
aaa0086
fix(tests): skipping validation for falsy response
Mar 12, 2025
6d5f913
Merge branch 'develop' into feature/response-validation
amin-farjadi Mar 12, 2025
bd93ee6
Making Ruff happy
leandrodamascena Mar 17, 2025
558c03c
Refactoring documentation
leandrodamascena Mar 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 71 additions & 3 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, SchemaValidationError
from aws_lambda_powertools.event_handler.openapi.exceptions import (
RequestValidationError,
ResponseValidationError,
SchemaValidationError,
)
from aws_lambda_powertools.event_handler.openapi.types import (
COMPONENT_REF_PREFIX,
METHODS_WITH_BODY,
Expand Down Expand Up @@ -1496,6 +1500,7 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status=None,
):
"""
Parameters
Expand All @@ -1515,6 +1520,8 @@ def __init__(
Each prefix can be a static string or a compiled regex pattern
enable_validation: bool | None
Enables validation of the request body against the route schema, by default False.
response_validation_error_http_status
Enables response validation and sets returned status code if response is not validated.
"""
self._proxy_type = proxy_type
self._dynamic_routes: list[Route] = []
Expand All @@ -1530,6 +1537,28 @@ def __init__(
self.context: dict = {} # early init as customers might add context before event resolution
self.processed_stack_frames = []
self._response_builder_class = ResponseBuilder[BaseProxyEvent]
self._has_response_validation_error = response_validation_error_http_status is not None

if response_validation_error_http_status and not enable_validation:
msg = "'response_validation_error_http_status' cannot be set when enable_validation is False."
raise ValueError(msg)

if (
not isinstance(response_validation_error_http_status, HTTPStatus)
and response_validation_error_http_status is not None
):

try:
response_validation_error_http_status = HTTPStatus(response_validation_error_http_status)
except ValueError:
msg = f"'{response_validation_error_http_status}' must be an integer representing an HTTP status code."
raise ValueError(msg) from None

self._response_validation_error_http_status = (
response_validation_error_http_status
if response_validation_error_http_status
else HTTPStatus.UNPROCESSABLE_ENTITY
)

# Allow for a custom serializer or a concise json serialization
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
Expand All @@ -1539,7 +1568,14 @@ def __init__(

# Note the serializer argument: only use custom serializer if provided by the caller
# Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation.
self.use([OpenAPIValidationMiddleware(validation_serializer=serializer)])
self.use(
[
OpenAPIValidationMiddleware(
validation_serializer=serializer,
has_response_validation_error=self._has_response_validation_error,
),
],
)

def get_openapi_schema(
self,
Expand Down Expand Up @@ -2370,6 +2406,25 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
route=route,
)

# OpenAPIValidationMiddleware will only raise ResponseValidationError when
# 'self._response_validation_error_http_status' is not None
if isinstance(exp, ResponseValidationError):
http_status = (
self._response_validation_error_http_status
if self._response_validation_error_http_status
else HTTPStatus.UNPROCESSABLE_ENTITY
)
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
return self._response_builder_class(
response=Response(
status_code=http_status.value,
content_type=content_types.APPLICATION_JSON,
body={"statusCode": self._response_validation_error_http_status, "detail": errors},
),
serializer=self._serializer,
route=route,
)

if isinstance(exp, ServiceError):
return self._response_builder_class(
response=Response(
Expand Down Expand Up @@ -2582,6 +2637,7 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
"""Amazon API Gateway REST and HTTP API v1 payload resolver"""
super().__init__(
Expand All @@ -2591,6 +2647,7 @@ def __init__(
serializer,
strip_prefixes,
enable_validation,
response_validation_error_http_status,
)

def _get_base_path(self) -> str:
Expand Down Expand Up @@ -2664,6 +2721,7 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
"""Amazon API Gateway HTTP API v2 payload resolver"""
super().__init__(
Expand All @@ -2673,6 +2731,7 @@ def __init__(
serializer,
strip_prefixes,
enable_validation,
response_validation_error_http_status,
)

def _get_base_path(self) -> str:
Expand Down Expand Up @@ -2701,9 +2760,18 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
"""Amazon Application Load Balancer (ALB) resolver"""
super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes, enable_validation)
super().__init__(
ProxyEventType.ALBEvent,
cors,
debug,
serializer,
strip_prefixes,
enable_validation,
response_validation_error_http_status,
)

def _get_base_path(self) -> str:
# ALB doesn't have a stage variable, so we just return an empty string
Expand Down
4 changes: 4 additions & 0 deletions aws_lambda_powertools/event_handler/lambda_function_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
)

if TYPE_CHECKING:
from http import HTTPStatus

from aws_lambda_powertools.event_handler import CORSConfig
from aws_lambda_powertools.utilities.data_classes import LambdaFunctionUrlEvent

Expand Down Expand Up @@ -57,6 +59,7 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
super().__init__(
ProxyEventType.LambdaFunctionUrlEvent,
Expand All @@ -65,6 +68,7 @@ def __init__(
serializer,
strip_prefixes,
enable_validation,
response_validation_error_http_status,
)

def _get_base_path(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError
from aws_lambda_powertools.event_handler.openapi.params import Param

if TYPE_CHECKING:
Expand Down Expand Up @@ -58,7 +58,11 @@ def get_todos(): list[Todo]:
```
"""

def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
def __init__(
self,
validation_serializer: Callable[[Any], str] | None = None,
has_response_validation_error: bool = False,
):
"""
Initialize the OpenAPIValidationMiddleware.

Expand All @@ -67,8 +71,14 @@ def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
validation_serializer : Callable, optional
Optional serializer to use when serializing the response for validation.
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.

custom_serialize_response_error: ValidationException, optional
Optional error type to raise when response to be returned by the endpoint is not
serialisable according to field type.
Raises RequestValidationError by default.
"""
self._validation_serializer = validation_serializer
self._has_response_validation_error = has_response_validation_error

def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
logger.debug("OpenAPIValidationMiddleware handler")
Expand Down Expand Up @@ -165,6 +175,8 @@ def _serialize_response(
errors: list[dict[str, Any]] = []
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
if errors:
if self._has_response_validation_error:
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content)
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)

if hasattr(field, "serialize"):
Expand Down
10 changes: 10 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
self.body = body


class ResponseValidationError(ValidationException):
"""
Raised when the response body does not match the OpenAPI schema
"""

def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
super().__init__(errors)
self.body = body


class SerializationError(Exception):
"""
Base exception for all encoding errors
Expand Down
24 changes: 22 additions & 2 deletions aws_lambda_powertools/event_handler/vpc_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
)

if TYPE_CHECKING:
from http import HTTPStatus

from aws_lambda_powertools.event_handler import CORSConfig
from aws_lambda_powertools.utilities.data_classes import VPCLatticeEvent, VPCLatticeEventV2

Expand Down Expand Up @@ -53,9 +55,18 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
"""Amazon VPC Lattice resolver"""
super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes, enable_validation)
super().__init__(
ProxyEventType.VPCLatticeEvent,
cors,
debug,
serializer,
strip_prefixes,
enable_validation,
response_validation_error_http_status,
)

def _get_base_path(self) -> str:
return ""
Expand Down Expand Up @@ -102,9 +113,18 @@ def __init__(
serializer: Callable[[dict], str] | None = None,
strip_prefixes: list[str | Pattern] | None = None,
enable_validation: bool = False,
response_validation_error_http_status: HTTPStatus | int | None = None,
):
"""Amazon VPC Lattice resolver"""
super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes, enable_validation)
super().__init__(
ProxyEventType.VPCLatticeEventV2,
cors,
debug,
serializer,
strip_prefixes,
enable_validation,
response_validation_error_http_status,
)

def _get_base_path(self) -> str:
return ""
31 changes: 30 additions & 1 deletion docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ Let's rewrite the previous examples to signal our resolver what shape we expect

!!! info "By default, we hide extended error details for security reasons _(e.g., pydantic url, Pydantic code)_."

Any incoming request that fails validation will lead to a `HTTP 422: Unprocessable Entity error` response that will look similar to this:
Any incoming request or and outgoing response that fails validation will lead to a `HTTP 422: Unprocessable Entity error` response that will look similar to this:

```json hl_lines="2 3" title="data_validation_error_unsanitized_output.json"
--8<-- "examples/event_handler_rest/src/data_validation_error_unsanitized_output.json"
Expand Down Expand Up @@ -398,6 +398,35 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou
--8<-- "examples/event_handler_rest/src/validating_payload_subset_output.json"
```

#### Validating responses

The optional `response_validation_error_http_status` argument can be set for all the resolvers to distinguish between failed data validation of payload and response. The desired HTTP status code for failed response validation must be passed to this argument.

Following on from our previous example, we want to distinguish between an invalid payload sent by the user and an invalid response which is being proxying to the user from another endpoint.

=== "customizing_response_validation.py"

```python hl_lines="18 30 34 36"
--8<-- "examples/event_handler_rest/src/customizing_response_validation.py"
```

1. This enforces response data validation at runtime. A response with status code set here will be returned if response data is not valid.
2. We validate our response body against `Todo`.
3. Operation returns a string as oppose to a Todo object. This will lead to a `500` response as set in line 18.
4. The distinct `ResponseValidationError` exception can be caught to customise the response—see difference between the sanitized and unsanitized responses.

=== "sanitized_error_response.json"

```json hl_lines="2-3"
--8<-- "examples/event_handler_rest/src/response_validation_sanitized_error_output.json"
```

=== "unsanitized_error_response.json"

```json hl_lines="2-3"
--8<-- "examples/event_handler_rest/src/response_validation_error_unsanitized_output.json"
```

#### Validating query strings

!!! info "We will automatically validate and inject incoming query strings via type annotation."
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from http import HTTPStatus
from typing import Optional

import requests
from pydantic import BaseModel, Field

from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, content_types
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.openapi.exceptions import ResponseValidationError
from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.utilities.typing import LambdaContext

tracer = Tracer()
logger = Logger()
app = APIGatewayRestResolver(
enable_validation=True,
response_validation_error_http_status=HTTPStatus.INTERNAL_SERVER_ERROR, # (1)!
)


class Todo(BaseModel):
userId: int
id_: Optional[int] = Field(alias="id", default=None)
title: str
completed: bool

@app.get("/todos_bad_response/<todo_id>")
@tracer.capture_method
def get_todo_by_id(todo_id: int) -> Todo: # (2)!
todo = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}")
todo.raise_for_status()

return todo.json()["title"] # (3)!

@app.exception_handler(ResponseValidationError) # (4)!
def handle_response_validation_error(ex: ResponseValidationError):
logger.error("Request failed validation", path=app.current_event.path, errors=ex.errors())

return Response(
status_code=500,
content_type=content_types.APPLICATION_JSON,
body="Unexpected response.",
)

@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP)
@tracer.capture_lambda_handler
def lambda_handler(event: dict, context: LambdaContext) -> dict:
return app.resolve(event, context)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"statusCode": 500,
"body": "{\"statusCode\": 500, \"detail\": [{\"type\": \"model_attributes_type\", \"loc\": [\"response\", ]}]}",
"isBase64Encoded": false,
"headers": {
"Content-Type": "application/json"
}
}
Loading
Loading