Skip to content

feat(event_source): decode nested messages on SQS events #2349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 30, 2023
24 changes: 12 additions & 12 deletions aws_lambda_powertools/utilities/data_classes/sns_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,63 +20,63 @@ class SNSMessage(DictWrapper):
@property
def signature_version(self) -> str:
"""Version of the Amazon SNS signature used."""
return self["Sns"]["SignatureVersion"]
return self["SignatureVersion"]

@property
def timestamp(self) -> str:
"""The time (GMT) when the subscription confirmation was sent."""
return self["Sns"]["Timestamp"]
return self["Timestamp"]

@property
def signature(self) -> str:
"""Base64-encoded "SHA1withRSA" signature of the Message, MessageId, Type, Timestamp, and TopicArn values."""
return self["Sns"]["Signature"]
return self["Signature"]

@property
def signing_cert_url(self) -> str:
"""The URL to the certificate that was used to sign the message."""
return self["Sns"]["SigningCertUrl"]
return self["SigningCertUrl"]

@property
def message_id(self) -> str:
"""A Universally Unique Identifier, unique for each message published.

For a message that Amazon SNS resends during a retry, the message ID of the original message is used."""
return self["Sns"]["MessageId"]
return self["MessageId"]

@property
def message(self) -> str:
"""A string that describes the message."""
return self["Sns"]["Message"]
return self["Message"]

@property
def message_attributes(self) -> Dict[str, SNSMessageAttribute]:
return {k: SNSMessageAttribute(v) for (k, v) in self["Sns"]["MessageAttributes"].items()}
return {k: SNSMessageAttribute(v) for (k, v) in self["MessageAttributes"].items()}

@property
def get_type(self) -> str:
"""The type of message.

For a subscription confirmation, the type is SubscriptionConfirmation."""
# Note: this name conflicts with existing python builtins
return self["Sns"]["Type"]
return self["Type"]

@property
def unsubscribe_url(self) -> str:
"""A URL that you can use to unsubscribe the endpoint from this topic.

If you visit this URL, Amazon SNS unsubscribes the endpoint and stops sending notifications to this endpoint."""
return self["Sns"]["UnsubscribeUrl"]
return self["UnsubscribeUrl"]

@property
def topic_arn(self) -> str:
"""The Amazon Resource Name (ARN) for the topic that this endpoint is subscribed to."""
return self["Sns"]["TopicArn"]
return self["TopicArn"]

@property
def subject(self) -> str:
"""The Subject parameter specified when the notification was published to the topic."""
return self["Sns"]["Subject"]
return self["Subject"]


class SNSEventRecord(DictWrapper):
Expand All @@ -96,7 +96,7 @@ def event_source(self) -> str:

@property
def sns(self) -> SNSMessage:
return SNSMessage(self._data)
return SNSMessage(self._data["Sns"])


class SNSEvent(DictWrapper):
Expand Down
63 changes: 62 additions & 1 deletion aws_lambda_powertools/utilities/data_classes/sqs_event.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, Dict, Iterator, Optional
from typing import Any, Dict, Iterator, Optional, Type, TypeVar

from aws_lambda_powertools.utilities.data_classes import S3Event
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.sns_event import SNSMessage


class SQSRecordAttributes(DictWrapper):
Expand Down Expand Up @@ -83,6 +85,8 @@ def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]: # type: ignor
class SQSRecord(DictWrapper):
"""An Amazon SQS message"""

NestedEvent = TypeVar("NestedEvent", bound=DictWrapper)

@property
def message_id(self) -> str:
"""A unique identifier for the message.
Expand Down Expand Up @@ -174,6 +178,63 @@ def queue_url(self) -> str:

return queue_url

@property
def decode_nested_s3_event(self) -> S3Event:
"""Returns the nested `S3Event` object that is sent in the body of a SQS message.

Even though you can typecast the object returned by `record.json_body`
directly, this method is provided as a shortcut for convenience.

Notes
-----

This method does not validate whether the SQS message body is actually a valid S3 event.

Examples
--------

```python
nested_event: S3Event = record.decode_nested_s3_event
```
"""
return self._decode_nested_event(S3Event)

@property
def decode_nested_sns_event(self) -> SNSMessage:
"""Returns the nested `SNSMessage` object that is sent in the body of a SQS message.

Even though you can typecast the object returned by `record.json_body`
directly, this method is provided as a shortcut for convenience.

Notes
-----

This method does not validate whether the SQS message body is actually
a valid SNS message.

Examples
--------

```python
nested_message: SNSMessage = record.decode_nested_sns_event
```
"""
return self._decode_nested_event(SNSMessage)

def _decode_nested_event(self, nested_event_class: Type[NestedEvent]) -> NestedEvent:
"""Returns the nested event source data object.

This is useful for handling events that are sent in the body of a SQS message.

Examples
--------

```python
data: S3Event = self._decode_nested_event(S3Event)
```
"""
return nested_event_class(self.json_body)


class SQSEvent(DictWrapper):
"""SQS Event
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/test_data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def test_sns_trigger_event():
assert event.sns_message == "Hello from SNS!"


