Skip to content

Commit 70c35b1

Browse files
authored
fix(batch): report multiple failures (#967)
1 parent 27e5022 commit 70c35b1

File tree

2 files changed

+77
-36
lines changed

2 files changed

+77
-36
lines changed

Diff for: aws_lambda_powertools/utilities/batch/base.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def _clean(self):
385385
)
386386

387387
messages = self._get_messages_to_report()
388-
self.batch_response = {"batchItemFailures": [messages]}
388+
self.batch_response = {"batchItemFailures": messages}
389389

390390
def _has_messages_to_report(self) -> bool:
391391
if self.fail_messages:
@@ -397,7 +397,7 @@ def _has_messages_to_report(self) -> bool:
397397
def _entire_batch_failed(self) -> bool:
398398
return len(self.exceptions) == len(self.records)
399399

400-
def _get_messages_to_report(self) -> Dict[str, str]:
400+
def _get_messages_to_report(self) -> List[Dict[str, str]]:
401401
"""
402402
Format messages to use in batch deletion
403403
"""
@@ -406,20 +406,25 @@ def _get_messages_to_report(self) -> Dict[str, str]:
406406
# Event Source Data Classes follow python idioms for fields
407407
# while Parser/Pydantic follows the event field names to the latter
408408
def _collect_sqs_failures(self):
409-
if self.model:
410-
return {"itemIdentifier": msg.messageId for msg in self.fail_messages}
411-
return {"itemIdentifier": msg.message_id for msg in self.fail_messages}
409+
failures = []
410+
for msg in self.fail_messages:
411+
msg_id = msg.messageId if self.model else msg.message_id
412+
failures.append({"itemIdentifier": msg_id})
413+
return failures
412414

413415
def _collect_kinesis_failures(self):
414-
if self.model:
415-
# Pydantic model uses int but Lambda poller expects str
416-
return {"itemIdentifier": msg.kinesis.sequenceNumber for msg in self.fail_messages}
417-
return {"itemIdentifier": msg.kinesis.sequence_number for msg in self.fail_messages}
416+
failures = []
417+
for msg in self.fail_messages:
418+
msg_id = msg.kinesis.sequenceNumber if self.model else msg.kinesis.sequence_number
419+
failures.append({"itemIdentifier": msg_id})
420+
return failures
418421

419422
def _collect_dynamodb_failures(self):
420-
if self.model:
421-
return {"itemIdentifier": msg.dynamodb.SequenceNumber for msg in self.fail_messages}
422-
return {"itemIdentifier": msg.dynamodb.sequence_number for msg in self.fail_messages}
423+
failures = []
424+
for msg in self.fail_messages:
425+
msg_id = msg.dynamodb.SequenceNumber if self.model else msg.dynamodb.sequence_number
426+
failures.append({"itemIdentifier": msg_id})
427+
return failures
423428

424429
@overload
425430
def _to_batch_type(self, record: dict, event_type: EventType, model: "BatchTypeModels") -> "BatchTypeModels":

Diff for: tests/functional/test_utilities_batch.py

+60-24
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,8 @@ def test_batch_processor_middleware_with_failure(sqs_event_factory, record_handl
414414
# GIVEN
415415
first_record = SQSRecord(sqs_event_factory("fail"))
416416
second_record = SQSRecord(sqs_event_factory("success"))
417-
event = {"Records": [first_record.raw_event, second_record.raw_event]}
417+
third_record = SQSRecord(sqs_event_factory("fail"))
418+
event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}
418419

419420
processor = BatchProcessor(event_type=EventType.SQS)
420421

@@ -426,7 +427,7 @@ def lambda_handler(event, context):
426427
result = lambda_handler(event, {})
427428

428429
# THEN
429-
assert len(result["batchItemFailures"]) == 1
430+
assert len(result["batchItemFailures"]) == 2
430431

431432

432433
def test_batch_processor_context_success_only(sqs_event_factory, record_handler):
@@ -453,7 +454,8 @@ def test_batch_processor_context_with_failure(sqs_event_factory, record_handler)
453454
# GIVEN
454455
first_record = SQSRecord(sqs_event_factory("failure"))
455456
second_record = SQSRecord(sqs_event_factory("success"))
456-
records = [first_record.raw_event, second_record.raw_event]
457+
third_record = SQSRecord(sqs_event_factory("fail"))
458+
records = [first_record.raw_event, second_record.raw_event, third_record.raw_event]
457459
processor = BatchProcessor(event_type=EventType.SQS)
458460

