Skip to content

refactor(event_handler): use standard collections for types + refactor code #6495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from functools import partial
from http import HTTPStatus
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Mapping, Match, Pattern, Sequence, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, Literal, Match, Pattern, TypeVar, cast

from typing_extensions import override

Expand Down Expand Up @@ -59,6 +59,9 @@
)
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent

if TYPE_CHECKING:
from collections.abc import Callable, Mapping, Sequence

logger = logging.getLogger(__name__)

_DYNAMIC_ROUTE_PATTERN = r"(<\w+>)"
Expand All @@ -68,6 +71,7 @@
_NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response"
_ROUTE_REGEX = "^{}$"
_JSON_DUMP_CALL = partial(json.dumps, separators=(",", ":"), cls=Encoder)

ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent)
ResponseT = TypeVar("ResponseT")
Expand Down Expand Up @@ -830,7 +834,7 @@ class ResponseBuilder(Generic[ResponseEventT]):
def __init__(
self,
response: Response,
serializer: Callable[[Any], str] = partial(json.dumps, separators=(",", ":"), cls=Encoder),
serializer: Callable[[Any], str] = _JSON_DUMP_CALL,
route: Route | None = None,
):
self.response = response
Expand Down Expand Up @@ -1723,8 +1727,9 @@ def get_openapi_schema(
security = security or self.openapi_config.security
openapi_extensions = openapi_extensions or self.openapi_config.openapi_extensions

from pydantic.json_schema import GenerateJsonSchema

from aws_lambda_powertools.event_handler.openapi.compat import (
GenerateJsonSchema,
get_compat_model_name_map,
get_definitions,
)
Expand Down
4 changes: 3 additions & 1 deletion aws_lambda_powertools/event_handler/appsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import asyncio
import logging
import warnings
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.event_handler.graphql_appsync.exceptions import InvalidBatchResponse, ResolverNotFoundError
from aws_lambda_powertools.event_handler.graphql_appsync.router import Router
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.utilities.typing import LambdaContext

from aws_lambda_powertools.warnings import PowertoolsUserWarning
Expand Down
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

from typing_extensions import override

Expand All @@ -14,6 +14,7 @@
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION

if TYPE_CHECKING:
from collections.abc import Callable
from http import HTTPStatus
from re import Match

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import logging
from typing import Any, Callable
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Callable

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Callable


class BaseRouter(ABC):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING

from aws_lambda_powertools.event_handler.graphql_appsync._registry import ResolverRegistry
from aws_lambda_powertools.event_handler.graphql_appsync.base import BaseRouter

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.utilities.data_classes.appsync_resolver_event import AppSyncResolverEvent
from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Pattern
from typing import TYPE_CHECKING, Pattern

from aws_lambda_powertools.event_handler.api_gateway import (
ApiGatewayResolver,
ProxyEventType,
)

if TYPE_CHECKING:
from collections.abc import Callable
from http import HTTPStatus

from aws_lambda_powertools.event_handler import CORSConfig
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def _validate_field(
"""
Validate a field, and append any errors to the existing_errors list.
"""
validated_value, errors = field.validate(value, value, loc=loc)
validated_value, errors = field.validate(value=value, loc=loc)

if isinstance(errors, list):
processed_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=())
Expand Down
34 changes: 16 additions & 18 deletions aws_lambda_powertools/event_handler/openapi/compat.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,30 @@
# mypy: ignore-errors
# flake8: noqa
from __future__ import annotations

from collections import deque
from copy import copy
from collections.abc import Mapping, Sequence

# MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different
# versions of a module, so we need to ignore errors here.

from dataclasses import dataclass, is_dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Deque, FrozenSet, List, Mapping, Sequence, Set, Tuple, Union

from typing_extensions import Annotated, Literal, get_origin, get_args

from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
from typing import TYPE_CHECKING, Any, Deque, FrozenSet, List, Set, Tuple, Union

from aws_lambda_powertools.event_handler.openapi.types import COMPONENT_REF_PREFIX, UnionType

from pydantic import TypeAdapter, ValidationError
from pydantic import BaseModel, TypeAdapter, ValidationError, create_model

# Importing from internal libraries in Pydantic may introduce potential risks, as these internal libraries
# are not part of the public API and may change without notice in future releases.
# We use this for forward reference, as it allows us to handle forward references in type annotations.
from pydantic._internal._typing_extra import eval_type_lenient
from pydantic.fields import FieldInfo
from pydantic._internal._utils import lenient_issubclass
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from typing_extensions import Annotated, Literal, get_args, get_origin

from aws_lambda_powertools.event_handler.openapi.types import UnionType

if TYPE_CHECKING:
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue

from aws_lambda_powertools.event_handler.openapi.types import IncEx, ModelNameMap

Undefined = PydanticUndefined
Expand Down Expand Up @@ -119,7 +113,10 @@ def serialize(
)

def validate(
self, value: Any, values: dict[str, Any] = {}, *, loc: tuple[int | str, ...] = ()
self,
value: Any,
*,
loc: tuple[int | str, ...] = (),
) -> tuple[Any, list[dict[str, Any]] | None]:
try:
return (self._type_adapter.validate_python(value, from_attributes=True), None)
Expand Down Expand Up @@ -184,7 +181,8 @@ def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:

def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
"Field required",
[{"type": "missing", "loc": loc, "input": {}}],
).errors()[0]
error["input"] = None
return error
Expand Down Expand Up @@ -308,7 +306,7 @@ def value_is_sequence(value: Any) -> bool:

