Skip to content

Commit 265bee5

Browse files
Merge branch 'v3' into tracing_annotations
2 parents 8023443 + ebdc3ad commit 265bee5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+660
-577
lines changed

aws_lambda_powertools/event_handler/middlewares/base.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from __future__ import annotations
2+
13
from abc import ABC, abstractmethod
2-
from typing import Generic, Protocol
4+
from typing import TYPE_CHECKING, Generic, Protocol
35

4-
from aws_lambda_powertools.event_handler.api_gateway import Response
56
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
67

8+
if TYPE_CHECKING:
9+
from aws_lambda_powertools.event_handler.api_gateway import Response
10+
711

812
class NextMiddleware(Protocol):
913
def __call__(self, app: EventHandlerInstance) -> Response:

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

+39-36
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1+
from __future__ import annotations
2+
13
import dataclasses
24
import json
35
import logging
46
from copy import deepcopy
5-
from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple
7+
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
68

79
from pydantic import BaseModel
810

9-
from aws_lambda_powertools.event_handler import Response
10-
from aws_lambda_powertools.event_handler.api_gateway import Route
11-
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
11+
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler
1212
from aws_lambda_powertools.event_handler.openapi.compat import (
13-
ModelField,
1413
_model_dump,
1514
_normalize_errors,
1615
_regenerate_error_with_loc,
@@ -20,8 +19,14 @@
2019
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
2120
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
2221
from aws_lambda_powertools.event_handler.openapi.params import Param
23-
from aws_lambda_powertools.event_handler.openapi.types import IncEx
24-
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
22+
23+
if TYPE_CHECKING:
24+
from aws_lambda_powertools.event_handler import Response
25+
from aws_lambda_powertools.event_handler.api_gateway import Route
26+
from aws_lambda_powertools.event_handler.middlewares import NextMiddleware
27+
from aws_lambda_powertools.event_handler.openapi.compat import ModelField
28+
from aws_lambda_powertools.event_handler.openapi.types import IncEx
29+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
2530

2631
logger = logging.getLogger(__name__)
2732

@@ -36,8 +41,6 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
3641
--------
3742
3843
```python
39-
from typing import List
40-
4144
from pydantic import BaseModel
4245
4346
from aws_lambda_powertools.event_handler.api_gateway import (
@@ -50,12 +53,12 @@ class Todo(BaseModel):
5053
app = APIGatewayRestResolver(enable_validation=True)
5154
5255
@app.get("/todos")
53-
def get_todos(): List[Todo]:
56+
def get_todos(): list[Todo]:
5457
return [Todo(name="hello world")]
5558
```
5659
"""
5760

58-
def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None):
61+
def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
5962
"""
6063
Initialize the OpenAPIValidationMiddleware.
6164
@@ -72,8 +75,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
7275

7376
route: Route = app.context["_route"]
7477

75-
values: Dict[str, Any] = {}
76-
errors: List[Any] = []
78+
values: dict[str, Any] = {}
79+
errors: list[Any] = []
7780

7881
# Process path values, which can be found on the route_args
7982
path_values, path_errors = _request_params_to_args(
@@ -147,10 +150,10 @@ def _handle_response(self, *, route: Route, response: Response):
147150
def _serialize_response(
148151
self,
149152
*,
150-
field: Optional[ModelField] = None,
153+
field: ModelField | None = None,
151154
response_content: Any,
152-
include: Optional[IncEx] = None,
153-
exclude: Optional[IncEx] = None,
155+
include: IncEx | None = None,
156+
exclude: IncEx | None = None,
154157
by_alias: bool = True,
155158
exclude_unset: bool = False,
156159
exclude_defaults: bool = False,
@@ -160,7 +163,7 @@ def _serialize_response(
160163
Serialize the response content according to the field type.
161164
"""
162165
if field:
163-
errors: List[Dict[str, Any]] = []
166+
errors: list[dict[str, Any]] = []
164167
# MAINTENANCE: remove this when we drop pydantic v1
165168
if not hasattr(field, "serializable"):
166169
response_content = self._prepare_response_content(
@@ -232,7 +235,7 @@ def _prepare_response_content(
232235
return dataclasses.asdict(res)
233236
return res
234237

235-
def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
238+
def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
236239
"""
237240
Get the request body from the event, and parse it as JSON.
238241
"""
@@ -261,7 +264,7 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
261264
def _request_params_to_args(
262265
required_params: Sequence[ModelField],
263266
received_params: Mapping[str, Any],
264-
) -> Tuple[Dict[str, Any], List[Any]]:
267+
) -> tuple[dict[str, Any], list[Any]]:
265268
"""
266269
Convert the request params to a dictionary of values using validation, and returns a list of errors.
267270
"""
@@ -294,14 +297,14 @@ def _request_params_to_args(
294297

295298

296299
def _request_body_to_args(
297-
required_params: List[ModelField],
298-
received_body: Optional[Dict[str, Any]],
299-
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
300+
required_params: list[ModelField],
301+
received_body: dict[str, Any] | None,
302+
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
300303
"""
301304
Convert the request body to a dictionary of values using validation, and returns a list of errors.
302305
"""
303-
values: Dict[str, Any] = {}
304-
errors: List[Dict[str, Any]] = []
306+
values: dict[str, Any] = {}
307+
errors: list[dict[str, Any]] = []
305308

306309
received_body, field_alias_omitted = _get_embed_body(
307310
field=required_params[0],
@@ -313,11 +316,11 @@ def _request_body_to_args(
313316
# This sets the location to:
314317
# { "user": { object } } if field.alias == user
315318
# { { object } if field_alias is omitted
316-
loc: Tuple[str, ...] = ("body", field.alias)
319+
loc: tuple[str, ...] = ("body", field.alias)
317320
if field_alias_omitted:
318321
loc = ("body",)
319322

320-
value: Optional[Any] = None
323+
value: Any | None = None
321324

322325
# Now that we know what to look for, try to get the value from the received body
323326
if received_body is not None:
@@ -347,8 +350,8 @@ def _validate_field(
347350
*,
348351
field: ModelField,
349352
value: Any,
350-
loc: Tuple[str, ...],
351-
existing_errors: List[Dict[str, Any]],
353+
loc: tuple[str, ...],
354+
existing_errors: list[dict[str, Any]],
352355
):
353356
"""
354357
Validate a field, and append any errors to the existing_errors list.
@@ -367,9 +370,9 @@ def _validate_field(
367370
def _get_embed_body(
368371
*,
369372
field: ModelField,
370-
required_params: List[ModelField],
371-
received_body: Optional[Dict[str, Any]],
372-
) -> Tuple[Optional[Dict[str, Any]], bool]:
373+
required_params: list[ModelField],
374+
received_body: dict[str, Any] | None,
375+
) -> tuple[dict[str, Any] | None, bool]:
373376
field_info = field.field_info
374377
embed = getattr(field_info, "embed", None)
375378

@@ -382,15 +385,15 @@ def _get_embed_body(
382385

383386

384387
def _normalize_multi_query_string_with_param(
385-
query_string: Dict[str, List[str]],
388+
query_string: dict[str, list[str]],
386389
params: Sequence[ModelField],
387-
) -> Dict[str, Any]:
390+
) -> dict[str, Any]:
388391
"""
389392
Extract and normalize resolved_query_string_parameters
390393
391394
Parameters
392395
----------
393-
query_string: Dict
396+
query_string: dict
394397
A dictionary containing the initial query string parameters.
395398
params: Sequence[ModelField]
396399
A sequence of ModelField objects representing parameters.
@@ -399,7 +402,7 @@ def _normalize_multi_query_string_with_param(
399402
-------
400403
A dictionary containing the processed multi_query_string_parameters.
401404
"""
402-
resolved_query_string: Dict[str, Any] = query_string
405+
resolved_query_string: dict[str, Any] = query_string
403406
for param in filter(is_scalar_field, params):
404407
try:
405408
# if the target parameter is a scalar, we keep the first value of the query string
@@ -416,7 +419,7 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any],
416419
417420
Parameters
418421
----------
419-
headers: Dict
422+
headers: MutableMapping[str, Any]
420423
A dictionary containing the initial header parameters.
421424
params: Sequence[ModelField]
422425
A sequence of ModelField objects representing parameters.

aws_lambda_powertools/event_handler/middlewares/schema_validation.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
from __future__ import annotations
2+
13
import logging
2-
from typing import Dict, Optional
4+
from typing import TYPE_CHECKING
35

4-
from aws_lambda_powertools.event_handler.api_gateway import Response
56
from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError
67
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
7-
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
88
from aws_lambda_powertools.utilities.validation import validate
99
from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError
1010

11+
if TYPE_CHECKING:
12+
from aws_lambda_powertools.event_handler.api_gateway import Response
13+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
14+
1115
logger = logging.getLogger(__name__)
1216

1317

@@ -48,21 +52,21 @@ def lambda_handler(event, context):
4852

4953
def __init__(
5054
self,
51-
inbound_schema: Dict,
52-
inbound_formats: Optional[Dict] = None,
53-
outbound_schema: Optional[Dict] = None,
54-
outbound_formats: Optional[Dict] = None,
55+
inbound_schema: dict,
56+
inbound_formats: dict | None = None,
57+
outbound_schema: dict | None = None,
58+
outbound_formats: dict | None = None,
5559
):
5660
"""See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters.
5761
5862
Parameters
5963
----------
60-
inbound_schema : Dict
64+
inbound_schema : dict
6165
JSON Schema to validate incoming event
62-
inbound_formats : Optional[Dict], optional
66+
inbound_formats : dict | None, optional
6367
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
6468
JSON Schema to validate outbound event, by default None
65-
outbound_formats : Optional[Dict], optional
69+
outbound_formats : dict | None, optional
6670
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
6771
""" # noqa: E501
6872
super().__init__()

aws_lambda_powertools/middleware_factory/factory.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
1+
from __future__ import annotations
2+
13
import functools
24
import inspect
35
import logging
46
import os
5-
from typing import Any, Callable, Optional
7+
from typing import Any, Callable
68

7-
from ..shared import constants
8-
from ..shared.functions import resolve_truthy_env_var_choice
9-
from ..tracing import Tracer
10-
from .exceptions import MiddlewareInvalidArgumentError
9+
from aws_lambda_powertools.middleware_factory.exceptions import MiddlewareInvalidArgumentError
10+
from aws_lambda_powertools.shared import constants
11+
from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice
12+
from aws_lambda_powertools.tracing import Tracer
1113

1214
logger = logging.getLogger(__name__)
1315

1416

1517
# Maintenance: we can't yet provide an accurate return type without ParamSpec etc. see #1066
16-
def lambda_handler_decorator(decorator: Optional[Callable] = None, trace_execution: Optional[bool] = None) -> Callable:
18+
def lambda_handler_decorator(decorator: Callable | None = None, trace_execution: bool | None = None) -> Callable:
1719
"""Decorator factory for decorating Lambda handlers.
1820
1921
You can use lambda_handler_decorator to create your own middlewares,
@@ -112,7 +114,7 @@ def lambda_handler(event, context):
112114
)
113115

114116
@functools.wraps(decorator)
115-
def final_decorator(func: Optional[Callable] = None, **kwargs: Any):
117+
def final_decorator(func: Callable | None = None, **kwargs: Any):
116118
# If called with kwargs return new func with kwargs
117119
if func is None:
118120
return functools.partial(final_decorator, **kwargs)

aws_lambda_powertools/shared/cookies.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from datetime import datetime
1+
from __future__ import annotations
2+
23
from enum import Enum
34
from io import StringIO
4-
from typing import List, Optional
5+
from typing import TYPE_CHECKING
6+
7+
if TYPE_CHECKING:
8+
from datetime import datetime
59

610

711
class SameSite(Enum):
@@ -41,10 +45,10 @@ def __init__(
4145
domain: str = "",
4246
secure: bool = True,
4347
http_only: bool = False,
44-
max_age: Optional[int] = None,
45-
expires: Optional[datetime] = None,
46-
same_site: Optional[SameSite] = None,
47-
custom_attributes: Optional[List[str]] = None,
48+
max_age: int | None = None,
49+
expires: datetime | None = None,
50+
same_site: SameSite | None = None,
51+
custom_attributes: list[str] | None = None,
4852
):
4953
"""
5054
@@ -62,13 +66,13 @@ def __init__(
6266
Marks the cookie as secure, only sendable to the server with an encrypted request over the HTTPS protocol
6367
http_only: bool
6468
Enabling this attribute makes the cookie inaccessible to the JavaScript `Document.cookie` API
65-
max_age: Optional[int]
69+
max_age: int | None
6670
Defines the period of time after which the cookie is invalid. Use negative values to force cookie deletion.
67-
expires: Optional[datetime]
71+
expires: datetime | None
6872
Defines a date where the permanent cookie expires.
69-
same_site: Optional[SameSite]
73+
same_site: SameSite | None
7074
Determines if the cookie should be sent to third party websites
71-
custom_attributes: Optional[List[str]]
75+
custom_attributes: list[str] | None
7276
List of additional custom attributes to set on the cookie
7377
"""
7478
self.name = name

0 commit comments

Comments
 (0)