Skip to content

fix(batch): delete >10 messages in legacy sqs processor #818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions aws_lambda_powertools/utilities/batch/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
Batch SQS utilities
"""
import logging
import math
import sys
from typing import Callable, Dict, List, Optional, Tuple, cast
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Dict, List, Optional, Tuple, cast

import boto3
from botocore.config import Config
Expand Down Expand Up @@ -73,6 +75,7 @@ def __init__(
session = boto3_session or boto3.session.Session()
self.client = session.client("sqs", config=config)
self.suppress_exception = suppress_exception
self.max_message_batch = 10

super().__init__()

Expand Down Expand Up @@ -120,23 +123,39 @@ def _prepare(self):
self.success_messages.clear()
self.fail_messages.clear()

def _clean(self):
def _clean(self) -> Optional[List]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

niiiice one! Thanks

"""
Delete messages from Queue in case of partial failure.
"""

# If all messages were successful, fall back to the default SQS -
# Lambda behaviour which deletes messages if Lambda responds successfully
# Lambda behavior which deletes messages if Lambda responds successfully
if not self.fail_messages:
logger.debug(f"All {len(self.success_messages)} records successfully processed")
return
return None

queue_url = self._get_queue_url()
entries_to_remove = self._get_entries_to_clean()
# Batch delete up to 10 messages at a time (SQS limit)
max_workers = math.ceil(len(entries_to_remove) / self.max_message_batch)

delete_message_response = None
if entries_to_remove:
delete_message_response = self.client.delete_message_batch(QueueUrl=queue_url, Entries=entries_to_remove)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures, results = [], []
while entries_to_remove:
futures.append(
executor.submit(
self._delete_messages, queue_url, entries_to_remove[: self.max_message_batch], self.client
)
)
entries_to_remove = entries_to_remove[self.max_message_batch :]
for future in as_completed(futures):
try:
logger.debug("Deleted batch of processed messages from SQS")
results.append(future.result())
except Exception:
logger.exception("Couldn't remove batch of processed messages from SQS")
raise
if self.suppress_exception:
logger.debug(f"{len(self.fail_messages)} records failed processing, but exceptions are suppressed")
else:
Expand All @@ -147,6 +166,13 @@ def _clean(self):
child_exceptions=self.exceptions,
)

return results

def _delete_messages(self, queue_url: str, entries_to_remove: List, sqs_client: Any):
delete_message_response = sqs_client.delete_message_batch(
QueueUrl=queue_url,
Entries=entries_to_remove,
)
return delete_message_response


Expand Down
40 changes: 34 additions & 6 deletions tests/functional/test_utilities_batch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import math
from random import randint
from typing import Callable, Dict, Optional
from unittest.mock import patch
Expand Down Expand Up @@ -166,20 +167,26 @@ def factory(item: Dict) -> str:
return factory


def test_partial_sqs_processor_context_with_failure(sqs_event_factory, record_handler, partial_processor):
@pytest.mark.parametrize(
"success_messages_count",
([1, 18, 34]),
)
def test_partial_sqs_processor_context_with_failure(
success_messages_count, sqs_event_factory, record_handler, partial_processor
):
"""
Test processor with one failing record
Test processor with one failing record and multiple processed records
"""
fail_record = sqs_event_factory("fail")
success_record = sqs_event_factory("success")
success_records = [sqs_event_factory("success") for i in range(0, success_messages_count)]

records = [fail_record, success_record]
records = [fail_record, *success_records]

response = {"Successful": [{"Id": fail_record["messageId"]}], "Failed": []}

with Stubber(partial_processor.client) as stubber:
stubber.add_response("delete_message_batch", response)

for _ in range(0, math.ceil((success_messages_count / partial_processor.max_message_batch))):
stubber.add_response("delete_message_batch", response)
with pytest.raises(SQSBatchProcessingError) as error:
with partial_processor(records, record_handler) as ctx:
ctx.process()
Expand All @@ -188,6 +195,27 @@ def test_partial_sqs_processor_context_with_failure(sqs_event_factory, record_ha
stubber.assert_no_pending_responses()


def test_partial_sqs_processor_context_with_failure_exception(sqs_event_factory, record_handler, partial_processor):
    """
    Processor should propagate the botocore client error raised while cleaning up messages.
    """
    # One failing record forces the partial processor to call delete_message_batch on exit.
    failing = sqs_event_factory("fail")
    passing = sqs_event_factory("success")
    records = [failing, passing]

    with Stubber(partial_processor.client) as stubber:
        # Make the cleanup call itself fail with a 503 so the error path is exercised.
        stubber.add_client_error(
            method="delete_message_batch", service_error_code="ServiceUnavailable", http_status_code=503
        )
        with pytest.raises(Exception) as error:
            with partial_processor(records, record_handler) as ctx:
                ctx.process()

        assert "ServiceUnavailable" in str(error.value)
        stubber.assert_no_pending_responses()


def test_partial_sqs_processor_context_only_success(sqs_event_factory, record_handler, partial_processor):
"""
Test processor without failure
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_utilities_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ def test_partial_sqs_clean(monkeypatch, mocker, partial_sqs_processor):
entries_to_clean_mock = mocker.patch.object(PartialSQSProcessor, "_get_entries_to_clean")

queue_url_mock.return_value = mocker.sentinel.queue_url
entries_to_clean_mock.return_value = mocker.sentinel.entries_to_clean
entries_to_clean_mock.return_value = [mocker.sentinel.entries_to_clean]

client_mock = mocker.patch.object(partial_sqs_processor, "client", autospec=True)
with pytest.raises(SQSBatchProcessingError):
partial_sqs_processor._clean()

client_mock.delete_message_batch.assert_called_once_with(
QueueUrl=mocker.sentinel.queue_url, Entries=mocker.sentinel.entries_to_clean
QueueUrl=mocker.sentinel.queue_url, Entries=[mocker.sentinel.entries_to_clean]
)