def _annotation_is_complex(annotation: type[Any] | None) -> bool:
return (
lenient_issubclass(annotation, (BaseModel, Mapping)) # TODO: UploadFile
lenient_issubclass(annotation, (BaseModel, Mapping)) # Keep it to UploadFile
or _annotation_is_sequence(annotation)
or is_dataclass(annotation)
)
Expand Down
4 changes: 3 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/dependant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import inspect
import re
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, cast
from typing import TYPE_CHECKING, Any, ForwardRef, cast

from aws_lambda_powertools.event_handler.openapi.compat import (
ModelField,
Expand All @@ -27,6 +27,8 @@
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse, OpenAPIResponseContentModel

if TYPE_CHECKING:
from collections.abc import Callable

from pydantic import BaseModel

"""
Expand Down
4 changes: 3 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path, PurePath
from re import Pattern
from types import GeneratorType
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any
from uuid import UUID

from pydantic import BaseModel
Expand All @@ -17,6 +17,8 @@
from aws_lambda_powertools.event_handler.openapi.compat import _model_dump

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.event_handler.openapi.types import IncEx

from aws_lambda_powertools.event_handler.openapi.exceptions import SerializationError
Expand Down
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Literal, Sequence
from collections.abc import Sequence
from typing import Any, Literal


class ValidationException(Exception):
Expand Down
4 changes: 3 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import inspect
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Literal

from pydantic import BaseConfig
from pydantic.fields import FieldInfo
Expand All @@ -20,6 +20,8 @@
)

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.event_handler.openapi.models import Example
from aws_lambda_powertools.event_handler.openapi.types import CacheKey

Expand Down
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import types
from typing import TYPE_CHECKING, Any, Callable, Dict, Set, Type, TypedDict, Union
from typing import TYPE_CHECKING, Any, Dict, Set, Type, TypedDict, Union

if TYPE_CHECKING:
from collections.abc import Callable
from enum import Enum

from pydantic import BaseModel
Expand Down
7 changes: 5 additions & 2 deletions aws_lambda_powertools/event_handler/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

from typing import Any, Dict, List, Mapping
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Mapping


class _FrozenDict(dict):
Expand All @@ -18,7 +21,7 @@ def __hash__(self):
return hash(frozenset(self.keys()))


class _FrozenListDict(List[Dict[str, List[str]]]):
class _FrozenListDict(list[dict[str, list[str]]]):
"""
Freezes a list of dictionaries containing lists of strings.

Expand Down
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/vpc_lattice.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Pattern
from typing import TYPE_CHECKING, Pattern

from aws_lambda_powertools.event_handler.api_gateway import (
ApiGatewayResolver,
ProxyEventType,
)

if TYPE_CHECKING:
from collections.abc import Callable
from http import HTTPStatus

from aws_lambda_powertools.event_handler import CORSConfig
Expand Down
2 changes: 2 additions & 0 deletions tests/functional/event_handler/_pydantic/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json

import fastjsonschema
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from pydantic import BaseModel

from aws_lambda_powertools.event_handler import content_types
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import math
from collections import deque
from dataclasses import dataclass
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json

from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Router
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ class User(BaseModel):

@app.get("/")
def handler() -> User:
return User(name="Ruben Fonseca")
return User(name="Powertools")

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import pytest

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.openapi.models import (
APIKey,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.openapi.models import Server

Expand Down
Loading
Loading