Skip to content

Commit 5affb63

Browse files
committed
fix(event_handler): allow use of Response with data validation
1 parent f83e1c5 commit 5affb63

File tree

4 files changed

+65
-4
lines changed

4 files changed

+65
-4
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
_ROUTE_REGEX = "^{}$"
6767

6868
ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent)
69+
ResponseT = TypeVar("ResponseT")
6970

7071
if TYPE_CHECKING:
7172
from aws_lambda_powertools.event_handler.openapi.compat import (
@@ -207,14 +208,14 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
207208
return headers
208209

209210

210-
class Response:
211+
class Response(Generic[ResponseT]):
211212
"""Response data class that provides greater control over what is returned from the proxy event"""
212213

213214
def __init__(
214215
self,
215216
status_code: int,
216217
content_type: Optional[str] = None,
217-
body: Any = None,
218+
body: Optional[ResponseT] = None,
218219
headers: Optional[Dict[str, Union[str, List[str]]]] = None,
219220
cookies: Optional[List[Cookie]] = None,
220221
compress: Optional[bool] = None,

aws_lambda_powertools/event_handler/openapi/params.py

+12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pydantic import BaseConfig
66
from pydantic.fields import FieldInfo
77

8+
from aws_lambda_powertools.event_handler import Response
89
from aws_lambda_powertools.event_handler.openapi.compat import (
910
ModelField,
1011
Required,
@@ -724,13 +725,24 @@ def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -
724725
# If the annotation is an Annotated type, we need to extract the type annotation and the FieldInfo
725726
if get_origin(annotation) is Annotated:
726727
field_info, type_annotation = get_field_info_annotated_type(annotation, value, is_path_param)
728+
# If the annotation is a Response type, we recursively call this function with the inner type
729+
elif get_origin(annotation) is Response:
730+
field_info, type_annotation = get_field_info_response_type(annotation, value)
727731
# If the annotation is not an Annotated type, we use it as the type annotation
728732
else:
729733
type_annotation = annotation
730734

731735
return field_info, type_annotation
732736

733737

738+
def get_field_info_response_type(annotation, value) -> Tuple[Optional[FieldInfo], Any]:
739+
# Example: get_args(Response[inner_type]) == (inner_type,) # noqa: ERA001
740+
(inner_type,) = get_args(annotation)
741+
742+
# Recursively resolve the inner type
743+
return get_field_info_and_type_annotation(inner_type, value, False)
744+
745+
734746
def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]:
735747
"""
736748
Get the FieldInfo and type annotation from an Annotated type.

tests/functional/event_handler/test_openapi_params.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from pydantic import BaseModel
66

7-
from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver
7+
from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Response
88
from aws_lambda_powertools.event_handler.openapi.models import (
99
Example,
1010
Parameter,
@@ -153,6 +153,24 @@ def handler() -> str:
153153
assert response.schema_.type == "string"
154154

155155

156+
def test_openapi_with_response_returns():
157+
app = APIGatewayRestResolver()
158+
159+
@app.get("/")
160+
def handler() -> Response[Annotated[str, Body(title="Response title")]]:
161+
return Response(body="Hello, world", status_code=200)
162+
163+
schema = app.get_openapi_schema()
164+
assert len(schema.paths.keys()) == 1
165+
166+
get = schema.paths["/"].get
167+
assert get.parameters is None
168+
169+
response = get.responses[200].content[JSON_CONTENT_TYPE]
170+
assert response.schema_.title == "Response title"
171+
assert response.schema_.type == "string"
172+
173+
156174
def test_openapi_with_omitted_param():
157175
app = APIGatewayRestResolver()
158176

tests/functional/event_handler/test_openapi_validation_middleware.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from pydantic import BaseModel
88

9-
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
9+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
1010
from aws_lambda_powertools.event_handler.openapi.params import Body
1111
from aws_lambda_powertools.shared.types import Annotated
1212
from tests.functional.utils import load_event
@@ -330,3 +330,33 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model:
330330
LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}})
331331
result = app(LOAD_GW_EVENT, {})
332332
assert result["statusCode"] == 200
333+
334+
335+
def test_validate_response_return():
336+
# GIVEN an APIGatewayRestResolver with validation enabled
337+
app = APIGatewayRestResolver(enable_validation=True)
338+
339+
class Model(BaseModel):
340+
name: str
341+
age: int
342+
343+
# WHEN a handler is defined with a body parameter
344+
@app.post("/")
345+
def handler(user: Annotated[Model, Body(embed=True)]) -> Response[Model]:
346+
return Response(body=user, status_code=200)
347+
348+
LOAD_GW_EVENT["httpMethod"] = "POST"
349+
LOAD_GW_EVENT["path"] = "/"
350+
LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30})
351+
352+
# THEN the handler should be invoked and return 422
353+
# THEN the body must be a dict
354+
result = app(LOAD_GW_EVENT, {})
355+
assert result["statusCode"] == 422
356+
assert "missing" in result["body"]
357+
358+
# THEN the handler should be invoked and return 200
359+
# THEN the body must be a dict
360+
LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}})
361+
result = app(LOAD_GW_EVENT, {})
362+
assert result["statusCode"] == 200

0 commit comments

Comments
 (0)