Skip to content

feat(event_source): support custom json_deserializer; add json_body in SQSEvent #2200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
7 changes: 4 additions & 3 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import base64
import json
from collections.abc import Mapping
from typing import Any, Dict, Iterator, List, Optional
from typing import Any, Callable, Dict, Iterator, List, Optional

from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer


class DictWrapper(Mapping):
"""Provides a single read only access to a wrapper dict"""

def __init__(self, data: Dict[str, Any]):
def __init__(self, data: Dict[str, Any], json_deserializer: Optional[Callable] = None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring ;) We apparently missed one for DictWrapper back in the day, but given that json_deserializer will be added to all data classes, we can propagate the docstring so that everyone benefits from it.

pushing that change

image

self._data = data
self._json_data: Optional[Any] = None
self._json_deserializer = json_deserializer or json.loads

def __getitem__(self, key: str) -> Any:
return self._data[key]
Expand Down Expand Up @@ -122,7 +123,7 @@ def body(self) -> Optional[str]:
def json_body(self) -> Any:
"""Parses the submitted body as json"""
if self._json_data is None:
self._json_data = json.loads(self.decoded_body)
self._json_data = self._json_deserializer(self.decoded_body)
return self._json_data

@property
Expand Down
5 changes: 2 additions & 3 deletions aws_lambda_powertools/utilities/data_classes/kafka_event.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import base64
import json
from typing import Any, Dict, Iterator, List, Optional

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
Expand Down Expand Up @@ -55,7 +54,7 @@ def decoded_value(self) -> bytes:
def json_value(self) -> Any:
"""Decodes the text encoded data as JSON."""
if self._json_data is None:
self._json_data = json.loads(self.decoded_value.decode("utf-8"))
self._json_data = self._json_deserializer(self.decoded_value.decode("utf-8"))
return self._json_data

@property
Expand Down Expand Up @@ -117,7 +116,7 @@ def records(self) -> Iterator[KafkaEventRecord]:
"""The Kafka records."""
for chunk in self["records"].values():
for record in chunk:
yield KafkaEventRecord(record)
yield KafkaEventRecord(data=record, json_deserializer=self._json_deserializer)

@property
def record(self) -> KafkaEventRecord:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import base64
import json
from typing import Iterator, Optional

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
Expand Down Expand Up @@ -75,7 +74,7 @@ def data_as_text(self) -> str:
def data_as_json(self) -> dict:
"""Decoded base64-encoded data loaded to json"""
if self._json_data is None:
self._json_data = json.loads(self.data_as_text)
self._json_data = self._json_deserializer(self.data_as_text)
return self._json_data


Expand Down Expand Up @@ -110,4 +109,4 @@ def region(self) -> str:
@property
def records(self) -> Iterator[KinesisFirehoseRecord]:
for record in self["records"]:
yield KinesisFirehoseRecord(record)
yield KinesisFirehoseRecord(data=record, json_deserializer=self._json_deserializer)
12 changes: 11 additions & 1 deletion aws_lambda_powertools/utilities/data_classes/sqs_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ def body(self) -> str:
"""The message's contents (not URL-encoded)."""
return self["body"]

@property
def json_body(self) -> Dict:
    """Deserialize the JSON string stored under ``"body"`` and cache the result.

    The body is passed to ``self._json_deserializer``; the parsed value is kept
    in ``self._json_data`` so subsequent accesses do not parse again.

    Returns
    -------
    Dict
        The value produced by the configured deserializer.

    Raises
    ------
    Exception
        Whatever the deserializer raises for a body that is not valid JSON
        (e.g. ``json.JSONDecodeError`` when the deserializer is ``json.loads``).
    """
    # No blanket try/except here on purpose: returning the raw string for a
    # non-JSON body would hide the real problem and surface later as an
    # unrelated error (e.g. indexing a str). Callers who want the raw text can
    # read self["body"] directly, and can catch the parse error at the source.
    if self._json_data is None:
        self._json_data = self._json_deserializer(self["body"])
    return self._json_data
Copy link
Contributor Author

@leandrodamascena leandrodamascena May 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a try/except block here because the body could be a plain string that cannot be deserialized as JSON.
Should we consider this for the other json_body fields in Kinesis/Kafka/APIGW? Should we return the plain body if the deserialization fails?
I would like to hear from you.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should fail hard here to give the caller a chance to handle this error correctly.