459461
# WHEN
@@ -462,8 +464,10 @@ def test_batch_processor_context_with_failure(sqs_event_factory, record_handler)
462464

463465
# THEN
464466
assert processed_messages[1] == ("success", second_record.body, second_record.raw_event)
465-
assert len(batch.fail_messages) == 1
466-
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record.message_id}]}
467+
assert len(batch.fail_messages) == 2
468+
assert batch.response() == {
469+
"batchItemFailures": [{"itemIdentifier": first_record.message_id}, {"itemIdentifier": third_record.message_id}]
470+
}
467471

468472

469473
def test_batch_processor_kinesis_context_success_only(kinesis_event_factory, kinesis_record_handler):
@@ -491,8 +495,9 @@ def test_batch_processor_kinesis_context_with_failure(kinesis_event_factory, kin
491495
# GIVEN
492496
first_record = KinesisStreamRecord(kinesis_event_factory("failure"))
493497
second_record = KinesisStreamRecord(kinesis_event_factory("success"))
498+
third_record = KinesisStreamRecord(kinesis_event_factory("failure"))
494499

495-
records = [first_record.raw_event, second_record.raw_event]
500+
records = [first_record.raw_event, second_record.raw_event, third_record.raw_event]
496501
processor = BatchProcessor(event_type=EventType.KinesisDataStreams)
497502

498503
# WHEN
@@ -501,15 +506,21 @@ def test_batch_processor_kinesis_context_with_failure(kinesis_event_factory, kin
501506

502507
# THEN
503508
assert processed_messages[1] == ("success", b64_to_str(second_record.kinesis.data), second_record.raw_event)
504-
assert len(batch.fail_messages) == 1
505-
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record.kinesis.sequence_number}]}
509+
assert len(batch.fail_messages) == 2
510+
assert batch.response() == {
511+
"batchItemFailures": [
512+
{"itemIdentifier": first_record.kinesis.sequence_number},
513+
{"itemIdentifier": third_record.kinesis.sequence_number},
514+
]
515+
}
506516

507517

508518
def test_batch_processor_kinesis_middleware_with_failure(kinesis_event_factory, kinesis_record_handler):
509519
# GIVEN
510520
first_record = KinesisStreamRecord(kinesis_event_factory("failure"))
511521
second_record = KinesisStreamRecord(kinesis_event_factory("success"))
512-
event = {"Records": [first_record.raw_event, second_record.raw_event]}
522+
third_record = KinesisStreamRecord(kinesis_event_factory("failure"))
523+
event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}
513524

514525
processor = BatchProcessor(event_type=EventType.KinesisDataStreams)
515526

@@ -521,7 +532,7 @@ def lambda_handler(event, context):
521532
result = lambda_handler(event, {})
522533

523534
# THEN
524-
assert len(result["batchItemFailures"]) == 1
535+
assert len(result["batchItemFailures"]) == 2
525536

526537

527538
def test_batch_processor_dynamodb_context_success_only(dynamodb_event_factory, dynamodb_record_handler):
@@ -548,7 +559,8 @@ def test_batch_processor_dynamodb_context_with_failure(dynamodb_event_factory, d
548559
# GIVEN
549560
first_record = dynamodb_event_factory("failure")
550561
second_record = dynamodb_event_factory("success")
551-
records = [first_record, second_record]
562+
third_record = dynamodb_event_factory("failure")
563+
records = [first_record, second_record, third_record]
552564
processor = BatchProcessor(event_type=EventType.DynamoDBStreams)
553565

554566
# WHEN
@@ -557,15 +569,21 @@ def test_batch_processor_dynamodb_context_with_failure(dynamodb_event_factory, d
557569

558570
# THEN
559571
assert processed_messages[1] == ("success", second_record["dynamodb"]["NewImage"]["Message"]["S"], second_record)
560-
assert len(batch.fail_messages) == 1
561-
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}]}
572+
assert len(batch.fail_messages) == 2
573+
assert batch.response() == {
574+
"batchItemFailures": [
575+
{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]},
576+
{"itemIdentifier": third_record["dynamodb"]["SequenceNumber"]},
577+
]
578+
}
562579

