Skip to content

Commit b0f170e

Browse files
committed
chore: test model support
1 parent 53b8e75 commit b0f170e

File tree

2 files changed

+281
-10
lines changed

2 files changed

+281
-10
lines changed

Diff for: aws_lambda_powertools/utilities/batch/base.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = N
213213
EventType.KinesisDataStreams: KinesisStreamRecord,
214214
EventType.DynamoDBStreams: DynamoDBRecord,
215215
}
216+
self._EVENT_ID_MAPPING = {}
216217

217218
super().__init__()
218219

@@ -268,14 +269,26 @@ def _get_messages_to_report(self) -> Dict[str, str]:
268269
"""
269270
return self._COLLECTOR_MAPPING[self.event_type]()
270271

272+
# Event Source Data Classes follow python idioms for fields
273+
# while Parser/Pydantic follows the event field names to the latter
271274
def _collect_sqs_failures(self):
272-
return {"itemIdentifier": msg.message_id for msg in self.fail_messages}
275+
if self.model:
276+
return {"itemIdentifier": msg.messageId for msg in self.fail_messages}
277+
else:
278+
return {"itemIdentifier": msg.message_id for msg in self.fail_messages}
273279

274280
def _collect_kinesis_failures(self):
275-
return {"itemIdentifier": msg.kinesis.sequence_number for msg in self.fail_messages}
281+
if self.model:
282+
# Pydantic model uses int but Lambda poller expects str
283+
return {"itemIdentifier": msg.kinesis.sequenceNumber for msg in self.fail_messages}
284+
else:
285+
return {"itemIdentifier": msg.kinesis.sequence_number for msg in self.fail_messages}
276286

277287
def _collect_dynamodb_failures(self):
278-
return {"itemIdentifier": msg.dynamodb.sequence_number for msg in self.fail_messages}
288+
if self.model:
289+
return {"itemIdentifier": msg.dynamodb.SequenceNumber for msg in self.fail_messages}
290+
else:
291+
return {"itemIdentifier": msg.dynamodb.sequence_number for msg in self.fail_messages}
279292

280293
@overload
281294
def _to_batch_type(self, record: dict, event_type: EventType, model: "BatchTypeModels") -> "BatchTypeModels":

Diff for: tests/functional/test_utilities_batch.py

+265-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import json
12
from random import randint
2-
from typing import Callable
3+
from typing import Callable, Dict, Optional
34
from unittest.mock import patch
45

56
import pytest
@@ -12,6 +13,11 @@
1213
from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord
1314
from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord
1415
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
16+
from aws_lambda_powertools.utilities.parser import BaseModel, validator
17+
from aws_lambda_powertools.utilities.parser.models import DynamoDBStreamChangedRecordModel, DynamoDBStreamRecordModel
18+
from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecord as KinesisDataStreamRecordModel
19+
from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecordPayload, SqsRecordModel
20+
from aws_lambda_powertools.utilities.parser.types import Literal
1521
from tests.functional.utils import b64_to_str, str_to_b64
1622

1723

@@ -22,7 +28,12 @@ def factory(body: str):
2228
"messageId": "059f36b4-87a3-44ab-83d2-661975830a7d",
2329
"receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a",
2430
"body": body,
25-
"attributes": {},
31+
"attributes": {
32+
"ApproximateReceiveCount": "1",
33+
"SentTimestamp": "1545082649183",
34+
"SenderId": "AIDAIENQZJOLO23YVJ4VO",
35+
"ApproximateFirstReceiveTimestamp": "1545082649185",
36+
},
2637
"messageAttributes": {},
2738
"md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3",
2839
"eventSource": "aws:sqs",
@@ -66,7 +77,7 @@ def factory(body: str):
6677
"eventVersion": "1.0",
6778
"dynamodb": {
6879
"Keys": {"Id": {"N": "101"}},
69-
"NewImage": {"message": {"S": body}},
80+
"NewImage": {"Message": {"S": body}},
7081
"StreamViewType": "NEW_AND_OLD_IMAGES",
7182
"SequenceNumber": seq,
7283
"SizeBytes": 26,
@@ -105,7 +116,7 @@ def handler(record: KinesisStreamRecord):
105116
@pytest.fixture(scope="module")
106117
def dynamodb_record_handler() -> Callable:
107118
def handler(record: DynamoDBRecord):
108-
body = record.dynamodb.new_image.get("message").get_value
119+
body = record.dynamodb.new_image.get("Message").get_value
109120
if "fail" in body:
110121
raise Exception("Failed to process record.")
111122
return body
@@ -142,6 +153,14 @@ def stubbed_partial_processor_suppressed(config) -> PartialSQSProcessor:
142153
yield stubber, processor
143154

144155

156+
@pytest.fixture(scope="module")
157+
def order_event_factory() -> Callable:
158+
def factory(item: Dict) -> str:
159+
return json.dumps({"item": item})
160+
161+
return factory
162+
163+
145164
def test_partial_sqs_processor_context_with_failure(sqs_event_factory, record_handler, partial_processor):
146165
"""
147166
Test processor with one failing record
@@ -513,8 +532,8 @@ def test_batch_processor_dynamodb_context_success_only(dynamodb_event_factory, d
513532

514533
# THEN
515534
assert processed_messages == [
516-
("success", first_record["dynamodb"]["NewImage"]["message"]["S"], first_record),
517-
("success", second_record["dynamodb"]["NewImage"]["message"]["S"], second_record),
535+
("success", first_record["dynamodb"]["NewImage"]["Message"]["S"], first_record),
536+
("success", second_record["dynamodb"]["NewImage"]["Message"]["S"], second_record),
518537
]
519538

520539
assert batch.response() == {"batchItemFailures": []}
@@ -532,7 +551,7 @@ def test_batch_processor_dynamodb_context_with_failure(dynamodb_event_factory, d
532551
processed_messages = batch.process()
533552

534553
# THEN
535-
assert processed_messages[1] == ("success", second_record["dynamodb"]["NewImage"]["message"]["S"], second_record)
554+
assert processed_messages[1] == ("success", second_record["dynamodb"]["NewImage"]["Message"]["S"], second_record)
536555
assert len(batch.fail_messages) == 1
537556
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}]}
538557

