Skip to content

Commit f7e6bc1

Browse files
committed
fix: tests
1 parent 0312c11 commit f7e6bc1

File tree

6 files changed

+119
-52
lines changed

6 files changed

+119
-52
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+18-14
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ def __init__(
263263
# _middleware_stack_built is used to ensure the middleware stack is only built once.
264264
self._middleware_stack_built = False
265265

266+
# _dependant is used to cache the dependant model for the handler function
267+
self._dependant: Optional["Dependant"] = None
268+
266269
def __call__(
267270
self,
268271
router_middlewares: List[Callable],
@@ -353,6 +356,15 @@ def _build_middleware_stack(self, router_middlewares: List[Callable[..., Any]])
353356

354357
self._middleware_stack_built = True
355358

359+
@property
360+
def dependant(self) -> "Dependant":
361+
if self._dependant is None:
362+
from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant
363+
364+
self._dependant = get_dependant(path=self.path, call=self.func)
365+
366+
return self._dependant
367+
356368
def _get_openapi_path(
357369
self,
358370
*,
@@ -1168,8 +1180,7 @@ def get_openapi_schema(
11681180
get_compat_model_name_map,
11691181
get_definitions,
11701182
)
1171-
from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant
1172-
from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, Server
1183+
from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Server
11731184
from aws_lambda_powertools.event_handler.openapi.types import (
11741185
COMPONENT_REF_TEMPLATE,
11751186
)
@@ -1213,13 +1224,8 @@ def get_openapi_schema(
12131224

12141225
# Add routes to the OpenAPI schema
12151226
for route in all_routes:
1216-
dependant = get_dependant(
1217-
path=route.path,
1218-
call=route.func,
1219-
)
1220-
12211227
result = route._get_openapi_path(
1222-
dependant=dependant,
1228+
dependant=route.dependant,
12231229
operation_ids=operation_ids,
12241230
model_name_map=model_name_map,
12251231
field_mapping=field_mapping,
@@ -1238,7 +1244,7 @@ def get_openapi_schema(
12381244
if tags:
12391245
output["tags"] = tags
12401246

1241-
output["paths"] = paths
1247+
output["paths"] = {k: PathItem(**v) for k, v in paths.items()}
12421248

12431249
return OpenAPI(**output)
12441250

@@ -1701,20 +1707,18 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]:
17011707
"""
17021708

17031709
from aws_lambda_powertools.event_handler.openapi.dependant import (
1704-
get_dependant,
17051710
get_flat_params,
17061711
)
17071712

17081713
responses_from_routes: List["ModelField"] = []
17091714
request_fields_from_routes: List["ModelField"] = []
17101715

17111716
for route in routes:
1712-
dependant = get_dependant(path=route.path, call=route.func)
1713-
params = get_flat_params(dependant)
1717+
params = get_flat_params(route.dependant)
17141718
request_fields_from_routes.extend(params)
17151719

1716-
if dependant.return_param:
1717-
responses_from_routes.append(dependant.return_param)
1720+
if route.dependant.return_param:
1721+
responses_from_routes.append(route.dependant.return_param)
17181722

17191723
flat_models = list(responses_from_routes + request_fields_from_routes)
17201724
return flat_models

Diff for: aws_lambda_powertools/event_handler/openapi/compat.py

+25-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# mypy: ignore-errors
22
# flake8: noqa
3+
from copy import copy
34

45
# MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different
56
# versions of a module, so we need to ignore errors here.
@@ -10,19 +11,27 @@
1011

1112
from typing_extensions import Annotated, Literal
1213

13-
from aws_lambda_powertools.event_handler.openapi.types import COMPONENT_REF_PREFIX, PYDANTIC_V2, ModelNameMap
14+
from pydantic import BaseModel
15+
from pydantic.fields import FieldInfo
16+
17+
from aws_lambda_powertools.event_handler.openapi.types import (
18+
COMPONENT_REF_PREFIX,
19+
PYDANTIC_V2,
20+
ModelNameMap,
21+
)
1422

1523
if PYDANTIC_V2:
1624
from pydantic import TypeAdapter
1725
from pydantic._internal._typing_extra import eval_type_lenient
1826
from pydantic.fields import FieldInfo
1927
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
20-
from pydantic_core import PydanticUndefined
28+
from pydantic_core import PydanticUndefined, PydanticUndefinedType
2129

2230
from aws_lambda_powertools.event_handler.openapi.types import IncEx
2331

2432
Undefined = PydanticUndefined
2533
Required = PydanticUndefined
34+
UndefinedType = PydanticUndefinedType
2635

2736
evaluate_forwardref = eval_type_lenient
2837

@@ -49,7 +58,7 @@ def default(self) -> Any:
4958
def type_(self) -> Any:
5059
return self.field_info.annotation
5160

52-
def __post__init__(self) -> None:
61+
def __post_init__(self) -> None:
5362
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
5463
Annotated[self.field_info.annotation, self.field_info],
5564
)
@@ -125,9 +134,15 @@ def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
125134
def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field_name: str) -> Any:
126135
return annotation
127136

137+
def model_rebuild(model: Type[BaseModel]) -> None:
138+
model.model_rebuild()
139+
140+
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
141+
return type(field_info).from_annotation(annotation)
142+
128143
else:
129144
from pydantic import BaseModel
130-
from pydantic.fields import ModelField, Required, Undefined
145+
from pydantic.fields import ModelField, Required, Undefined, UndefinedType
131146
from pydantic.schema import (
132147
field_schema,
133148
get_annotation_from_field_info,
@@ -192,3 +207,9 @@ def get_model_definitions(
192207
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
193208
models = get_flat_models_from_fields(fields, known_models=set())
194209
return get_model_name_map(models)
210+
211+
def model_rebuild(model: Type[BaseModel]) -> None:
212+
model.update_forward_refs()
213+
214+
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
215+
return copy(field_info)

Diff for: aws_lambda_powertools/event_handler/openapi/dependant.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,12 @@ def get_dependant(
160160
is_path_param = param_name in path_param_names
161161

162162
# Analyze the parameter to get the Pydantic field.
163-
_, param_field = analyze_param(
163+
param_field = analyze_param(
164164
param_name=param_name,
165165
annotation=param.annotation,
166166
value=param.default,
167167
is_path_param=is_path_param,
168+
is_response_param=False,
168169
)
169170
if param_field is None:
170171
raise AssertionError(f"Param field is None for param: {param_name}")
@@ -174,11 +175,12 @@ def get_dependant(
174175
# If the return annotation is not empty, add it to the dependant model.
175176
return_annotation = endpoint_signature.return_annotation
176177
if return_annotation is not inspect.Signature.empty:
177-
_, param_field = analyze_param(
178-
param_name="Return",
178+
param_field = analyze_param(
179+
param_name="return",
179180
annotation=return_annotation,
180181
value=None,
181182
is_path_param=False,
183+
is_response_param=True,
182184
)
183185
if param_field is None:
184186
raise AssertionError("Param field is None for return annotation")

Diff for: aws_lambda_powertools/event_handler/openapi/models.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pydantic import AnyUrl, BaseModel, Field
55
from typing_extensions import Annotated, Literal
66

7+
from aws_lambda_powertools.event_handler.openapi.compat import model_rebuild
78
from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2
89

910
"""
@@ -205,7 +206,7 @@ class Schema(BaseModel):
205206
deprecated: Optional[bool] = None
206207
readOnly: Optional[bool] = None
207208
writeOnly: Optional[bool] = None
208-
examples: Optional[List[Any]] = None
209+
examples: Optional[List["Example"]] = None
209210
# Ref: OpenAPI 3.1.0: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#schema-object
210211
# Schema Object
211212
discriminator: Optional[Discriminator] = None
@@ -577,6 +578,6 @@ class Config:
577578
extra = "allow"
578579

579580

580-
Schema.update_forward_refs()
581-
Operation.update_forward_refs()
582-
Encoding.update_forward_refs()
581+
model_rebuild(Schema)
582+
model_rebuild(Operation)
583+
model_rebuild(Encoding)

Diff for: aws_lambda_powertools/event_handler/openapi/params.py

+65-26
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import inspect
2-
from copy import copy
32
from enum import Enum
4-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
54

65
from pydantic import BaseConfig
76
from pydantic.fields import FieldInfo
8-
from typing_extensions import Annotated, get_args, get_origin
7+
from typing_extensions import Annotated, Literal, get_args, get_origin
98

109
from aws_lambda_powertools.event_handler.openapi.compat import (
1110
ModelField,
1211
Required,
1312
Undefined,
13+
UndefinedType,
14+
copy_field_info,
1415
get_annotation_from_field_info,
1516
)
1617
from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2, CacheKey
@@ -302,7 +303,8 @@ def analyze_param(
302303
annotation: Any,
303304
value: Any,
304305
is_path_param: bool,
305-
) -> Tuple[Any, Optional[ModelField]]:
306+
is_response_param: bool,
307+
) -> Optional[ModelField]:
306308
"""
307309
Analyze a parameter annotation and value to determine the type and default value of the parameter.
308310
@@ -316,10 +318,12 @@ def analyze_param(
316318
The value of the parameter
317319
is_path_param
318320
Whether the parameter is a path parameter
321+
is_response_param
322+
Whether the parameter is the return annotation
319323
320324
Returns
321325
-------
322-
Tuple[Any, Optional[ModelField]]
326+
Optional[ModelField]
323327
The type annotation and the Pydantic field representing the parameter
324328
"""
325329
field_info, type_annotation = _get_field_info_and_type_annotation(annotation, value, is_path_param)
@@ -336,12 +340,16 @@ def analyze_param(
336340

337341
# Check if the parameter is part of the path. Otherwise, defaults to query.
338342
if is_path_param:
339-
field_info = Path(annotation=type_annotation, default=default_value)
343+
field_info = Path(annotation=type_annotation)
340344
else:
341345
field_info = Query(annotation=type_annotation, default=default_value)
342346

347+
# When we have a response field, we need to set the default value to Required
348+
if is_response_param:
349+
field_info.default = Required
350+
343351
field = _create_model_field(field_info, type_annotation, param_name, is_path_param)
344-
return type_annotation, field
352+
return field
345353

346354

347355
def _get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]:
@@ -372,7 +380,10 @@ def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tu
372380

373381
if isinstance(powertools_annotation, FieldInfo):
374382
# Copy `field_info` because we mutate `field_info.default` later
375-
field_info = copy(powertools_annotation)
383+
field_info = copy_field_info(
384+
field_info=powertools_annotation,
385+
annotation=annotation,
386+
)
376387
if field_info.default not in [Undefined, Required]:
377388
raise AssertionError("FieldInfo needs to have a default value of Undefined or Required")
378389

@@ -386,6 +397,44 @@ def _get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tu
386397
return field_info, type_annotation
387398

388399

400+
def _create_response_field(
401+
name: str,
402+
type_: Type[Any],
403+
default: Optional[Any] = Undefined,
404+
required: Union[bool, UndefinedType] = Undefined,
405+
model_config: Type[BaseConfig] = BaseConfig,
406+
field_info: Optional[FieldInfo] = None,
407+
alias: Optional[str] = None,
408+
mode: Literal["validation", "serialization"] = "validation",
409+
) -> ModelField:
410+
"""
411+
Create a new response field. Raises if type_ is invalid.
412+
"""
413+
if PYDANTIC_V2:
414+
field_info = field_info or FieldInfo(
415+
annotation=type_,
416+
default=default,
417+
alias=alias,
418+
)
419+
else:
420+
field_info = field_info or FieldInfo()
421+
kwargs = {"name": name, "field_info": field_info}
422+
if PYDANTIC_V2:
423+
kwargs.update({"mode": mode})
424+
else:
425+
kwargs.update(
426+
{
427+
"type_": type_,
428+
"class_validators": {},
429+
"default": default,
430+
"required": required,
431+
"model_config": model_config,
432+
"alias": alias,
433+
},
434+
)
435+
return ModelField(**kwargs) # type: ignore[arg-type]
436+
437+
389438
def _create_model_field(
390439
field_info: Optional[FieldInfo],
391440
type_annotation: Any,
@@ -411,21 +460,11 @@ def _create_model_field(
411460
alias = field_info.alias or param_name
412461
field_info.alias = alias
413462

414-
# Create the Pydantic field
415-
kwargs = {"name": param_name, "field_info": field_info}
416-
417-
if PYDANTIC_V2:
418-
kwargs.update({"mode": "validation"})
419-
else:
420-
kwargs.update(
421-
{
422-
"type_": use_annotation,
423-
"class_validators": {},
424-
"default": field_info.default,
425-
"required": field_info.default in (Required, Undefined),
426-
"model_config": BaseConfig,
427-
"alias": alias,
428-
},
429-
)
430-
431-
return ModelField(**kwargs) # type: ignore[arg-type]
463+
return _create_response_field(
464+
name=param_name,
465+
type_=use_annotation,
466+
default=field_info.default,
467+
alias=alias,
468+
required=field_info.default in (Required, Undefined),
469+
field_info=field_info,
470+
)

Diff for: tests/functional/event_handler/test_openapi_params.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_openapi_with_custom_params():
9898
def handler(
9999
count: Annotated[
100100
int,
101-
Query(lt=100, gt=0, examples=[Example(summary="Example 1", value=10)]),
101+
Query(gt=0, lt=100, examples=[Example(summary="Example 1", value=10)]),
102102
] = 1,
103103
):
104104
raise NotImplementedError()

0 commit comments

Comments
 (0)