From daa17eb758e3d021387aa9d206dbe72124c08611 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 15:05:43 +0100 Subject: [PATCH 1/7] chore(event_handler): only apply serialization at the end --- aws_lambda_powertools/event_handler/api_gateway.py | 2 ++ tests/functional/event_handler/test_api_gateway.py | 12 ++++++------ tests/functional/event_handler/test_base_path.py | 12 ++++++------ 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index a2b81974a21..8c44f6b8a9f 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -788,6 +788,8 @@ def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dic logger.debug("Encoding bytes response with base64") self.response.base64_encoded = True self.response.body = base64.b64encode(self.response.body).decode() + elif self.response.is_json(): + self.response.body = self.serializer(self.response.body) # We only apply the serializer when the content type is JSON and the # body is not a str, to avoid double encoding diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 570de9ec808..4ef5fa0896f 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -367,7 +367,7 @@ def test_override_route_compress_parameter(): # AND the Response object with compress=False app = ApiGatewayResolver() mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} - expected_value = '{"test": "value"}' + expected_value = {"test": "value"} @app.get("/my/request", compress=True) def with_compression() -> Response: @@ -381,7 +381,7 @@ def handler(event, context): # THEN the response is not compressed assert result["isBase64Encoded"] is False - assert result["body"] == expected_value + assert json.loads(result["body"]) == expected_value assert result["multiValueHeaders"].get("Content-Encoding") is None @@ -681,7 +681,7 @@ def another_one(): def test_no_content_response(): # GIVEN a response with no content-type or body response = Response(status_code=204, content_type=None, body=None, headers=None) - response_builder = ResponseBuilder(response) + response_builder = ResponseBuilder(response, serializer=json.dumps) # WHEN calling to_dict result = response_builder.build(APIGatewayProxyEvent(LOAD_GW_EVENT)) @@ -1482,7 +1482,7 @@ def get_lambda() -> Response: # THEN call the exception_handler assert result["statusCode"] == 500 assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] - assert result["body"] == "CUSTOM ERROR FORMAT" + assert result["body"] == '"CUSTOM ERROR FORMAT"' def test_exception_handler_not_found(): @@ -1778,11 +1778,11 @@ def test_route_match_prioritize_full_match(): @router.get("/my/{path}") def dynamic_handler() -> Response: - return Response(200, content_types.APPLICATION_JSON, json.dumps({"hello": "dynamic"})) + return Response(200, content_types.APPLICATION_JSON, {"hello": "dynamic"}) @router.get("/my/path") def static_handler() -> Response: - return Response(200, content_types.APPLICATION_JSON, json.dumps({"hello": "static"})) + return Response(200, content_types.APPLICATION_JSON, {"hello": "static"}) app.include_router(router) diff --git a/tests/functional/event_handler/test_base_path.py b/tests/functional/event_handler/test_base_path.py index 479a46bda55..adf3c5849df 100644 --- a/tests/functional/event_handler/test_base_path.py +++ b/tests/functional/event_handler/test_base_path.py @@ -21,7 +21,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' def test_base_path_api_gateway_http(): @@ -38,7 +38,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' def test_base_path_alb(): @@ -53,7 +53,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' def test_base_path_lambda_function_url(): @@ -70,7 +70,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' def test_vpc_lattice(): @@ -85,7 +85,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' def test_vpc_latticev2(): @@ -100,4 +100,4 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' From 9a0b142df0232cf055eb52fec851e0325409ed81 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 14:01:51 +0100 Subject: [PATCH 2/7] fix: avoid double encoding --- aws_lambda_powertools/event_handler/api_gateway.py | 5 ++++- tests/functional/event_handler/test_api_gateway.py | 2 +- tests/functional/event_handler/test_base_path.py | 12 ++++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 8c44f6b8a9f..721da601077 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -788,7 +788,10 @@ def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dic logger.debug("Encoding bytes response with base64") self.response.base64_encoded = True self.response.body = base64.b64encode(self.response.body).decode() - elif self.response.is_json(): + + # We only apply the serializer when the content type is JSON and the + # body is not a str, to avoid double encoding + elif self.response.is_json() and not isinstance(self.response.body, str): self.response.body = self.serializer(self.response.body) # We only apply the serializer when the content type is JSON and the diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 4ef5fa0896f..3cb1261eccd 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1482,7 +1482,7 @@ def get_lambda() -> Response: # THEN call the exception_handler assert result["statusCode"] == 500 assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] - assert result["body"] == '"CUSTOM ERROR FORMAT"' + assert result["body"] == "CUSTOM ERROR FORMAT" def test_exception_handler_not_found(): diff --git a/tests/functional/event_handler/test_base_path.py b/tests/functional/event_handler/test_base_path.py index adf3c5849df..479a46bda55 100644 --- a/tests/functional/event_handler/test_base_path.py +++ b/tests/functional/event_handler/test_base_path.py @@ -21,7 +21,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" def test_base_path_api_gateway_http(): @@ -38,7 +38,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" def test_base_path_alb(): @@ -53,7 +53,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" def test_base_path_lambda_function_url(): @@ -70,7 +70,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" def test_vpc_lattice(): @@ -85,7 +85,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" def test_vpc_latticev2(): @@ -100,4 +100,4 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" From 5f99afc49c4acf67980a77c52ac11679a2b20039 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 14:06:12 +0100 Subject: [PATCH 3/7] fix: rolled back test changes --- tests/functional/event_handler/test_api_gateway.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 3cb1261eccd..e370ca4b99d 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -367,7 +367,7 @@ def test_override_route_compress_parameter(): # AND the Response object with compress=False app = ApiGatewayResolver() mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} - expected_value = {"test": "value"} + expected_value = '{"test": "value"}' @app.get("/my/request", compress=True) def with_compression() -> Response: @@ -381,7 +381,7 @@ def handler(event, context): # THEN the response is not compressed assert result["isBase64Encoded"] is False - assert json.loads(result["body"]) == expected_value + assert result["body"] == expected_value assert result["multiValueHeaders"].get("Content-Encoding") is None @@ -1778,11 +1778,11 @@ def test_route_match_prioritize_full_match(): @router.get("/my/{path}") def dynamic_handler() -> Response: - return Response(200, content_types.APPLICATION_JSON, {"hello": "dynamic"}) + return Response(200, content_types.APPLICATION_JSON, json.dumps({"hello": "dynamic"})) @router.get("/my/path") def static_handler() -> Response: - return Response(200, content_types.APPLICATION_JSON, {"hello": "static"}) + return Response(200, content_types.APPLICATION_JSON, json.dumps({"hello": "static"})) app.include_router(router) From 79a677ef9dfa2778546e77ea8f7ded8387af95a3 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 14:59:04 +0100 Subject: [PATCH 4/7] fix(event_handler): allow use of Response with data validation --- .../event_handler/api_gateway.py | 5 +-- .../event_handler/openapi/params.py | 12 +++++++ .../event_handler/test_openapi_params.py | 20 +++++++++++- .../test_openapi_validation_middleware.py | 32 ++++++++++++++++++- 4 files changed, 65 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 721da601077..ef4b2be5860 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -66,6 +66,7 @@ _ROUTE_REGEX = "^{}$" ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent) +ResponseT = TypeVar("ResponseT") if TYPE_CHECKING: from aws_lambda_powertools.event_handler.openapi.compat import ( @@ -207,14 +208,14 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: return headers -class Response: +class Response(Generic[ResponseT]): """Response data class that provides greater control over what is returned from the proxy event""" def __init__( self, status_code: int, content_type: Optional[str] = None, - body: Any = None, + body: Optional[ResponseT] = None, headers: Optional[Dict[str, Union[str, List[str]]]] = None, cookies: Optional[List[Cookie]] = None, compress: Optional[bool] = None, diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index c8099d20404..28154466ff6 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -5,6 +5,7 @@ from pydantic import BaseConfig from pydantic.fields import FieldInfo +from aws_lambda_powertools.event_handler import Response from aws_lambda_powertools.event_handler.openapi.compat import ( ModelField, Required, @@ -724,6 +725,9 @@ def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) - # If the annotation is an Annotated type, we need to extract the type annotation and the FieldInfo if get_origin(annotation) is Annotated: field_info, type_annotation = get_field_info_annotated_type(annotation, value, is_path_param) + # If the annotation is a Response type, we recursively call this function with the inner type + elif get_origin(annotation) is Response: + field_info, type_annotation = get_field_info_response_type(annotation, value) # If the annotation is not an Annotated type, we use it as the type annotation else: type_annotation = annotation @@ -731,6 +735,14 @@ def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) - return field_info, type_annotation +def get_field_info_response_type(annotation, value) -> Tuple[Optional[FieldInfo], Any]: + # Example: get_args(Response[inner_type]) == (inner_type,) # noqa: ERA001 + (inner_type,) = get_args(annotation) + + # Recursively resolve the inner type + return get_field_info_and_type_annotation(inner_type, value, False) + + def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: """ Get the FieldInfo and type annotation from an Annotated type. diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index ec31bb14236..6e4f0395aff 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Response from aws_lambda_powertools.event_handler.openapi.models import ( Example, Parameter, @@ -153,6 +153,24 @@ def handler() -> str: assert response.schema_.type == "string" +def test_openapi_with_response_returns(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler() -> Response[Annotated[str, Body(title="Response title")]]: + return Response(body="Hello, world", status_code=200) + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/"].get + assert get.parameters is None + + response = get.responses[200].content[JSON_CONTENT_TYPE] + assert response.schema_.title == "Response title" + assert response.schema_.type == "string" + + def test_openapi_with_omitted_param(): app = APIGatewayRestResolver() diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 2e14979acce..56ea3eec019 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -6,7 +6,7 @@ from pydantic import BaseModel -from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from aws_lambda_powertools.event_handler.openapi.params import Body from aws_lambda_powertools.shared.types import Annotated from tests.functional.utils import load_event @@ -330,3 +330,33 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model: LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}}) result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 + + +def test_validate_response_return(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + # WHEN a handler is defined with a body parameter + @app.post("/") + def handler(user: Annotated[Model, Body(embed=True)]) -> Response[Model]: + return Response(body=user, status_code=200) + + LOAD_GW_EVENT["httpMethod"] = "POST" + LOAD_GW_EVENT["path"] = "/" + LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + + # THEN the handler should be invoked and return 422 + # THEN the body must be a dict + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert "missing" in result["body"] + + # THEN the handler should be invoked and return 200 + # THEN the body must be a dict + LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}}) + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 From 9cbe227b1cac42af574c4c8cffe475ca10b005c0 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 15:11:28 +0100 Subject: [PATCH 5/7] fix: remove code from bad rebase --- aws_lambda_powertools/event_handler/api_gateway.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index ef4b2be5860..5b7262e5d55 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -795,11 +795,6 @@ def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dic elif self.response.is_json() and not isinstance(self.response.body, str): self.response.body = self.serializer(self.response.body) - # We only apply the serializer when the content type is JSON and the - # body is not a str, to avoid double encoding - elif self.response.is_json() and not isinstance(self.response.body, str): - self.response.body = self.serializer(self.response.body) - return { "statusCode": self.response.status_code, "body": self.response.body, From eff77c83f08d7afaff44dbf0227577b905f8662f Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 15:13:22 +0100 Subject: [PATCH 6/7] fix: remove unused code --- tests/functional/event_handler/test_api_gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index e370ca4b99d..570de9ec808 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -681,7 +681,7 @@ def another_one(): def test_no_content_response(): # GIVEN a response with no content-type or body response = Response(status_code=204, content_type=None, body=None, headers=None) - response_builder = ResponseBuilder(response, serializer=json.dumps) + response_builder = ResponseBuilder(response) # WHEN calling to_dict result = response_builder.build(APIGatewayProxyEvent(LOAD_GW_EVENT)) From 129f93e653e1ebdf40fdc44e33c23674c1219ac0 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 15:30:26 +0100 Subject: [PATCH 7/7] fix: refactored test --- .../test_openapi_validation_middleware.py | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 56ea3eec019..9c7ca371d54 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -342,21 +342,39 @@ class Model(BaseModel): # WHEN a handler is defined with a body parameter @app.post("/") - def handler(user: Annotated[Model, Body(embed=True)]) -> Response[Model]: + def handler(user: Model) -> Response[Model]: return Response(body=user, status_code=200) LOAD_GW_EVENT["httpMethod"] = "POST" LOAD_GW_EVENT["path"] = "/" LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) - # THEN the handler should be invoked and return 422 - # THEN the body must be a dict - result = app(LOAD_GW_EVENT, {}) - assert result["statusCode"] == 422 - assert "missing" in result["body"] - # THEN the handler should be invoked and return 200 # THEN the body must be a dict - LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}}) result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 + assert result["body"] == {"name": "John", "age": 30} + + +def test_validate_response_invalid_return(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + # WHEN a handler is defined with a body parameter + @app.post("/") + def handler(user: Model) -> Response[Model]: + return Response(body=user, status_code=200) + + LOAD_GW_EVENT["httpMethod"] = "POST" + LOAD_GW_EVENT["path"] = "/" + LOAD_GW_EVENT["body"] = json.dumps({}) + + # THEN the handler should be invoked and return 422 + # THEN the body should have the word missing + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert "missing" in result["body"]