Skip to content

fix(event_handler): apply serialization as the last operation for middlewares #3392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,8 +699,14 @@ def _generate_operation_id(self) -> str:
class ResponseBuilder(Generic[ResponseEventT]):
"""Internally used Response builder"""

def __init__(self, response: Response, route: Optional[Route] = None):
def __init__(
self,
response: Response,
serializer: Callable[[Any], str] = json.dumps,
route: Optional[Route] = None,
):
self.response = response
self.serializer = serializer
self.route = route

def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
Expand Down Expand Up @@ -783,6 +789,11 @@ def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dic
self.response.base64_encoded = True
self.response.body = base64.b64encode(self.response.body).decode()

# We only apply the serializer when the content type is JSON and the
# body is not a str, to avoid double encoding
elif self.response.is_json() and not isinstance(self.response.body, str):
self.response.body = self.serializer(self.response.body)

return {
"statusCode": self.response.status_code,
"body": self.response.body,
Expand Down Expand Up @@ -1332,14 +1343,6 @@ def __init__(

self.use([OpenAPIValidationMiddleware()])

# When using validation, we need to skip the serializer, as the middleware is doing it automatically.
# However, if the user is using a custom serializer, we need to abort.
if serializer:
raise ValueError("Cannot use a custom serializer when using validation")

# Install a dummy serializer
self._serializer = lambda args: args # type: ignore

def get_openapi_schema(
self,
*,
Expand Down Expand Up @@ -1717,7 +1720,7 @@ def resolve(self, event, context) -> Dict[str, Any]:
event = event.raw_event

if self._debug:
print(self._json_dump(event))
print(self._serializer(event))

# Populate router(s) dependencies without keeping a reference to each registered router
BaseRouter.current_event = self._to_proxy_event(event)
Expand Down Expand Up @@ -1881,19 +1884,23 @@ def _not_found(self, method: str) -> ResponseBuilder:
if method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with null response")
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=""))
return ResponseBuilder(
response=Response(status_code=204, content_type=None, headers=headers, body=""),
serializer=self._serializer,
)

handler = self._lookup_exception_handler(NotFoundError)
if handler:
return self._response_builder_class(handler(NotFoundError()))
return self._response_builder_class(response=handler(NotFoundError()), serializer=self._serializer)

return self._response_builder_class(
Response(
response=Response(
status_code=HTTPStatus.NOT_FOUND.value,
content_type=content_types.APPLICATION_JSON,
headers=headers,
body=self._json_dump({"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"}),
body={"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"},
),
serializer=self._serializer,
)

def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> ResponseBuilder:
Expand All @@ -1903,10 +1910,11 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
self._reset_processed_stack()

return self._response_builder_class(
self._to_response(
response=self._to_response(
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
),
route,
serializer=self._serializer,
route=route,
)
except Exception as exc:
# If exception is handled then return the response builder to reduce noise
Expand All @@ -1920,12 +1928,13 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
# we'll let the original exception propagate, so
# they get more information about what went wrong.
return self._response_builder_class(
Response(
response=Response(
status_code=500,
content_type=content_types.TEXT_PLAIN,
body="".join(traceback.format_exc()),
),
route,
serializer=self._serializer,
route=route,
)

raise
Expand Down Expand Up @@ -1958,18 +1967,19 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
handler = self._lookup_exception_handler(type(exp))
if handler:
try:
return self._response_builder_class(handler(exp), route)
return self._response_builder_class(response=handler(exp), serializer=self._serializer, route=route)
except ServiceError as service_error:
exp = service_error

if isinstance(exp, ServiceError):
return self._response_builder_class(
Response(
response=Response(
status_code=exp.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
body={"statusCode": exp.status_code, "message": exp.msg},
),
route,
serializer=self._serializer,
route=route,
)

return None
Expand All @@ -1995,12 +2005,9 @@ def _to_response(self, result: Union[Dict, Tuple, Response]) -> Response:
return Response(
status_code=status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump(result),
body=result,
)

def _json_dump(self, obj: Any) -> str:
return self._serializer(obj)

def include_router(self, router: "Router", prefix: Optional[str] = None) -> None:
"""Adds all routes and context defined in a router

Expand Down
6 changes: 5 additions & 1 deletion aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]:
"""Build the full response dict to be returned by the lambda"""
self._route(event, None)

body = self.response.body
if self.response.is_json() and not isinstance(self.response.body, str):
body = self.serializer(self.response.body)

return {
"messageVersion": "1.0",
"response": {
Expand All @@ -32,7 +36,7 @@ def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]:
"httpStatusCode": self.response.status_code,
"responseBody": {
self.response.content_type: {
"body": self.response.body,
"body": body,
},
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ def _handle_response(self, *, route: Route, response: Response):
if response.body:
# Validate and serialize the response, if it's JSON
if response.is_json():
response.body = json.dumps(
self._serialize_response(field=route.dependant.return_param, response_content=response.body),
sort_keys=True,
response.body = self._serialize_response(
field=route.dependant.return_param,
response_content=response.body,
)

return response
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def handler(event, context):
# WHEN calling the event handler
result = handler(mock_event, None)

# THEN then the response is not compressed
# THEN the response is not compressed
assert result["isBase64Encoded"] is False
assert result["body"] == expected_value
assert result["multiValueHeaders"].get("Content-Encoding") is None
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/event_handler/test_bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def claims() -> Dict[str, Any]:
assert result["response"]["httpStatusCode"] == 200

body = result["response"]["responseBody"]["application/json"]["body"]
assert body == json.dumps({"output": claims_response})
assert json.loads(body) == {"output": claims_response}


def test_bedrock_agent_with_path_params():
Expand Down Expand Up @@ -79,7 +79,7 @@ def claims():
assert result["response"]["httpStatusCode"] == 200

body = result["response"]["responseBody"]["application/json"]["body"]
assert body == json.dumps(output)
assert json.loads(body) == output


def test_bedrock_agent_event_with_no_matches():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import PurePath
from typing import List, Tuple

import pytest
from pydantic import BaseModel

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
Expand All @@ -15,11 +14,6 @@
LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")


def test_validate_with_customn_serializer():
with pytest.raises(ValueError):
APIGatewayRestResolver(enable_validation=True, serializer=json.dumps)


def test_validate_scalars():
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)
Expand Down Expand Up @@ -128,7 +122,7 @@ def handler() -> List[int]:
# THEN the body must be [123, 234]
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == "[123, 234]"
assert json.loads(result["body"]) == [123, 234]


def test_validate_return_tuple():
Expand All @@ -148,7 +142,7 @@ def handler() -> Tuple:
# THEN the body must be a tuple
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == "[1, 2, 3]"
assert json.loads(result["body"]) == [1, 2, 3]


def test_validate_return_purepath():
Expand All @@ -169,7 +163,7 @@ def handler() -> str:
# THEN the body must be a string
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == json.dumps(sample_path.as_posix())
assert result["body"] == sample_path.as_posix()


def test_validate_return_enum():
Expand All @@ -190,7 +184,7 @@ def handler() -> Model:
# THEN the body must be a string
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == '"powertools"'
assert result["body"] == "powertools"


def test_validate_return_dataclass():
Expand Down