Skip to content

refactor(openapi): add from __future__ import annotations #4990

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions aws_lambda_powertools/event_handler/openapi/compat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# mypy: ignore-errors
# flake8: noqa
from __future__ import annotations

from collections import deque
from copy import copy

Expand All @@ -8,7 +10,7 @@

from dataclasses import dataclass, is_dataclass
from enum import Enum
from typing import Any, Dict, List, Set, Tuple, Type, Union, FrozenSet, Deque, Sequence, Mapping
from typing import Any, Deque, FrozenSet, List, Mapping, Sequence, Set, Tuple, Union

from typing_extensions import Annotated, Literal, get_origin, get_args

Expand Down Expand Up @@ -56,7 +58,7 @@

sequence_types = tuple(sequence_annotation_to_type.keys())

RequestErrorModel: Type[BaseModel] = create_model("Request")
RequestErrorModel: type[BaseModel] = create_model("Request")


class ErrorWrapper(Exception):
Expand Down Expand Up @@ -101,8 +103,8 @@ def serialize(
value: Any,
*,
mode: Literal["json", "python"] = "json",
include: Union[IncEx, None] = None,
exclude: Union[IncEx, None] = None,
include: IncEx | None = None,
exclude: IncEx | None = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
Expand All @@ -120,8 +122,8 @@ def serialize(
)

def validate(
self, value: Any, values: Dict[str, Any] = {}, *, loc: Tuple[Union[int, str], ...] = ()
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
self, value: Any, values: dict[str, Any] = {}, *, loc: tuple[int | str, ...] = ()
) -> tuple[Any, list[dict[str, Any]] | None]:
try:
return (self._type_adapter.validate_python(value, from_attributes=True), None)
except ValidationError as exc:
Expand All @@ -136,11 +138,11 @@ def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]],
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]],
JsonSchemaValue,
],
) -> Dict[str, Any]:
) -> dict[str, Any]:
json_schema = field_mapping[(field, field.mode)]
if "$ref" not in json_schema:
# MAINTENANCE: remove when deprecating Pydantic v1
Expand All @@ -151,39 +153,39 @@ def get_schema_from_model_field(

def get_definitions(
*,
fields: List[ModelField],
fields: list[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
) -> Tuple[
Dict[
Tuple[ModelField, Literal["validation", "serialization"]],
Dict[str, Any],
) -> tuple[
dict[
tuple[ModelField, Literal["validation", "serialization"]],
dict[str, Any],
],
Dict[str, Dict[str, Any]],
dict[str, dict[str, Any]],
]:
inputs = [(field, field.mode, field._type_adapter.core_schema) for field in fields]
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)

return field_mapping, definitions


def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
return {}


def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field_name: str) -> Any:
return annotation


def model_rebuild(model: Type[BaseModel]) -> None:
def model_rebuild(model: type[BaseModel]) -> None:
model.model_rebuild()


def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return type(field_info).from_annotation(annotation)


def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors()[0]
Expand Down Expand Up @@ -220,13 +222,13 @@ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]


def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
def _normalize_errors(errors: Sequence[Any]) -> list[dict[str, Any]]:
return errors # type: ignore[return-value]


def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> Type[BaseModel]:
def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> type[BaseModel]:
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
model: Type[BaseModel] = create_model(model_name, **field_params)
model: type[BaseModel] = create_model(model_name, **field_params)
return model


Expand All @@ -241,7 +243,7 @@ def model_json(model: BaseModel, **kwargs: Any) -> Any:
# Common code for both versions


def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_complex(annotation: type[Any] | None) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
Expand All @@ -258,11 +260,11 @@ def field_annotation_is_scalar(annotation: Any) -> bool:
return annotation is Ellipsis or not field_annotation_is_complex(annotation)


def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_sequence(annotation: type[Any] | None) -> bool:
return _annotation_is_sequence(annotation) or _annotation_is_sequence(get_origin(annotation))


def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_scalar_sequence(annotation: type[Any] | None) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_sequence = False
Expand Down Expand Up @@ -307,24 +309,22 @@ def value_is_sequence(value: Any) -> bool:
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]


def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
def _annotation_is_complex(annotation: type[Any] | None) -> bool:
return (
lenient_issubclass(annotation, (BaseModel, Mapping)) # TODO: UploadFile
or _annotation_is_sequence(annotation)
or is_dataclass(annotation)
)


def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
def _annotation_is_sequence(annotation: type[Any] | None) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False
return lenient_issubclass(annotation, sequence_types)


