From 1a0802264b1c04aa24ea3dfbc8a0e5ec360b905f Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Mon, 17 Feb 2025 14:06:04 -0300 Subject: [PATCH 1/7] fix(openapi): validate response serialization when falsy --- .../middlewares/openapi_validation.py | 42 +++++------------- .../test_openapi_validation_middleware.py | 44 +++++++++++++++++++ 2 files changed, 55 insertions(+), 31 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 555ec519bf6..891b9b2f592 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -136,59 +136,40 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> return self._handle_response(route=route, response=response) def _handle_response(self, *, route: Route, response: Response): - # Process the response body if it exists - if response.body: - # Validate and serialize the response, if it's JSON - if response.is_json(): + # Check if we have a return type defined + if route.dependant.return_param: + try: + # Validate all responses, including None response.body = self._serialize_response( field=route.dependant.return_param, response_content=response.body, ) + except RequestValidationError as e: + logger.error(f"Response validation failed: {str(e)}") + response.status_code = 422 + response.body = {"detail": e.errors()} return response def _serialize_response( self, *, - field: ModelField | None = None, + field: Any = None, response_content: Any, include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool = True, + by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, ) -> Any: - """ - Serialize the response content according to the field type. - """ if field: errors: list[dict[str, Any]] = [] - # MAINTENANCE: remove this when we drop pydantic v1 - if not hasattr(field, "serializable"): - response_content = self._prepare_response_content( - response_content, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors) if errors: raise RequestValidationError(errors=_normalize_errors(errors), body=response_content) - if hasattr(field, "serialize"): - return field.serialize( - value, - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - - return jsonable_encoder( + return field.serialize( value, include=include, exclude=exclude, @@ -196,7 +177,6 @@ def _serialize_response( exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, - custom_serializer=self._validation_serializer, ) else: # Just serialize the response content returned from the handler diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 54425f34986..c7078adf850 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1128,3 +1128,47 @@ def handler(user_id: int = 123): # THEN the handler should be invoked and return 200 result = app(minimal_event, {}) assert result["statusCode"] == 200 + + +def test_validate_optional_return_types(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + # AND handlers defined with different Optional return types + @app.get("/none_not_allowed") + def handler_none_not_allowed() -> Model: + return None # type: ignore + + @app.get("/none_allowed") + def handler_none_allowed() -> Optional[Model]: + return None + + @app.get("/valid_optional") + def handler_valid_optional() -> Optional[Model]: + return Model(name="John", age=30) + + # WHEN returning None for a non-Optional type + gw_event["path"] = "/none_not_allowed" + result = app(gw_event, {}) + # THEN it should return a validation error + assert result["statusCode"] == 422 + body = json.loads(result["body"]) + assert "model_attributes_type" in body["detail"][0]["type"] + + # WHEN returning None for an Optional type + gw_event["path"] = "/none_allowed" + result = app(gw_event, {}) + # THEN it should succeed + assert result["statusCode"] == 200 + assert result["body"] == "null" + + # WHEN returning a valid model for an Optional type + gw_event["path"] = "/valid_optional" + result = app(gw_event, {}) + # THEN it should succeed and return the serialized model + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"name": "John", "age": 30} From e414efedbebae6826698a0ff9358adaaae276261 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Mon, 17 Feb 2025 21:46:57 -0300 Subject: [PATCH 2/7] revert serialize --- .../event_handler/middlewares/openapi_validation.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 891b9b2f592..aad24806b9d 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -139,7 +139,7 @@ def _handle_response(self, *, route: Route, response: Response): # Check if we have a return type defined if route.dependant.return_param: try: - # Validate all responses, including None + # Validate and serialize the response, including None response.body = self._serialize_response( field=route.dependant.return_param, response_content=response.body, @@ -154,15 +154,18 @@ def _handle_response(self, *, route: Route, response: Response): def _serialize_response( self, *, - field: Any = None, + field: ModelField | None = None, response_content: Any, include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool = False, + by_alias: bool = True, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, ) -> Any: + """ + Serialize the response content according to the field type. + """ if field: errors: list[dict[str, Any]] = [] value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors) From 2d17882131f5563603e5934295dd5d4a54d3c15d Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Tue, 18 Feb 2025 11:28:11 -0300 Subject: [PATCH 3/7] change comment --- .../event_handler/middlewares/openapi_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index aad24806b9d..2bf365f4910 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -182,7 +182,7 @@ def _serialize_response( exclude_none=exclude_none, ) else: - # Just serialize the response content returned from the handler + # Just serialize the response content returned from the handler. return jsonable_encoder(response_content, custom_serializer=self._validation_serializer) def _prepare_response_content( From 5e46875eac8641c2e616bb7178369fb68964de0c Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Tue, 18 Feb 2025 17:08:05 -0300 Subject: [PATCH 4/7] add more tests --- .../test_openapi_validation_middleware.py | 109 +++++++++++++++++- 1 file changed, 108 insertions(+), 1 deletion(-) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index c7078adf850..4b1a17dcf5d 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import PurePath -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import pytest from pydantic import BaseModel @@ -1172,3 +1172,110 @@ def handler_valid_optional() -> Optional[Model]: # THEN it should succeed and return the serialized model assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} + + +def test_serialize_response_without_field(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined without return type annotation + @app.get("/test") + def handler(): + return {"message": "Hello, World!"} + + gw_event["path"] = "/test" + + # THEN the handler should be invoked and return 200 + # AND the body must be a JSON object + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '{"message":"Hello, World!"}' + + +def test_serialize_response_list(gw_event): + """Test serialization of list responses containing complex types""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a list containing various types + @app.get("/test") + def handler(): + return [{"set": [1, 2, 3]}, {"simple": "value"}] + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '[{"set":[1,2,3]},{"simple":"value"}]' + + +def test_serialize_response_nested_dict(gw_event): + """Test serialization of nested dictionary responses""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a nested dictionary with complex types + @app.get("/test") + def handler(): + return {"nested": {"date": "2000-01-01", "set": [1, 2, 3]}, "simple": "value"} + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '{"nested":{"date":"2000-01-01","set":[1,2,3]},"simple":"value"}' + + +@dataclass +class Person: + name: str + birth_date: str + scores: Set[int] + + +def test_serialize_response_dataclass(gw_event): + """Test serialization of dataclass responses""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a dataclass instance + @app.get("/test") + def handler(): + return Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91]) + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '{"name":"John Doe","birth_date":"1990-01-01","scores":[95,87,91]}' + + +def test_serialize_response_mixed_types(gw_event): + """Test serialization of mixed type responses""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a response with mixed types + @app.get("/test") + def handler(): + person = Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91]) + return { + "person": person, + "records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}], + "metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]}, + } + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + expected = { + "person": {"name": "John Doe", "birth_date": "1990-01-01", "scores": [95, 87, 91]}, + "records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}], + "metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]}, + } + assert json.loads(response["body"]) == expected From 4c4cee3fc59ab2d662c57d4ba0422acf20b00580 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Wed, 19 Feb 2025 10:51:27 -0300 Subject: [PATCH 5/7] revert serialize --- .../middlewares/openapi_validation.py | 12 ++++++ .../test_openapi_validation_middleware.py | 40 ++++++++++++++----- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 2bf365f4910..708d4f771fa 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -172,6 +172,17 @@ def _serialize_response( if errors: raise RequestValidationError(errors=_normalize_errors(errors), body=response_content) + if hasattr(field, "serialize"): + return field.serialize( + value, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + return field.serialize( value, include=include, @@ -180,6 +191,7 @@ def _serialize_response( exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, + custom_serializer=self._validation_serializer, ) else: # Just serialize the response content returned from the handler. diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 4b1a17dcf5d..14b79cb5f97 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1130,7 +1130,7 @@ def handler(user_id: int = 123): assert result["statusCode"] == 200 -def test_validate_optional_return_types(gw_event): +def test_validation_error_none_returned_non_optional_type(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -1138,37 +1138,57 @@ class Model(BaseModel): name: str age: int - # AND handlers defined with different Optional return types @app.get("/none_not_allowed") def handler_none_not_allowed() -> Model: return None # type: ignore - @app.get("/none_allowed") - def handler_none_allowed() -> Optional[Model]: - return None - - @app.get("/valid_optional") - def handler_valid_optional() -> Optional[Model]: - return Model(name="John", age=30) - # WHEN returning None for a non-Optional type gw_event["path"] = "/none_not_allowed" result = app(gw_event, {}) + # THEN it should return a validation error assert result["statusCode"] == 422 body = json.loads(result["body"]) assert "model_attributes_type" in body["detail"][0]["type"] + +def test_none_returned_for_optional_type(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/none_allowed") + def handler_none_allowed() -> Optional[Model]: + return None + # WHEN returning None for an Optional type gw_event["path"] = "/none_allowed" result = app(gw_event, {}) + # THEN it should succeed assert result["statusCode"] == 200 assert result["body"] == "null" + +def test_valid_model_returned_for_optional_type(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/valid_optional") + def handler_valid_optional() -> Optional[Model]: + return Model(name="John", age=30) + # WHEN returning a valid model for an Optional type gw_event["path"] = "/valid_optional" result = app(gw_event, {}) + # THEN it should succeed and return the serialized model assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} From df0c9ea73a97e699e23e66ed4e9b13a2b77b61a0 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Wed, 19 Feb 2025 10:59:38 -0300 Subject: [PATCH 6/7] fix mypy --- .../event_handler/middlewares/openapi_validation.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 708d4f771fa..c546fab94fe 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -182,15 +182,8 @@ def _serialize_response( exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) - - return field.serialize( + return jsonable_encoder( value, - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, custom_serializer=self._validation_serializer, ) else: From 05a4ede6c7ca3824dad74f62eef508db117191f7 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 19 Feb 2025 17:43:19 +0000 Subject: [PATCH 7/7] Refactoring tests + removing additional code --- .../middlewares/openapi_validation.py | 21 +-- .../_pydantic/test_openapi_serialization.py | 132 ++++++++++++++++- .../test_openapi_validation_middleware.py | 136 +++--------------- 3 files changed, 161 insertions(+), 128 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index c546fab94fe..5420d76469f 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -138,16 +138,11 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> def _handle_response(self, *, route: Route, response: Response): # Check if we have a return type defined if route.dependant.return_param: - try: - # Validate and serialize the response, including None - response.body = self._serialize_response( - field=route.dependant.return_param, - response_content=response.body, - ) - except RequestValidationError as e: - logger.error(f"Response validation failed: {str(e)}") - response.status_code = 422 - response.body = {"detail": e.errors()} + # Validate and serialize the response, including None + response.body = self._serialize_response( + field=route.dependant.return_param, + response_content=response.body, + ) return response @@ -184,6 +179,12 @@ def _serialize_response( ) return jsonable_encoder( value, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, custom_serializer=self._validation_serializer, ) else: diff --git a/tests/functional/event_handler/_pydantic/test_openapi_serialization.py b/tests/functional/event_handler/_pydantic/test_openapi_serialization.py index 7d70488c021..ef5c8ddd938 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_serialization.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_serialization.py @@ -1,11 +1,20 @@ import json -from typing import Dict +from dataclasses import dataclass +from typing import Dict, Optional, Set import pytest +from pydantic import BaseModel from aws_lambda_powertools.event_handler import APIGatewayRestResolver +@dataclass +class Person: + name: str + birth_date: str + scores: Set[int] + + def test_openapi_duplicated_serialization(): # GIVEN APIGatewayRestResolver is initialized with enable_validation=True app = APIGatewayRestResolver(enable_validation=True) @@ -61,3 +70,124 @@ def handler(): # THEN we the custom serializer should be used assert response["body"] == "hello world" + + +def test_valid_model_returned_for_optional_type(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/valid_optional") + def handler_valid_optional() -> Optional[Model]: + return Model(name="John", age=30) + + # WHEN returning a valid model for an Optional type + gw_event["path"] = "/valid_optional" + result = app(gw_event, {}) + + # THEN it should succeed and return the serialized model + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"name": "John", "age": 30} + + +def test_serialize_response_without_field(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined without return type annotation + @app.get("/test") + def handler(): + return {"message": "Hello, World!"} + + gw_event["path"] = "/test" + + # THEN the handler should be invoked and return 200 + # AND the body must be a JSON object + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '{"message":"Hello, World!"}' + + +def test_serialize_response_list(gw_event): + """Test serialization of list responses containing complex types""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a list containing various types + @app.get("/test") + def handler(): + return [{"set": [1, 2, 3]}, {"simple": "value"}] + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '[{"set":[1,2,3]},{"simple":"value"}]' + + +def test_serialize_response_nested_dict(gw_event): + """Test serialization of nested dictionary responses""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a nested dictionary with complex types + @app.get("/test") + def handler(): + return {"nested": {"date": "2000-01-01", "set": [1, 2, 3]}, "simple": "value"} + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '{"nested":{"date":"2000-01-01","set":[1,2,3]},"simple":"value"}' + + +def test_serialize_response_dataclass(gw_event): + """Test serialization of dataclass responses""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a dataclass instance + @app.get("/test") + def handler(): + return Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91]) + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '{"name":"John Doe","birth_date":"1990-01-01","scores":[95,87,91]}' + + +def test_serialize_response_mixed_types(gw_event): + """Test serialization of mixed type responses""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a response with mixed types + @app.get("/test") + def handler(): + person = Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91]) + return { + "person": person, + "records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}], + "metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]}, + } + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + expected = { + "person": {"name": "John Doe", "birth_date": "1990-01-01", "scores": [95, 87, 91]}, + "records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}], + "metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]}, + } + assert json.loads(response["body"]) == expected diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 14b79cb5f97..f0b4acc94ad 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import PurePath -from typing import List, Optional, Set, Tuple +from typing import List, Optional, Tuple import pytest from pydantic import BaseModel @@ -1173,7 +1173,17 @@ def handler_none_allowed() -> Optional[Model]: assert result["body"] == "null" -def test_valid_model_returned_for_optional_type(gw_event): +@pytest.mark.parametrize( + "path, body", + [ + ("/empty_dict", {}), + ("/empty_list", []), + ("/none", "null"), + ("/empty_string", ""), + ], + ids=["empty_dict", "empty_list", "none", "empty_string"], +) +def test_none_returned_for_falsy_return(gw_event, path, body): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -1181,121 +1191,13 @@ class Model(BaseModel): name: str age: int - @app.get("/valid_optional") - def handler_valid_optional() -> Optional[Model]: - return Model(name="John", age=30) + @app.get(path) + def handler_none_allowed() -> Model: + return body - # WHEN returning a valid model for an Optional type - gw_event["path"] = "/valid_optional" + # WHEN returning None for an Optional type + gw_event["path"] = path result = app(gw_event, {}) - # THEN it should succeed and return the serialized model - assert result["statusCode"] == 200 - assert json.loads(result["body"]) == {"name": "John", "age": 30} - - -def test_serialize_response_without_field(gw_event): - # GIVEN an APIGatewayRestResolver with validation enabled - app = APIGatewayRestResolver(enable_validation=True) - - # WHEN a handler is defined without return type annotation - @app.get("/test") - def handler(): - return {"message": "Hello, World!"} - - gw_event["path"] = "/test" - - # THEN the handler should be invoked and return 200 - # AND the body must be a JSON object - response = app(gw_event, None) - assert response["statusCode"] == 200 - assert response["body"] == '{"message":"Hello, World!"}' - - -def test_serialize_response_list(gw_event): - """Test serialization of list responses containing complex types""" - # GIVEN an APIGatewayRestResolver with validation enabled - app = APIGatewayRestResolver(enable_validation=True) - - # WHEN a handler returns a list containing various types - @app.get("/test") - def handler(): - return [{"set": [1, 2, 3]}, {"simple": "value"}] - - gw_event["path"] = "/test" - - # THEN the response should be properly serialized - response = app(gw_event, None) - assert response["statusCode"] == 200 - assert response["body"] == '[{"set":[1,2,3]},{"simple":"value"}]' - - -def test_serialize_response_nested_dict(gw_event): - """Test serialization of nested dictionary responses""" - # GIVEN an APIGatewayRestResolver with validation enabled - app = APIGatewayRestResolver(enable_validation=True) - - # WHEN a handler returns a nested dictionary with complex types - @app.get("/test") - def handler(): - return {"nested": {"date": "2000-01-01", "set": [1, 2, 3]}, "simple": "value"} - - gw_event["path"] = "/test" - - # THEN the response should be properly serialized - response = app(gw_event, None) - assert response["statusCode"] == 200 - assert response["body"] == '{"nested":{"date":"2000-01-01","set":[1,2,3]},"simple":"value"}' - - -@dataclass -class Person: - name: str - birth_date: str - scores: Set[int] - - -def test_serialize_response_dataclass(gw_event): - """Test serialization of dataclass responses""" - # GIVEN an APIGatewayRestResolver with validation enabled - app = APIGatewayRestResolver(enable_validation=True) - - # WHEN a handler returns a dataclass instance - @app.get("/test") - def handler(): - return Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91]) - - gw_event["path"] = "/test" - - # THEN the response should be properly serialized - response = app(gw_event, None) - assert response["statusCode"] == 200 - assert response["body"] == '{"name":"John Doe","birth_date":"1990-01-01","scores":[95,87,91]}' - - -def test_serialize_response_mixed_types(gw_event): - """Test serialization of mixed type responses""" - # GIVEN an APIGatewayRestResolver with validation enabled - app = APIGatewayRestResolver(enable_validation=True) - - # WHEN a handler returns a response with mixed types - @app.get("/test") - def handler(): - person = Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91]) - return { - "person": person, - "records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}], - "metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]}, - } - - gw_event["path"] = "/test" - - # THEN the response should be properly serialized - response = app(gw_event, None) - assert response["statusCode"] == 200 - expected = { - "person": {"name": "John Doe", "birth_date": "1990-01-01", "scores": [95, 87, 91]}, - "records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}], - "metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]}, - } - assert json.loads(response["body"]) == expected + # THEN it should succeed + assert result["statusCode"] == 422