From 998f2707a41888b1469a8e26a010af842ba40e6d Mon Sep 17 00:00:00 2001 From: Eric Nielsen <4120606+ericbn@users.noreply.github.com> Date: Wed, 14 Aug 2024 19:47:25 -0500 Subject: [PATCH 1/2] refactor(middlewares): add from __future__ import annotations and update code according to ruff rules TCH, UP006, UP007, UP037 and FA100. --- .../event_handler/middlewares/base.py | 8 ++- .../middlewares/openapi_validation.py | 70 ++++++++++--------- .../middlewares/schema_validation.py | 24 ++++--- 3 files changed, 56 insertions(+), 46 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/base.py b/aws_lambda_powertools/event_handler/middlewares/base.py index 700f02e8a30..3998c7c80bd 100644 --- a/aws_lambda_powertools/event_handler/middlewares/base.py +++ b/aws_lambda_powertools/event_handler/middlewares/base.py @@ -1,9 +1,13 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Generic, Protocol +from typing import TYPE_CHECKING, Generic, Protocol -from aws_lambda_powertools.event_handler.api_gateway import Response from aws_lambda_powertools.event_handler.types import EventHandlerInstance +if TYPE_CHECKING: + from aws_lambda_powertools.event_handler.api_gateway import Response + class NextMiddleware(Protocol): def __call__(self, app: EventHandlerInstance) -> Response: diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 12b70987f8a..1a179da0ff5 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -1,13 +1,13 @@ +from __future__ import annotations + import dataclasses import json import logging from copy import deepcopy -from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence from pydantic import BaseModel -from aws_lambda_powertools.event_handler import Response -from aws_lambda_powertools.event_handler.api_gateway import Route from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware from aws_lambda_powertools.event_handler.openapi.compat import ( ModelField, @@ -20,8 +20,12 @@ from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError from aws_lambda_powertools.event_handler.openapi.params import Param -from aws_lambda_powertools.event_handler.openapi.types import IncEx -from aws_lambda_powertools.event_handler.types import EventHandlerInstance + +if TYPE_CHECKING: + from aws_lambda_powertools.event_handler import Response + from aws_lambda_powertools.event_handler.api_gateway import Route + from aws_lambda_powertools.event_handler.openapi.types import IncEx + from aws_lambda_powertools.event_handler.types import EventHandlerInstance logger = logging.getLogger(__name__) @@ -36,8 +40,6 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler): -------- ```python - from typing import List - from pydantic import BaseModel from aws_lambda_powertools.event_handler.api_gateway import ( @@ -50,12 +52,12 @@ class Todo(BaseModel): app = APIGatewayRestResolver(enable_validation=True) @app.get("/todos") - def get_todos(): List[Todo]: + def get_todos(): list[Todo]: return [Todo(name="hello world")] ``` """ - def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None): + def __init__(self, validation_serializer: Callable[[Any], str] | None = None): """ Initialize the OpenAPIValidationMiddleware. @@ -72,8 +74,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> route: Route = app.context["_route"] - values: Dict[str, Any] = {} - errors: List[Any] = [] + values: dict[str, Any] = {} + errors: list[Any] = [] # Process path values, which can be found on the route_args path_values, path_errors = _request_params_to_args( @@ -147,10 +149,10 @@ def _handle_response(self, *, route: Route, response: Response): def _serialize_response( self, *, - field: Optional[ModelField] = None, + field: ModelField | None = None, response_content: Any, - include: Optional[IncEx] = None, - exclude: Optional[IncEx] = None, + include: IncEx | None = None, + exclude: IncEx | None = None, by_alias: bool = True, exclude_unset: bool = False, exclude_defaults: bool = False, @@ -160,7 +162,7 @@ def _serialize_response( Serialize the response content according to the field type. """ if field: - errors: List[Dict[str, Any]] = [] + errors: list[dict[str, Any]] = [] # MAINTENANCE: remove this when we drop pydantic v1 if not hasattr(field, "serializable"): response_content = self._prepare_response_content( @@ -232,7 +234,7 @@ def _prepare_response_content( return dataclasses.asdict(res) return res - def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]: + def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]: """ Get the request body from the event, and parse it as JSON. """ @@ -261,7 +263,7 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]: def _request_params_to_args( required_params: Sequence[ModelField], received_params: Mapping[str, Any], -) -> Tuple[Dict[str, Any], List[Any]]: +) -> tuple[dict[str, Any], list[Any]]: """ Convert the request params to a dictionary of values using validation, and returns a list of errors. """ @@ -294,14 +296,14 @@ def _request_params_to_args( def _request_body_to_args( - required_params: List[ModelField], - received_body: Optional[Dict[str, Any]], -) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: + required_params: list[ModelField], + received_body: dict[str, Any] | None, +) -> tuple[dict[str, Any], list[dict[str, Any]]]: """ Convert the request body to a dictionary of values using validation, and returns a list of errors. """ - values: Dict[str, Any] = {} - errors: List[Dict[str, Any]] = [] + values: dict[str, Any] = {} + errors: list[dict[str, Any]] = [] received_body, field_alias_omitted = _get_embed_body( field=required_params[0], @@ -313,11 +315,11 @@ def _request_body_to_args( # This sets the location to: # { "user": { object } } if field.alias == user # { { object } if field_alias is omitted - loc: Tuple[str, ...] = ("body", field.alias) + loc: tuple[str, ...] = ("body", field.alias) if field_alias_omitted: loc = ("body",) - value: Optional[Any] = None + value: Any | None = None # Now that we know what to look for, try to get the value from the received body if received_body is not None: @@ -347,8 +349,8 @@ def _validate_field( *, field: ModelField, value: Any, - loc: Tuple[str, ...], - existing_errors: List[Dict[str, Any]], + loc: tuple[str, ...], + existing_errors: list[dict[str, Any]], ): """ Validate a field, and append any errors to the existing_errors list. @@ -367,9 +369,9 @@ def _validate_field( def _get_embed_body( *, field: ModelField, - required_params: List[ModelField], - received_body: Optional[Dict[str, Any]], -) -> Tuple[Optional[Dict[str, Any]], bool]: + required_params: list[ModelField], + received_body: dict[str, Any] | None, +) -> tuple[dict[str, Any] | None, bool]: field_info = field.field_info embed = getattr(field_info, "embed", None) @@ -382,15 +384,15 @@ def _get_embed_body( def _normalize_multi_query_string_with_param( - query_string: Dict[str, List[str]], + query_string: dict[str, list[str]], params: Sequence[ModelField], -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Extract and normalize resolved_query_string_parameters Parameters ---------- - query_string: Dict + query_string: dict A dictionary containing the initial query string parameters. params: Sequence[ModelField] A sequence of ModelField objects representing parameters. @@ -399,7 +401,7 @@ def _normalize_multi_query_string_with_param( ------- A dictionary containing the processed multi_query_string_parameters. """ - resolved_query_string: Dict[str, Any] = query_string + resolved_query_string: dict[str, Any] = query_string for param in filter(is_scalar_field, params): try: # if the target parameter is a scalar, we keep the first value of the query string @@ -416,7 +418,7 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], Parameters ---------- - headers: Dict + headers: MutableMapping[str, Any] A dictionary containing the initial header parameters. params: Sequence[ModelField] A sequence of ModelField objects representing parameters. diff --git a/aws_lambda_powertools/event_handler/middlewares/schema_validation.py b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py index 66be47a48f3..c31d15bec03 100644 --- a/aws_lambda_powertools/event_handler/middlewares/schema_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py @@ -1,13 +1,17 @@ +from __future__ import annotations + import logging -from typing import Dict, Optional +from typing import TYPE_CHECKING -from aws_lambda_powertools.event_handler.api_gateway import Response from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware -from aws_lambda_powertools.event_handler.types import EventHandlerInstance from aws_lambda_powertools.utilities.validation import validate from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError +if TYPE_CHECKING: + from aws_lambda_powertools.event_handler.api_gateway import Response + from aws_lambda_powertools.event_handler.types import EventHandlerInstance + logger = logging.getLogger(__name__) @@ -48,21 +52,21 @@ def lambda_handler(event, context): def __init__( self, - inbound_schema: Dict, - inbound_formats: Optional[Dict] = None, - outbound_schema: Optional[Dict] = None, - outbound_formats: Optional[Dict] = None, + inbound_schema: dict, + inbound_formats: dict | None = None, + outbound_schema: dict | None = None, + outbound_formats: dict | None = None, ): """See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters. Parameters ---------- - inbound_schema : Dict + inbound_schema : dict JSON Schema to validate incoming event - inbound_formats : Optional[Dict], optional + inbound_formats : dict | None, optional Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None JSON Schema to validate outbound event, by default None - outbound_formats : Optional[Dict], optional + outbound_formats : dict | None, optional Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None """ # noqa: E501 super().__init__() From f7b8f732a9021cb1ddece86c94f09be9d6643be5 Mon Sep 17 00:00:00 2001 From: Eric Nielsen <4120606+ericbn@users.noreply.github.com> Date: Thu, 15 Aug 2024 08:56:25 -0500 Subject: [PATCH 2/2] Move more types to TYPE_CHECKING --- .../event_handler/middlewares/openapi_validation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 1a179da0ff5..eaed5083ab7 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -8,9 +8,8 @@ from pydantic import BaseModel -from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware +from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler from aws_lambda_powertools.event_handler.openapi.compat import ( - ModelField, _model_dump, _normalize_errors, _regenerate_error_with_loc, @@ -24,6 +23,8 @@ if TYPE_CHECKING: from aws_lambda_powertools.event_handler import Response from aws_lambda_powertools.event_handler.api_gateway import Route + from aws_lambda_powertools.event_handler.middlewares import NextMiddleware + from aws_lambda_powertools.event_handler.openapi.compat import ModelField from aws_lambda_powertools.event_handler.openapi.types import IncEx from aws_lambda_powertools.event_handler.types import EventHandlerInstance