Skip to content

Commit 58e610c

Browse files
rafaelgsrleandrodamascenaheitorlessa
authored
feat(event_source): decode nested messages on SQS events (#2349)
Co-authored-by: Leandro Damascena <[email protected]> Co-authored-by: Heitor Lessa <[email protected]>
1 parent c62d7bb commit 58e610c

File tree

3 files changed

+169
-14
lines changed

3 files changed

+169
-14
lines changed

aws_lambda_powertools/utilities/data_classes/sns_event.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -20,63 +20,63 @@ class SNSMessage(DictWrapper):
2020
@property
2121
def signature_version(self) -> str:
2222
"""Version of the Amazon SNS signature used."""
23-
return self["Sns"]["SignatureVersion"]
23+
return self["SignatureVersion"]
2424

2525
@property
2626
def timestamp(self) -> str:
2727
"""The time (GMT) when the subscription confirmation was sent."""
28-
return self["Sns"]["Timestamp"]
28+
return self["Timestamp"]
2929

3030
@property
3131
def signature(self) -> str:
3232
"""Base64-encoded "SHA1withRSA" signature of the Message, MessageId, Type, Timestamp, and TopicArn values."""
33-
return self["Sns"]["Signature"]
33+
return self["Signature"]
3434

3535
@property
3636
def signing_cert_url(self) -> str:
3737
"""The URL to the certificate that was used to sign the message."""
38-
return self["Sns"]["SigningCertUrl"]
38+
return self["SigningCertUrl"]
3939

4040
@property
4141
def message_id(self) -> str:
4242
"""A Universally Unique Identifier, unique for each message published.
4343
4444
For a message that Amazon SNS resends during a retry, the message ID of the original message is used."""
45-
return self["Sns"]["MessageId"]
45+
return self["MessageId"]
4646

4747
@property
4848
def message(self) -> str:
4949
"""A string that describes the message."""
50-
return self["Sns"]["Message"]
50+
return self["Message"]
5151

5252
@property
5353
def message_attributes(self) -> Dict[str, SNSMessageAttribute]:
54-
return {k: SNSMessageAttribute(v) for (k, v) in self["Sns"]["MessageAttributes"].items()}
54+
return {k: SNSMessageAttribute(v) for (k, v) in self["MessageAttributes"].items()}
5555

5656
@property
5757
def get_type(self) -> str:
5858
"""The type of message.
5959
6060
For a subscription confirmation, the type is SubscriptionConfirmation."""
6161
# Note: this name conflicts with existing python builtins
62-
return self["Sns"]["Type"]
62+
return self["Type"]
6363

6464
@property
6565
def unsubscribe_url(self) -> str:
6666
"""A URL that you can use to unsubscribe the endpoint from this topic.
6767
6868
If you visit this URL, Amazon SNS unsubscribes the endpoint and stops sending notifications to this endpoint."""
69-
return self["Sns"]["UnsubscribeUrl"]
69+
return self["UnsubscribeUrl"]
7070

7171
@property
7272
def topic_arn(self) -> str:
7373
"""The Amazon Resource Name (ARN) for the topic that this endpoint is subscribed to."""
74-
return self["Sns"]["TopicArn"]
74+
return self["TopicArn"]
7575

7676
@property
7777
def subject(self) -> str:
7878
"""The Subject parameter specified when the notification was published to the topic."""
79-
return self["Sns"]["Subject"]
79+
return self["Subject"]
8080

8181

8282
class SNSEventRecord(DictWrapper):
@@ -96,7 +96,7 @@ def event_source(self) -> str:
9696

9797
@property
9898
def sns(self) -> SNSMessage:
99-
return SNSMessage(self._data)
99+
return SNSMessage(self._data["Sns"])
100100

101101

102102
class SNSEvent(DictWrapper):

aws_lambda_powertools/utilities/data_classes/sqs_event.py

+62-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from typing import Any, Dict, Iterator, Optional
1+
from typing import Any, Dict, Iterator, Optional, Type, TypeVar
22

3+
from aws_lambda_powertools.utilities.data_classes import S3Event
34
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
5+
from aws_lambda_powertools.utilities.data_classes.sns_event import SNSMessage
46

57

68
class SQSRecordAttributes(DictWrapper):
@@ -83,6 +85,8 @@ def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]: # type: ignor
8385
class SQSRecord(DictWrapper):
8486
"""An Amazon SQS message"""
8587

