Skip to content

Commit 8d48b3d

Browse files
committed
refactor(openapi): add from __future__ import annotations
and update code according to ruff rules TCH, UP006, UP007, UP037 and FA100.
1 parent 456bf82 commit 8d48b3d

File tree

8 files changed

+492
-473
lines changed

8 files changed

+492
-473
lines changed

aws_lambda_powertools/event_handler/openapi/compat.py

+30-30
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# mypy: ignore-errors
22
# flake8: noqa
3+
from __future__ import annotations
4+
35
from collections import deque
46
from copy import copy
57

@@ -8,7 +10,7 @@
810

911
from dataclasses import dataclass, is_dataclass
1012
from enum import Enum
11-
from typing import Any, Dict, List, Set, Tuple, Type, Union, FrozenSet, Deque, Sequence, Mapping
13+
from typing import Any, Deque, FrozenSet, List, Mapping, Sequence, Set, Tuple, Union
1214

1315
from typing_extensions import Annotated, Literal, get_origin, get_args
1416

@@ -56,7 +58,7 @@
5658

5759
sequence_types = tuple(sequence_annotation_to_type.keys())
5860

59-
RequestErrorModel: Type[BaseModel] = create_model("Request")
61+
RequestErrorModel: type[BaseModel] = create_model("Request")
6062

6163

