Skip to content

fix(openapi): validate response serialization when falsy #6119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 20, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,13 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
return self._handle_response(route=route, response=response)

def _handle_response(self, *, route: Route, response: Response):
# Process the response body if it exists
if response.body:
# Validate and serialize the response, if it's JSON
if response.is_json():
response.body = self._serialize_response(
field=route.dependant.return_param,
response_content=response.body,
)
# Check if we have a return type defined
if route.dependant.return_param:
# Validate and serialize the response, including None
response.body = self._serialize_response(
field=route.dependant.return_param,
response_content=response.body,
)

return response

Expand All @@ -164,15 +163,6 @@ def _serialize_response(
"""
if field:
errors: list[dict[str, Any]] = []
# MAINTENANCE: remove this when we drop pydantic v1
if not hasattr(field, "serializable"):
response_content = self._prepare_response_content(
response_content,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)

value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
if errors:
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
Expand All @@ -187,7 +177,6 @@ def _serialize_response(
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)

return jsonable_encoder(
value,
include=include,
Expand All @@ -199,7 +188,7 @@ def _serialize_response(
custom_serializer=self._validation_serializer,
)
else:
# Just serialize the response content returned from the handler
# Just serialize the response content returned from the handler.
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)

def _prepare_response_content(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import json
from typing import Dict
from dataclasses import dataclass
from typing import Dict, Optional, Set

import pytest
from pydantic import BaseModel

from aws_lambda_powertools.event_handler import APIGatewayRestResolver


@dataclass
class Person:
name: str
birth_date: str
scores: Set[int]


def test_openapi_duplicated_serialization():
# GIVEN APIGatewayRestResolver is initialized with enable_validation=True
app = APIGatewayRestResolver(enable_validation=True)
Expand Down Expand Up @@ -61,3 +70,124 @@ def handler():

# THEN we the custom serializer should be used
assert response["body"] == "hello world"


def test_valid_model_returned_for_optional_type(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

@app.get("/valid_optional")
def handler_valid_optional() -> Optional[Model]:
return Model(name="John", age=30)

# WHEN returning a valid model for an Optional type
gw_event["path"] = "/valid_optional"
result = app(gw_event, {})

# THEN it should succeed and return the serialized model
assert result["statusCode"] == 200
assert json.loads(result["body"]) == {"name": "John", "age": 30}


def test_serialize_response_without_field(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler is defined without return type annotation
@app.get("/test")
def handler():
return {"message": "Hello, World!"}

gw_event["path"] = "/test"

# THEN the handler should be invoked and return 200
# AND the body must be a JSON object
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '{"message":"Hello, World!"}'


def test_serialize_response_list(gw_event):
"""Test serialization of list responses containing complex types"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a list containing various types
@app.get("/test")
def handler():
return [{"set": [1, 2, 3]}, {"simple": "value"}]

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '[{"set":[1,2,3]},{"simple":"value"}]'


def test_serialize_response_nested_dict(gw_event):
"""Test serialization of nested dictionary responses"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a nested dictionary with complex types
@app.get("/test")
def handler():
return {"nested": {"date": "2000-01-01", "set": [1, 2, 3]}, "simple": "value"}

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '{"nested":{"date":"2000-01-01","set":[1,2,3]},"simple":"value"}'


def test_serialize_response_dataclass(gw_event):
"""Test serialization of dataclass responses"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a dataclass instance
@app.get("/test")
def handler():
return Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91])

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '{"name":"John Doe","birth_date":"1990-01-01","scores":[95,87,91]}'


def test_serialize_response_mixed_types(gw_event):
"""Test serialization of mixed type responses"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a response with mixed types
@app.get("/test")
def handler():
person = Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91])
return {
"person": person,
"records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}],
"metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]},
}

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
expected = {
"person": {"name": "John Doe", "birth_date": "1990-01-01", "scores": [95, 87, 91]},
"records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}],
"metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]},
}
assert json.loads(response["body"]) == expected
Original file line number Diff line number Diff line change
Expand Up @@ -1128,3 +1128,76 @@ def handler(user_id: int = 123):
# THEN the handler should be invoked and return 200
result = app(minimal_event, {})
assert result["statusCode"] == 200


def test_validation_error_none_returned_non_optional_type(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

@app.get("/none_not_allowed")
def handler_none_not_allowed() -> Model:
return None # type: ignore

# WHEN returning None for a non-Optional type
gw_event["path"] = "/none_not_allowed"
result = app(gw_event, {})

# THEN it should return a validation error
assert result["statusCode"] == 422
body = json.loads(result["body"])
assert "model_attributes_type" in body["detail"][0]["type"]


def test_none_returned_for_optional_type(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

@app.get("/none_allowed")
def handler_none_allowed() -> Optional[Model]:
return None

# WHEN returning None for an Optional type
gw_event["path"] = "/none_allowed"
result = app(gw_event, {})

# THEN it should succeed
assert result["statusCode"] == 200
assert result["body"] == "null"


@pytest.mark.parametrize(
"path, body",
[
("/empty_dict", {}),
("/empty_list", []),
("/none", "null"),
("/empty_string", ""),
],
ids=["empty_dict", "empty_list", "none", "empty_string"],
)
def test_none_returned_for_falsy_return(gw_event, path, body):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

@app.get(path)
def handler_none_allowed() -> Model:
return body

# WHEN returning None for an Optional type
gw_event["path"] = path
result = app(gw_event, {})

# THEN it should succeed
assert result["statusCode"] == 422
Loading