Skip to content

fix(event_handler): allow fine grained Response with data validation #3394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
_ROUTE_REGEX = "^{}$"

ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent)
ResponseT = TypeVar("ResponseT")

if TYPE_CHECKING:
from aws_lambda_powertools.event_handler.openapi.compat import (
Expand Down Expand Up @@ -207,14 +208,14 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
return headers


class Response:
class Response(Generic[ResponseT]):
"""Response data class that provides greater control over what is returned from the proxy event"""

def __init__(
self,
status_code: int,
content_type: Optional[str] = None,
body: Any = None,
body: Optional[ResponseT] = None,
headers: Optional[Dict[str, Union[str, List[str]]]] = None,
cookies: Optional[List[Cookie]] = None,
compress: Optional[bool] = None,
Expand Down
12 changes: 12 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import BaseConfig
from pydantic.fields import FieldInfo

from aws_lambda_powertools.event_handler import Response
from aws_lambda_powertools.event_handler.openapi.compat import (
ModelField,
Required,
Expand Down Expand Up @@ -724,13 +725,24 @@ def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -
# If the annotation is an Annotated type, we need to extract the type annotation and the FieldInfo
if get_origin(annotation) is Annotated:
field_info, type_annotation = get_field_info_annotated_type(annotation, value, is_path_param)
# If the annotation is a Response type, we recursively call this function with the inner type
elif get_origin(annotation) is Response:
field_info, type_annotation = get_field_info_response_type(annotation, value)
# If the annotation is not an Annotated type, we use it as the type annotation
else:
type_annotation = annotation

return field_info, type_annotation


def get_field_info_response_type(annotation, value) -> Tuple[Optional[FieldInfo], Any]:
# Example: get_args(Response[inner_type]) == (inner_type,) # noqa: ERA001
(inner_type,) = get_args(annotation)

# Recursively resolve the inner type
return get_field_info_and_type_annotation(inner_type, value, False)


def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]:
"""
Get the FieldInfo and type annotation from an Annotated type.
Expand Down
20 changes: 19 additions & 1 deletion tests/functional/event_handler/test_openapi_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel

from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Response
from aws_lambda_powertools.event_handler.openapi.models import (
Example,
Parameter,
Expand Down Expand Up @@ -153,6 +153,24 @@ def handler() -> str:
assert response.schema_.type == "string"


def test_openapi_with_response_returns():
app = APIGatewayRestResolver()

@app.get("/")
def handler() -> Response[Annotated[str, Body(title="Response title")]]:
return Response(body="Hello, world", status_code=200)

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1

get = schema.paths["/"].get
assert get.parameters is None

response = get.responses[200].content[JSON_CONTENT_TYPE]
assert response.schema_.title == "Response title"
assert response.schema_.type == "string"


def test_openapi_with_omitted_param():
app = APIGatewayRestResolver()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pydantic import BaseModel

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.event_handler.openapi.params import Body
from aws_lambda_powertools.shared.types import Annotated
from tests.functional.utils import load_event
Expand Down Expand Up @@ -330,3 +330,51 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model:
LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}})
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200


def test_validate_response_return():
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

# WHEN a handler is defined with a body parameter
@app.post("/")
def handler(user: Model) -> Response[Model]:
return Response(body=user, status_code=200)

LOAD_GW_EVENT["httpMethod"] = "POST"
LOAD_GW_EVENT["path"] = "/"
LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30})

# THEN the handler should be invoked and return 200
# THEN the body must be a dict
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == {"name": "John", "age": 30}


def test_validate_response_invalid_return():
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

# WHEN a handler is defined with a body parameter
@app.post("/")
def handler(user: Model) -> Response[Model]:
return Response(body=user, status_code=200)

LOAD_GW_EVENT["httpMethod"] = "POST"
LOAD_GW_EVENT["path"] = "/"
LOAD_GW_EVENT["body"] = json.dumps({})

# THEN the handler should be invoked and return 422
# THEN the body should have the word missing
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 422
assert "missing" in result["body"]