Skip to content

Commit 7a7f10c

Browse files
fix(idempotency): add support for Optional type when serializing output (#5590)
* Accepting None when working with output serialization * Fix Python3.8/3.9 * Make mypy happy * Making it work in python 3.8 and 3.9 * Making it work in python 3.8 and 3.9
1 parent 1261c07 commit 7a7f10c

File tree

6 files changed

+163
-2
lines changed

6 files changed

+163
-2
lines changed

aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
BaseIdempotencyModelSerializer,
1212
BaseIdempotencySerializer,
1313
)
14+
from aws_lambda_powertools.utilities.idempotency.serialization.functions import get_actual_type
1415

1516
DataClass = Any
1617

@@ -37,6 +38,9 @@ def from_dict(self, data: dict) -> DataClass:
3738

3839
@classmethod
3940
def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer:
41+
42+
model_type = get_actual_type(model_type=model_type)
43+
4044
if model_type is None:
4145
raise IdempotencyNoSerializationModelError("No serialization model was supplied")
4246

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import sys
2+
from typing import Any, Optional, Union, get_args, get_origin
3+
4+
# Conditionally import or define UnionType based on Python version
5+
if sys.version_info >= (3, 10):
6+
from types import UnionType # Available in Python 3.10+
7+
else:
8+
UnionType = Union # Fallback for Python 3.8 and 3.9
9+
10+
from aws_lambda_powertools.utilities.idempotency.exceptions import (
11+
IdempotencyModelTypeError,
12+
)
13+
14+
15+
def get_actual_type(model_type: Any) -> Any:
16+
"""
17+
Extract the actual type from a potentially Optional or Union type.
18+
This function handles types that may be wrapped in Optional or Union,
19+
including the Python 3.10+ Union syntax (Type | None).
20+
Parameters
21+
----------
22+
model_type: Any
23+
The type to analyze. Can be a simple type, Optional[Type], BaseModel, dataclass
24+
Returns
25+
-------
26+
The actual type without Optional or Union wrappers.
27+
Raises:
28+
IdempotencyModelTypeError: If the type specification is invalid
29+
(e.g., Union with multiple non-None types).
30+
"""
31+
32+
# Get the origin of the type (e.g., Union, Optional)
33+
origin = get_origin(model_type)
34+
35+
# Check if type is Union, Optional, or UnionType (Python 3.10+)
36+
if origin in (Union, Optional) or (sys.version_info >= (3, 10) and origin in (Union, UnionType)):
37+
# Get type arguments
38+
args = get_args(model_type)
39+
40+
# Filter out NoneType
41+
actual_type = _extract_non_none_types(args)
42+
43+
# Ensure only one non-None type exists
44+
if len(actual_type) != 1:
45+
raise IdempotencyModelTypeError(
46+
"Invalid type: expected a single type, optionally wrapped in Optional or Union with None.",
47+
)
48+
49+
return actual_type[0]
50+
51+
# If not a Union/Optional type, return original type
52+
return model_type
53+
54+
55+
def _extract_non_none_types(args: tuple) -> list:
56+
"""Extract non-None types from type arguments."""
57+
return [arg for arg in args if arg is not type(None)]

aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py

+4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
BaseIdempotencyModelSerializer,
1313
BaseIdempotencySerializer,
1414
)
15+
from aws_lambda_powertools.utilities.idempotency.serialization.functions import get_actual_type
1516

1617

1718
class PydanticSerializer(BaseIdempotencyModelSerializer):
@@ -34,6 +35,9 @@ def from_dict(self, data: dict) -> BaseModel:
3435

3536
@classmethod
3637
def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer:
38+
39+
model_type = get_actual_type(model_type=model_type)
40+
3741
if model_type is None:
3842
raise IdempotencyNoSerializationModelError("No serialization model was supplied")
3943

docs/utilities/idempotency.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,10 @@ By default, `idempotent_function` serializes, stores, and returns your annotated
212212

213213
The output serializer supports any JSON serializable data, **Python Dataclasses** and **Pydantic Models**.
214214

215-
!!! info "When using the `output_serializer` parameter, the data will continue to be stored in your persistent storage as a JSON string."
215+
!!! info
216+
When using the `output_serializer` parameter, the data will continue to be stored in your persistent storage as a JSON string.
217+
218+
Function returns must be annotated with a single type, optionally wrapped in `Optional` or `Union` with `None`.
216219

217220
=== "Pydantic"
218221

