from decimal import Clamped, Context, Decimal, Inexact, Overflow, Rounded, Underflow
from typing import Any, Callable, Dict, Optional, Sequence, Set

# NOTE: DynamoDB supports up to 38 digits precision
# Therefore, this ensures our Decimal follows what's stored in the table.
# Every listed signal is trapped, so a number that would need rounding or
# clamping raises instead of being silently altered.
DYNAMODB_CONTEXT = Context(
    Emin=-128,
    Emax=126,
    prec=38,
    traps=[Clamped, Overflow, Inexact, Rounded, Underflow],
)


class TypeDeserializer:
    """
    Deserializes DynamoDB types to Python types.

    It's based on boto3's [DynamoDB TypeDeserializer](https://boto3.amazonaws.com/v1/documentation/api/latest/_modules/boto3/dynamodb/types.html).

    The only notable difference is that for Binary (`B`, `BS`) values we return Python Bytes directly,
    since we don't support Python 2.
    """

    def deserialize(self, value: Dict) -> Any:
        """Deserialize DynamoDB data types into Python types.

        Here are the various conversions:

        DynamoDB                                Python
        --------                                ------
        {'NULL': True}                          None
        {'BOOL': True/False}                    True/False
        {'N': str(value)}                       Decimal(value)
        {'S': string}                           string
        {'B': bytes}                            bytes
        {'NS': [str(value)]}                    set([Decimal(value)])
        {'SS': [string]}                        set([string])
        {'BS': [bytes]}                         set([bytes])
        {'L': list}                             list
        {'M': dict}                             dict

        Parameters
        ----------
        value: Dict
            DynamoDB value to be deserialized to a python type.
            Must be a non-empty dictionary whose first key is the DynamoDB type descriptor.

        Returns
        --------
        any
            Python native type converted from DynamoDB type

        Raises
        ------
        TypeError
            When `value` is empty/None or its type descriptor is not supported
        """
        # Same guard as boto3: fail with a TypeError rather than an opaque
        # IndexError/AttributeError when there is no type descriptor at all.
        if not value:
            raise TypeError("Value must be a nonempty dictionary whose key is a valid dynamodb type.")

        dynamodb_type = next(iter(value))

        # Dispatch by naming convention: {'NS': ...} -> self._deserialize_ns
        deserializer: Optional[Callable] = getattr(self, f"_deserialize_{dynamodb_type}".lower(), None)
        if deserializer is None:
            raise TypeError(f"Dynamodb type {dynamodb_type} is not supported")

        return deserializer(value[dynamodb_type])

    def _deserialize_null(self, value: bool) -> None:
        return None

    def _deserialize_bool(self, value: bool) -> bool:
        return value

    def _deserialize_n(self, value: str) -> Decimal:
        # Built under DYNAMODB_CONTEXT so out-of-range/over-precise numbers raise
        return DYNAMODB_CONTEXT.create_decimal(value)

    def _deserialize_s(self, value: str) -> str:
        return value

    def _deserialize_b(self, value: bytes) -> bytes:
        return value

    def _deserialize_ns(self, value: Sequence[str]) -> Set[Decimal]:
        return set(map(self._deserialize_n, value))

    def _deserialize_ss(self, value: Sequence[str]) -> Set[str]:
        return set(map(self._deserialize_s, value))

    def _deserialize_bs(self, value: Sequence[bytes]) -> Set[bytes]:
        return set(map(self._deserialize_b, value))

    def _deserialize_l(self, value: Sequence[Dict]) -> Sequence[Any]:
        # Elements are themselves typed DynamoDB values, hence the recursion
        return [self.deserialize(v) for v in value]

    def _deserialize_m(self, value: Dict) -> Dict:
        # Keys are attribute names; only the values carry type descriptors
        return {k: self.deserialize(v) for k, v in value.items()}
