diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 555ec519bf6..5420d76469f 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -136,14 +136,13 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> return self._handle_response(route=route, response=response) def _handle_response(self, *, route: Route, response: Response): - # Process the response body if it exists - if response.body: - # Validate and serialize the response, if it's JSON - if response.is_json(): - response.body = self._serialize_response( - field=route.dependant.return_param, - response_content=response.body, - ) + # Check if we have a return type defined + if route.dependant.return_param: + # Validate and serialize the response, including None + response.body = self._serialize_response( + field=route.dependant.return_param, + response_content=response.body, + ) return response @@ -164,15 +163,6 @@ def _serialize_response( """ if field: errors: list[dict[str, Any]] = [] - # MAINTENANCE: remove this when we drop pydantic v1 - if not hasattr(field, "serializable"): - response_content = self._prepare_response_content( - response_content, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors) if errors: raise RequestValidationError(errors=_normalize_errors(errors), body=response_content) @@ -187,7 +177,6 @@ def _serialize_response( exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) - return jsonable_encoder( value, include=include, @@ -199,7 +188,7 @@ def _serialize_response( custom_serializer=self._validation_serializer, ) else: - # Just serialize the response content returned from the handler + # Just serialize the response content returned from the handler. return jsonable_encoder(response_content, custom_serializer=self._validation_serializer) def _prepare_response_content( diff --git a/tests/functional/event_handler/_pydantic/test_openapi_serialization.py b/tests/functional/event_handler/_pydantic/test_openapi_serialization.py index 7d70488c021..ef5c8ddd938 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_serialization.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_serialization.py @@ -1,11 +1,20 @@ import json -from typing import Dict +from dataclasses import dataclass +from typing import Dict, Optional, Set import pytest +from pydantic import BaseModel from aws_lambda_powertools.event_handler import APIGatewayRestResolver +@dataclass +class Person: + name: str + birth_date: str + scores: Set[int] + + def test_openapi_duplicated_serialization(): # GIVEN APIGatewayRestResolver is initialized with enable_validation=True app = APIGatewayRestResolver(enable_validation=True) @@ -61,3 +70,124 @@ def handler(): # THEN we the custom serializer should be used assert response["body"] == "hello world" + + +def test_valid_model_returned_for_optional_type(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/valid_optional") + def handler_valid_optional() -> Optional[Model]: + return Model(name="John", age=30) + + # WHEN returning a valid model for an Optional type + gw_event["path"] = "/valid_optional" + result = app(gw_event, {}) + + # THEN it should succeed and return the serialized model + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"name": "John", "age": 30} + + +def test_serialize_response_without_field(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined without return type annotation + @app.get("/test") + def handler(): + return {"message": "Hello, World!"} + + gw_event["path"] = "/test" + + # THEN the handler should be invoked and return 200 + # AND the body must be a JSON object + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '{"message":"Hello, World!"}' + + +def test_serialize_response_list(gw_event): + """Test serialization of list responses containing complex types""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a list containing various types + @app.get("/test") + def handler(): + return [{"set": [1, 2, 3]}, {"simple": "value"}] + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '[{"set":[1,2,3]},{"simple":"value"}]' + + +def test_serialize_response_nested_dict(gw_event): + """Test serialization of nested dictionary responses""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a nested dictionary with complex types + @app.get("/test") + def handler(): + return {"nested": {"date": "2000-01-01", "set": [1, 2, 3]}, "simple": "value"} + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '{"nested":{"date":"2000-01-01","set":[1,2,3]},"simple":"value"}' + + +def test_serialize_response_dataclass(gw_event): + """Test serialization of dataclass responses""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a dataclass instance + @app.get("/test") + def handler(): + return Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91]) + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + assert response["body"] == '{"name":"John Doe","birth_date":"1990-01-01","scores":[95,87,91]}' + + +def test_serialize_response_mixed_types(gw_event): + """Test serialization of mixed type responses""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler returns a response with mixed types + @app.get("/test") + def handler(): + person = Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91]) + return { + "person": person, + "records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}], + "metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]}, + } + + gw_event["path"] = "/test" + + # THEN the response should be properly serialized + response = app(gw_event, None) + assert response["statusCode"] == 200 + expected = { + "person": {"name": "John Doe", "birth_date": "1990-01-01", "scores": [95, 87, 91]}, + "records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}], + "metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]}, + } + assert json.loads(response["body"]) == expected diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 54425f34986..f0b4acc94ad 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1128,3 +1128,76 @@ def handler(user_id: int = 123): # THEN the handler should be invoked and return 200 result = app(minimal_event, {}) assert result["statusCode"] == 200 + + +def test_validation_error_none_returned_non_optional_type(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/none_not_allowed") + def handler_none_not_allowed() -> Model: + return None # type: ignore + + # WHEN returning None for a non-Optional type + gw_event["path"] = "/none_not_allowed" + result = app(gw_event, {}) + + # THEN it should return a validation error + assert result["statusCode"] == 422 + body = json.loads(result["body"]) + assert "model_attributes_type" in body["detail"][0]["type"] + + +def test_none_returned_for_optional_type(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get("/none_allowed") + def handler_none_allowed() -> Optional[Model]: + return None + + # WHEN returning None for an Optional type + gw_event["path"] = "/none_allowed" + result = app(gw_event, {}) + + # THEN it should succeed + assert result["statusCode"] == 200 + assert result["body"] == "null" + + +@pytest.mark.parametrize( + "path, body", + [ + ("/empty_dict", {}), + ("/empty_list", []), + ("/none", "null"), + ("/empty_string", ""), + ], + ids=["empty_dict", "empty_list", "none", "empty_string"], +) +def test_none_returned_for_falsy_return(gw_event, path, body): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + @app.get(path) + def handler_none_allowed() -> Model: + return body + + # WHEN returning None for an Optional type + gw_event["path"] = path + result = app(gw_event, {}) + + # THEN it should succeed + assert result["statusCode"] == 422