tests/functional/idempotency/_boto3/test_idempotency.py

+48-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
import datetime
33
import warnings
4-
from typing import Any
4+
from typing import Any, Optional
55
from unittest.mock import MagicMock, Mock
66

77
import jmespath
@@ -2014,3 +2014,50 @@ def lambda_handler(event, context):
20142014

20152015
stubber.assert_no_pending_responses()
20162016
stubber.deactivate()
2017+
2018+
2019+
@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"])
2020+
def test_idempotent_function_serialization_dataclass_with_optional_return(output_serializer_type: str):
2021+
# GIVEN
2022+
dataclasses = get_dataclasses_lib()
2023+
config = IdempotencyConfig(use_local_cache=True)
2024+
mock_event = {"customer_id": "fake", "transaction_id": "fake-id"}
2025+
idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_dataclass_with_optional_return.<locals>.collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501
2026+
persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key)
2027+
2028+
@dataclasses.dataclass
2029+
class PaymentInput:
2030+
customer_id: str
2031+
transaction_id: str
2032+
2033+
@dataclasses.dataclass
2034+
class PaymentOutput:
2035+
customer_id: str
2036+
transaction_id: str
2037+
2038+
if output_serializer_type == "explicit":
2039+
output_serializer = DataclassSerializer(
2040+
model=PaymentOutput,
2041+
)
2042+
else:
2043+
output_serializer = DataclassSerializer
2044+
2045+
@idempotent_function(
2046+
data_keyword_argument="payment",
2047+
persistence_store=persistence_layer,
2048+
config=config,
2049+
output_serializer=output_serializer,
2050+
)
2051+
def collect_payment(payment: PaymentInput) -> Optional[PaymentOutput]:
2052+
return PaymentOutput(**dataclasses.asdict(payment))
2053+
2054+
# WHEN
2055+
payment = PaymentInput(**mock_event)
2056+
first_call: PaymentOutput = collect_payment(payment=payment)
2057+
assert first_call.customer_id == payment.customer_id
2058+
assert first_call.transaction_id == payment.transaction_id
2059+
assert isinstance(first_call, PaymentOutput)
2060+
second_call: PaymentOutput = collect_payment(payment=payment)
2061+
assert isinstance(second_call, PaymentOutput)
2062+
assert second_call.customer_id == payment.customer_id
2063+
assert second_call.transaction_id == payment.transaction_id

tests/functional/idempotency/_pydantic/test_idempotency_with_pydantic.py

+46
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import pytest
24
from pydantic import BaseModel
35

@@ -219,3 +221,47 @@ def collect_payment(payment: Payment):
219221

220222
# THEN idempotency key assertion happens at MockPersistenceLayer
221223
assert result == payment.transaction_id
224+
225+
226+
@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"])
227+
def test_idempotent_function_serialization_pydantic_with_optional_return(output_serializer_type: str):
228+
# GIVEN
229+
config = IdempotencyConfig(use_local_cache=True)
230+
mock_event = {"customer_id": "fake", "transaction_id": "fake-id"}
231+
idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_pydantic_with_optional_return.<locals>.collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501
232+
persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key)
233+
234+
class PaymentInput(BaseModel):
235+
customer_id: str
236+
transaction_id: str
237+
238+
class PaymentOutput(BaseModel):
239+
customer_id: str
240+
transaction_id: str
241+
242+
if output_serializer_type == "explicit":
243+
output_serializer = PydanticSerializer(
244+
model=PaymentOutput,
245+
)
246+
else:
247+
output_serializer = PydanticSerializer
248+
249+
@idempotent_function(
250+
data_keyword_argument="payment",
251+
persistence_store=persistence_layer,
252+
config=config,
253+
output_serializer=output_serializer,
254+
)
255+
def collect_payment(payment: PaymentInput) -> Optional[PaymentOutput]:
256+
return PaymentOutput(**payment.dict())
257+
258+
# WHEN
259+
payment = PaymentInput(**mock_event)
260+
first_call: PaymentOutput = collect_payment(payment=payment)
261+
assert first_call.customer_id == payment.customer_id
262+
assert first_call.transaction_id == payment.transaction_id
263+
assert isinstance(first_call, PaymentOutput)
264+
second_call: PaymentOutput = collect_payment(payment=payment)
265+
assert isinstance(second_call, PaymentOutput)
266+
assert second_call.customer_id == payment.customer_id
267+
assert second_call.transaction_id == payment.transaction_id

0 commit comments

Comments
 (0)