From e4a9b5e4ef6c0c67ddce03be4cddc5a2ff05be38 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 19 Nov 2024 17:39:02 +0000 Subject: [PATCH 1/5] Accepting None when working with output serialization --- .../idempotency/serialization/dataclass.py | 4 ++ .../idempotency/serialization/functions.py | 43 ++++++++++++++++ .../idempotency/serialization/pydantic.py | 4 ++ docs/utilities/idempotency.md | 5 +- .../idempotency/_boto3/test_idempotency.py | 49 ++++++++++++++++++- .../test_idempotency_with_pydantic.py | 46 +++++++++++++++++ 6 files changed, 149 insertions(+), 2 deletions(-) create mode 100644 aws_lambda_powertools/utilities/idempotency/serialization/functions.py diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py b/aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py index be5d7007ef3..fc8b72252c0 100644 --- a/aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py +++ b/aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py @@ -11,6 +11,7 @@ BaseIdempotencyModelSerializer, BaseIdempotencySerializer, ) +from aws_lambda_powertools.utilities.idempotency.serialization.functions import get_actual_type DataClass = Any @@ -37,6 +38,9 @@ def from_dict(self, data: dict) -> DataClass: @classmethod def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer: + + model_type = get_actual_type(model_type=model_type) + if model_type is None: raise IdempotencyNoSerializationModelError("No serialization model was supplied") diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/functions.py b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py new file mode 100644 index 00000000000..a610202eff8 --- /dev/null +++ b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py @@ -0,0 +1,43 @@ +from types import UnionType +from typing import Any, Optional, Union, get_args, get_origin + +from aws_lambda_powertools.utilities.idempotency.exceptions import ( + IdempotencyModelTypeError, +) + + +def get_actual_type(model_type: Any) -> Any: + """ + Extract the actual type from a potentially Optional or Union type. + + This function handles types that may be wrapped in Optional or Union, + including the Python 3.10+ Union syntax (Type | None). + + Parameters + ---------- + model_type: Any + The type to analyze. Can be a simple type, Optional[Type], BaseModel, dataclass + Returns + ------- + The actual type without Optional or Union wrappers. + + Raises: + IdempotencyModelTypeError: If the type specification is invalid + (e.g., Union with multiple non-None types). + """ + # Check if the type is Optional, Union, or the new Union syntax + if get_origin(model_type) in (Optional, Union, UnionType): + # Get the arguments of the type (e.g., for Optional[int], this would be (int, NoneType)) + args = get_args(model_type) + + # Filter out NoneType to get the actual type(s) + actual_type = [arg for arg in args if arg is not type(None)] + + # Ensure there's exactly one non-None type + if len(actual_type) != 1: + raise IdempotencyModelTypeError( + "Invalid type: expected a single type, optionally wrapped in Optional or Union with None.", + ) + return actual_type[0] + + return model_type diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py b/aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py index 42ae179833f..8ba45a40583 100644 --- a/aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py +++ b/aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py @@ -12,6 +12,7 @@ BaseIdempotencyModelSerializer, BaseIdempotencySerializer, ) +from aws_lambda_powertools.utilities.idempotency.serialization.functions import get_actual_type class PydanticSerializer(BaseIdempotencyModelSerializer): @@ -34,6 +35,9 @@ def from_dict(self, data: dict) -> BaseModel: @classmethod def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer: + + model_type = get_actual_type(model_type=model_type) + if model_type is None: raise IdempotencyNoSerializationModelError("No serialization model was supplied") diff --git a/docs/utilities/idempotency.md b/docs/utilities/idempotency.md index f263aa1cb6e..cfe85877961 100644 --- a/docs/utilities/idempotency.md +++ b/docs/utilities/idempotency.md @@ -212,7 +212,10 @@ By default, `idempotent_function` serializes, stores, and returns your annotated The output serializer supports any JSON serializable data, **Python Dataclasses** and **Pydantic Models**. -!!! info "When using the `output_serializer` parameter, the data will continue to be stored in your persistent storage as a JSON string." +!!! info + When using the `output_serializer` parameter, the data will continue to be stored in your persistent storage as a JSON string. + + Function returns must be annotated with a single type, optionally wrapped in `Optional` or `Union` with `None`. === "Pydantic" diff --git a/tests/functional/idempotency/_boto3/test_idempotency.py b/tests/functional/idempotency/_boto3/test_idempotency.py index 35f82333e9c..f2214e2fd65 100644 --- a/tests/functional/idempotency/_boto3/test_idempotency.py +++ b/tests/functional/idempotency/_boto3/test_idempotency.py @@ -1,7 +1,7 @@ import copy import datetime import warnings -from typing import Any +from typing import Any, Optional from unittest.mock import MagicMock, Mock import jmespath @@ -2014,3 +2014,50 @@ def lambda_handler(event, context): stubber.assert_no_pending_responses() stubber.deactivate() + + +@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"]) +def test_idempotent_function_serialization_dataclass_with_optional_return(output_serializer_type: str): + # GIVEN + dataclasses = get_dataclasses_lib() + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_dataclass_with_optional_return..collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + @dataclasses.dataclass + class PaymentInput: + customer_id: str + transaction_id: str + + @dataclasses.dataclass + class PaymentOutput: + customer_id: str + transaction_id: str + + if output_serializer_type == "explicit": + output_serializer = DataclassSerializer( + model=PaymentOutput, + ) + else: + output_serializer = DataclassSerializer + + @idempotent_function( + data_keyword_argument="payment", + persistence_store=persistence_layer, + config=config, + output_serializer=output_serializer, + ) + def collect_payment(payment: PaymentInput) -> Optional[PaymentOutput]: + return PaymentOutput(**dataclasses.asdict(payment)) + + # WHEN + payment = PaymentInput(**mock_event) + first_call: PaymentOutput = collect_payment(payment=payment) + assert first_call.customer_id == payment.customer_id + assert first_call.transaction_id == payment.transaction_id + assert isinstance(first_call, PaymentOutput) + second_call: PaymentOutput = collect_payment(payment=payment) + assert isinstance(second_call, PaymentOutput) + assert second_call.customer_id == payment.customer_id + assert second_call.transaction_id == payment.transaction_id diff --git a/tests/functional/idempotency/_pydantic/test_idempotency_with_pydantic.py b/tests/functional/idempotency/_pydantic/test_idempotency_with_pydantic.py index aaac5948e63..f8e3debbc30 100644 --- a/tests/functional/idempotency/_pydantic/test_idempotency_with_pydantic.py +++ b/tests/functional/idempotency/_pydantic/test_idempotency_with_pydantic.py @@ -1,3 +1,5 @@ +from typing import Optional + import pytest from pydantic import BaseModel @@ -219,3 +221,47 @@ def collect_payment(payment: Payment): # THEN idempotency key assertion happens at MockPersistenceLayer assert result == payment.transaction_id + + +@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"]) +def test_idempotent_function_serialization_pydantic_with_optional_return(output_serializer_type: str): + # GIVEN + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_pydantic_with_optional_return..collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + class PaymentInput(BaseModel): + customer_id: str + transaction_id: str + + class PaymentOutput(BaseModel): + customer_id: str + transaction_id: str + + if output_serializer_type == "explicit": + output_serializer = PydanticSerializer( + model=PaymentOutput, + ) + else: + output_serializer = PydanticSerializer + + @idempotent_function( + data_keyword_argument="payment", + persistence_store=persistence_layer, + config=config, + output_serializer=output_serializer, + ) + def collect_payment(payment: PaymentInput) -> Optional[PaymentOutput]: + return PaymentOutput(**payment.dict()) + + # WHEN + payment = PaymentInput(**mock_event) + first_call: PaymentOutput = collect_payment(payment=payment) + assert first_call.customer_id == payment.customer_id + assert first_call.transaction_id == payment.transaction_id + assert isinstance(first_call, PaymentOutput) + second_call: PaymentOutput = collect_payment(payment=payment) + assert isinstance(second_call, PaymentOutput) + assert second_call.customer_id == payment.customer_id + assert second_call.transaction_id == payment.transaction_id From ae2940a8556f9b18be9ddf70434ce29739c88110 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Fri, 20 Dec 2024 10:33:22 +0000 Subject: [PATCH 2/5] Fix Python3.8/3.9 --- .../utilities/idempotency/serialization/functions.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/functions.py b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py index a610202eff8..52293992e89 100644 --- a/aws_lambda_powertools/utilities/idempotency/serialization/functions.py +++ b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py @@ -1,6 +1,10 @@ -from types import UnionType from typing import Any, Optional, Union, get_args, get_origin +try: + from types import UnionType +except ImportError: + UnionType = None + from aws_lambda_powertools.utilities.idempotency.exceptions import ( IdempotencyModelTypeError, ) @@ -9,10 +13,8 @@ def get_actual_type(model_type: Any) -> Any: """ Extract the actual type from a potentially Optional or Union type. - This function handles types that may be wrapped in Optional or Union, including the Python 3.10+ Union syntax (Type | None). - Parameters ---------- model_type: Any @@ -20,13 +22,12 @@ def get_actual_type(model_type: Any) -> Any: Returns ------- The actual type without Optional or Union wrappers. - Raises: IdempotencyModelTypeError: If the type specification is invalid (e.g., Union with multiple non-None types). """ # Check if the type is Optional, Union, or the new Union syntax - if get_origin(model_type) in (Optional, Union, UnionType): + if get_origin(model_type) in (Optional, Union) or (UnionType is not None and get_origin(model_type) is UnionType): # Get the arguments of the type (e.g., for Optional[int], this would be (int, NoneType)) args = get_args(model_type) From cb60c14f7a439983ad52b23ffc3a0decb005b727 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Fri, 20 Dec 2024 10:35:54 +0000 Subject: [PATCH 3/5] Make mypy happy --- .../utilities/idempotency/serialization/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/functions.py b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py index 52293992e89..484bfabbb56 100644 --- a/aws_lambda_powertools/utilities/idempotency/serialization/functions.py +++ b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py @@ -3,7 +3,7 @@ try: from types import UnionType except ImportError: - UnionType = None + UnionType = None # type: ignore[assignment, misc] from aws_lambda_powertools.utilities.idempotency.exceptions import ( IdempotencyModelTypeError, From 05a96d82a7403722f10f12004cc183faf8b136b8 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 23 Dec 2024 12:16:33 +0000 Subject: [PATCH 4/5] Making it work in python 3.8 and 3.9 --- .../idempotency/serialization/functions.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/functions.py b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py index 484bfabbb56..9a1934d3e35 100644 --- a/aws_lambda_powertools/utilities/idempotency/serialization/functions.py +++ b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py @@ -1,9 +1,11 @@ +import sys from typing import Any, Optional, Union, get_args, get_origin -try: - from types import UnionType -except ImportError: - UnionType = None # type: ignore[assignment, misc] +# Conditionally import or define UnionType based on Python version +if sys.version_info >= (3, 10): + from types import UnionType # Available in Python 3.10+ +else: + UnionType = Union # Fallback for Python 3.8 and 3.9 from aws_lambda_powertools.utilities.idempotency.exceptions import ( IdempotencyModelTypeError, @@ -26,19 +28,30 @@ def get_actual_type(model_type: Any) -> Any: IdempotencyModelTypeError: If the type specification is invalid (e.g., Union with multiple non-None types). """ - # Check if the type is Optional, Union, or the new Union syntax - if get_origin(model_type) in (Optional, Union) or (UnionType is not None and get_origin(model_type) is UnionType): - # Get the arguments of the type (e.g., for Optional[int], this would be (int, NoneType)) + + # Get the origin of the type (e.g., Union, Optional) + origin = get_origin(model_type) + + # Check if type is Union, Optional, or UnionType (Python 3.10+) + if origin in (Union, Optional) or (sys.version_info >= (3, 10) and isinstance(origin, UnionType)): + # Get type arguments args = get_args(model_type) - # Filter out NoneType to get the actual type(s) - actual_type = [arg for arg in args if arg is not type(None)] + # Filter out NoneType + actual_type = _extract_non_none_types(args) - # Ensure there's exactly one non-None type + # Ensure only one non-None type exists if len(actual_type) != 1: raise IdempotencyModelTypeError( "Invalid type: expected a single type, optionally wrapped in Optional or Union with None.", ) + return actual_type[0] + # If not a Union/Optional type, return original type return model_type + + +def _extract_non_none_types(args: tuple) -> list: + """Extract non-None types from type arguments.""" + return [arg for arg in args if arg is not type(None)] From 2c043c39906ce7a21558ffa304ee08dadda77871 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 23 Dec 2024 12:51:28 +0000 Subject: [PATCH 5/5] Making it work in python 3.8 and 3.9 --- .../utilities/idempotency/serialization/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/functions.py b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py index 9a1934d3e35..72a8d6940c9 100644 --- a/aws_lambda_powertools/utilities/idempotency/serialization/functions.py +++ b/aws_lambda_powertools/utilities/idempotency/serialization/functions.py @@ -33,7 +33,7 @@ def get_actual_type(model_type: Any) -> Any: origin = get_origin(model_type) # Check if type is Union, Optional, or UnionType (Python 3.10+) - if origin in (Union, Optional) or (sys.version_info >= (3, 10) and isinstance(origin, UnionType)): + if origin in (Union, Optional) or (sys.version_info >= (3, 10) and origin in (Union, UnionType)): # Get type arguments args = get_args(model_type)