diff --git a/aws_lambda_powertools/utilities/batch/sqs.py b/aws_lambda_powertools/utilities/batch/sqs.py index 411e400615d..ee6a960c129 100644 --- a/aws_lambda_powertools/utilities/batch/sqs.py +++ b/aws_lambda_powertools/utilities/batch/sqs.py @@ -4,8 +4,10 @@ Batch SQS utilities """ import logging +import math import sys -from typing import Callable, Dict, List, Optional, Tuple, cast +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, Callable, Dict, List, Optional, Tuple, cast import boto3 from botocore.config import Config @@ -73,6 +75,7 @@ def __init__( session = boto3_session or boto3.session.Session() self.client = session.client("sqs", config=config) self.suppress_exception = suppress_exception + self.max_message_batch = 10 super().__init__() @@ -120,23 +123,39 @@ def _prepare(self): self.success_messages.clear() self.fail_messages.clear() - def _clean(self): + def _clean(self) -> Optional[List]: """ Delete messages from Queue in case of partial failure. """ + # If all messages were successful, fall back to the default SQS - - # Lambda behaviour which deletes messages if Lambda responds successfully + # Lambda behavior which deletes messages if Lambda responds successfully if not self.fail_messages: logger.debug(f"All {len(self.success_messages)} records successfully processed") - return + return None queue_url = self._get_queue_url() entries_to_remove = self._get_entries_to_clean() + # Batch delete up to 10 messages at a time (SQS limit) + max_workers = math.ceil(len(entries_to_remove) / self.max_message_batch) - delete_message_response = None if entries_to_remove: - delete_message_response = self.client.delete_message_batch(QueueUrl=queue_url, Entries=entries_to_remove) - + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures, results = [], [] + while entries_to_remove: + futures.append( + executor.submit( + self._delete_messages, queue_url, entries_to_remove[: self.max_message_batch], self.client + ) + ) + entries_to_remove = entries_to_remove[self.max_message_batch :] + for future in as_completed(futures): + try: + logger.debug("Deleted batch of processed messages from SQS") + results.append(future.result()) + except Exception: + logger.exception("Couldn't remove batch of processed messages from SQS") + raise if self.suppress_exception: logger.debug(f"{len(self.fail_messages)} records failed processing, but exceptions are suppressed") else: @@ -147,6 +166,13 @@ def _clean(self): child_exceptions=self.exceptions, ) + return results + + def _delete_messages(self, queue_url: str, entries_to_remove: List, sqs_client: Any): + delete_message_response = sqs_client.delete_message_batch( + QueueUrl=queue_url, + Entries=entries_to_remove, + ) return delete_message_response diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py index d32a044279b..2d9e6bab612 100644 --- a/tests/functional/test_utilities_batch.py +++ b/tests/functional/test_utilities_batch.py @@ -1,4 +1,5 @@ import json +import math from random import randint from typing import Callable, Dict, Optional from unittest.mock import patch @@ -166,20 +167,26 @@ def factory(item: Dict) -> str: return factory -def test_partial_sqs_processor_context_with_failure(sqs_event_factory, record_handler, partial_processor): +@pytest.mark.parametrize( + "success_messages_count", + ([1, 18, 34]), +) +def test_partial_sqs_processor_context_with_failure( + success_messages_count, sqs_event_factory, record_handler, partial_processor +): """ - Test processor with one failing record + Test processor with one failing record and multiple processed records """ fail_record = sqs_event_factory("fail") - success_record = sqs_event_factory("success") + success_records = [sqs_event_factory("success") for i in range(0, success_messages_count)] - records = [fail_record, success_record] + records = [fail_record, *success_records] response = {"Successful": [{"Id": fail_record["messageId"]}], "Failed": []} with Stubber(partial_processor.client) as stubber: - stubber.add_response("delete_message_batch", response) - + for _ in range(0, math.ceil((success_messages_count / partial_processor.max_message_batch))): + stubber.add_response("delete_message_batch", response) with pytest.raises(SQSBatchProcessingError) as error: with partial_processor(records, record_handler) as ctx: ctx.process() @@ -188,6 +195,27 @@ def test_partial_sqs_processor_context_with_failure(sqs_event_factory, record_ha stubber.assert_no_pending_responses() +def test_partial_sqs_processor_context_with_failure_exception(sqs_event_factory, record_handler, partial_processor): + """ + Test processor with one failing record + """ + fail_record = sqs_event_factory("fail") + success_record = sqs_event_factory("success") + + records = [fail_record, success_record] + + with Stubber(partial_processor.client) as stubber: + stubber.add_client_error( + method="delete_message_batch", service_error_code="ServiceUnavailable", http_status_code=503 + ) + with pytest.raises(Exception) as error: + with partial_processor(records, record_handler) as ctx: + ctx.process() + + assert "ServiceUnavailable" in str(error.value) + stubber.assert_no_pending_responses() + + def test_partial_sqs_processor_context_only_success(sqs_event_factory, record_handler, partial_processor): """ Test processor without failure diff --git a/tests/unit/test_utilities_batch.py b/tests/unit/test_utilities_batch.py index c491f0829cb..57de0223404 100644 --- a/tests/unit/test_utilities_batch.py +++ b/tests/unit/test_utilities_batch.py @@ -128,12 +128,12 @@ def test_partial_sqs_clean(monkeypatch, mocker, partial_sqs_processor): entries_to_clean_mock = mocker.patch.object(PartialSQSProcessor, "_get_entries_to_clean") queue_url_mock.return_value = mocker.sentinel.queue_url - entries_to_clean_mock.return_value = mocker.sentinel.entries_to_clean + entries_to_clean_mock.return_value = [mocker.sentinel.entries_to_clean] client_mock = mocker.patch.object(partial_sqs_processor, "client", autospec=True) with pytest.raises(SQSBatchProcessingError): partial_sqs_processor._clean() client_mock.delete_message_batch.assert_called_once_with( - QueueUrl=mocker.sentinel.queue_url, Entries=mocker.sentinel.entries_to_clean + QueueUrl=mocker.sentinel.queue_url, Entries=[mocker.sentinel.entries_to_clean] )