Skip to content

feat(openapi): enhance support for tuple return type validation #5997

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -314,4 +314,7 @@ examples/**/sam/.aws-sam

cdk.out
# NOTE: different accounts will be used for E2E thus creating unnecessary git clutter
cdk.context.json
cdk.context.json

# vim
*.swp
23 changes: 20 additions & 3 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,8 @@ def analyze_param(
ModelField | None
The type annotation and the Pydantic field representing the parameter
"""
field_info, type_annotation = get_field_info_and_type_annotation(annotation, value, is_path_param)
field_info, type_annotation = \
get_field_info_and_type_annotation(annotation, value, is_path_param, is_response_param)

# If the value is a FieldInfo, we use it as the FieldInfo for the parameter
if isinstance(value, FieldInfo):
Expand Down Expand Up @@ -962,7 +963,9 @@ def analyze_param(
return field


def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -> tuple[FieldInfo | None, Any]:
def get_field_info_and_type_annotation(
annotation, value, is_path_param: bool, is_response_param: bool
) -> tuple[FieldInfo | None, Any]:
"""
Get the FieldInfo and type annotation from an annotation and value.
"""
Expand All @@ -976,19 +979,33 @@ def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -
# If the annotation is a Response type, we recursively call this function with the inner type
elif get_origin(annotation) is Response:
field_info, type_annotation = get_field_info_response_type(annotation, value)
# If the response param is a tuple with two elements, we use the first element as the type annotation,
# just like we did in the APIGateway._to_response
elif is_response_param and get_origin(annotation) is tuple and len(get_args(annotation)) == 2:
field_info, type_annotation = get_field_info_tuple_type(annotation, value)
# If the annotation is not an Annotated type, we use it as the type annotation
else:
type_annotation = annotation

return field_info, type_annotation


def get_field_info_tuple_type(annotation, value) -> tuple[FieldInfo | None, Any]:
(inner_type, _) = get_args(annotation)

# If the inner type is an Annotated type, we need to extract the type annotation and the FieldInfo
if get_origin(inner_type) is Annotated:
return get_field_info_annotated_type(inner_type, value, False)

return None, inner_type


def get_field_info_response_type(annotation, value) -> tuple[FieldInfo | None, Any]:
# Example: get_args(Response[inner_type]) == (inner_type,) # noqa: ERA001
(inner_type,) = get_args(annotation)

# Recursively resolve the inner type
return get_field_info_and_type_annotation(inner_type, value, False)
return get_field_info_and_type_annotation(inner_type, value, False, True)


def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tuple[FieldInfo | None, Any]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from datetime import datetime
from typing import List
from typing import List, Tuple

from pydantic import BaseModel, Field
from typing_extensions import Annotated
Expand Down Expand Up @@ -172,6 +172,42 @@ def handler() -> Response[Annotated[str, Body(title="Response title")]]:
assert response.schema_.type == "string"


def test_openapi_with_tuple_returns():
app = APIGatewayRestResolver()

@app.get("/")
def handler() -> Tuple[str, int]:
return "Hello, world", 200

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1

get = schema.paths["/"].get
assert get.parameters is None

response = get.responses[200].content[JSON_CONTENT_TYPE]
assert response.schema_.title == "Return"
assert response.schema_.type == "string"


def test_openapi_with_tuple_annotated_returns():
app = APIGatewayRestResolver()

@app.get("/")
def handler() -> Tuple[Annotated[str, Body(title="Response title")], int]:
return "Hello, world", 200

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1

get = schema.paths["/"].get
assert get.parameters is None

response = get.responses[200].content[JSON_CONTENT_TYPE]
assert response.schema_.title == "Response title"
assert response.schema_.type == "string"


def test_openapi_with_omitted_param():
app = APIGatewayRestResolver()

Expand Down
Loading