Skip to content

Commit 43c2515

Browse files
feat(api-gateway): add support for custom serializer (aws-powertools#568)
Co-authored-by: Heitor Lessa <[email protected]>
1 parent e0ab7a1 commit 43c2515

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import traceback
77
import zlib
88
from enum import Enum
9+
from functools import partial
910
from http import HTTPStatus
1011
from typing import Any, Callable, Dict, List, Optional, Set, Union
1112

@@ -263,6 +264,7 @@ def __init__(
263264
proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent,
264265
cors: Optional[CORSConfig] = None,
265266
debug: Optional[bool] = None,
267+
serializer: Optional[Callable[[Dict], str]] = None,
266268
):
267269
"""
268270
Parameters
@@ -284,6 +286,13 @@ def __init__(
284286
env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug
285287
)
286288

289+
# Allow for a custom serializer or a concise json serialization
290+
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
291+
292+
if self._debug:
293+
# Always does a pretty print when in debug mode
294+
self._serializer = partial(json.dumps, indent=4, cls=Encoder)
295+
287296
def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
288297
"""Get route decorator with GET `method`
289298
@@ -592,8 +601,4 @@ def _to_response(self, result: Union[Dict, Response]) -> Response:
592601
)
593602

594603
def _json_dump(self, obj: Any) -> str:
595-
"""Does a concise json serialization or pretty print when in debug mode"""
596-
if self._debug:
597-
return json.dumps(obj, indent=4, cls=Encoder)
598-
else:
599-
return json.dumps(obj, separators=(",", ":"), cls=Encoder)
604+
return self._serializer(obj)

tests/functional/event_handler/test_api_gateway.py

+41
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import zlib
44
from copy import deepcopy
55
from decimal import Decimal
6+
from enum import Enum
7+
from json import JSONEncoder
68
from pathlib import Path
79
from typing import Dict
810

@@ -728,3 +730,42 @@ def get_account(account_id: str):
728730

729731
ret = app.resolve(event, None)
730732
assert ret["statusCode"] == 200
733+
734+
735+
def test_custom_serializer():
736+
# GIVEN a custom serializer to handle enums and sets
737+
class CustomEncoder(JSONEncoder):
738+
def default(self, data):
739+
if isinstance(data, Enum):
740+
return data.value
741+
try:
742+
iterable = iter(data)
743+
except TypeError:
744+
pass
745+
else:
746+
return sorted(iterable)
747+
return JSONEncoder.default(self, data)
748+
749+
def custom_serializer(data) -> str:
750+
return json.dumps(data, cls=CustomEncoder)
751+
752+
app = ApiGatewayResolver(serializer=custom_serializer)
753+
754+
class Color(Enum):
755+
RED = 1
756+
BLUE = 2
757+
758+
@app.get("/colors")
759+
def get_color() -> Dict:
760+
return {
761+
"color": Color.RED,
762+
"variations": {"light", "dark"},
763+
}
764+
765+
# WHEN calling handler
766+
response = app({"httpMethod": "GET", "path": "/colors"}, None)
767+
768+
# THEN then use the custom serializer
769+
body = response["body"]
770+
expected = '{"color": 1, "variations": ["dark", "light"]}'
771+
assert expected == body

0 commit comments

Comments
 (0)