Skip to content

Commit 6a47ee8

Browse files
authored
fix(event_handler): apply serialization as the last operation for middlewares (#3392)
1 parent f94526f commit 6a47ee8

File tree

6 files changed

+48
-43
lines changed

6 files changed

+48
-43
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+33-26
Original file line numberDiff line numberDiff line change
@@ -699,8 +699,14 @@ def _generate_operation_id(self) -> str:
699699
class ResponseBuilder(Generic[ResponseEventT]):
700700
"""Internally used Response builder"""
701701

702-
def __init__(self, response: Response, route: Optional[Route] = None):
702+
def __init__(
703+
self,
704+
response: Response,
705+
serializer: Callable[[Any], str] = json.dumps,
706+
route: Optional[Route] = None,
707+
):
703708
self.response = response
709+
self.serializer = serializer
704710
self.route = route
705711

706712
def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
@@ -783,6 +789,11 @@ def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dic
783789
self.response.base64_encoded = True
784790
self.response.body = base64.b64encode(self.response.body).decode()
785791

792+
# We only apply the serializer when the content type is JSON and the
793+
# body is not a str, to avoid double encoding
794+
elif self.response.is_json() and not isinstance(self.response.body, str):
795+
self.response.body = self.serializer(self.response.body)
796+
786797
return {
787798
"statusCode": self.response.status_code,
788799
"body": self.response.body,
@@ -1332,14 +1343,6 @@ def __init__(
13321343

13331344
self.use([OpenAPIValidationMiddleware()])
13341345

1335-
# When using validation, we need to skip the serializer, as the middleware is doing it automatically.
1336-
# However, if the user is using a custom serializer, we need to abort.
1337-
if serializer:
1338-
raise ValueError("Cannot use a custom serializer when using validation")
1339-
1340-
# Install a dummy serializer
1341-
self._serializer = lambda args: args # type: ignore
1342-
13431346
def get_openapi_schema(
13441347
self,
13451348
*,
@@ -1717,7 +1720,7 @@ def resolve(self, event, context) -> Dict[str, Any]:
17171720
event = event.raw_event
17181721

17191722
if self._debug:
1720-
print(self._json_dump(event))
1723+
print(self._serializer(event))
17211724

17221725
# Populate router(s) dependencies without keeping a reference to each registered router
17231726
BaseRouter.current_event = self._to_proxy_event(event)
@@ -1881,19 +1884,23 @@ def _not_found(self, method: str) -> ResponseBuilder:
18811884
if method == "OPTIONS":
18821885
logger.debug("Pre-flight request detected. Returning CORS with null response")
18831886
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
1884-
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=""))
1887+
return ResponseBuilder(
1888+
response=Response(status_code=204, content_type=None, headers=headers, body=""),
1889+
serializer=self._serializer,
1890+
)
18851891

18861892
handler = self._lookup_exception_handler(NotFoundError)
18871893
if handler:
1888-
return self._response_builder_class(handler(NotFoundError()))
1894+
return self._response_builder_class(response=handler(NotFoundError()), serializer=self._serializer)
18891895

18901896
return self._response_builder_class(
1891-
Response(
1897+
response=Response(
18921898
status_code=HTTPStatus.NOT_FOUND.value,
18931899
content_type=content_types.APPLICATION_JSON,
18941900
headers=headers,
1895-
body=self._json_dump({"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"}),
1901+
body={"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"},
18961902
),
1903+
serializer=self._serializer,
18971904
)
18981905

18991906
def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> ResponseBuilder:
@@ -1903,10 +1910,11 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
19031910
self._reset_processed_stack()
19041911

19051912
return self._response_builder_class(
1906-
self._to_response(
1913+
response=self._to_response(
19071914
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
19081915
),
1909-
route,
1916+
serializer=self._serializer,
1917+
route=route,
19101918
)
19111919
except Exception as exc:
19121920
# If exception is handled then return the response builder to reduce noise
@@ -1920,12 +1928,13 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
19201928
# we'll let the original exception propagate, so
19211929
# they get more information about what went wrong.
19221930
return self._response_builder_class(
1923-
Response(
1931+
response=Response(
19241932
status_code=500,
19251933
content_type=content_types.TEXT_PLAIN,
19261934
body="".join(traceback.format_exc()),
19271935
),
1928-
route,
1936+
serializer=self._serializer,
1937+
route=route,
19291938
)
19301939

19311940
raise
@@ -1958,18 +1967,19 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
19581967
handler = self._lookup_exception_handler(type(exp))
19591968
if handler:
19601969
try:
1961-
return self._response_builder_class(handler(exp), route)
1970+
return self._response_builder_class(response=handler(exp), serializer=self._serializer, route=route)
19621971
except ServiceError as service_error:
19631972
exp = service_error
19641973

19651974
if isinstance(exp, ServiceError):
19661975
return self._response_builder_class(
1967-
Response(
1976+
response=Response(
19681977
status_code=exp.status_code,
19691978
content_type=content_types.APPLICATION_JSON,
1970-
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
1979+
body={"statusCode": exp.status_code, "message": exp.msg},
19711980
),
1972-
route,
1981+
serializer=self._serializer,
1982+
route=route,
19731983
)
19741984

19751985
return None
@@ -1995,12 +2005,9 @@ def _to_response(self, result: Union[Dict, Tuple, Response]) -> Response:
19952005
return Response(
19962006
status_code=status_code,
19972007
content_type=content_types.APPLICATION_JSON,
1998-
body=self._json_dump(result),
2008+
body=result,
19992009
)
20002010

2001-
def _json_dump(self, obj: Any) -> str:
2002-
return self._serializer(obj)
2003-
20042011
def include_router(self, router: "Router", prefix: Optional[str] = None) -> None:
20052012
"""Adds all routes and context defined in a router
20062013

aws_lambda_powertools/event_handler/bedrock_agent.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]:
2323
"""Build the full response dict to be returned by the lambda"""
2424
self._route(event, None)
2525

26+
body = self.response.body
27+
if self.response.is_json() and not isinstance(self.response.body, str):
28+
body = self.serializer(self.response.body)
29+
2630
return {
2731
"messageVersion": "1.0",
2832
"response": {
@@ -32,7 +36,7 @@ def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]:
3236
"httpStatusCode": self.response.status_code,
3337
"responseBody": {
3438
self.response.content_type: {
35-
"body": self.response.body,
39+
"body": body,
3640
},
3741
},
3842
},

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ def _handle_response(self, *, route: Route, response: Response):
112112
if response.body:
113113
# Validate and serialize the response, if it's JSON
114114
if response.is_json():
115-
response.body = json.dumps(
116-
self._serialize_response(field=route.dependant.return_param, response_content=response.body),
117-
sort_keys=True,
115+
response.body = self._serialize_response(
116+
field=route.dependant.return_param,
117+
response_content=response.body,
118118
)
119119

120120
return response

tests/functional/event_handler/test_api_gateway.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def handler(event, context):
379379
# WHEN calling the event handler
380380
result = handler(mock_event, None)
381381

382-
# THEN then the response is not compressed
382+
# THEN the response is not compressed
383383
assert result["isBase64Encoded"] is False
384384
assert result["body"] == expected_value
385385
assert result["multiValueHeaders"].get("Content-Encoding") is None

tests/functional/event_handler/test_bedrock_agent.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def claims() -> Dict[str, Any]:
3131
assert result["response"]["httpStatusCode"] == 200
3232

3333
body = result["response"]["responseBody"]["application/json"]["body"]
34-
assert body == json.dumps({"output": claims_response})
34+
assert json.loads(body) == {"output": claims_response}
3535

3636

3737
def test_bedrock_agent_with_path_params():
@@ -79,7 +79,7 @@ def claims():
7979
assert result["response"]["httpStatusCode"] == 200
8080

8181
body = result["response"]["responseBody"]["application/json"]["body"]
82-
assert body == json.dumps(output)
82+
assert json.loads(body) == output
8383

8484

8585
def test_bedrock_agent_event_with_no_matches():

tests/functional/event_handler/test_openapi_validation_middleware.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from pathlib import PurePath
55
from typing import List, Tuple
66

7-
import pytest
87
from pydantic import BaseModel
98

109
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
@@ -15,11 +14,6 @@
1514
LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")
1615

1716

18-
def test_validate_with_customn_serializer():
19-
with pytest.raises(ValueError):
20-
APIGatewayRestResolver(enable_validation=True, serializer=json.dumps)
21-
22-
2317
def test_validate_scalars():
2418
# GIVEN an APIGatewayRestResolver with validation enabled
2519
app = APIGatewayRestResolver(enable_validation=True)
@@ -128,7 +122,7 @@ def handler() -> List[int]:
128122
# THEN the body must be [123, 234]
129123
result = app(LOAD_GW_EVENT, {})
130124
assert result["statusCode"] == 200
131-
assert result["body"] == "[123, 234]"
125+
assert json.loads(result["body"]) == [123, 234]
132126

133127

134128
def test_validate_return_tuple():
@@ -148,7 +142,7 @@ def handler() -> Tuple:
148142
# THEN the body must be a tuple
149143
result = app(LOAD_GW_EVENT, {})
150144
assert result["statusCode"] == 200
151-
assert result["body"] == "[1, 2, 3]"
145+
assert json.loads(result["body"]) == [1, 2, 3]
152146

153147

154148
def test_validate_return_purepath():
@@ -169,7 +163,7 @@ def handler() -> str:
169163
# THEN the body must be a string
170164
result = app(LOAD_GW_EVENT, {})
171165
assert result["statusCode"] == 200
172-
assert result["body"] == json.dumps(sample_path.as_posix())
166+
assert result["body"] == sample_path.as_posix()
173167

174168

175169
def test_validate_return_enum():
@@ -190,7 +184,7 @@ def handler() -> Model:
190184
# THEN the body must be a string
191185
result = app(LOAD_GW_EVENT, {})
192186
assert result["statusCode"] == 200
193-
assert result["body"] == '"powertools"'
187+
assert result["body"] == "powertools"
194188

195189

196190
def test_validate_return_dataclass():

0 commit comments

Comments
 (0)