88+
NestedEvent = TypeVar("NestedEvent", bound=DictWrapper)
89+
8690
@property
8791
def message_id(self) -> str:
8892
"""A unique identifier for the message.
@@ -174,6 +178,63 @@ def queue_url(self) -> str:
174178

175179
return queue_url
176180

181+
@property
182+
def decoded_nested_s3_event(self) -> S3Event:
183+
"""Returns the nested `S3Event` object that is sent in the body of a SQS message.
184+
185+
Even though you can typecast the object returned by `record.json_body`
186+
directly, this method is provided as a shortcut for convenience.
187+
188+
Notes
189+
-----
190+
191+
This method does not validate whether the SQS message body is actually a valid S3 event.
192+
193+
Examples
194+
--------
195+
196+
```python
197+
nested_event: S3Event = record.decoded_nested_s3_event
198+
```
199+
"""
200+
return self._decode_nested_event(S3Event)
201+
202+
@property
203+
def decoded_nested_sns_event(self) -> SNSMessage:
204+
"""Returns the nested `SNSMessage` object that is sent in the body of a SQS message.
205+
206+
Even though you can typecast the object returned by `record.json_body`
207+
directly, this method is provided as a shortcut for convenience.
208+
209+
Notes
210+
-----
211+
212+
This method does not validate whether the SQS message body is actually
213+
a valid SNS message.
214+
215+
Examples
216+
--------
217+
218+
```python
219+
nested_message: SNSMessage = record.decoded_nested_sns_event
220+
```
221+
"""
222+
return self._decode_nested_event(SNSMessage)
223+
224+
def _decode_nested_event(self, nested_event_class: Type[NestedEvent]) -> NestedEvent:
225+
"""Returns the nested event source data object.
226+
227+
This is useful for handling events that are sent in the body of a SQS message.
228+
229+
Examples
230+
--------
231+
232+
```python
233+
data: S3Event = self._decode_nested_event(S3Event)
234+
```
235+
"""
236+
return nested_event_class(self.json_body)
237+
177238

178239
class SQSEvent(DictWrapper):
179240
"""SQS Event

tests/unit/data_classes/test_sqs_event.py

+95-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from aws_lambda_powertools.utilities.data_classes import SQSEvent
1+
import json
2+
3+
from aws_lambda_powertools.utilities.data_classes import S3Event, SQSEvent
4+
from aws_lambda_powertools.utilities.data_classes.sns_event import SNSMessage
25
from tests.functional.utils import load_event
36

47

@@ -38,3 +41,94 @@ def test_seq_trigger_event():
3841

