Skip to content

Commit 199f33e

Browse files
refactor(event_handler): add from __future__ import annotations in the Middlewares (#4975)
* refactor(middlewares): add from __future__ import annotations and update code according to ruff rules TCH, UP006, UP007, UP037 and FA100. * Move more types to TYPE_CHECKING --------- Co-authored-by: Leandro Damascena <[email protected]>
1 parent 46414a8 commit 199f33e

File tree

3 files changed

+59
-48
lines changed

3 files changed

+59
-48
lines changed

aws_lambda_powertools/event_handler/middlewares/base.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from __future__ import annotations
2+
13
from abc import ABC, abstractmethod
2-
from typing import Generic, Protocol
4+
from typing import TYPE_CHECKING, Generic, Protocol
35

4-
from aws_lambda_powertools.event_handler.api_gateway import Response
56
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
67

8+
if TYPE_CHECKING:
9+
from aws_lambda_powertools.event_handler.api_gateway import Response
10+
711

812
class NextMiddleware(Protocol):
913
def __call__(self, app: EventHandlerInstance) -> Response:

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

+39-36
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1+
from __future__ import annotations
2+
13
import dataclasses
24
import json
35
import logging
46
from copy import deepcopy
5-
from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple
7+
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
68

79
from pydantic import BaseModel
810

9-
from aws_lambda_powertools.event_handler import Response
10-
from aws_lambda_powertools.event_handler.api_gateway import Route
11-
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
11+
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler
1212
from aws_lambda_powertools.event_handler.openapi.compat import (
13-
ModelField,
1413
_model_dump,
1514
_normalize_errors,
1615
_regenerate_error_with_loc,
@@ -20,8 +19,14 @@
2019
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
2120
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
2221
from aws_lambda_powertools.event_handler.openapi.params import Param
23-
from aws_lambda_powertools.event_handler.openapi.types import IncEx
24-
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
22+
23+
if TYPE_CHECKING:
24+
from aws_lambda_powertools.event_handler import Response
25+
from aws_lambda_powertools.event_handler.api_gateway import Route
26+
from aws_lambda_powertools.event_handler.middlewares import NextMiddleware
27+
from aws_lambda_powertools.event_handler.openapi.compat import ModelField
28+
from aws_lambda_powertools.event_handler.openapi.types import IncEx
29+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
2530

2631
logger = logging.getLogger(__name__)
2732

@@ -36,8 +41,6 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
3641
--------
3742
3843
```python
39-
from typing import List
40-
4144
from pydantic import BaseModel
4245
4346
from aws_lambda_powertools.event_handler.api_gateway import (
@@ -50,12 +53,12 @@ class Todo(BaseModel):
5053
app = APIGatewayRestResolver(enable_validation=True)
5154
5255
@app.get("/todos")
53-
def get_todos(): List[Todo]:
56+
def get_todos(): list[Todo]:
5457
return [Todo(name="hello world")]
5558
```
5659
"""
5760

58-
def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None):
61+
def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
5962
"""
6063
Initialize the OpenAPIValidationMiddleware.
6164
@@ -72,8 +75,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
7275

7376
route: Route = app.context["_route"]
7477

75-
values: Dict[str, Any] = {}
76-
errors: List[Any] = []
78+
values: dict[str, Any] = {}
79+
errors: list[Any] = []
7780

7881
# Process path values, which can be found on the route_args
7982
path_values, path_errors = _request_params_to_args(
@@ -147,10 +150,10 @@ def _handle_response(self, *, route: Route, response: Response):
147150
def _serialize_response(
148151
self,
149152
*,
150-
field: Optional[ModelField] = None,
153+
field: ModelField | None = None,
151154
response_content: Any,
152-
include: Optional[IncEx] = None,
153-
exclude: Optional[IncEx] = None,
155+
include: IncEx | None = None,
156+
exclude: IncEx | None = None,
154157
by_alias: bool = True,
155158
exclude_unset: bool = False,
156159
exclude_defaults: bool = False,
@@ -160,7 +163,7 @@ def _serialize_response(
160163
Serialize the response content according to the field type.
161164
"""
162165
if field:
163-
errors: List[Dict[str, Any]] = []
166+
errors: list[dict[str, Any]] = []
164167
# MAINTENANCE: remove this when we drop pydantic v1
165168
if not hasattr(field, "serializable"):
166169
response_content = self._prepare_response_content(
@@ -232,7 +235,7 @@ def _prepare_response_content(
232235
return dataclasses.asdict(res)
233236
return res
234237

235-
def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
238+
def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
236239
"""
237240
Get the request body from the event, and parse it as JSON.
238241
"""
@@ -261,7 +264,7 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
261264
def _request_params_to_args(
262265
required_params: Sequence[ModelField],
263266
received_params: Mapping[str, Any],
264-
) -> Tuple[Dict[str, Any], List[Any]]:
267+
) -> tuple[dict[str, Any], list[Any]]:
265268
"""
266269
Convert the request params to a dictionary of values using validation, and returns a list of errors.
267270
"""
@@ -294,14 +297,14 @@ def _request_params_to_args(
294297

295298

296299
def _request_body_to_args(
297-
required_params: List[ModelField],
298-
received_body: Optional[Dict[str, Any]],
299-
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
300+
required_params: list[ModelField],
301+
received_body: dict[str, Any] | None,
302+
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
300303
"""
301304
Convert the request body to a dictionary of values using validation, and returns a list of errors.
302305
"""
303-
values: Dict[str, Any] = {}
304-
errors: List[Dict[str, Any]] = []
306+
values: dict[str, Any] = {}
307+
errors: list[dict[str, Any]] = []
305308

306309
received_body, field_alias_omitted = _get_embed_body(
307310
field=required_params[0],
@@ -313,11 +316,11 @@ def _request_body_to_args(
313316
# This sets the location to:
314317
# { "user": { object } } if field.alias == user
315318
# { { object } if field_alias is omitted
316-
loc: Tuple[str, ...] = ("body", field.alias)
319+
loc: tuple[str, ...] = ("body", field.alias)
317320
if field_alias_omitted:
318321
loc = ("body",)
319322

320-
value: Optional[Any] = None
323+
value: Any | None = None
321324

322325
# Now that we know what to look for, try to get the value from the received body
323326
if received_body is not None:
@@ -347,8 +350,8 @@ def _validate_field(
347350
*,
348351
field: ModelField,
349352
value: Any,
350-
loc: Tuple[str, ...],
351-
existing_errors: List[Dict[str, Any]],
353+
loc: tuple[str, ...],
354+
existing_errors: list[dict[str, Any]],
352355
):
353356
"""
354357
Validate a field, and append any errors to the existing_errors list.
@@ -367,9 +370,9 @@ def _validate_field(
367370
def _get_embed_body(
368371
*,
369372
field: ModelField,
370-
required_params: List[ModelField],
371-
received_body: Optional[Dict[str, Any]],
372-
) -> Tuple[Optional[Dict[str, Any]], bool]:
373+
required_params: list[ModelField],
374+
received_body: dict[str, Any] | None,
375+
) -> tuple[dict[str, Any] | None, bool]:
373376
field_info = field.field_info
374377
embed = getattr(field_info, "embed", None)
375378

@@ -382,15 +385,15 @@ def _get_embed_body(
382385

383386

384387
def _normalize_multi_query_string_with_param(
385-
query_string: Dict[str, List[str]],
388+
query_string: dict[str, list[str]],
386389
params: Sequence[ModelField],
387-
) -> Dict[str, Any]:
390+
) -> dict[str, Any]:
388391
"""
389392
Extract and normalize resolved_query_string_parameters
390393
391394
Parameters
392395
----------
393-
query_string: Dict
396+
query_string: dict
394397
A dictionary containing the initial query string parameters.
395398
params: Sequence[ModelField]
396399
A sequence of ModelField objects representing parameters.
@@ -399,7 +402,7 @@ def _normalize_multi_query_string_with_param(
399402
-------
400403
A dictionary containing the processed multi_query_string_parameters.
401404
"""
402-
resolved_query_string: Dict[str, Any] = query_string
405+
resolved_query_string: dict[str, Any] = query_string
403406
for param in filter(is_scalar_field, params):
404407
try:
405408
# if the target parameter is a scalar, we keep the first value of the query string
@@ -416,7 +419,7 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any],
416419
417420
Parameters
418421
----------
419-
headers: Dict
422+
headers: MutableMapping[str, Any]
420423
A dictionary containing the initial header parameters.
421424
params: Sequence[ModelField]
422425
A sequence of ModelField objects representing parameters.

aws_lambda_powertools/event_handler/middlewares/schema_validation.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
from __future__ import annotations
2+
13
import logging
2-
from typing import Dict, Optional
4+
from typing import TYPE_CHECKING
35

4-
from aws_lambda_powertools.event_handler.api_gateway import Response
56
from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError
67
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
7-
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
88
from aws_lambda_powertools.utilities.validation import validate
99
from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError
1010

11+
if TYPE_CHECKING:
12+
from aws_lambda_powertools.event_handler.api_gateway import Response
13+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
14+
1115
logger = logging.getLogger(__name__)
1216

1317

@@ -48,21 +52,21 @@ def lambda_handler(event, context):
4852

4953
def __init__(
5054
self,
51-
inbound_schema: Dict,
52-
inbound_formats: Optional[Dict] = None,
53-
outbound_schema: Optional[Dict] = None,
54-
outbound_formats: Optional[Dict] = None,
55+
inbound_schema: dict,
56+
inbound_formats: dict | None = None,
57+
outbound_schema: dict | None = None,
58+
outbound_formats: dict | None = None,
5559
):
5660
"""See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters.
5761
5862
Parameters
5963
----------
60-
inbound_schema : Dict
64+
inbound_schema : dict
6165
JSON Schema to validate incoming event
62-
inbound_formats : Optional[Dict], optional
66+
inbound_formats : dict | None, optional
6367
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
6468
JSON Schema to validate outbound event, by default None
65-
outbound_formats : Optional[Dict], optional
69+
outbound_formats : dict | None, optional
6670
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
6771
""" # noqa: E501
6872
super().__init__()

0 commit comments

Comments
 (0)