6264
class ErrorWrapper(Exception):
@@ -101,8 +103,8 @@ def serialize(
101103
value: Any,
102104
*,
103105
mode: Literal["json", "python"] = "json",
104-
include: Union[IncEx, None] = None,
105-
exclude: Union[IncEx, None] = None,
106+
include: IncEx | None = None,
107+
exclude: IncEx | None = None,
106108
by_alias: bool = True,
107109
exclude_unset: bool = False,
108110
exclude_defaults: bool = False,
@@ -120,8 +122,8 @@ def serialize(
120122
)
121123

122124
def validate(
123-
self, value: Any, values: Dict[str, Any] = {}, *, loc: Tuple[Union[int, str], ...] = ()
124-
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
125+
self, value: Any, values: dict[str, Any] = {}, *, loc: tuple[int | str, ...] = ()
126+
) -> tuple[Any, list[dict[str, Any]] | None]:
125127
try:
126128
return (self._type_adapter.validate_python(value, from_attributes=True), None)
127129
except ValidationError as exc:
@@ -136,11 +138,11 @@ def get_schema_from_model_field(
136138
*,
137139
field: ModelField,
138140
model_name_map: ModelNameMap,
139-
field_mapping: Dict[
140-
Tuple[ModelField, Literal["validation", "serialization"]],
141+
field_mapping: dict[
142+
tuple[ModelField, Literal["validation", "serialization"]],
141143
JsonSchemaValue,
142144
],
143-
) -> Dict[str, Any]:
145+
) -> dict[str, Any]:
144146
json_schema = field_mapping[(field, field.mode)]
145147
if "$ref" not in json_schema:
146148
# MAINTENANCE: remove when deprecating Pydantic v1
@@ -151,39 +153,39 @@ def get_schema_from_model_field(
151153

152154
def get_definitions(
153155
*,
154-
fields: List[ModelField],
156+
fields: list[ModelField],
155157
schema_generator: GenerateJsonSchema,
156158
model_name_map: ModelNameMap,
157-
) -> Tuple[
158-
Dict[
159-
Tuple[ModelField, Literal["validation", "serialization"]],
160-
Dict[str, Any],
159+
) -> tuple[
160+
dict[
161+
tuple[ModelField, Literal["validation", "serialization"]],
162+
dict[str, Any],
161163
],
162-
Dict[str, Dict[str, Any]],
164+
dict[str, dict[str, Any]],
163165
]:
164166
inputs = [(field, field.mode, field._type_adapter.core_schema) for field in fields]
165167
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
166168

167169
return field_mapping, definitions
168170

169171

170-
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
172+
def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
171173
return {}
172174

173175

174176
def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field_name: str) -> Any:
175177
return annotation
176178

177179

178-
def model_rebuild(model: Type[BaseModel]) -> None:
180+
def model_rebuild(model: type[BaseModel]) -> None:
179181
model.model_rebuild()
180182

181183

182184
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
183185
return type(field_info).from_annotation(annotation)
184186

185187

186-
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
188+
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
187189
error = ValidationError.from_exception_data(
188190
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
189191
).errors()[0]
@@ -220,13 +222,13 @@ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
220222
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
221223

222224

223-
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
225+
def _normalize_errors(errors: Sequence[Any]) -> list[dict[str, Any]]:
224226
return errors # type: ignore[return-value]
225227

226228

227-
def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> Type[BaseModel]:
229+
def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> type[BaseModel]:
228230
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
229-
model: Type[BaseModel] = create_model(model_name, **field_params)
231+
model: type[BaseModel] = create_model(model_name, **field_params)
230232
return model
231233

232234

@@ -241,7 +243,7 @@ def model_json(model: BaseModel, **kwargs: Any) -> Any:
241243
# Common code for both versions
242244

243245

244-
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
246+
def field_annotation_is_complex(annotation: type[Any] | None) -> bool:
245247
origin = get_origin(annotation)
246248
if origin is Union or origin is UnionType:
247249
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
@@ -258,11 +260,11 @@ def field_annotation_is_scalar(annotation: Any) -> bool:
258260
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
259261

260262

261-
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
263+
def field_annotation_is_sequence(annotation: type[Any] | None) -> bool:
262264
return _annotation_is_sequence(annotation) or _annotation_is_sequence(get_origin(annotation))
263265

264266

265-
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
267+
def field_annotation_is_scalar_sequence(annotation: type[Any] | None) -> bool:
266268
origin = get_origin(annotation)
267269
if origin is Union or origin is UnionType:
268270
at_least_one_scalar_sequence = False
@@ -307,24 +309,22 @@ def value_is_sequence(value: Any) -> bool:
307309
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]
308310

309311

310-
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
312+
def _annotation_is_complex(annotation: type[Any] | None) -> bool:
311313
return (
312314
lenient_issubclass(annotation, (BaseModel, Mapping)) # TODO: UploadFile
313315
or _annotation_is_sequence(annotation)
314316
or is_dataclass(annotation)
315317
)
316318

317319

318-
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
320+
def _annotation_is_sequence(annotation: type[Any] | None) -> bool:
319321
if lenient_issubclass(annotation, (str, bytes)):
320322
return False
321323
return lenient_issubclass(annotation, sequence_types)
322324

323325

324-
def _regenerate_error_with_loc(
325-
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
326-
) -> List[Dict[str, Any]]:
327-
updated_loc_errors: List[Any] = [
326+
def _regenerate_error_with_loc(*, errors: Sequence[Any], loc_prefix: tuple[str | int, ...]) -> list[dict[str, Any]]:
327+
updated_loc_errors: list[Any] = [
328328
{**err, "loc": loc_prefix + err.get("loc", ())} for err in _normalize_errors(errors)
329329
]
330330

aws_lambda_powertools/event_handler/openapi/dependant.py

+21-18
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
from __future__ import annotations
2+
13
import inspect
24
import re
3-
from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, cast
4-
5-
from pydantic import BaseModel
5+
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, cast
66

77
from aws_lambda_powertools.event_handler.openapi.compat import (
88
ModelField,
@@ -26,6 +26,9 @@
2626
)
2727
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse, OpenAPIResponseContentModel
2828

29+
if TYPE_CHECKING:
30+
from pydantic import BaseModel
31+
2932
"""
3033
This turns the opaque function signature into typed, validated models.
3134
@@ -76,7 +79,7 @@ def add_param_to_fields(
7679
raise AssertionError(f"Unsupported param type: {field_info.in_}")
7780

7881

79-
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
82+
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
8083
"""
8184
Evaluates a type annotation, which can be a string or a ForwardRef.
8285
"""
@@ -128,7 +131,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
128131
return inspect.Signature(typed_params)
129132

130133

131-
def get_path_param_names(path: str) -> Set[str]:
134+
def get_path_param_names(path: str) -> set[str]:
132135
"""
133136
Returns the path parameter names from a path template. Those are the strings between { and }.
134137
@@ -139,7 +142,7 @@ def get_path_param_names(path: str) -> Set[str]:
139142
140143
Returns
141144
-------
142-
Set[str]
145+
set[str]
143146
The path parameter names
144147
145148
"""
@@ -150,8 +153,8 @@ def get_dependant(
150153
*,
151154
path: str,
152155
call: Callable[..., Any],
153-
name: Optional[str] = None,
154-
responses: Optional[Dict[int, OpenAPIResponse]] = None,
156+
name: str | None = None,
157+
responses: dict[int, OpenAPIResponse] | None = None,
155158
) -> Dependant:
156159
"""
157160
Returns a dependant model for a handler function. A dependant model is a model that contains
@@ -165,7 +168,7 @@ def get_dependant(
165168
The handler function
166169
name: str, optional
167170
The name of the handler function
168-
responses: List[Dict[int, OpenAPIResponse]], optional
171+
responses: list[dict[int, OpenAPIResponse]], optional
169172
The list of extra responses for the handler function
170173
171174
Returns
@@ -210,7 +213,7 @@ def get_dependant(
210213
return dependant
211214

212215

213-
def _add_extra_responses(dependant: Dependant, responses: Optional[Dict[int, OpenAPIResponse]]):
216+
def _add_extra_responses(dependant: Dependant, responses: dict[int, OpenAPIResponse] | None):
214217
# Also add the optional extra responses to the dependant model.
215218
if not responses:
216219
return
@@ -278,7 +281,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
278281
return True
279282

280283

281-
def get_flat_params(dependant: Dependant) -> List[ModelField]:
284+
def get_flat_params(dependant: Dependant) -> list[ModelField]:
282285
"""
283286
Get a list of all the parameters from a Dependant object.
284287
@@ -289,7 +292,7 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
289292
290293
Returns
291294
-------
292-
List[ModelField]
295+
list[ModelField]
293296
A list of ModelField objects containing the flat parameters from the Dependant object.
294297
295298
"""
@@ -302,7 +305,7 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
302305
)
303306

304307

305-
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
308+
def get_body_field(*, dependant: Dependant, name: str) -> ModelField | None:
306309
"""
307310
Get the Body field for a given Dependant object.
308311
"""
@@ -348,24 +351,24 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
348351

349352
def get_body_field_info(
350353
*,
351-
body_model: Type[BaseModel],
354+
body_model: type[BaseModel],
352355
flat_dependant: Dependant,
353356
required: bool,
354-
) -> Tuple[Type[Body], Dict[str, Any]]:
357+
) -> tuple[type[Body], dict[str, Any]]:
355358
"""
356359
Get the Body field info and kwargs for a given body model.
357360
"""
358361

359-
body_field_info_kwargs: Dict[str, Any] = {"annotation": body_model, "alias": "body"}
362+
body_field_info_kwargs: dict[str, Any] = {"annotation": body_model, "alias": "body"}
360363

361364
if not required:
362365
body_field_info_kwargs["default"] = None
363366

364367
if any(isinstance(f.field_info, _File) for f in flat_dependant.body_params):
365-
# MAINTENANCE: body_field_info: Type[Body] = _File
368+
# MAINTENANCE: body_field_info: type[Body] = _File
366369
raise NotImplementedError("_File fields are not supported in request bodies")
367370
elif any(isinstance(f.field_info, _Form) for f in flat_dependant.body_params):
368-
# MAINTENANCE: body_field_info: Type[Body] = _Form
371+
# MAINTENANCE: body_field_info: type[Body] = _Form
369372
raise NotImplementedError("_Form fields are not supported in request bodies")
370373
else:
371374
body_field_info = Body

0 commit comments

Comments
 (0)