Skip to content

refactor(shared): add from __future__ import annotations #4942

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions aws_lambda_powertools/shared/cookies.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from datetime import datetime
from __future__ import annotations

from enum import Enum
from io import StringIO
from typing import List, Optional
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from datetime import datetime


class SameSite(Enum):
Expand Down Expand Up @@ -41,10 +45,10 @@ def __init__(
domain: str = "",
secure: bool = True,
http_only: bool = False,
max_age: Optional[int] = None,
expires: Optional[datetime] = None,
same_site: Optional[SameSite] = None,
custom_attributes: Optional[List[str]] = None,
max_age: int | None = None,
expires: datetime | None = None,
same_site: SameSite | None = None,
custom_attributes: list[str] | None = None,
):
"""

Expand All @@ -62,13 +66,13 @@ def __init__(
Marks the cookie as secure, only sendable to the server with an encrypted request over the HTTPS protocol
http_only: bool
Enabling this attribute makes the cookie inaccessible to the JavaScript `Document.cookie` API
max_age: Optional[int]
max_age: int | None
Defines the period of time after which the cookie is invalid. Use negative values to force cookie deletion.
expires: Optional[datetime]
expires: datetime | None
Defines a date where the permanent cookie expires.
same_site: Optional[SameSite]
same_site: SameSite | None
Determines if the cookie should be sent to third party websites
custom_attributes: Optional[List[str]]
custom_attributes: list[str] | None
List of additional custom attributes to set on the cookie
"""
self.name = name
Expand Down
18 changes: 10 additions & 8 deletions aws_lambda_powertools/shared/dynamodb_deserializer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from decimal import Clamped, Context, Decimal, Inexact, Overflow, Rounded, Underflow
from typing import Any, Callable, Dict, Optional, Sequence, Set
from typing import Any, Callable, Sequence

# NOTE: DynamoDB supports up to 38 digits precision
# Therefore, this ensures our Decimal follows what's stored in the table
Expand All @@ -21,7 +23,7 @@ class TypeDeserializer:
since we don't support Python 2.
"""

def deserialize(self, value: Dict) -> Any:
def deserialize(self, value: dict) -> Any:
"""Deserialize DynamoDB data types into Python types.

Parameters
Expand Down Expand Up @@ -57,7 +59,7 @@ def deserialize(self, value: Dict) -> Any:
"""

dynamodb_type = list(value.keys())[0]
deserializer: Optional[Callable] = getattr(self, f"_deserialize_{dynamodb_type}".lower(), None)
deserializer: Callable | None = getattr(self, f"_deserialize_{dynamodb_type}".lower(), None)
if deserializer is None:
raise TypeError(f"Dynamodb type {dynamodb_type} is not supported")

Expand All @@ -78,17 +80,17 @@ def _deserialize_s(self, value: str) -> str:
def _deserialize_b(self, value: bytes) -> bytes:
return value

def _deserialize_ns(self, value: Sequence[str]) -> Set[Decimal]:
def _deserialize_ns(self, value: Sequence[str]) -> set[Decimal]:
return set(map(self._deserialize_n, value))

def _deserialize_ss(self, value: Sequence[str]) -> Set[str]:
def _deserialize_ss(self, value: Sequence[str]) -> set[str]:
return set(map(self._deserialize_s, value))

def _deserialize_bs(self, value: Sequence[bytes]) -> Set[bytes]:
def _deserialize_bs(self, value: Sequence[bytes]) -> set[bytes]:
return set(map(self._deserialize_b, value))

def _deserialize_l(self, value: Sequence[Dict]) -> Sequence[Any]:
def _deserialize_l(self, value: Sequence[dict]) -> Sequence[Any]:
return [self.deserialize(v) for v in value]

def _deserialize_m(self, value: Dict) -> Dict:
def _deserialize_m(self, value: dict) -> dict:
return {k: self.deserialize(v) for k, v in value.items()}
22 changes: 11 additions & 11 deletions aws_lambda_powertools/shared/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings
from binascii import Error as BinAsciiError
from pathlib import Path
from typing import Any, Dict, Generator, Optional, Union, overload
from typing import Any, Generator, overload

from aws_lambda_powertools.shared import constants

Expand All @@ -32,7 +32,7 @@ def strtobool(value: str) -> bool:
raise ValueError(f"invalid truth value {value!r}")


def resolve_truthy_env_var_choice(env: str, choice: Optional[bool] = None) -> bool:
def resolve_truthy_env_var_choice(env: str, choice: bool | None = None) -> bool:
"""Pick explicit choice over truthy env value, if available, otherwise return truthy env value

NOTE: Environment variable should be resolved by the caller.
Expand All @@ -52,27 +52,27 @@ def resolve_truthy_env_var_choice(env: str, choice: Optional[bool] = None) -> bo
return choice if choice is not None else strtobool(env)


def resolve_max_age(env: str, choice: Optional[int]) -> int:
def resolve_max_age(env: str, choice: int | None) -> int:
"""Resolve max age value"""
return choice if choice is not None else int(env)


