Skip to content

Commit 998f270

Browse files
committed
refactor(middlewares): add from __future__ import annotations
and update code according to ruff rules TCH, UP006, UP007, UP037 and FA100.
1 parent 2d59b7a commit 998f270

File tree

3 files changed

+56
-46
lines changed

3 files changed

+56
-46
lines changed

aws_lambda_powertools/event_handler/middlewares/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from __future__ import annotations
2+
13
from abc import ABC, abstractmethod
2-
from typing import Generic, Protocol
4+
from typing import TYPE_CHECKING, Generic, Protocol
35

4-
from aws_lambda_powertools.event_handler.api_gateway import Response
56
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
67

8+
if TYPE_CHECKING:
9+
from aws_lambda_powertools.event_handler.api_gateway import Response
10+
711

812
class NextMiddleware(Protocol):
913
def __call__(self, app: EventHandlerInstance) -> Response:

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1+
from __future__ import annotations
2+
13
import dataclasses
24
import json
35
import logging
46
from copy import deepcopy
5-
from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple
7+
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
68

79
from pydantic import BaseModel
810

9-
from aws_lambda_powertools.event_handler import Response
10-
from aws_lambda_powertools.event_handler.api_gateway import Route
1111
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
1212
from aws_lambda_powertools.event_handler.openapi.compat import (
1313
ModelField,
@@ -20,8 +20,12 @@
2020
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
2121
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
2222
from aws_lambda_powertools.event_handler.openapi.params import Param
23-
from aws_lambda_powertools.event_handler.openapi.types import IncEx
24-
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
23+
24+
if TYPE_CHECKING:
25+
from aws_lambda_powertools.event_handler import Response
26+
from aws_lambda_powertools.event_handler.api_gateway import Route
27+
from aws_lambda_powertools.event_handler.openapi.types import IncEx
28+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
2529

2630
logger = logging.getLogger(__name__)
2731

@@ -36,8 +40,6 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
3640
--------
3741
3842
```python
39-
from typing import List
40-
4143
from pydantic import BaseModel
4244
4345
from aws_lambda_powertools.event_handler.api_gateway import (
@@ -50,12 +52,12 @@ class Todo(BaseModel):
5052
app = APIGatewayRestResolver(enable_validation=True)
5153
5254
@app.get("/todos")
53-
def get_todos(): List[Todo]:
55+
def get_todos(): list[Todo]:
5456
return [Todo(name="hello world")]
5557
```
5658
"""
5759

58-
def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None):
60+
def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
5961
"""
6062
Initialize the OpenAPIValidationMiddleware.
6163
@@ -72,8 +74,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
7274

7375
route: Route = app.context["_route"]
7476

75-
values: Dict[str, Any] = {}
76-
errors: List[Any] = []
77+
values: dict[str, Any] = {}
78+
errors: list[Any] = []
7779

7880
# Process path values, which can be found on the route_args
7981
path_values, path_errors = _request_params_to_args(
@@ -147,10 +149,10 @@ def _handle_response(self, *, route: Route, response: Response):
147149
def _serialize_response(
148150
self,
149151
*,
150-
field: Optional[ModelField] = None,
152+
field: ModelField | None = None,
151153
response_content: Any,
152-
include: Optional[IncEx] = None,
153-
exclude: Optional[IncEx] = None,
154+
include: IncEx | None = None,
155+
exclude: IncEx | None = None,
154156
by_alias: bool = True,
155157
exclude_unset: bool = False,
156158
exclude_defaults: bool = False,
@@ -160,7 +162,7 @@ def _serialize_response(
160162
Serialize the response content according to the field type.
161163
"""
162164
if field:
163-
errors: List[Dict[str, Any]] = []
165+
errors: list[dict[str, Any]] = []
164166
# MAINTENANCE: remove this when we drop pydantic v1
165167
if not hasattr(field, "serializable"):
166168
response_content = self._prepare_response_content(
@@ -232,7 +234,7 @@ def _prepare_response_content(
232234
return dataclasses.asdict(res)
233235
return res
234236

235-
def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
237+
def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
236238
"""
237239
Get the request body from the event, and parse it as JSON.
238240
"""
@@ -261,7 +263,7 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
261263
def _request_params_to_args(
262264
required_params: Sequence[ModelField],
263265
received_params: Mapping[str, Any],
264-
) -> Tuple[Dict[str, Any], List[Any]]:
266+
) -> tuple[dict[str, Any], list[Any]]:
265267
"""
266268
Convert the request params to a dictionary of values using validation, and returns a list of errors.
267269
"""
@@ -294,14 +296,14 @@ def _request_params_to_args(
294296

295297

296298
def _request_body_to_args(
297-
required_params: List[ModelField],
298-
received_body: Optional[Dict[str, Any]],
299-
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
299+
required_params: list[ModelField],
300+
received_body: dict[str, Any] | None,
301+
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
300302
"""
301303
Convert the request body to a dictionary of values using validation, and returns a list of errors.
302304
"""
303-
values: Dict[str, Any] = {}
304-
errors: List[Dict[str, Any]] = []
305+
values: dict[str, Any] = {}
306+
errors: list[dict[str, Any]] = []
305307

306308
received_body, field_alias_omitted = _get_embed_body(
307309
field=required_params[0],
@@ -313,11 +315,11 @@ def _request_body_to_args(
313315
# This sets the location to:
314316
# { "user": { object } } if field.alias == user
315317
# { { object } if field_alias is omitted
316-
loc: Tuple[str, ...] = ("body", field.alias)
318+
loc: tuple[str, ...] = ("body", field.alias)
317319
if field_alias_omitted:
318320
loc = ("body",)
319321

320-
value: Optional[Any] = None
322+
value: Any | None = None
321323

322324
# Now that we know what to look for, try to get the value from the received body
323325
if received_body is not None:
@@ -347,8 +349,8 @@ def _validate_field(
347349
*,
348350
field: ModelField,
349351
value: Any,
350-
loc: Tuple[str, ...],
351-
existing_errors: List[Dict[str, Any]],
352+
loc: tuple[str, ...],
353+
existing_errors: list[dict[str, Any]],
352354
):
353355
"""
354356
Validate a field, and append any errors to the existing_errors list.
@@ -367,9 +369,9 @@ def _validate_field(
367369
def _get_embed_body(
368370
*,
369371
field: ModelField,
370-
required_params: List[ModelField],
371-
received_body: Optional[Dict[str, Any]],
372-
) -> Tuple[Optional[Dict[str, Any]], bool]:
372+
required_params: list[ModelField],
373+
received_body: dict[str, Any] | None,
374+
) -> tuple[dict[str, Any] | None, bool]:
373375
field_info = field.field_info
374376
embed = getattr(field_info, "embed", None)
375377

@@ -382,15 +384,15 @@ def _get_embed_body(
382384

383385

384386
def _normalize_multi_query_string_with_param(
385-
query_string: Dict[str, List[str]],
387+
query_string: dict[str, list[str]],
386388
params: Sequence[ModelField],
387-
) -> Dict[str, Any]:
389+
) -> dict[str, Any]:
388390
"""
389391
Extract and normalize resolved_query_string_parameters
390392
391393
Parameters
392394
----------
393-
query_string: Dict
395+
query_string: dict
394396
A dictionary containing the initial query string parameters.
395397
params: Sequence[ModelField]
396398
A sequence of ModelField objects representing parameters.
@@ -399,7 +401,7 @@ def _normalize_multi_query_string_with_param(
399401
-------
400402
A dictionary containing the processed multi_query_string_parameters.
401403
"""
402-
resolved_query_string: Dict[str, Any] = query_string
404+
resolved_query_string: dict[str, Any] = query_string
403405
for param in filter(is_scalar_field, params):
404406
try:
405407
# if the target parameter is a scalar, we keep the first value of the query string
@@ -416,7 +418,7 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any],
416418
417419
Parameters
418420
----------
419-
headers: Dict
421+
headers: MutableMapping[str, Any]
420422
A dictionary containing the initial header parameters.
421423
params: Sequence[ModelField]
422424
A sequence of ModelField objects representing parameters.

aws_lambda_powertools/event_handler/middlewares/schema_validation.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
from __future__ import annotations
2+
13
import logging
2-
from typing import Dict, Optional
4+
from typing import TYPE_CHECKING
35

4-
from aws_lambda_powertools.event_handler.api_gateway import Response
56
from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError
67
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
7-
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
88
from aws_lambda_powertools.utilities.validation import validate
99
from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError
1010

11+
if TYPE_CHECKING:
12+
from aws_lambda_powertools.event_handler.api_gateway import Response
13+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
14+
1115
logger = logging.getLogger(__name__)
1216

1317

@@ -48,21 +52,21 @@ def lambda_handler(event, context):
4852

4953
def __init__(
5054
self,
51-
inbound_schema: Dict,
52-
inbound_formats: Optional[Dict] = None,
53-
outbound_schema: Optional[Dict] = None,
54-
outbound_formats: Optional[Dict] = None,
55+
inbound_schema: dict,
56+
inbound_formats: dict | None = None,
57+
outbound_schema: dict | None = None,
58+
outbound_formats: dict | None = None,
5559
):
5660
"""See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters.
5761
5862
Parameters
5963
----------
60-
inbound_schema : Dict
64+
inbound_schema : dict
6165
JSON Schema to validate incoming event
62-
inbound_formats : Optional[Dict], optional
66+
inbound_formats : dict | None, optional
6367
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
6468
JSON Schema to validate outbound event, by default None
65-
outbound_formats : Optional[Dict], optional
69+
outbound_formats : dict | None, optional
6670
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
6771
""" # noqa: E501
6872
super().__init__()

0 commit comments

Comments
 (0)