def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
updated_loc_errors: List[Any] = [
def _regenerate_error_with_loc(*, errors: Sequence[Any], loc_prefix: tuple[str | int, ...]) -> list[dict[str, Any]]:
updated_loc_errors: list[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())} for err in _normalize_errors(errors)
]

Expand Down
39 changes: 21 additions & 18 deletions aws_lambda_powertools/event_handler/openapi/dependant.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import inspect
import re
from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, cast

from pydantic import BaseModel
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, cast

from aws_lambda_powertools.event_handler.openapi.compat import (
ModelField,
Expand All @@ -26,6 +26,9 @@
)
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse, OpenAPIResponseContentModel

if TYPE_CHECKING:
from pydantic import BaseModel

"""
This turns the opaque function signature into typed, validated models.

Expand Down Expand Up @@ -76,7 +79,7 @@ def add_param_to_fields(
raise AssertionError(f"Unsupported param type: {field_info.in_}")


def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
"""
Evaluates a type annotation, which can be a string or a ForwardRef.
"""
Expand Down Expand Up @@ -128,7 +131,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
return inspect.Signature(typed_params)


def get_path_param_names(path: str) -> Set[str]:
def get_path_param_names(path: str) -> set[str]:
"""
Returns the path parameter names from a path template. Those are the strings between { and }.

Expand All @@ -139,7 +142,7 @@ def get_path_param_names(path: str) -> Set[str]:

Returns
-------
Set[str]
set[str]
The path parameter names

"""
Expand All @@ -150,8 +153,8 @@ def get_dependant(
*,
path: str,
call: Callable[..., Any],
name: Optional[str] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
name: str | None = None,
responses: dict[int, OpenAPIResponse] | None = None,
) -> Dependant:
"""
Returns a dependant model for a handler function. A dependant model is a model that contains
Expand All @@ -165,7 +168,7 @@ def get_dependant(
The handler function
name: str, optional
The name of the handler function
responses: List[Dict[int, OpenAPIResponse]], optional
responses: list[dict[int, OpenAPIResponse]], optional
The list of extra responses for the handler function

Returns
Expand Down Expand Up @@ -210,7 +213,7 @@ def get_dependant(
return dependant


def _add_extra_responses(dependant: Dependant, responses: Optional[Dict[int, OpenAPIResponse]]):
def _add_extra_responses(dependant: Dependant, responses: dict[int, OpenAPIResponse] | None):
# Also add the optional extra responses to the dependant model.
if not responses:
return
Expand Down Expand Up @@ -278,7 +281,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
return True


def get_flat_params(dependant: Dependant) -> List[ModelField]:
def get_flat_params(dependant: Dependant) -> list[ModelField]:
"""
Get a list of all the parameters from a Dependant object.

Expand All @@ -289,7 +292,7 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:

Returns
-------
List[ModelField]
list[ModelField]
A list of ModelField objects containing the flat parameters from the Dependant object.

"""
Expand All @@ -302,7 +305,7 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
)


def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
def get_body_field(*, dependant: Dependant, name: str) -> ModelField | None:
"""
Get the Body field for a given Dependant object.
"""
Expand Down Expand Up @@ -348,24 +351,24 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:

def get_body_field_info(
*,
body_model: Type[BaseModel],
body_model: type[BaseModel],
flat_dependant: Dependant,
required: bool,
) -> Tuple[Type[Body], Dict[str, Any]]:
) -> tuple[type[Body], dict[str, Any]]:
"""
Get the Body field info and kwargs for a given body model.
"""

body_field_info_kwargs: Dict[str, Any] = {"annotation": body_model, "alias": "body"}
body_field_info_kwargs: dict[str, Any] = {"annotation": body_model, "alias": "body"}

if not required:
body_field_info_kwargs["default"] = None

if any(isinstance(f.field_info, _File) for f in flat_dependant.body_params):
# MAINTENANCE: body_field_info: Type[Body] = _File
# MAINTENANCE: body_field_info: type[Body] = _File
raise NotImplementedError("_File fields are not supported in request bodies")
elif any(isinstance(f.field_info, _Form) for f in flat_dependant.body_params):
# MAINTENANCE: body_field_info: Type[Body] = _Form
# MAINTENANCE: body_field_info: type[Body] = _Form
raise NotImplementedError("_Form fields are not supported in request bodies")
else:
body_field_info = Body
Expand Down
Loading
Loading