diff --git a/aws_lambda_powertools/utilities/data_classes/sns_event.py b/aws_lambda_powertools/utilities/data_classes/sns_event.py index 84ee1c1ef0f..5d29d682ef2 100644 --- a/aws_lambda_powertools/utilities/data_classes/sns_event.py +++ b/aws_lambda_powertools/utilities/data_classes/sns_event.py @@ -20,38 +20,38 @@ class SNSMessage(DictWrapper): @property def signature_version(self) -> str: """Version of the Amazon SNS signature used.""" - return self["Sns"]["SignatureVersion"] + return self["SignatureVersion"] @property def timestamp(self) -> str: """The time (GMT) when the subscription confirmation was sent.""" - return self["Sns"]["Timestamp"] + return self["Timestamp"] @property def signature(self) -> str: """Base64-encoded "SHA1withRSA" signature of the Message, MessageId, Type, Timestamp, and TopicArn values.""" - return self["Sns"]["Signature"] + return self["Signature"] @property def signing_cert_url(self) -> str: """The URL to the certificate that was used to sign the message.""" - return self["Sns"]["SigningCertUrl"] + return self["SigningCertUrl"] @property def message_id(self) -> str: """A Universally Unique Identifier, unique for each message published. For a message that Amazon SNS resends during a retry, the message ID of the original message is used.""" - return self["Sns"]["MessageId"] + return self["MessageId"] @property def message(self) -> str: """A string that describes the message.""" - return self["Sns"]["Message"] + return self["Message"] @property def message_attributes(self) -> Dict[str, SNSMessageAttribute]: - return {k: SNSMessageAttribute(v) for (k, v) in self["Sns"]["MessageAttributes"].items()} + return {k: SNSMessageAttribute(v) for (k, v) in self["MessageAttributes"].items()} @property def get_type(self) -> str: @@ -59,24 +59,24 @@ def get_type(self) -> str: For a subscription confirmation, the type is SubscriptionConfirmation.""" # Note: this name conflicts with existing python builtins - return self["Sns"]["Type"] + return self["Type"] @property def unsubscribe_url(self) -> str: """A URL that you can use to unsubscribe the endpoint from this topic. If you visit this URL, Amazon SNS unsubscribes the endpoint and stops sending notifications to this endpoint.""" - return self["Sns"]["UnsubscribeUrl"] + return self["UnsubscribeUrl"] @property def topic_arn(self) -> str: """The Amazon Resource Name (ARN) for the topic that this endpoint is subscribed to.""" - return self["Sns"]["TopicArn"] + return self["TopicArn"] @property def subject(self) -> str: """The Subject parameter specified when the notification was published to the topic.""" - return self["Sns"]["Subject"] + return self["Subject"] class SNSEventRecord(DictWrapper): @@ -96,7 +96,7 @@ def event_source(self) -> str: @property def sns(self) -> SNSMessage: - return SNSMessage(self._data) + return SNSMessage(self._data["Sns"]) class SNSEvent(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/sqs_event.py b/aws_lambda_powertools/utilities/data_classes/sqs_event.py index 2b3224358d8..ffec9854a2e 100644 --- a/aws_lambda_powertools/utilities/data_classes/sqs_event.py +++ b/aws_lambda_powertools/utilities/data_classes/sqs_event.py @@ -1,6 +1,8 @@ -from typing import Any, Dict, Iterator, Optional +from typing import Any, Dict, Iterator, Optional, Type, TypeVar +from aws_lambda_powertools.utilities.data_classes import S3Event from aws_lambda_powertools.utilities.data_classes.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.sns_event import SNSMessage class SQSRecordAttributes(DictWrapper): @@ -83,6 +85,8 @@ def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]: # type: ignor class SQSRecord(DictWrapper): """An Amazon SQS message""" + NestedEvent = TypeVar("NestedEvent", bound=DictWrapper) + @property def message_id(self) -> str: """A unique identifier for the message. @@ -174,6 +178,63 @@ def queue_url(self) -> str: return queue_url + @property + def decoded_nested_s3_event(self) -> S3Event: + """Returns the nested `S3Event` object that is sent in the body of a SQS message. + + Even though you can typecast the object returned by `record.json_body` + directly, this method is provided as a shortcut for convenience. + + Notes + ----- + + This method does not validate whether the SQS message body is actually a valid S3 event. + + Examples + -------- + + ```python + nested_event: S3Event = record.decoded_nested_s3_event + ``` + """ + return self._decode_nested_event(S3Event) + + @property + def decoded_nested_sns_event(self) -> SNSMessage: + """Returns the nested `SNSMessage` object that is sent in the body of a SQS message. + + Even though you can typecast the object returned by `record.json_body` + directly, this method is provided as a shortcut for convenience. + + Notes + ----- + + This method does not validate whether the SQS message body is actually + a valid SNS message. + + Examples + -------- + + ```python + nested_message: SNSMessage = record.decoded_nested_sns_event + ``` + """ + return self._decode_nested_event(SNSMessage) + + def _decode_nested_event(self, nested_event_class: Type[NestedEvent]) -> NestedEvent: + """Returns the nested event source data object. + + This is useful for handling events that are sent in the body of a SQS message. + + Examples + -------- + + ```python + data: S3Event = self._decode_nested_event(S3Event) + ``` + """ + return nested_event_class(self.json_body) + class SQSEvent(DictWrapper): """SQS Event diff --git a/tests/unit/data_classes/test_sqs_event.py b/tests/unit/data_classes/test_sqs_event.py index efacd4f026d..fe7b5e4a99a 100644 --- a/tests/unit/data_classes/test_sqs_event.py +++ b/tests/unit/data_classes/test_sqs_event.py @@ -1,4 +1,7 @@ -from aws_lambda_powertools.utilities.data_classes import SQSEvent +import json + +from aws_lambda_powertools.utilities.data_classes import S3Event, SQSEvent +from aws_lambda_powertools.utilities.data_classes.sns_event import SNSMessage from tests.functional.utils import load_event @@ -38,3 +41,94 @@ def test_seq_trigger_event(): record_2 = records[1] assert record_2.json_body == {"message": "foo1"} + + +def test_decode_nested_s3_event(): + raw_event = load_event("s3SqsEvent.json") + event = SQSEvent(raw_event) + + records = list(event.records) + record = records[0] + attributes = record.attributes + + assert len(records) == 1 + assert record.message_id == raw_event["Records"][0]["messageId"] + assert attributes.aws_trace_header is None + raw_attributes = raw_event["Records"][0]["attributes"] + assert attributes.approximate_receive_count == raw_attributes["ApproximateReceiveCount"] + assert attributes.sent_timestamp == raw_attributes["SentTimestamp"] + assert attributes.sender_id == raw_attributes["SenderId"] + assert attributes.approximate_first_receive_timestamp == raw_attributes["ApproximateFirstReceiveTimestamp"] + assert attributes.sequence_number is None + assert attributes.message_group_id is None + assert attributes.message_deduplication_id is None + assert record.md5_of_body == raw_event["Records"][0]["md5OfBody"] + assert record.event_source == raw_event["Records"][0]["eventSource"] + assert record.event_source_arn == raw_event["Records"][0]["eventSourceARN"] + assert record.aws_region == raw_event["Records"][0]["awsRegion"] + + s3_event: S3Event = record.decoded_nested_s3_event + s3_record = s3_event.record + raw_body = json.loads(raw_event["Records"][0]["body"]) + + assert s3_event.bucket_name == raw_body["Records"][0]["s3"]["bucket"]["name"] + assert s3_event.object_key == raw_body["Records"][0]["s3"]["object"]["key"] + raw_s3_record = raw_body["Records"][0] + assert s3_record.aws_region == raw_s3_record["awsRegion"] + assert s3_record.event_name == raw_s3_record["eventName"] + assert s3_record.event_source == raw_s3_record["eventSource"] + assert s3_record.event_time == raw_s3_record["eventTime"] + assert s3_record.event_version == raw_s3_record["eventVersion"] + assert s3_record.glacier_event_data is None + assert s3_record.request_parameters.source_ip_address == raw_s3_record["requestParameters"]["sourceIPAddress"] + assert s3_record.response_elements["x-amz-request-id"] == raw_s3_record["responseElements"]["x-amz-request-id"] + assert s3_record.s3.s3_schema_version == raw_s3_record["s3"]["s3SchemaVersion"] + assert s3_record.s3.bucket.arn == raw_s3_record["s3"]["bucket"]["arn"] + assert s3_record.s3.bucket.name == raw_s3_record["s3"]["bucket"]["name"] + assert ( + s3_record.s3.bucket.owner_identity.principal_id == raw_s3_record["s3"]["bucket"]["ownerIdentity"]["principalId"] + ) + assert s3_record.s3.configuration_id == raw_s3_record["s3"]["configurationId"] + assert s3_record.s3.get_object.etag == raw_s3_record["s3"]["object"]["eTag"] + assert s3_record.s3.get_object.key == raw_s3_record["s3"]["object"]["key"] + assert s3_record.s3.get_object.sequencer == raw_s3_record["s3"]["object"]["sequencer"] + assert s3_record.s3.get_object.size == raw_s3_record["s3"]["object"]["size"] + assert s3_record.s3.get_object.version_id == raw_s3_record["s3"]["object"]["versionId"] + + +def test_decode_nested_sns_event(): + raw_event = load_event("snsSqsEvent.json") + event = SQSEvent(raw_event) + + records = list(event.records) + record = records[0] + attributes = record.attributes + + assert len(records) == 1 + assert record.message_id == raw_event["Records"][0]["messageId"] + raw_attributes = raw_event["Records"][0]["attributes"] + assert attributes.aws_trace_header is None + assert attributes.approximate_receive_count == raw_attributes["ApproximateReceiveCount"] + assert attributes.sent_timestamp == raw_attributes["SentTimestamp"] + assert attributes.sender_id == raw_attributes["SenderId"] + assert attributes.approximate_first_receive_timestamp == raw_attributes["ApproximateFirstReceiveTimestamp"] + assert attributes.sequence_number is None + assert attributes.message_group_id is None + assert attributes.message_deduplication_id is None + assert record.md5_of_body == raw_event["Records"][0]["md5OfBody"] + assert record.event_source == raw_event["Records"][0]["eventSource"] + assert record.event_source_arn == raw_event["Records"][0]["eventSourceARN"] + assert record.aws_region == raw_event["Records"][0]["awsRegion"] + + sns_message: SNSMessage = record.decoded_nested_sns_event + raw_body = json.loads(raw_event["Records"][0]["body"]) + message = json.loads(sns_message.message) + + assert sns_message.get_type == raw_body["Type"] + assert sns_message.message_id == raw_body["MessageId"] + assert sns_message.topic_arn == raw_body["TopicArn"] + assert sns_message.timestamp == raw_body["Timestamp"] + assert sns_message.signature_version == raw_body["SignatureVersion"] + raw_message = json.loads(raw_body["Message"]) + assert message["message"] == raw_message["message"] + assert message["username"] == raw_message["username"]