563580

564581
def test_batch_processor_dynamodb_middleware_with_failure(dynamodb_event_factory, dynamodb_record_handler):
565582
# GIVEN
566583
first_record = dynamodb_event_factory("failure")
567584
second_record = dynamodb_event_factory("success")
568-
event = {"Records": [first_record, second_record]}
585+
third_record = dynamodb_event_factory("failure")
586+
event = {"Records": [first_record, second_record, third_record]}
569587

570588
processor = BatchProcessor(event_type=EventType.DynamoDBStreams)
571589

@@ -577,7 +595,7 @@ def lambda_handler(event, context):
577595
result = lambda_handler(event, {})
578596

579597
# THEN
580-
assert len(result["batchItemFailures"]) == 1
598+
assert len(result["batchItemFailures"]) == 2
581599

582600

583601
def test_batch_processor_context_model(sqs_event_factory, order_event_factory):
@@ -639,17 +657,23 @@ def record_handler(record: OrderSqs):
639657
order_event = order_event_factory({"type": "success"})
640658
order_event_fail = order_event_factory({"type": "fail"})
641659
first_record = sqs_event_factory(order_event_fail)
660+
third_record = sqs_event_factory(order_event_fail)
642661
second_record = sqs_event_factory(order_event)
643-
records = [first_record, second_record]
662+
records = [first_record, second_record, third_record]
644663

645664
# WHEN
646665
processor = BatchProcessor(event_type=EventType.SQS, model=OrderSqs)
647666
with processor(records, record_handler) as batch:
648667
batch.process()
649668

650669
# THEN
651-
assert len(batch.fail_messages) == 1
652-
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["messageId"]}]}
670+
assert len(batch.fail_messages) == 2
671+
assert batch.response() == {
672+
"batchItemFailures": [
673+
{"itemIdentifier": first_record["messageId"]},
674+
{"itemIdentifier": third_record["messageId"]},
675+
]
676+
}
653677

654678

655679
def test_batch_processor_dynamodb_context_model(dynamodb_event_factory, order_event_factory):
@@ -726,16 +750,22 @@ def record_handler(record: OrderDynamoDBRecord):
726750
order_event_fail = order_event_factory({"type": "fail"})
727751
first_record = dynamodb_event_factory(order_event_fail)
728752
second_record = dynamodb_event_factory(order_event)
729-
records = [first_record, second_record]
753+
third_record = dynamodb_event_factory(order_event_fail)
754+
records = [first_record, second_record, third_record]
730755

731756
# WHEN
732757
processor = BatchProcessor(event_type=EventType.DynamoDBStreams, model=OrderDynamoDBRecord)
733758
with processor(records, record_handler) as batch:
734759
batch.process()
735760

736761
# THEN
737-
assert len(batch.fail_messages) == 1
738-
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}]}
762+
assert len(batch.fail_messages) == 2
763+
assert batch.response() == {
764+
"batchItemFailures": [
765+
{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]},
766+
{"itemIdentifier": third_record["dynamodb"]["SequenceNumber"]},
767+
]
768+
}
739769

740770

741771
def test_batch_processor_kinesis_context_parser_model(kinesis_event_factory, order_event_factory):
@@ -807,16 +837,22 @@ def record_handler(record: OrderKinesisRecord):
807837

808838
first_record = kinesis_event_factory(order_event_fail)
809839
second_record = kinesis_event_factory(order_event)
810-
records = [first_record, second_record]
840+
third_record = kinesis_event_factory(order_event_fail)
841+
records = [first_record, second_record, third_record]
811842

812843
# WHEN
813844
processor = BatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord)
814845
with processor(records, record_handler) as batch:
815846
batch.process()
816847

817848
# THEN
818-
assert len(batch.fail_messages) == 1
819-
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["kinesis"]["sequenceNumber"]}]}
849+
assert len(batch.fail_messages) == 2
850+
assert batch.response() == {
851+
"batchItemFailures": [
852+
{"itemIdentifier": first_record["kinesis"]["sequenceNumber"]},
853+
{"itemIdentifier": third_record["kinesis"]["sequenceNumber"]},
854+
]
855+
}
820856

821857

822858
def test_batch_processor_error_when_entire_batch_fails(sqs_event_factory, record_handler):

0 commit comments

Comments
 (0)