diff --git a/aws_lambda_powertools/event_handler/middlewares/base.py b/aws_lambda_powertools/event_handler/middlewares/base.py index 700f02e8a30..3998c7c80bd 100644 --- a/aws_lambda_powertools/event_handler/middlewares/base.py +++ b/aws_lambda_powertools/event_handler/middlewares/base.py @@ -1,9 +1,13 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Generic, Protocol +from typing import TYPE_CHECKING, Generic, Protocol -from aws_lambda_powertools.event_handler.api_gateway import Response from aws_lambda_powertools.event_handler.types import EventHandlerInstance +if TYPE_CHECKING: + from aws_lambda_powertools.event_handler.api_gateway import Response + class NextMiddleware(Protocol): def __call__(self, app: EventHandlerInstance) -> Response: diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 12b70987f8a..eaed5083ab7 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -1,16 +1,15 @@ +from __future__ import annotations + import dataclasses import json import logging from copy import deepcopy -from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence from pydantic import BaseModel -from aws_lambda_powertools.event_handler import Response -from aws_lambda_powertools.event_handler.api_gateway import Route -from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware +from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler from aws_lambda_powertools.event_handler.openapi.compat import ( - ModelField, _model_dump, _normalize_errors, _regenerate_error_with_loc, @@ -20,8 +19,14 @@ from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError from aws_lambda_powertools.event_handler.openapi.params import Param -from aws_lambda_powertools.event_handler.openapi.types import IncEx -from aws_lambda_powertools.event_handler.types import EventHandlerInstance + +if TYPE_CHECKING: + from aws_lambda_powertools.event_handler import Response + from aws_lambda_powertools.event_handler.api_gateway import Route + from aws_lambda_powertools.event_handler.middlewares import NextMiddleware + from aws_lambda_powertools.event_handler.openapi.compat import ModelField + from aws_lambda_powertools.event_handler.openapi.types import IncEx + from aws_lambda_powertools.event_handler.types import EventHandlerInstance logger = logging.getLogger(__name__) @@ -36,8 +41,6 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler): -------- ```python - from typing import List - from pydantic import BaseModel from aws_lambda_powertools.event_handler.api_gateway import ( @@ -50,12 +53,12 @@ class Todo(BaseModel): app = APIGatewayRestResolver(enable_validation=True) @app.get("/todos") - def get_todos(): List[Todo]: + def get_todos(): list[Todo]: return [Todo(name="hello world")] ``` """ - def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None): + def __init__(self, validation_serializer: Callable[[Any], str] | None = None): """ Initialize the OpenAPIValidationMiddleware. @@ -72,8 +75,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> route: Route = app.context["_route"] - values: Dict[str, Any] = {} - errors: List[Any] = [] + values: dict[str, Any] = {} + errors: list[Any] = [] # Process path values, which can be found on the route_args path_values, path_errors = _request_params_to_args( @@ -147,10 +150,10 @@ def _handle_response(self, *, route: Route, response: Response): def _serialize_response( self, *, - field: Optional[ModelField] = None, + field: ModelField | None = None, response_content: Any, - include: Optional[IncEx] = None, - exclude: Optional[IncEx] = None, + include: IncEx | None = None, + exclude: IncEx | None = None, by_alias: bool = True, exclude_unset: bool = False, exclude_defaults: bool = False, @@ -160,7 +163,7 @@ def _serialize_response( Serialize the response content according to the field type. """ if field: - errors: List[Dict[str, Any]] = [] + errors: list[dict[str, Any]] = [] # MAINTENANCE: remove this when we drop pydantic v1 if not hasattr(field, "serializable"): response_content = self._prepare_response_content( @@ -232,7 +235,7 @@ def _prepare_response_content( return dataclasses.asdict(res) return res - def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]: + def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]: """ Get the request body from the event, and parse it as JSON. """ @@ -261,7 +264,7 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]: def _request_params_to_args( required_params: Sequence[ModelField], received_params: Mapping[str, Any], -) -> Tuple[Dict[str, Any], List[Any]]: +) -> tuple[dict[str, Any], list[Any]]: """ Convert the request params to a dictionary of values using validation, and returns a list of errors. """ @@ -294,14 +297,14 @@ def _request_params_to_args( def _request_body_to_args( - required_params: List[ModelField], - received_body: Optional[Dict[str, Any]], -) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: + required_params: list[ModelField], + received_body: dict[str, Any] | None, +) -> tuple[dict[str, Any], list[dict[str, Any]]]: """ Convert the request body to a dictionary of values using validation, and returns a list of errors. """ - values: Dict[str, Any] = {} - errors: List[Dict[str, Any]] = [] + values: dict[str, Any] = {} + errors: list[dict[str, Any]] = [] received_body, field_alias_omitted = _get_embed_body( field=required_params[0], @@ -313,11 +316,11 @@ def _request_body_to_args( # This sets the location to: # { "user": { object } } if field.alias == user # { { object } if field_alias is omitted - loc: Tuple[str, ...] = ("body", field.alias) + loc: tuple[str, ...] = ("body", field.alias) if field_alias_omitted: loc = ("body",) - value: Optional[Any] = None + value: Any | None = None # Now that we know what to look for, try to get the value from the received body if received_body is not None: @@ -347,8 +350,8 @@ def _validate_field( *, field: ModelField, value: Any, - loc: Tuple[str, ...], - existing_errors: List[Dict[str, Any]], + loc: tuple[str, ...], + existing_errors: list[dict[str, Any]], ): """ Validate a field, and append any errors to the existing_errors list. @@ -367,9 +370,9 @@ def _validate_field( def _get_embed_body( *, field: ModelField, - required_params: List[ModelField], - received_body: Optional[Dict[str, Any]], -) -> Tuple[Optional[Dict[str, Any]], bool]: + required_params: list[ModelField], + received_body: dict[str, Any] | None, +) -> tuple[dict[str, Any] | None, bool]: field_info = field.field_info embed = getattr(field_info, "embed", None) @@ -382,15 +385,15 @@ def _get_embed_body( def _normalize_multi_query_string_with_param( - query_string: Dict[str, List[str]], + query_string: dict[str, list[str]], params: Sequence[ModelField], -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Extract and normalize resolved_query_string_parameters Parameters ---------- - query_string: Dict + query_string: dict A dictionary containing the initial query string parameters. params: Sequence[ModelField] A sequence of ModelField objects representing parameters. @@ -399,7 +402,7 @@ def _normalize_multi_query_string_with_param( ------- A dictionary containing the processed multi_query_string_parameters. """ - resolved_query_string: Dict[str, Any] = query_string + resolved_query_string: dict[str, Any] = query_string for param in filter(is_scalar_field, params): try: # if the target parameter is a scalar, we keep the first value of the query string @@ -416,7 +419,7 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], Parameters ---------- - headers: Dict + headers: MutableMapping[str, Any] A dictionary containing the initial header parameters. params: Sequence[ModelField] A sequence of ModelField objects representing parameters. diff --git a/aws_lambda_powertools/event_handler/middlewares/schema_validation.py b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py index 66be47a48f3..c31d15bec03 100644 --- a/aws_lambda_powertools/event_handler/middlewares/schema_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py @@ -1,13 +1,17 @@ +from __future__ import annotations + import logging -from typing import Dict, Optional +from typing import TYPE_CHECKING -from aws_lambda_powertools.event_handler.api_gateway import Response from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware -from aws_lambda_powertools.event_handler.types import EventHandlerInstance from aws_lambda_powertools.utilities.validation import validate from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError +if TYPE_CHECKING: + from aws_lambda_powertools.event_handler.api_gateway import Response + from aws_lambda_powertools.event_handler.types import EventHandlerInstance + logger = logging.getLogger(__name__) @@ -48,21 +52,21 @@ def lambda_handler(event, context): def __init__( self, - inbound_schema: Dict, - inbound_formats: Optional[Dict] = None, - outbound_schema: Optional[Dict] = None, - outbound_formats: Optional[Dict] = None, + inbound_schema: dict, + inbound_formats: dict | None = None, + outbound_schema: dict | None = None, + outbound_formats: dict | None = None, ): """See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters. Parameters ---------- - inbound_schema : Dict + inbound_schema : dict JSON Schema to validate incoming event - inbound_formats : Optional[Dict], optional + inbound_formats : dict | None, optional Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None JSON Schema to validate outbound event, by default None - outbound_formats : Optional[Dict], optional + outbound_formats : dict | None, optional Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None """ # noqa: E501 super().__init__()