Skip to content

Commit 689072f

Browse files
refactor(openapi): add from __future__ import annotations (#4990)
* refactor(openapi): add from __future__ import annotations and update code according to ruff rules TCH, UP006, UP007, UP037 and FA100. * Fix type alias with Python 3.8 See https://bugs.python.org/issue45117 * Fix pydantic not working with Python 3.8 TypeError: You have a type annotation 'str | None' which makes use of newer typing features than are supported in your version of Python. To handle this error, you should either remove the use of new syntax or install the `eval_type_backport` package. * Removing pydantic v1 reference --------- Co-authored-by: Leandro Damascena <[email protected]>
1 parent 161a5a1 commit 689072f

File tree

9 files changed

+316
-301
lines changed

9 files changed

+316
-301
lines changed

aws_lambda_powertools/event_handler/openapi/compat.py

+30-30
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# mypy: ignore-errors
22
# flake8: noqa
3+
from __future__ import annotations
4+
35
from collections import deque
46
from copy import copy
57

@@ -8,7 +10,7 @@
810

911
from dataclasses import dataclass, is_dataclass
1012
from enum import Enum
11-
from typing import Any, Dict, List, Set, Tuple, Type, Union, FrozenSet, Deque, Sequence, Mapping
13+
from typing import Any, Deque, FrozenSet, List, Mapping, Sequence, Set, Tuple, Union
1214

1315
from typing_extensions import Annotated, Literal, get_origin, get_args
1416

@@ -56,7 +58,7 @@
5658

5759
sequence_types = tuple(sequence_annotation_to_type.keys())
5860

59-
RequestErrorModel: Type[BaseModel] = create_model("Request")
61+
RequestErrorModel: type[BaseModel] = create_model("Request")
6062

6163

6264
class ErrorWrapper(Exception):
@@ -101,8 +103,8 @@ def serialize(
101103
value: Any,
102104
*,
103105
mode: Literal["json", "python"] = "json",
104-
include: Union[IncEx, None] = None,
105-
exclude: Union[IncEx, None] = None,
106+
include: IncEx | None = None,
107+
exclude: IncEx | None = None,
106108
by_alias: bool = True,
107109
exclude_unset: bool = False,
108110
exclude_defaults: bool = False,
@@ -120,8 +122,8 @@ def serialize(
120122
)
121123

122124
def validate(
123-
self, value: Any, values: Dict[str, Any] = {}, *, loc: Tuple[Union[int, str], ...] = ()
124-
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
125+
self, value: Any, values: dict[str, Any] = {}, *, loc: tuple[int | str, ...] = ()
126+
) -> tuple[Any, list[dict[str, Any]] | None]:
125127
try:
126128
return (self._type_adapter.validate_python(value, from_attributes=True), None)
127129
except ValidationError as exc:
@@ -136,11 +138,11 @@ def get_schema_from_model_field(
136138
*,
137139
field: ModelField,
138140
model_name_map: ModelNameMap,
139-
field_mapping: Dict[
140-
Tuple[ModelField, Literal["validation", "serialization"]],
141+
field_mapping: dict[
142+
tuple[ModelField, Literal["validation", "serialization"]],
141143
JsonSchemaValue,
142144
],
143-
) -> Dict[str, Any]:
145+
) -> dict[str, Any]:
144146
json_schema = field_mapping[(field, field.mode)]
145147
if "$ref" not in json_schema:
146148
# MAINTENANCE: remove when deprecating Pydantic v1
@@ -151,39 +153,39 @@ def get_schema_from_model_field(
151153

152154
def get_definitions(
153155
*,
154-
fields: List[ModelField],
156+
fields: list[ModelField],
155157
schema_generator: GenerateJsonSchema,
156158
model_name_map: ModelNameMap,
157-
) -> Tuple[
158-
Dict[
159-
Tuple[ModelField, Literal["validation", "serialization"]],
160-
Dict[str, Any],
159+
) -> tuple[
160+
dict[
161+
tuple[ModelField, Literal["validation", "serialization"]],
162+
dict[str, Any],
161163
],
162-
Dict[str, Dict[str, Any]],
164+
dict[str, dict[str, Any]],
163165
]:
164166
inputs = [(field, field.mode, field._type_adapter.core_schema) for field in fields]
165167
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
166168

167169
return field_mapping, definitions
168170

169171

170-
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
172+
def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
171173
return {}
172174

173175

174176
def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field_name: str) -> Any:
175177
return annotation
176178

177179

178-
def model_rebuild(model: Type[BaseModel]) -> None:
180+
def model_rebuild(model: type[BaseModel]) -> None:
179181
model.model_rebuild()
180182

181183

182184
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
183185
return type(field_info).from_annotation(annotation)
184186

185187

186-
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
188+
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
187189
error = ValidationError.from_exception_data(
188190
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
189191
).errors()[0]
@@ -220,13 +222,13 @@ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
220222
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
221223

222224

223-
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
225+
def _normalize_errors(errors: Sequence[Any]) -> list[dict[str, Any]]:
224226
return errors # type: ignore[return-value]
225227

226228

227-
def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> Type[BaseModel]:
229+
def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> type[BaseModel]:
228230
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
229-
model: Type[BaseModel] = create_model(model_name, **field_params)
231+
model: type[BaseModel] = create_model(model_name, **field_params)
230232
return model
231233

232234

@@ -241,7 +243,7 @@ def model_json(model: BaseModel, **kwargs: Any) -> Any:
241243
# Common code for both versions
242244

243245

244-
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
246+
def field_annotation_is_complex(annotation: type[Any] | None) -> bool:
245247
origin = get_origin(annotation)
246248
if origin is Union or origin is UnionType:
247249
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
@@ -258,11 +260,11 @@ def field_annotation_is_scalar(annotation: Any) -> bool:
258260
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
259261

260262

261-
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
263+
def field_annotation_is_sequence(annotation: type[Any] | None) -> bool:
262264
return _annotation_is_sequence(annotation) or _annotation_is_sequence(get_origin(annotation))
263265

264266

265-
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
267+
def field_annotation_is_scalar_sequence(annotation: type[Any] | None) -> bool:
266268
origin = get_origin(annotation)
267269
if origin is Union or origin is UnionType:
268270
at_least_one_scalar_sequence = False
@@ -307,24 +309,22 @@ def value_is_sequence(value: Any) -> bool:
307309
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]
308310

309311

310-
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
312+
def _annotation_is_complex(annotation: type[Any] | None) -> bool:
311313
return (
312314
lenient_issubclass(annotation, (BaseModel, Mapping)) # TODO: UploadFile
313315
or _annotation_is_sequence(annotation)
314316
or is_dataclass(annotation)
315317
)
316318

317319

318-
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
320+
def _annotation_is_sequence(annotation: type[Any] | None) -> bool:
319321
if lenient_issubclass(annotation, (str, bytes)):
320322
return False
321323
return lenient_issubclass(annotation, sequence_types)
322324

323325

324-
def _regenerate_error_with_loc(
325-
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
326-
) -> List[Dict[str, Any]]:
327-
updated_loc_errors: List[Any] = [
326+
def _regenerate_error_with_loc(*, errors: Sequence[Any], loc_prefix: tuple[str | int, ...]) -> list[dict[str, Any]]:
327+
updated_loc_errors: list[Any] = [
328328
{**err, "loc": loc_prefix + err.get("loc", ())} for err in _normalize_errors(errors)
329329
]
330330

aws_lambda_powertools/event_handler/openapi/dependant.py

+21-18
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
from __future__ import annotations
2+
13
import inspect
24
import re
3-
from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, cast
4-
5-
from pydantic import BaseModel
5+
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, cast
66

77
from aws_lambda_powertools.event_handler.openapi.compat import (
88
ModelField,
@@ -26,6 +26,9 @@
2626
)
2727
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse, OpenAPIResponseContentModel
2828

29+
if TYPE_CHECKING:
30+
from pydantic import BaseModel
31+
2932
"""
3033
This turns the opaque function signature into typed, validated models.
3134
@@ -76,7 +79,7 @@ def add_param_to_fields(
7679
raise AssertionError(f"Unsupported param type: {field_info.in_}")
7780

7881

79-
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
82+
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
8083
"""
8184
Evaluates a type annotation, which can be a string or a ForwardRef.
8285
"""
@@ -128,7 +131,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
128131
return inspect.Signature(typed_params)
129132

130133

131-
def get_path_param_names(path: str) -> Set[str]:
134+
def get_path_param_names(path: str) -> set[str]:
132135
"""
133136
Returns the path parameter names from a path template. Those are the strings between { and }.
134137
@@ -139,7 +142,7 @@ def get_path_param_names(path: str) -> Set[str]:
139142
140143
Returns
141144
-------
142-
Set[str]
145+
set[str]
143146
The path parameter names
144147
145148
"""
@@ -150,8 +153,8 @@ def get_dependant(
150153
*,
151154
path: str,
152155
call: Callable[..., Any],
153-
name: Optional[str] = None,
154-
responses: Optional[Dict[int, OpenAPIResponse]] = None,
156+
name: str | None = None,
157+
responses: dict[int, OpenAPIResponse] | None = None,
155158
) -> Dependant:
156159
"""
157160
Returns a dependant model for a handler function. A dependant model is a model that contains
@@ -165,7 +168,7 @@ def get_dependant(
165168
The handler function
166169
name: str, optional
167170
The name of the handler function
168-
responses: List[Dict[int, OpenAPIResponse]], optional
171+
responses: list[dict[int, OpenAPIResponse]], optional
169172
The list of extra responses for the handler function
170173
171174
Returns
@@ -210,7 +213,7 @@ def get_dependant(
210213
return dependant
211214

212215

213-
def _add_extra_responses(dependant: Dependant, responses: Optional[Dict[int, OpenAPIResponse]]):
216+
def _add_extra_responses(dependant: Dependant, responses: dict[int, OpenAPIResponse] | None):
214217
# Also add the optional extra responses to the dependant model.
215218
if not responses:
216219
return
@@ -278,7 +281,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
278281
return True
279282

280283

281-
def get_flat_params(dependant: Dependant) -> List[ModelField]:
284+
def get_flat_params(dependant: Dependant) -> list[ModelField]:
282285
"""
283286
Get a list of all the parameters from a Dependant object.
284287
@@ -289,7 +292,7 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
289292
290293
Returns
291294
-------
292-
List[ModelField]
295+
list[ModelField]
293296
A list of ModelField objects containing the flat parameters from the Dependant object.
294297
295298
"""
@@ -302,7 +305,7 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
302305
)
303306

304307

305-
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
308+
def get_body_field(*, dependant: Dependant, name: str) -> ModelField | None:
306309
"""
307310
Get the Body field for a given Dependant object.
308311
"""
@@ -348,24 +351,24 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
348351

349352
def get_body_field_info(
350353
*,
351-
body_model: Type[BaseModel],
354+
body_model: type[BaseModel],
352355
flat_dependant: Dependant,
353356
required: bool,
354-
) -> Tuple[Type[Body], Dict[str, Any]]:
357+
) -> tuple[type[Body], dict[str, Any]]:
355358
"""
356359
Get the Body field info and kwargs for a given body model.
357360
"""
358361

359-
body_field_info_kwargs: Dict[str, Any] = {"annotation": body_model, "alias": "body"}
362+
body_field_info_kwargs: dict[str, Any] = {"annotation": body_model, "alias": "body"}
360363

361364
if not required:
362365
body_field_info_kwargs["default"] = None
363366

364367
if any(isinstance(f.field_info, _File) for f in flat_dependant.body_params):
365-
# MAINTENANCE: body_field_info: Type[Body] = _File
368+
# MAINTENANCE: body_field_info: type[Body] = _File
366369
raise NotImplementedError("_File fields are not supported in request bodies")
367370
elif any(isinstance(f.field_info, _Form) for f in flat_dependant.body_params):
368-
# MAINTENANCE: body_field_info: Type[Body] = _Form
371+
# MAINTENANCE: body_field_info: type[Body] = _Form
369372
raise NotImplementedError("_Form fields are not supported in request bodies")
370373
else:
371374
body_field_info = Body

0 commit comments

Comments
 (0)