diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 5c1fea14731..d1ce8f90a07 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -1,7 +1,7 @@ import base64 import json from collections.abc import Mapping -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Callable, Dict, Iterator, List, Optional from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer @@ -9,9 +9,19 @@ class DictWrapper(Mapping): """Provides a single read only access to a wrapper dict""" - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: Dict[str, Any], json_deserializer: Optional[Callable] = None): + """ + Parameters + ---------- + data : Dict[str, Any] + Lambda Event Source Event payload + json_deserializer : Callable, optional + function to deserialize `str`, `bytes`, `bytearray` containing a JSON document to a Python `obj`, + by default json.loads + """ self._data = data self._json_data: Optional[Any] = None + self._json_deserializer = json_deserializer or json.loads def __getitem__(self, key: str) -> Any: return self._data[key] @@ -122,7 +132,7 @@ def body(self) -> Optional[str]: def json_body(self) -> Any: """Parses the submitted body as json""" if self._json_data is None: - self._json_data = json.loads(self.decoded_body) + self._json_data = self._json_deserializer(self.decoded_body) return self._json_data @property diff --git a/aws_lambda_powertools/utilities/data_classes/kafka_event.py b/aws_lambda_powertools/utilities/data_classes/kafka_event.py index e52cc5d8dc1..4773d9e50de 100644 --- a/aws_lambda_powertools/utilities/data_classes/kafka_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kafka_event.py @@ -1,5 +1,4 @@ import base64 -import json from typing import Any, Dict, Iterator, List, Optional from 
aws_lambda_powertools.utilities.data_classes.common import DictWrapper @@ -55,7 +54,7 @@ def decoded_value(self) -> bytes: def json_value(self) -> Any: """Decodes the text encoded data as JSON.""" if self._json_data is None: - self._json_data = json.loads(self.decoded_value.decode("utf-8")) + self._json_data = self._json_deserializer(self.decoded_value.decode("utf-8")) return self._json_data @property @@ -117,7 +116,7 @@ def records(self) -> Iterator[KafkaEventRecord]: """The Kafka records.""" for chunk in self["records"].values(): for record in chunk: - yield KafkaEventRecord(record) + yield KafkaEventRecord(data=record, json_deserializer=self._json_deserializer) @property def record(self) -> KafkaEventRecord: diff --git a/aws_lambda_powertools/utilities/data_classes/kinesis_firehose_event.py b/aws_lambda_powertools/utilities/data_classes/kinesis_firehose_event.py index 5683902f9d0..47dc196856d 100644 --- a/aws_lambda_powertools/utilities/data_classes/kinesis_firehose_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kinesis_firehose_event.py @@ -1,5 +1,4 @@ import base64 -import json from typing import Iterator, Optional from aws_lambda_powertools.utilities.data_classes.common import DictWrapper @@ -75,7 +74,7 @@ def data_as_text(self) -> str: def data_as_json(self) -> dict: """Decoded base64-encoded data loaded to json""" if self._json_data is None: - self._json_data = json.loads(self.data_as_text) + self._json_data = self._json_deserializer(self.data_as_text) return self._json_data @@ -110,4 +109,4 @@ def region(self) -> str: @property def records(self) -> Iterator[KinesisFirehoseRecord]: for record in self["records"]: - yield KinesisFirehoseRecord(record) + yield KinesisFirehoseRecord(data=record, json_deserializer=self._json_deserializer) diff --git a/aws_lambda_powertools/utilities/data_classes/sqs_event.py b/aws_lambda_powertools/utilities/data_classes/sqs_event.py index 7d0dbe49352..2b3224358d8 100644 --- 
a/aws_lambda_powertools/utilities/data_classes/sqs_event.py +++ b/aws_lambda_powertools/utilities/data_classes/sqs_event.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterator, Optional +from typing import Any, Dict, Iterator, Optional from aws_lambda_powertools.utilities.data_classes.common import DictWrapper @@ -103,6 +103,35 @@ def body(self) -> str: """The message's contents (not URL-encoded).""" return self["body"] + @property + def json_body(self) -> Any: + """Deserializes JSON string available in 'body' property + + Notes + ----- + + **Strict typing** + + Caller controls the type as we can't use recursive generics here. + + JSON Union types would force caller to have to cast a type. Instead, + we choose Any to ease ergonomics and other tools receiving this data. + + Examples + -------- + + **Type deserialized data from JSON string** + + ```python + data: dict = record.json_body # {"telemetry": [], ...} + # or + data: list = record.json_body # ["telemetry_values"] + ``` + """ + if self._json_data is None: + self._json_data = self._json_deserializer(self["body"]) + return self._json_data + @property def attributes(self) -> SQSRecordAttributes: """A map of the attributes requested in ReceiveMessage to their respective values.""" @@ -157,4 +186,4 @@ class SQSEvent(DictWrapper): @property def records(self) -> Iterator[SQSRecord]: for record in self["Records"]: - yield SQSRecord(record) + yield SQSRecord(data=record, json_deserializer=self._json_deserializer) diff --git a/tests/events/sqsEvent.json b/tests/events/sqsEvent.json index ef03b128943..2bfcd1c7b8f 100644 --- a/tests/events/sqsEvent.json +++ b/tests/events/sqsEvent.json @@ -25,7 +25,7 @@ { "messageId": "2e1424d4-f796-459a-8184-9c92662be6da", "receiptHandle": "AQEBzWwaftRI0KuVm4tP+/7q1rGgNqicHq...", - "body": "Test message2.", + "body": "{\"message\": \"foo1\"}", "attributes": { "ApproximateReceiveCount": "1", "SentTimestamp": "1545082650636", @@ -39,4 +39,4 @@ "awsRegion": "us-east-2" } ] -} \ No newline 
at end of file +} diff --git a/tests/functional/test_data_classes.py b/tests/functional/test_data_classes.py index 068e8738fad..b3a24b0865a 100644 --- a/tests/functional/test_data_classes.py +++ b/tests/functional/test_data_classes.py @@ -113,6 +113,47 @@ def message(self) -> str: assert DataClassSample(data1).raw_event is data1 +def test_dict_wrapper_with_default_custom_json_deserializer(): + class DataClassSample(DictWrapper): + @property + def json_body(self) -> dict: + return self._json_deserializer(self["body"]) + + data = {"body": '{"message": "foo1"}'} + event = DataClassSample(data=data) + assert event.json_body == json.loads(data["body"]) + + +def test_dict_wrapper_with_valid_custom_json_deserializer(): + class DataClassSample(DictWrapper): + @property + def json_body(self) -> dict: + return self._json_deserializer(self["body"]) + + def fake_json_deserializer(record: dict): + return json.loads(record) + + data = {"body": '{"message": "foo1"}'} + event = DataClassSample(data=data, json_deserializer=fake_json_deserializer) + assert event.json_body == json.loads(data["body"]) + + +def test_dict_wrapper_with_invalid_custom_json_deserializer(): + class DataClassSample(DictWrapper): + @property + def json_body(self) -> dict: + return self._json_deserializer(self["body"]) + + def fake_json_deserializer() -> None: + # invalid fn signature should raise TypeError + pass + + data = {"body": {"message": "foo1"}} + with pytest.raises(TypeError): + event = DataClassSample(data=data, json_deserializer=fake_json_deserializer) + assert event.json_body == {"message": "foo1"} + + def test_dict_wrapper_implements_mapping(): class DataClassSample(DictWrapper): pass @@ -926,6 +967,9 @@ def test_seq_trigger_event(): assert record.queue_url == "https://sqs.us-east-2.amazonaws.com/123456789012/my-queue" assert record.aws_region == "us-east-2" + record_2 = records[1] + assert record_2.json_body == {"message": "foo1"} + def test_default_api_gateway_proxy_event(): event = 
APIGatewayProxyEvent(load_event("apiGatewayProxyEvent_noVersionAuth.json"))