Skip to content

Commit 8765206

Browse files
authored
feat(event_handler): use custom serializer during openapi serialization (#3900)
* feat(event_handler): use custom serializer during openapi serialization * fix: comments
1 parent e79eef4 commit 8765206

File tree

4 files changed

+50
-4
lines changed

4 files changed

+50
-4
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1401,7 +1401,9 @@ def __init__(
14011401
if self._enable_validation:
14021402
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware
14031403

1404-
self.use([OpenAPIValidationMiddleware()])
1404+
# Note the serializer argument: only use custom serializer if provided by the caller
1405+
# Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation.
1406+
self.use([OpenAPIValidationMiddleware(validation_serializer=serializer)])
14051407

14061408
def get_openapi_schema(
14071409
self,

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import logging
44
from copy import deepcopy
5-
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
5+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
66

77
from pydantic import BaseModel
88

@@ -55,6 +55,18 @@ def get_todos(): List[Todo]:
5555
```
5656
"""
5757

58+
def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None):
59+
"""
60+
Initialize the OpenAPIValidationMiddleware.
61+
62+
Parameters
63+
----------
64+
validation_serializer : Callable, optional
65+
Optional serializer to use when serializing the response for validation.
66+
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
67+
"""
68+
self._validation_serializer = validation_serializer
69+
5870
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
5971
logger.debug("OpenAPIValidationMiddleware handler")
6072

@@ -181,10 +193,11 @@ def _serialize_response(
181193
exclude_unset=exclude_unset,
182194
exclude_defaults=exclude_defaults,
183195
exclude_none=exclude_none,
196+
custom_serializer=self._validation_serializer,
184197
)
185198
else:
186199
# Just serialize the response content returned from the handler
187-
return jsonable_encoder(response_content)
200+
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)
188201

189202
def _prepare_response_content(
190203
self,

aws_lambda_powertools/event_handler/openapi/encoders.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def jsonable_encoder( # noqa: PLR0911
2929
exclude_unset: bool = False,
3030
exclude_defaults: bool = False,
3131
exclude_none: bool = False,
32+
custom_serializer: Optional[Callable[[Any], str]] = None,
3233
) -> Any:
3334
"""
3435
JSON encodes an arbitrary Python object into JSON serializable data types.
@@ -55,6 +56,8 @@ def jsonable_encoder( # noqa: PLR0911
5556
by default False
5657
exclude_none : bool, optional
5758
Whether fields that are equal to None should be excluded, by default False
59+
custom_serializer : Callable, optional
60+
A custom serializer to use for encoding the object, when everything else fails.
5861
5962
Returns
6063
-------
@@ -134,6 +137,10 @@ def jsonable_encoder( # noqa: PLR0911
134137
if isinstance(obj, classes_tuple):
135138
return encoder(obj)
136139

140+
# Use custom serializer if present
141+
if custom_serializer:
142+
return custom_serializer(obj)
143+
137144
# Default
138145
return _dump_other(
139146
obj=obj,
@@ -259,7 +266,7 @@ def _dump_other(
259266
exclude_defaults: bool = False,
260267
) -> Any:
261268
"""
262-
Dump an object to ah hashable object, using the same parameters as jsonable_encoder
269+
Dump an object to a hashable object, using the same parameters as jsonable_encoder
263270
"""
264271
try:
265272
data = dict(obj)

tests/functional/event_handler/test_openapi_serialization.py

+24
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,27 @@ def handler():
3737

3838
# THEN we should get a dictionary
3939
assert isinstance(schema, Dict)
40+
41+
42+
def test_openapi_serialize_other(gw_event):
43+
# GIVEN a custom serializer
44+
def serializer(_):
45+
return "hello world"
46+
47+
# GIVEN APIGatewayRestResolver is initialized with enable_validation=True and the custom serializer
48+
app = APIGatewayRestResolver(enable_validation=True, serializer=serializer)
49+
50+
# GIVEN a custom class
51+
class CustomClass(object):
52+
__slots__ = []
53+
54+
# GIVEN a handler that returns an instance of that class
55+
@app.get("/my/path")
56+
def handler():
57+
return CustomClass()
58+
59+
# WHEN we invoke the handler
60+
response = app(gw_event, {})
61+
62+
# THEN we the custom serializer should be used
63+
assert response["body"] == "hello world"

0 commit comments

Comments
 (0)