Skip to content

fix(event_handler): apply serialization as the last operation for middlewares #3392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 25 additions & 26 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,8 +699,9 @@ def _generate_operation_id(self) -> str:
class ResponseBuilder(Generic[ResponseEventT]):
"""Internally used Response builder"""

def __init__(self, response: Response, route: Optional[Route] = None):
def __init__(self, response: Response, serializer: Callable[[Any], str], route: Optional[Route] = None):
self.response = response
self.serializer = serializer
self.route = route

def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
Expand Down Expand Up @@ -782,6 +783,8 @@ def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dic
logger.debug("Encoding bytes response with base64")
self.response.base64_encoded = True
self.response.body = base64.b64encode(self.response.body).decode()
elif self.response.is_json():
self.response.body = self.serializer(self.response.body)

return {
"statusCode": self.response.status_code,
Expand Down Expand Up @@ -1332,14 +1335,6 @@ def __init__(

self.use([OpenAPIValidationMiddleware()])

# When using validation, we need to skip the serializer, as the middleware is doing it automatically.
# However, if the user is using a custom serializer, we need to abort.
if serializer:
raise ValueError("Cannot use a custom serializer when using validation")

# Install a dummy serializer
self._serializer = lambda args: args # type: ignore

def get_openapi_schema(
self,
*,
Expand Down Expand Up @@ -1717,7 +1712,7 @@ def resolve(self, event, context) -> Dict[str, Any]:
event = event.raw_event

if self._debug:
print(self._json_dump(event))
print(self._serializer(event))

# Populate router(s) dependencies without keeping a reference to each registered router
BaseRouter.current_event = self._to_proxy_event(event)
Expand Down Expand Up @@ -1881,19 +1876,23 @@ def _not_found(self, method: str) -> ResponseBuilder:
if method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with null response")
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=""))
return ResponseBuilder(
response=Response(status_code=204, content_type=None, headers=headers, body=""),
serializer=self._serializer,
)

handler = self._lookup_exception_handler(NotFoundError)
if handler:
return self._response_builder_class(handler(NotFoundError()))
return self._response_builder_class(response=handler(NotFoundError()), serializer=self._serializer)

return self._response_builder_class(
Response(
response=Response(
status_code=HTTPStatus.NOT_FOUND.value,
content_type=content_types.APPLICATION_JSON,
headers=headers,
body=self._json_dump({"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"}),
body={"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"},
),
serializer=self._serializer,
)

def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> ResponseBuilder:
Expand All @@ -1903,10 +1902,11 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
self._reset_processed_stack()

return self._response_builder_class(
self._to_response(
response=self._to_response(
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
),
route,
serializer=self._serializer,
route=route,
)
except Exception as exc:
# If exception is handled then return the response builder to reduce noise
Expand All @@ -1920,12 +1920,13 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
# we'll let the original exception propagate, so
# they get more information about what went wrong.
return self._response_builder_class(
Response(
response=Response(
status_code=500,
content_type=content_types.TEXT_PLAIN,
body="".join(traceback.format_exc()),
),
route,
serializer=self._serializer,
route=route,
)

raise
Expand Down Expand Up @@ -1958,18 +1959,19 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
handler = self._lookup_exception_handler(type(exp))
if handler:
try:
return self._response_builder_class(handler(exp), route)
return self._response_builder_class(response=handler(exp), serializer=self._serializer, route=route)
except ServiceError as service_error:
exp = service_error

if isinstance(exp, ServiceError):
return self._response_builder_class(
Response(
response=Response(
status_code=exp.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
body={"statusCode": exp.status_code, "message": exp.msg},
),
route,
serializer=self._serializer,
route=route,
)

return None
Expand All @@ -1995,12 +1997,9 @@ def _to_response(self, result: Union[Dict, Tuple, Response]) -> Response:
return Response(
status_code=status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump(result),
body=result,
)

def _json_dump(self, obj: Any) -> str:
return self._serializer(obj)

def include_router(self, router: "Router", prefix: Optional[str] = None) -> None:
"""Adds all routes and context defined in a router

Expand Down
6 changes: 5 additions & 1 deletion aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]:
"""Build the full response dict to be returned by the lambda"""
self._route(event, None)

body = self.response.body
if self.response.is_json():
body = self.serializer(self.response.body)

return {
"messageVersion": "1.0",
"response": {
Expand All @@ -32,7 +36,7 @@ def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]:
"httpStatusCode": self.response.status_code,
"responseBody": {
self.response.content_type: {
"body": self.response.body,
"body": body,
},
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ def _handle_response(self, *, route: Route, response: Response):
if response.body:
# Validate and serialize the response, if it's JSON
if response.is_json():
response.body = json.dumps(
self._serialize_response(field=route.dependant.return_param, response_content=response.body),
sort_keys=True,
response.body = self._serialize_response(
field=route.dependant.return_param,
response_content=response.body,
)

return response
Expand Down
14 changes: 7 additions & 7 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def test_override_route_compress_parameter():
# AND the Response object with compress=False
app = ApiGatewayResolver()
mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}}
expected_value = '{"test": "value"}'
expected_value = {"test": "value"}

@app.get("/my/request", compress=True)
def with_compression() -> Response:
Expand All @@ -379,9 +379,9 @@ def handler(event, context):
# WHEN calling the event handler
result = handler(mock_event, None)

# THEN then the response is not compressed
# THEN the response is not compressed
assert result["isBase64Encoded"] is False
assert result["body"] == expected_value
assert json.loads(result["body"]) == expected_value
assert result["multiValueHeaders"].get("Content-Encoding") is None


Expand Down Expand Up @@ -681,7 +681,7 @@ def another_one():
def test_no_content_response():
# GIVEN a response with no content-type or body
response = Response(status_code=204, content_type=None, body=None, headers=None)
response_builder = ResponseBuilder(response)
response_builder = ResponseBuilder(response, serializer=json.dumps)

# WHEN calling to_dict
result = response_builder.build(APIGatewayProxyEvent(LOAD_GW_EVENT))
Expand Down Expand Up @@ -1482,7 +1482,7 @@ def get_lambda() -> Response:
# THEN call the exception_handler
assert result["statusCode"] == 500
assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON]
assert result["body"] == "CUSTOM ERROR FORMAT"
assert result["body"] == '"CUSTOM ERROR FORMAT"'


def test_exception_handler_not_found():
Expand Down Expand Up @@ -1778,11 +1778,11 @@ def test_route_match_prioritize_full_match():

@router.get("/my/{path}")
def dynamic_handler() -> Response:
return Response(200, content_types.APPLICATION_JSON, json.dumps({"hello": "dynamic"}))
return Response(200, content_types.APPLICATION_JSON, {"hello": "dynamic"})

@router.get("/my/path")
def static_handler() -> Response:
return Response(200, content_types.APPLICATION_JSON, json.dumps({"hello": "static"}))
return Response(200, content_types.APPLICATION_JSON, {"hello": "static"})

app.include_router(router)

Expand Down
12 changes: 6 additions & 6 deletions tests/functional/event_handler/test_base_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def handle():

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""
assert result["body"] == '""'


def test_base_path_api_gateway_http():
Expand All @@ -38,7 +38,7 @@ def handle():

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""
assert result["body"] == '""'


def test_base_path_alb():
Expand All @@ -53,7 +53,7 @@ def handle():

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""
assert result["body"] == '""'


def test_base_path_lambda_function_url():
Expand All @@ -70,7 +70,7 @@ def handle():

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""
assert result["body"] == '""'


def test_vpc_lattice():
Expand All @@ -85,7 +85,7 @@ def handle():

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""
assert result["body"] == '""'


def test_vpc_latticev2():
Expand All @@ -100,4 +100,4 @@ def handle():

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""
assert result["body"] == '""'
4 changes: 2 additions & 2 deletions tests/functional/event_handler/test_bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def claims() -> Dict[str, Any]:
assert result["response"]["httpStatusCode"] == 200

body = result["response"]["responseBody"]["application/json"]["body"]
assert body == json.dumps({"output": claims_response})
assert json.loads(body) == {"output": claims_response}


def test_bedrock_agent_with_path_params():
Expand Down Expand Up @@ -79,7 +79,7 @@ def claims():
assert result["response"]["httpStatusCode"] == 200

body = result["response"]["responseBody"]["application/json"]["body"]
assert body == json.dumps(output)
assert json.loads(body) == output


def test_bedrock_agent_event_with_no_matches():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import PurePath
from typing import List, Tuple

import pytest
from pydantic import BaseModel

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
Expand All @@ -15,11 +14,6 @@
LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")


def test_validate_with_customn_serializer():
with pytest.raises(ValueError):
APIGatewayRestResolver(enable_validation=True, serializer=json.dumps)


def test_validate_scalars():
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)
Expand Down Expand Up @@ -128,7 +122,7 @@ def handler() -> List[int]:
# THEN the body must be [123, 234]
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == "[123, 234]"
assert json.loads(result["body"]) == [123, 234]


def test_validate_return_tuple():
Expand All @@ -148,7 +142,7 @@ def handler() -> Tuple:
# THEN the body must be a tuple
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == "[1, 2, 3]"
assert json.loads(result["body"]) == [1, 2, 3]


def test_validate_return_purepath():
Expand Down