Skip to content

Commit 961ae5c

Browse files
Merge branch 'v3' into streaming_annotations
2 parents c79b963 + 456bf82 commit 961ae5c

File tree

11 files changed

+146
-124
lines changed

11 files changed

+146
-124
lines changed

aws_lambda_powertools/event_handler/middlewares/base.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from __future__ import annotations
2+
13
from abc import ABC, abstractmethod
2-
from typing import Generic, Protocol
4+
from typing import TYPE_CHECKING, Generic, Protocol
35

4-
from aws_lambda_powertools.event_handler.api_gateway import Response
56
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
67

8+
if TYPE_CHECKING:
9+
from aws_lambda_powertools.event_handler.api_gateway import Response
10+
711

812
class NextMiddleware(Protocol):
913
def __call__(self, app: EventHandlerInstance) -> Response:

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

+39-36
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1+
from __future__ import annotations
2+
13
import dataclasses
24
import json
35
import logging
46
from copy import deepcopy
5-
from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple
7+
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
68

79
from pydantic import BaseModel
810

9-
from aws_lambda_powertools.event_handler import Response
10-
from aws_lambda_powertools.event_handler.api_gateway import Route
11-
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
11+
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler
1212
from aws_lambda_powertools.event_handler.openapi.compat import (
13-
ModelField,
1413
_model_dump,
1514
_normalize_errors,
1615
_regenerate_error_with_loc,
@@ -20,8 +19,14 @@
2019
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
2120
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
2221
from aws_lambda_powertools.event_handler.openapi.params import Param
23-
from aws_lambda_powertools.event_handler.openapi.types import IncEx
24-
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
22+
23+
if TYPE_CHECKING:
24+
from aws_lambda_powertools.event_handler import Response
25+
from aws_lambda_powertools.event_handler.api_gateway import Route
26+
from aws_lambda_powertools.event_handler.middlewares import NextMiddleware
27+
from aws_lambda_powertools.event_handler.openapi.compat import ModelField
28+
from aws_lambda_powertools.event_handler.openapi.types import IncEx
29+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
2530

2631
logger = logging.getLogger(__name__)
2732

@@ -36,8 +41,6 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
3641
--------
3742
3843
```python
39-
from typing import List
40-
4144
from pydantic import BaseModel
4245
4346
from aws_lambda_powertools.event_handler.api_gateway import (
@@ -50,12 +53,12 @@ class Todo(BaseModel):
5053
app = APIGatewayRestResolver(enable_validation=True)
5154
5255
@app.get("/todos")
53-
def get_todos(): List[Todo]:
56+
def get_todos(): list[Todo]:
5457
return [Todo(name="hello world")]
5558
```
5659
"""
5760

58-
def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None):
61+
def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
5962
"""
6063
Initialize the OpenAPIValidationMiddleware.
6164
@@ -72,8 +75,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
7275

7376
route: Route = app.context["_route"]
7477

75-
values: Dict[str, Any] = {}
76-
errors: List[Any] = []
78+
values: dict[str, Any] = {}
79+
errors: list[Any] = []
7780

7881
# Process path values, which can be found on the route_args
7982
path_values, path_errors = _request_params_to_args(
@@ -147,10 +150,10 @@ def _handle_response(self, *, route: Route, response: Response):
147150
def _serialize_response(
148151
self,
149152
*,
150-
field: Optional[ModelField] = None,
153+
field: ModelField | None = None,
151154
response_content: Any,
152-
include: Optional[IncEx] = None,
153-
exclude: Optional[IncEx] = None,
155+
include: IncEx | None = None,
156+
exclude: IncEx | None = None,
154157
by_alias: bool = True,
155158
exclude_unset: bool = False,
156159
exclude_defaults: bool = False,
@@ -160,7 +163,7 @@ def _serialize_response(
160163
Serialize the response content according to the field type.
161164
"""
162165
if field:
163-
errors: List[Dict[str, Any]] = []
166+
errors: list[dict[str, Any]] = []
164167
# MAINTENANCE: remove this when we drop pydantic v1
165168
if not hasattr(field, "serializable"):
166169
response_content = self._prepare_response_content(
@@ -232,7 +235,7 @@ def _prepare_response_content(
232235
return dataclasses.asdict(res)
233236
return res
234237

235-
def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
238+
def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
236239
"""
237240
Get the request body from the event, and parse it as JSON.
238241
"""
@@ -261,7 +264,7 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
261264
def _request_params_to_args(
262265
required_params: Sequence[ModelField],
263266
received_params: Mapping[str, Any],
264-
) -> Tuple[Dict[str, Any], List[Any]]:
267+
) -> tuple[dict[str, Any], list[Any]]:
265268
"""
266269
Convert the request params to a dictionary of values using validation, and returns a list of errors.
267270
"""
@@ -294,14 +297,14 @@ def _request_params_to_args(
294297

295298

296299
def _request_body_to_args(
297-
required_params: List[ModelField],
298-
received_body: Optional[Dict[str, Any]],
299-
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
300+
required_params: list[ModelField],
301+
received_body: dict[str, Any] | None,
302+
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
300303
"""
301304
Convert the request body to a dictionary of values using validation, and returns a list of errors.
302305
"""
303-
values: Dict[str, Any] = {}
304-
errors: List[Dict[str, Any]] = []
306+
values: dict[str, Any] = {}
307+
errors: list[dict[str, Any]] = []
305308

306309
received_body, field_alias_omitted = _get_embed_body(
307310
field=required_params[0],
@@ -313,11 +316,11 @@ def _request_body_to_args(
313316
# This sets the location to:
314317
# { "user": { object } } if field.alias == user
315318
# { { object } if field_alias is omitted
316-
loc: Tuple[str, ...] = ("body", field.alias)
319+
loc: tuple[str, ...] = ("body", field.alias)
317320
if field_alias_omitted:
318321
loc = ("body",)
319322

320-
value: Optional[Any] = None
323+
value: Any | None = None
321324

322325
# Now that we know what to look for, try to get the value from the received body
323326
if received_body is not None:
@@ -347,8 +350,8 @@ def _validate_field(
347350
*,
348351
field: ModelField,
349352
value: Any,
350-
loc: Tuple[str, ...],
351-
existing_errors: List[Dict[str, Any]],
353+
loc: tuple[str, ...],
354+
existing_errors: list[dict[str, Any]],
352355
):
353356
"""
354357
Validate a field, and append any errors to the existing_errors list.
@@ -367,9 +370,9 @@ def _validate_field(
367370
def _get_embed_body(
368371
*,
369372
field: ModelField,
370-
required_params: List[ModelField],
371-
received_body: Optional[Dict[str, Any]],
372-
) -> Tuple[Optional[Dict[str, Any]], bool]:
373+
required_params: list[ModelField],
374+
received_body: dict[str, Any] | None,
375+
) -> tuple[dict[str, Any] | None, bool]:
373376
field_info = field.field_info
374377
embed = getattr(field_info, "embed", None)
375378

@@ -382,15 +385,15 @@ def _get_embed_body(
382385

383386

384387
def _normalize_multi_query_string_with_param(
385-
query_string: Dict[str, List[str]],
388+
query_string: dict[str, list[str]],
386389
params: Sequence[ModelField],
387-
) -> Dict[str, Any]:
390+
) -> dict[str, Any]:
388391
"""
389392
Extract and normalize resolved_query_string_parameters
390393
391394
Parameters
392395
----------
393-
query_string: Dict
396+
query_string: dict
394397
A dictionary containing the initial query string parameters.
395398
params: Sequence[ModelField]
396399
A sequence of ModelField objects representing parameters.
@@ -399,7 +402,7 @@ def _normalize_multi_query_string_with_param(
399402
-------
400403
A dictionary containing the processed multi_query_string_parameters.
401404
"""
402-
resolved_query_string: Dict[str, Any] = query_string
405+
resolved_query_string: dict[str, Any] = query_string
403406
for param in filter(is_scalar_field, params):
404407
try:
405408
# if the target parameter is a scalar, we keep the first value of the query string
@@ -416,7 +419,7 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any],
416419
417420
Parameters
418421
----------
419-
headers: Dict
422+
headers: MutableMapping[str, Any]
420423
A dictionary containing the initial header parameters.
421424
params: Sequence[ModelField]
422425
A sequence of ModelField objects representing parameters.

aws_lambda_powertools/event_handler/middlewares/schema_validation.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
from __future__ import annotations
2+
13
import logging
2-
from typing import Dict, Optional
4+
from typing import TYPE_CHECKING
35

4-
from aws_lambda_powertools.event_handler.api_gateway import Response
56
from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError
67
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
7-
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
88
from aws_lambda_powertools.utilities.validation import validate
99
from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError
1010

11+
if TYPE_CHECKING:
12+
from aws_lambda_powertools.event_handler.api_gateway import Response
13+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
14+
1115
logger = logging.getLogger(__name__)
1216

1317

@@ -48,21 +52,21 @@ def lambda_handler(event, context):
4852

4953
def __init__(
5054
self,
51-
inbound_schema: Dict,
52-
inbound_formats: Optional[Dict] = None,
53-
outbound_schema: Optional[Dict] = None,
54-
outbound_formats: Optional[Dict] = None,
55+
inbound_schema: dict,
56+
inbound_formats: dict | None = None,
57+
outbound_schema: dict | None = None,
58+
outbound_formats: dict | None = None,
5559
):
5660
"""See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters.
5761
5862
Parameters
5963
----------
60-
inbound_schema : Dict
64+
inbound_schema : dict
6165
JSON Schema to validate incoming event
62-
inbound_formats : Optional[Dict], optional
66+
inbound_formats : dict | None, optional
6367
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
6468
JSON Schema to validate outbound event, by default None
65-
outbound_formats : Optional[Dict], optional
69+
outbound_formats : dict | None, optional
6670
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
6771
""" # noqa: E501
6872
super().__init__()

aws_lambda_powertools/utilities/data_masking/base.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import functools
44
import logging
55
import warnings
6-
from numbers import Number
7-
from typing import Any, Callable, Mapping, Optional, Sequence, Union, overload
6+
from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence, overload
87

98
from jsonpath_ng.ext import parse
109

@@ -14,6 +13,9 @@
1413
)
1514
from aws_lambda_powertools.utilities.data_masking.provider import BaseProvider
1615

16+
if TYPE_CHECKING:
17+
from numbers import Number
18+
1719
logger = logging.getLogger(__name__)
1820

1921

@@ -43,7 +45,7 @@ def lambda_handler(event, context):
4345

4446
def __init__(
4547
self,
46-
provider: Optional[BaseProvider] = None,
48+
provider: BaseProvider | None = None,
4749
raise_on_missing_field: bool = True,
4850
):
4951
self.provider = provider or BaseProvider()
@@ -111,7 +113,7 @@ def _apply_action(
111113
----------
112114
data : str | dict
113115
The input data to process.
114-
fields : Optional[List[str]]
116+
fields : list[str] | None
115117
A list of fields to apply the action to. If 'None', the action is applied to the entire 'data'.
116118
action : Callable
117119
The action to apply to the data. It should be a callable that performs an operation on the data
@@ -142,21 +144,21 @@ def _apply_action(
142144

143145
def _apply_action_to_fields(
144146
self,
145-
data: Union[dict, str],
147+
data: dict | str,
146148
fields: list,
147149
action: Callable,
148150
provider_options: dict | None = None,
149151
**encryption_context: str,
150-
) -> Union[dict, str]:
152+
) -> dict | str:
151153
"""
152154
This method takes the input data, which can be either a dictionary or a JSON string,
153155
and erases, encrypts, or decrypts the specified fields.
154156
155157
Parameters
156158
----------
157-
data : Union[dict, str])
159+
data : dict | str)
158160
The input data to process. It can be either a dictionary or a JSON string.
159-
fields : List
161+
fields : list
160162
A list of fields to apply the action to. Each field can be specified as a string or
161163
a list of strings representing nested keys in the dictionary.
162164
action : Callable

aws_lambda_powertools/utilities/data_masking/provider/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def encrypt(self, data) -> str:
2424
def decrypt(self, data) -> Any:
2525
# Implementation logic for data decryption
2626
27-
def erase(self, data) -> Union[str, Iterable]:
27+
def erase(self, data) -> str | Iterable:
2828
# Implementation logic for data masking
2929
pass
3030

0 commit comments

Comments
 (0)