3942
record_2 = records[1]
4043
assert record_2.json_body == {"message": "foo1"}
44+
45+
46+
def test_decode_nested_s3_event():
47+
raw_event = load_event("s3SqsEvent.json")
48+
event = SQSEvent(raw_event)
49+
50+
records = list(event.records)
51+
record = records[0]
52+
attributes = record.attributes
53+
54+
assert len(records) == 1
55+
assert record.message_id == raw_event["Records"][0]["messageId"]
56+
assert attributes.aws_trace_header is None
57+
raw_attributes = raw_event["Records"][0]["attributes"]
58+
assert attributes.approximate_receive_count == raw_attributes["ApproximateReceiveCount"]
59+
assert attributes.sent_timestamp == raw_attributes["SentTimestamp"]
60+
assert attributes.sender_id == raw_attributes["SenderId"]
61+
assert attributes.approximate_first_receive_timestamp == raw_attributes["ApproximateFirstReceiveTimestamp"]
62+
assert attributes.sequence_number is None
63+
assert attributes.message_group_id is None
64+
assert attributes.message_deduplication_id is None
65+
assert record.md5_of_body == raw_event["Records"][0]["md5OfBody"]
66+
assert record.event_source == raw_event["Records"][0]["eventSource"]
67+
assert record.event_source_arn == raw_event["Records"][0]["eventSourceARN"]
68+
assert record.aws_region == raw_event["Records"][0]["awsRegion"]
69+
70+
s3_event: S3Event = record.decoded_nested_s3_event
71+
s3_record = s3_event.record
72+
raw_body = json.loads(raw_event["Records"][0]["body"])
73+
74+
assert s3_event.bucket_name == raw_body["Records"][0]["s3"]["bucket"]["name"]
75+
assert s3_event.object_key == raw_body["Records"][0]["s3"]["object"]["key"]
76+
raw_s3_record = raw_body["Records"][0]
77+
assert s3_record.aws_region == raw_s3_record["awsRegion"]
78+
assert s3_record.event_name == raw_s3_record["eventName"]
79+
assert s3_record.event_source == raw_s3_record["eventSource"]
80+
assert s3_record.event_time == raw_s3_record["eventTime"]
81+
assert s3_record.event_version == raw_s3_record["eventVersion"]
82+
assert s3_record.glacier_event_data is None
83+
assert s3_record.request_parameters.source_ip_address == raw_s3_record["requestParameters"]["sourceIPAddress"]
84+
assert s3_record.response_elements["x-amz-request-id"] == raw_s3_record["responseElements"]["x-amz-request-id"]
85+
assert s3_record.s3.s3_schema_version == raw_s3_record["s3"]["s3SchemaVersion"]
86+
assert s3_record.s3.bucket.arn == raw_s3_record["s3"]["bucket"]["arn"]
87+
assert s3_record.s3.bucket.name == raw_s3_record["s3"]["bucket"]["name"]
88+
assert (
89+
s3_record.s3.bucket.owner_identity.principal_id == raw_s3_record["s3"]["bucket"]["ownerIdentity"]["principalId"]
90+
)
91+
assert s3_record.s3.configuration_id == raw_s3_record["s3"]["configurationId"]
92+
assert s3_record.s3.get_object.etag == raw_s3_record["s3"]["object"]["eTag"]
93+
assert s3_record.s3.get_object.key == raw_s3_record["s3"]["object"]["key"]
94+
assert s3_record.s3.get_object.sequencer == raw_s3_record["s3"]["object"]["sequencer"]
95+
assert s3_record.s3.get_object.size == raw_s3_record["s3"]["object"]["size"]
96+
assert s3_record.s3.get_object.version_id == raw_s3_record["s3"]["object"]["versionId"]
97+
98+
99+
def test_decode_nested_sns_event():
100+
raw_event = load_event("snsSqsEvent.json")
101+
event = SQSEvent(raw_event)
102+
103+
records = list(event.records)
104+
record = records[0]
105+
attributes = record.attributes
106+
107+
assert len(records) == 1
108+
assert record.message_id == raw_event["Records"][0]["messageId"]
109+
raw_attributes = raw_event["Records"][0]["attributes"]
110+
assert attributes.aws_trace_header is None
111+
assert attributes.approximate_receive_count == raw_attributes["ApproximateReceiveCount"]
112+
assert attributes.sent_timestamp == raw_attributes["SentTimestamp"]
113+
assert attributes.sender_id == raw_attributes["SenderId"]
114+
assert attributes.approximate_first_receive_timestamp == raw_attributes["ApproximateFirstReceiveTimestamp"]
115+
assert attributes.sequence_number is None
116+
assert attributes.message_group_id is None
117+
assert attributes.message_deduplication_id is None
118+
assert record.md5_of_body == raw_event["Records"][0]["md5OfBody"]
119+
assert record.event_source == raw_event["Records"][0]["eventSource"]
120+
assert record.event_source_arn == raw_event["Records"][0]["eventSourceARN"]
121+
assert record.aws_region == raw_event["Records"][0]["awsRegion"]
122+
123+
sns_message: SNSMessage = record.decoded_nested_sns_event
124+
raw_body = json.loads(raw_event["Records"][0]["body"])
125+
message = json.loads(sns_message.message)
126+
127+
assert sns_message.get_type == raw_body["Type"]
128+
assert sns_message.message_id == raw_body["MessageId"]
129+
assert sns_message.topic_arn == raw_body["TopicArn"]
130+
assert sns_message.timestamp == raw_body["Timestamp"]
131+
assert sns_message.signature_version == raw_body["SignatureVersion"]
132+
raw_message = json.loads(raw_body["Message"])
133+
assert message["message"] == raw_message["message"]
134+
assert message["username"] == raw_message["username"]

0 commit comments

Comments
 (0)