Skip to content

feat(event_handler): use custom serializer during openapi serialization #3900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,9 @@ def __init__(
if self._enable_validation:
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware

self.use([OpenAPIValidationMiddleware()])
# Note the serializer argument: only use custom serializer if provided by the caller
# Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation.
self.use([OpenAPIValidationMiddleware(validation_serializer=serializer)])

def get_openapi_schema(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
from copy import deepcopy
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple

from pydantic import BaseModel

Expand Down Expand Up @@ -55,6 +55,18 @@ def get_todos(): List[Todo]:
```
"""

def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None):
"""
Initialize the OpenAPIValidationMiddleware.

Parameters
----------
validation_serializer : Callable, optional
Optional serializer to use when serializing the response for validation.
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
"""
self._validation_serializer = validation_serializer

def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
logger.debug("OpenAPIValidationMiddleware handler")

Expand Down Expand Up @@ -181,10 +193,11 @@ def _serialize_response(
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_serializer=self._validation_serializer,
)
else:
# Just serialize the response content returned from the handler
return jsonable_encoder(response_content)
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)

def _prepare_response_content(
self,
Expand Down
9 changes: 8 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def jsonable_encoder( # noqa: PLR0911
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_serializer: Optional[Callable[[Any], str]] = None,
) -> Any:
"""
JSON encodes an arbitrary Python object into JSON serializable data types.
Expand All @@ -55,6 +56,8 @@ def jsonable_encoder( # noqa: PLR0911
by default False
exclude_none : bool, optional
Whether fields that are equal to None should be excluded, by default False
custom_serializer : Callable, optional
A custom serializer to use for encoding the object, when everything else fails.

Returns
-------
Expand Down Expand Up @@ -134,6 +137,10 @@ def jsonable_encoder( # noqa: PLR0911
if isinstance(obj, classes_tuple):
return encoder(obj)

# Use custom serializer if present
if custom_serializer:
return custom_serializer(obj)

# Default
return _dump_other(
obj=obj,
Expand Down Expand Up @@ -259,7 +266,7 @@ def _dump_other(
exclude_defaults: bool = False,
) -> Any:
"""
Dump an object to ah hashable object, using the same parameters as jsonable_encoder
Dump an object to a hashable object, using the same parameters as jsonable_encoder
"""
try:
data = dict(obj)
Expand Down
24 changes: 24 additions & 0 deletions tests/functional/event_handler/test_openapi_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,27 @@ def handler():

# THEN we should get a dictionary
assert isinstance(schema, Dict)


def test_openapi_serialize_other(gw_event):
# GIVEN a custom serializer
def serializer(_):
return "hello world"

# GIVEN APIGatewayRestResolver is initialized with enable_validation=True and the custom serializer
app = APIGatewayRestResolver(enable_validation=True, serializer=serializer)

# GIVEN a custom class
class CustomClass(object):
__slots__ = []

# GIVEN a handler that returns an instance of that class
@app.get("/my/path")
def handler():
return CustomClass()

# WHEN we invoke the handler
response = app(gw_event, {})

# THEN we the custom serializer should be used
assert response["body"] == "hello world"