@overload
def resolve_env_var_choice(env: Optional[str], choice: float) -> float: ...
def resolve_env_var_choice(env: str | None, choice: float) -> float: ...


@overload
def resolve_env_var_choice(env: Optional[str], choice: str) -> str: ...
def resolve_env_var_choice(env: str | None, choice: str) -> str: ...


@overload
def resolve_env_var_choice(env: Optional[str], choice: Optional[str]) -> str: ...
def resolve_env_var_choice(env: str | None, choice: str | None) -> str: ...


def resolve_env_var_choice(
env: Optional[str] = None,
choice: Optional[Union[str, float]] = None,
) -> Optional[Union[str, float]]:
env: str | None = None,
choice: str | float | None = None,
) -> str | float | None:
"""Pick explicit choice over env, if available, otherwise return env value received

NOTE: Environment variable should be resolved by the caller.
Expand Down Expand Up @@ -136,12 +136,12 @@ def powertools_debug_is_set() -> bool:
return False


def slice_dictionary(data: Dict, chunk_size: int) -> Generator[Dict, None, None]:
def slice_dictionary(data: dict, chunk_size: int) -> Generator[dict, None, None]:
for _ in range(0, len(data), chunk_size):
yield {dict_key: data[dict_key] for dict_key in itertools.islice(data, chunk_size)}


def extract_event_from_common_models(data: Any) -> Dict | Any:
def extract_event_from_common_models(data: Any) -> dict | Any:
"""Extract raw event from common types used in Powertools

If event cannot be extracted, return received data as is.
Expand Down
25 changes: 14 additions & 11 deletions aws_lambda_powertools/shared/headers_serializer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import warnings
from collections import defaultdict
from typing import Any, Dict, List, Union
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.shared.cookies import Cookie
if TYPE_CHECKING:
from aws_lambda_powertools.shared.cookies import Cookie


class BaseHeadersSerializer:
Expand All @@ -11,23 +14,23 @@ class BaseHeadersSerializer:
ALB and Lambda Function URL response payload.
"""

def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]:
def serialize(self, headers: dict[str, str | list[str]], cookies: list[Cookie]) -> dict[str, Any]:
"""
Serializes headers and cookies according to the request type.
Returns a dict that can be merged with the response payload.

Parameters
----------
headers: Dict[str, List[str]]
headers: dict[str, str | list[str]]
A dictionary of headers to set in the response
cookies: List[str]
cookies: list[Cookie]
A list of cookies to set in the response
"""
raise NotImplementedError()


class HttpApiHeadersSerializer(BaseHeadersSerializer):
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]:
def serialize(self, headers: dict[str, str | list[str]], cookies: list[Cookie]) -> dict[str, Any]:
"""
When using HTTP APIs or LambdaFunctionURLs, everything is taken care automatically for us.
We can directly assign a list of cookies and a dict of headers to the response payload, and the
Expand All @@ -39,7 +42,7 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Coo

# Format 2.0 doesn't have multiValueHeaders or multiValueQueryStringParameters fields.
# Duplicate headers are combined with commas and included in the headers field.
combined_headers: Dict[str, str] = {}
combined_headers: dict[str, str] = {}
for key, values in headers.items():
# omit headers with explicit null values
if values is None:
Expand All @@ -54,7 +57,7 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Coo


class MultiValueHeadersSerializer(BaseHeadersSerializer):
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]:
def serialize(self, headers: dict[str, str | list[str]], cookies: list[Cookie]) -> dict[str, Any]:
"""
When using REST APIs, headers can be encoded using the `multiValueHeaders` key on the response.
This is also the case when using an ALB integration with the `multiValueHeaders` option enabled.
Expand All @@ -63,7 +66,7 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Coo
https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-output-format
https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html#multi-value-headers-response
"""
payload: Dict[str, List[str]] = defaultdict(list)
payload: dict[str, list[str]] = defaultdict(list)
for key, values in headers.items():
# omit headers with explicit null values
if values is None:
Expand All @@ -83,14 +86,14 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Coo


class SingleValueHeadersSerializer(BaseHeadersSerializer):
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]:
def serialize(self, headers: dict[str, str | list[str]], cookies: list[Cookie]) -> dict[str, Any]:
"""
The ALB integration has `multiValueHeaders` disabled by default.
If we try to set multiple headers with the same key, or more than one cookie, print a warning.

https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html#respond-to-load-balancer
"""
payload: Dict[str, Dict[str, str]] = {}
payload: dict[str, dict[str, str]] = {}
payload.setdefault("headers", {})

if cookies:
Expand Down
4 changes: 1 addition & 3 deletions aws_lambda_powertools/shared/json_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ class Encoder(json.JSONEncoder):

def default(self, obj):
if isinstance(obj, decimal.Decimal):
if obj.is_nan():
return math.nan
return str(obj)
return math.nan if obj.is_nan() else str(obj)

if is_pydantic(obj):
return pydantic_to_dict(obj)
Expand Down
Loading