@@ -554,3 +573,242 @@ def lambda_handler(event, context):
554573

555574
# THEN
556575
assert len(result["batchItemFailures"]) == 1
576+
577+
578+
def test_batch_processor_context_model(sqs_event_factory, order_event_factory):
579+
# GIVEN
580+
class Order(BaseModel):
581+
item: dict
582+
583+
class OrderSqs(SqsRecordModel):
584+
body: Order
585+
586+
# auto transform json string
587+
# so Pydantic can auto-initialize nested Order model
588+
@validator("body", pre=True)
589+
def transform_body_to_dict(cls, value: str):
590+
return json.loads(value)
591+
592+
def record_handler(record: OrderSqs):
593+
return record.body.item
594+
595+
order_event = order_event_factory({"type": "success"})
596+
first_record = sqs_event_factory(order_event)
597+
second_record = sqs_event_factory(order_event)
598+
records = [first_record, second_record]
599+
600+
# WHEN
601+
processor = BatchProcessor(event_type=EventType.SQS, model=OrderSqs)
602+
with processor(records, record_handler) as batch:
603+
processed_messages = batch.process()
604+
605+
# THEN
606+
order_item = json.loads(order_event)["item"]
607+
assert processed_messages == [
608+
("success", order_item, first_record),
609+
("success", order_item, second_record),
610+
]
611+
612+
assert batch.response() == {"batchItemFailures": []}
613+
614+
615+
def test_batch_processor_context_model_with_failure(sqs_event_factory, order_event_factory):
616+
# GIVEN
617+
class Order(BaseModel):
618+
item: dict
619+
620+
class OrderSqs(SqsRecordModel):
621+
body: Order
622+
623+
# auto transform json string
624+
# so Pydantic can auto-initialize nested Order model
625+
@validator("body", pre=True)
626+
def transform_body_to_dict(cls, value: str):
627+
return json.loads(value)
628+
629+
def record_handler(record: OrderSqs):
630+
if "fail" in record.body.item["type"]:
631+
raise Exception("Failed to process record.")
632+
return record.body.item
633+
634+
order_event = order_event_factory({"type": "success"})
635+
order_event_fail = order_event_factory({"type": "fail"})
636+
first_record = sqs_event_factory(order_event_fail)
637+
second_record = sqs_event_factory(order_event)
638+
records = [first_record, second_record]
639+
640+
# WHEN
641+
processor = BatchProcessor(event_type=EventType.SQS, model=OrderSqs)
642+
with processor(records, record_handler) as batch:
643+
batch.process()
644+
645+
# THEN
646+
assert len(batch.fail_messages) == 1
647+
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["messageId"]}]}
648+
649+
650+
def test_batch_processor_dynamodb_context_model(dynamodb_event_factory, order_event_factory):
651+
# GIVEN
652+
class Order(BaseModel):
653+
item: dict
654+
655+
class OrderDynamoDB(BaseModel):
656+
Message: Order
657+
658+
# auto transform json string
659+
# so Pydantic can auto-initialize nested Order model
660+
@validator("Message", pre=True)
661+
def transform_message_to_dict(cls, value: Dict[Literal["S"], str]):
662+
return json.loads(value["S"])
663+
664+
class OrderDynamoDBChangeRecord(DynamoDBStreamChangedRecordModel):
665+
NewImage: Optional[OrderDynamoDB]
666+
OldImage: Optional[OrderDynamoDB]
667+
668+
class OrderDynamoDBRecord(DynamoDBStreamRecordModel):
669+
dynamodb: OrderDynamoDBChangeRecord
670+
671+
def record_handler(record: OrderDynamoDBRecord):
672+
return record.dynamodb.NewImage.Message.item
673+
674+
order_event = order_event_factory({"type": "success"})
675+
first_record = dynamodb_event_factory(order_event)
676+
second_record = dynamodb_event_factory(order_event)
677+
records = [first_record, second_record]
678+
679+
# WHEN
680+
processor = BatchProcessor(event_type=EventType.DynamoDBStreams, model=OrderDynamoDBRecord)
681+
with processor(records, record_handler) as batch:
682+
processed_messages = batch.process()
683+
684+
# THEN
685+
order_item = json.loads(order_event)["item"]
686+
assert processed_messages == [
687+
("success", order_item, first_record),
688+
("success", order_item, second_record),
689+
]
690+
691+
assert batch.response() == {"batchItemFailures": []}
692+
693+
694+
def test_batch_processor_dynamodb_context_model_with_failure(dynamodb_event_factory, order_event_factory):
695+
# GIVEN
696+
class Order(BaseModel):
697+
item: dict
698+
699+
class OrderDynamoDB(BaseModel):
700+
Message: Order
701+
702+
# auto transform json string
703+
# so Pydantic can auto-initialize nested Order model
704+
@validator("Message", pre=True)
705+
def transform_message_to_dict(cls, value: Dict[Literal["S"], str]):
706+
return json.loads(value["S"])
707+
708+
class OrderDynamoDBChangeRecord(DynamoDBStreamChangedRecordModel):
709+
NewImage: Optional[OrderDynamoDB]
710+
OldImage: Optional[OrderDynamoDB]
711+
712+
class OrderDynamoDBRecord(DynamoDBStreamRecordModel):
713+
dynamodb: OrderDynamoDBChangeRecord
714+
715+
def record_handler(record: OrderDynamoDBRecord):
716+
if "fail" in record.dynamodb.NewImage.Message.item["type"]:
717+
raise Exception("Failed to process record.")
718+
return record.dynamodb.NewImage.Message.item
719+
720+
order_event = order_event_factory({"type": "success"})
721+
order_event_fail = order_event_factory({"type": "fail"})
722+
first_record = dynamodb_event_factory(order_event_fail)
723+
second_record = dynamodb_event_factory(order_event)
724+
records = [first_record, second_record]
725+
726+
# WHEN
727+
processor = BatchProcessor(event_type=EventType.DynamoDBStreams, model=OrderDynamoDBRecord)
728+
with processor(records, record_handler) as batch:
729+
batch.process()
730+
731+
# THEN
732+
assert len(batch.fail_messages) == 1
733+
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}]}
734+
735+
736+
def test_batch_processor_kinesis_context_parser_model(kinesis_event_factory, order_event_factory):
737+
# GIVEN
738+
class Order(BaseModel):
739+
item: dict
740+
741+
class OrderKinesisPayloadRecord(KinesisDataStreamRecordPayload):
742+
data: Order
743+
744+
# auto transform json string
745+
# so Pydantic can auto-initialize nested Order model
746+
@validator("data", pre=True)
747+
def transform_message_to_dict(cls, value: str):
748+
# Powertools KinesisDataStreamRecordModel
749+
return json.loads(value)
750+
751+
class OrderKinesisRecord(KinesisDataStreamRecordModel):
752+
kinesis: OrderKinesisPayloadRecord
753+
754+
def record_handler(record: OrderKinesisRecord):
755+
return record.kinesis.data.item
756+
757+
order_event = order_event_factory({"type": "success"})
758+
first_record = kinesis_event_factory(order_event)
759+
second_record = kinesis_event_factory(order_event)
760+
records = [first_record, second_record]
761+
762+
# WHEN
763+
processor = BatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord)
764+
with processor(records, record_handler) as batch:
765+
processed_messages = batch.process()
766+
767+
# THEN
768+
order_item = json.loads(order_event)["item"]
769+
assert processed_messages == [
770+
("success", order_item, first_record),
771+
("success", order_item, second_record),
772+
]
773+
774+
assert batch.response() == {"batchItemFailures": []}
775+
776+
777+
def test_batch_processor_kinesis_context_parser_model_with_failure(kinesis_event_factory, order_event_factory):
778+
# GIVEN
779+
class Order(BaseModel):
780+
item: dict
781+
782+
class OrderKinesisPayloadRecord(KinesisDataStreamRecordPayload):
783+
data: Order
784+
785+
# auto transform json string
786+
# so Pydantic can auto-initialize nested Order model
787+
@validator("data", pre=True)
788+
def transform_message_to_dict(cls, value: str):
789+
# Powertools KinesisDataStreamRecordModel
790+
return json.loads(value)
791+
792+
class OrderKinesisRecord(KinesisDataStreamRecordModel):
793+
kinesis: OrderKinesisPayloadRecord
794+
795+
def record_handler(record: OrderKinesisRecord):
796+
if "fail" in record.kinesis.data.item["type"]:
797+
raise Exception("Failed to process record.")
798+
return record.kinesis.data.item
799+
800+
order_event = order_event_factory({"type": "success"})
801+
order_event_fail = order_event_factory({"type": "fail"})
802+
803+
first_record = kinesis_event_factory(order_event_fail)
804+
second_record = kinesis_event_factory(order_event)
805+
records = [first_record, second_record]
806+
807+
# WHEN
808+
processor = BatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord)
809+
with processor(records, record_handler) as batch:
810+
batch.process()
811+
812+
# THEN
813+
assert len(batch.fail_messages) == 1
814+
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["kinesis"]["sequenceNumber"]}]}

0 commit comments

Comments
 (0)