def test_seq_trigger_event():
def test_sqs_trigger_event():
event = SQSEvent(load_event("sqsEvent.json"))

records = list(event.records)
Expand Down
104 changes: 104 additions & 0 deletions tests/unit/data_classes/test_sqs_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import json
from typing import Dict

import pytest

from aws_lambda_powertools.utilities.data_classes import S3Event, SQSEvent
from aws_lambda_powertools.utilities.data_classes.sns_event import SNSMessage
from tests.functional.utils import load_event


@pytest.mark.parametrize(
"raw_event",
[
pytest.param(load_event("s3SqsEvent.json")),
],
ids=["s3_sqs"],
)
def test_decode_nested_s3_event(raw_event: Dict):
event = SQSEvent(raw_event)

records = list(event.records)
record = records[0]
attributes = record.attributes

assert len(records) == 1
assert record.message_id == "ca3e7a89-c358-40e5-8aa0-5da01403c267"
assert attributes.aws_trace_header is None
assert attributes.approximate_receive_count == "1"
assert attributes.sent_timestamp == "1681332219270"
assert attributes.sender_id == "AIDAJHIPRHEMV73VRJEBU"
assert attributes.approximate_first_receive_timestamp == "1681332239270"
assert attributes.sequence_number is None
assert attributes.message_group_id is None
assert attributes.message_deduplication_id is None
assert record.md5_of_body == "16f4460f4477d8d693a5abe94fdbbd73"
assert record.event_source == "aws:sqs"
assert record.event_source_arn == "arn:aws:sqs:us-east-1:123456789012:SQS"
assert record.aws_region == "us-east-1"

s3_event: S3Event = record.decode_nested_s3_event
s3_record = s3_event.record

assert s3_event.bucket_name == "xxx"
assert s3_event.object_key == "test.pdf"
assert s3_record.aws_region == "us-east-1"
assert s3_record.event_name == "ObjectCreated:Put"
assert s3_record.event_source == "aws:s3"
assert s3_record.event_time == "2023-04-12T20:43:38.021Z"
assert s3_record.event_version == "2.1"
assert s3_record.glacier_event_data is None
assert s3_record.request_parameters.source_ip_address == "93.108.161.96"
assert s3_record.response_elements["x-amz-request-id"] == "YMSSR8BZJ2Y99K6P"
assert s3_record.s3.s3_schema_version == "1.0"
assert s3_record.s3.bucket.arn == "arn:aws:s3:::xxx"
assert s3_record.s3.bucket.name == "xxx"
assert s3_record.s3.bucket.owner_identity.principal_id == "A1YQ72UWCM96UF"
assert s3_record.s3.configuration_id == "SNS"
assert s3_record.s3.get_object.etag == "2e3ad1e983318bbd8e73b080e2997980"
assert s3_record.s3.get_object.key == "test.pdf"
assert s3_record.s3.get_object.sequencer == "00643717F9F8B85354"
assert s3_record.s3.get_object.size == 104681
assert s3_record.s3.get_object.version_id == "yd3d4HaWOT2zguDLvIQLU6ptDTwKBnQV"
assert s3_record.user_identity.principal_id == "A1YQ72UWCM96UF"


@pytest.mark.parametrize(
"raw_event",
[
pytest.param(load_event("snsSqsEvent.json")),
],
ids=["sns_sqs"],
)
def test_decode_nested_sns_event(raw_event: Dict):
event = SQSEvent(raw_event)

records = list(event.records)
record = records[0]
attributes = record.attributes

assert len(records) == 1
assert record.message_id == "79406a00-bf15-46ca-978c-22c3613fcb30"
assert attributes.aws_trace_header is None
assert attributes.approximate_receive_count == "1"
assert attributes.sent_timestamp == "1611050827340"
assert attributes.sender_id == "AIDAISMY7JYY5F7RTT6AO"
assert attributes.approximate_first_receive_timestamp == "1611050827344"
assert attributes.sequence_number is None
assert attributes.message_group_id is None
assert attributes.message_deduplication_id is None
assert record.md5_of_body == "8910bdaaf9a30a607f7891037d4af0b0"
assert record.event_source == "aws:sqs"
assert record.event_source_arn == "arn:aws:sqs:eu-west-1:231436140809:powertools265"
assert record.aws_region == "eu-west-1"

sns_message: SNSMessage = record.decode_nested_sns_event
message = json.loads(sns_message.message)

assert sns_message.get_type == "Notification"
assert sns_message.message_id == "d88d4479-6ec0-54fe-b63f-1cf9df4bb16e"
assert sns_message.topic_arn == "arn:aws:sns:eu-west-1:231436140809:powertools265"
assert sns_message.timestamp == "2021-01-19T10:07:07.287Z"
assert sns_message.signature_version == "1"
assert message["message"] == "hello world"
assert message["username"] == "lessa"