b/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py index 7339ed33fce..d0d1bd7ab41 100644 --- a/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py +++ b/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py @@ -1,101 +1,9 @@ -from decimal import Clamped, Context, Decimal, Inexact, Overflow, Rounded, Underflow from enum import Enum -from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Set +from typing import Any, Dict, Iterator, Optional +from aws_lambda_powertools.shared.dynamodb_deserializer import TypeDeserializer from aws_lambda_powertools.utilities.data_classes.common import DictWrapper -# NOTE: DynamoDB supports up to 38 digits precision -# Therefore, this ensures our Decimal follows what's stored in the table -DYNAMODB_CONTEXT = Context( - Emin=-128, - Emax=126, - prec=38, - traps=[Clamped, Overflow, Inexact, Rounded, Underflow], -) - - -class TypeDeserializer: - """ - Deserializes DynamoDB types to Python types. - - It's based on boto3's [DynamoDB TypeDeserializer](https://boto3.amazonaws.com/v1/documentation/api/latest/_modules/boto3/dynamodb/types.html). - - The only notable difference is that for Binary (`B`, `BS`) values we return Python Bytes directly, - since we don't support Python 2. - """ - - def deserialize(self, value: Dict) -> Any: - """Deserialize DynamoDB data types into Python types. 
- - Parameters - ---------- - value: Any - DynamoDB value to be deserialized to a python type - - - Here are the various conversions: - - DynamoDB Python - -------- ------ - {'NULL': True} None - {'BOOL': True/False} True/False - {'N': Decimal(value)} Decimal(value) - {'S': string} string - {'B': bytes} bytes - {'NS': [str(value)]} set([str(value)]) - {'SS': [string]} set([string]) - {'BS': [bytes]} set([bytes]) - {'L': list} list - {'M': dict} dict - - Parameters - ---------- - value: Any - DynamoDB value to be deserialized to a python type - - Returns - -------- - any - Python native type converted from DynamoDB type - """ - - dynamodb_type = list(value.keys())[0] - deserializer: Optional[Callable] = getattr(self, f"_deserialize_{dynamodb_type}".lower(), None) - if deserializer is None: - raise TypeError(f"Dynamodb type {dynamodb_type} is not supported") - - return deserializer(value[dynamodb_type]) - - def _deserialize_null(self, value: bool) -> None: - return None - - def _deserialize_bool(self, value: bool) -> bool: - return value - - def _deserialize_n(self, value: str) -> Decimal: - return DYNAMODB_CONTEXT.create_decimal(value) - - def _deserialize_s(self, value: str) -> str: - return value - - def _deserialize_b(self, value: bytes) -> bytes: - return value - - def _deserialize_ns(self, value: Sequence[str]) -> Set[Decimal]: - return set(map(self._deserialize_n, value)) - - def _deserialize_ss(self, value: Sequence[str]) -> Set[str]: - return set(map(self._deserialize_s, value)) - - def _deserialize_bs(self, value: Sequence[bytes]) -> Set[bytes]: - return set(map(self._deserialize_b, value)) - - def _deserialize_l(self, value: Sequence[Dict]) -> Sequence[Any]: - return [self.deserialize(v) for v in value] - - def _deserialize_m(self, value: Dict) -> Dict: - return {k: self.deserialize(v) for k, v in value.items()} - class StreamViewType(Enum): """The type of data from the modified DynamoDB item that was captured in this stream record""" diff --git 
@field_validator("Keys", "NewImage", "OldImage", mode="before")
@classmethod
def deserialize_field(cls, value):
    """Convert a DynamoDB attribute map into plain Python values before field validation.

    Runs in "before" mode, so it receives the raw input for `Keys`, `NewImage`
    and `OldImage` and hands the result to pydantic's regular validation.

    Parameters
    ----------
    value: Any
        Raw field input; expected to be a dict of attribute name -> DynamoDB typed value

    Returns
    -------
    Any
        A dict of attribute name -> deserialized Python value when `value` is a dict,
        otherwise `value` unchanged.
    """
    # Anything that is not an attribute map (an explicit None for the optional
    # images, an already-built model instance, or malformed input) is passed
    # through untouched so pydantic accepts or rejects it with a proper
    # validation error instead of this hook failing on `.items()`.
    if not isinstance(value, dict):
        return value

    return {k: _DESERIALIZER.deserialize(v) for k, v in value.items()}
a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py index fd62fdf2624..af8b3b0196b 100644 --- a/tests/functional/test_utilities_batch.py +++ b/tests/functional/test_utilities_batch.py @@ -526,7 +526,7 @@ class OrderDynamoDB(BaseModel): # so Pydantic can auto-initialize nested Order model @field_validator("Message", mode="before") def transform_message_to_dict(cls, value: Dict[Literal["S"], str]): - return json.loads(value["S"]) + return json.loads(value) class OrderDynamoDBChangeRecord(DynamoDBStreamChangedRecordModel): NewImage: Optional[OrderDynamoDB] = None @@ -570,7 +570,7 @@ class OrderDynamoDB(BaseModel): # so Pydantic can auto-initialize nested Order model @field_validator("Message", mode="before") def transform_message_to_dict(cls, value: Dict[Literal["S"], str]): - return json.loads(value["S"]) + return json.loads(value) class OrderDynamoDBChangeRecord(DynamoDBStreamChangedRecordModel): NewImage: Optional[OrderDynamoDB] = None diff --git a/tests/unit/parser/schemas.py b/tests/unit/parser/schemas.py index 65499d319ae..b4b69135ff9 100644 --- a/tests/unit/parser/schemas.py +++ b/tests/unit/parser/schemas.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import List, Optional from pydantic import BaseModel @@ -13,12 +13,11 @@ SqsModel, SqsRecordModel, ) -from aws_lambda_powertools.utilities.parser.types import Literal class MyDynamoBusiness(BaseModel): - Message: Dict[Literal["S"], str] - Id: Dict[Literal["N"], int] + Message: str + Id: int class MyDynamoScheme(DynamoDBStreamChangedRecordModel): diff --git a/tests/unit/parser/test_dynamodb.py b/tests/unit/parser/test_dynamodb.py index abbcd152d6b..1a54c2d1991 100644 --- a/tests/unit/parser/test_dynamodb.py +++ b/tests/unit/parser/test_dynamodb.py @@ -21,19 +21,19 @@ def test_dynamo_db_stream_trigger_event(): new_image = parserd_event[0]["NewImage"] new_image_raw = raw_event["Records"][0]["dynamodb"]["NewImage"] - assert new_image.Message["S"] == 
new_image_raw["Message"]["S"] - assert new_image.Id["N"] == float(new_image_raw["Id"]["N"]) + assert new_image.Message == new_image_raw["Message"]["S"] + assert new_image.Id == float(new_image_raw["Id"]["N"]) # record index 1 old_image = parserd_event[1]["OldImage"] old_image_raw = raw_event["Records"][1]["dynamodb"]["OldImage"] - assert old_image.Message["S"] == old_image_raw["Message"]["S"] - assert old_image.Id["N"] == float(old_image_raw["Id"]["N"]) + assert old_image.Message == old_image_raw["Message"]["S"] + assert old_image.Id == float(old_image_raw["Id"]["N"]) new_image = parserd_event[1]["NewImage"] new_image_raw = raw_event["Records"][1]["dynamodb"]["NewImage"] - assert new_image.Message["S"] == new_image_raw["Message"]["S"] - assert new_image.Id["N"] == float(new_image_raw["Id"]["N"]) + assert new_image.Message == new_image_raw["Message"]["S"] + assert new_image.Id == float(new_image_raw["Id"]["N"]) def test_dynamo_db_stream_trigger_event_no_envelope(): @@ -65,12 +65,12 @@ def test_dynamo_db_stream_trigger_event_no_envelope(): keys = dynamodb.Keys raw_keys = raw_dynamodb["Keys"] assert keys is not None - id_key = keys["Id"] - assert id_key["N"] == raw_keys["Id"]["N"] + id_key = keys.get("Id") + assert id_key == int(raw_keys["Id"]["N"]) message_key = dynamodb.NewImage.Message assert message_key is not None - assert message_key["S"] == "New item!" + assert message_key == "New item!" 
from typing import Any, Dict, Optional

import pytest

from aws_lambda_powertools.shared.dynamodb_deserializer import TypeDeserializer


class DeserialiserModel:
    """Minimal stand-in for a model that deserializes a DynamoDB attribute map on access."""

    def __init__(self, data: Optional[Dict[str, Any]]):
        self._data = data
        self._deserializer = TypeDeserializer()

    def _deserialize_dynamodb_dict(self) -> Optional[Dict[str, Any]]:
        if self._data is None:
            return None

        return {k: self._deserializer.deserialize(v) for k, v in self._data.items()}

    @property
    def data(self) -> Optional[Dict[str, Any]]:
        """The wrapped attribute map with every value converted to a Python type."""
        return self._deserialize_dynamodb_dict()


def test_deserializer():
    # GIVEN an attribute map in DynamoDB wire format (numbers are sent as strings)
    model = DeserialiserModel(
        {
            "Id": {"S": "Id-123"},
            "Name": {"S": "John Doe"},
            "ZipCode": {"N": "12345"},
            "Active": {"BOOL": True},
            "Nickname": {"NULL": True},
            "Avatar": {"B": b"\x00\x01"},
            "Scores": {"NS": ["1", "2"]},
            "Tags": {"SS": ["a", "b"]},
            "Blobs": {"BS": [b"x", b"y"]},
            "Things": {"L": [{"N": "0"}, {"N": "1"}, {"N": "2"}, {"N": "3"}]},
            "MoreThings": {"M": {"a": {"S": "foo"}, "b": {"S": "bar"}}},
        },
    )

    # WHEN it is deserialized once
    data = model.data

    # THEN every DynamoDB type maps to its Python counterpart
    assert data.get("Id") == "Id-123"
    assert data.get("Name") == "John Doe"
    assert data.get("ZipCode") == 12345
    assert data.get("Active") is True
    assert "Nickname" in data
    assert data.get("Nickname") is None
    assert data.get("Avatar") == b"\x00\x01"
    assert data.get("Scores") == {1, 2}
    assert data.get("Tags") == {"a", "b"}
    assert data.get("Blobs") == {b"x", b"y"}
    assert data.get("Things") == [0, 1, 2, 3]
    assert data.get("MoreThings") == {"a": "foo", "b": "bar"}


def test_deserializer_no_data():
    # GIVEN a model without any attribute map
    model = DeserialiserModel(None)

    # THEN there is nothing to deserialize
    assert model.data is None


def test_deserializer_error():
    # GIVEN an attribute whose type descriptor DynamoDB does not define
    model = DeserialiserModel(
        {
            "Id": {"X": None},
        },
    )

    # WHEN the property is read, deserialization itself must fail
    with pytest.raises(TypeError):
        _ = model.data