Reasoning in a mental model fashion:

  • Do we offer a non-JSON body property e.g., body?
    • If so, json_body should fail if it can't deserialize it
  • If the customer received a poison pill record, a non-JSON string, will they be given the chance to handle a JSON deserialization error?
    • If we return self["body"], it can result in a totally different error we can't even anticipate, e.g., json_body["blah"]

As a rule of thumb, you'd want to propagate the error as close to the actual problem so they instinctively know how to handle it. If we catch all, we're masking the actual error and leading them to an unknown error later on.

Tiny note: until Python 3.11, try/except has a performance overhead, even if it's tiny, so we can also improve here by letting callers handle the error themselves (and not accidentally twice).


@property
def attributes(self) -> SQSRecordAttributes:
"""A map of the attributes requested in ReceiveMessage to their respective values."""
Expand Down Expand Up @@ -157,4 +167,4 @@ class SQSEvent(DictWrapper):
@property
def records(self) -> Iterator[SQSRecord]:
for record in self["Records"]:
yield SQSRecord(record)
yield SQSRecord(data=record, json_deserializer=self._json_deserializer)
4 changes: 2 additions & 2 deletions tests/events/sqsEvent.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
{
"messageId": "2e1424d4-f796-459a-8184-9c92662be6da",
"receiptHandle": "AQEBzWwaftRI0KuVm4tP+/7q1rGgNqicHq...",
"body": "Test message2.",
"body": "{\"message\": \"foo1\"}",
"attributes": {
"ApproximateReceiveCount": "1",
"SentTimestamp": "1545082650636",
Expand All @@ -39,4 +39,4 @@
"awsRegion": "us-east-2"
}
]
}
}
43 changes: 43 additions & 0 deletions tests/functional/test_data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,46 @@ def message(self) -> str:
assert DataClassSample(data1).raw_event is data1


def test_dict_wrapper_with_default_custom_json_deserializer():
    """Without an explicit ``json_deserializer`` the wrapper still parses a JSON body."""

    class DataClassSample(DictWrapper):
        @property
        def json_body(self) -> dict:
            raw_body = self["body"]
            return self._json_deserializer(raw_body)

    expected = {"message": "foo1"}
    wrapped = DataClassSample(data={"body": '{"message": "foo1"}'})
    assert wrapped.json_body == expected


def test_dict_wrapper_with_valid_custom_json_deserializer():
    """A caller-supplied ``json_deserializer`` must be the callable the data class uses."""

    class DataClassSample(DictWrapper):
        @property
        def json_body(self) -> dict:
            return self._json_deserializer(self["body"])

    # Record every payload handed to the custom deserializer; without this the
    # test would also pass if the wrapper ignored the argument and used its default.
    seen_payloads = []

    def fake_json_deserializer(record: str):
        seen_payloads.append(record)
        return json.loads(record)

    data = {"body": '{"message": "foo1"}'}
    event = DataClassSample(data=data, json_deserializer=fake_json_deserializer)
    assert event.json_body == {"message": "foo1"}
    assert seen_payloads == ['{"message": "foo1"}']


def test_dict_wrapper_with_wrong_custom_json_deserializer():
    """A deserializer that accepts no arguments fails with ``TypeError`` when it is called."""

    class DataClassSample(DictWrapper):
        @property
        def json_body(self) -> dict:
            return self._json_deserializer(self["body"])

    def fake_json_deserializer() -> None:
        pass

    data = {"body": {"message": "foo1"}}
    # Build the wrapper outside the raises-block so only the deserializer call
    # itself is allowed to produce the expected TypeError.
    event = DataClassSample(data=data, json_deserializer=fake_json_deserializer)

    with pytest.raises(TypeError):
        # The body is passed positionally to a zero-argument callable.
        _ = event.json_body


def test_dict_wrapper_implements_mapping():
class DataClassSample(DictWrapper):
pass
Expand Down Expand Up @@ -926,6 +966,9 @@ def test_seq_trigger_event():
assert record.queue_url == "https://sqs.us-east-2.amazonaws.com/123456789012/my-queue"
assert record.aws_region == "us-east-2"

record_2 = records[1]
assert record_2.json_body == {"message": "foo1"}


def test_default_api_gateway_proxy_event():
event = APIGatewayProxyEvent(load_event("apiGatewayProxyEvent_noVersionAuth.json"))
Expand Down