
feat(batch): add support to SQS FIFO queues (SqsFifoPartialProcessor) #1934


Merged · 6 commits · Feb 20, 2023
4 changes: 4 additions & 0 deletions aws_lambda_powertools/utilities/batch/__init__.py
@@ -16,6 +16,9 @@
batch_processor,
)
from aws_lambda_powertools.utilities.batch.exceptions import ExceptionInfo
from aws_lambda_powertools.utilities.batch.sqs_fifo_partial_processor import (
SQSFifoPartialProcessor,
)

__all__ = (
"BatchProcessor",
@@ -26,6 +29,7 @@
"EventType",
"FailureResponse",
"SuccessResponse",
"SQSFifoPartialProcessor",
"batch_processor",
"async_batch_processor",
)
101 changes: 101 additions & 0 deletions aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py
@@ -0,0 +1,101 @@
import sys
from typing import List, Optional, Tuple, Type

from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType
from aws_lambda_powertools.utilities.parser.models import SqsRecordModel

#
# type specifics
#
has_pydantic = "pydantic" in sys.modules

# For IntelliSense and Mypy to work, we need to account for possible SQS subclasses
# We need them as subclasses as we must access their message ID or sequence number metadata via dot notation
if has_pydantic:
BatchTypeModels = Optional[Type[SqsRecordModel]]


class SQSFifoCircuitBreakerError(Exception):
"""
Signals a record not processed due to the SQS FIFO processing being interrupted
"""

pass


class SQSFifoPartialProcessor(BatchProcessor):
"""Specialized BatchProcessor subclass that handles FIFO SQS batch records.

As soon as a record fails processing, the remaining records are marked as
failed without being processed, and returned in the native partial batch response.

Example
_______

## Process batch triggered by a FIFO SQS

```python
import json

from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.utilities.batch import SQSFifoPartialProcessor, EventType, batch_processor
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.typing import LambdaContext


processor = SQSFifoPartialProcessor()
tracer = Tracer()
logger = Logger()


@tracer.capture_method
def record_handler(record: SQSRecord):
payload: str = record.body
if payload:
item: dict = json.loads(payload)
...

@logger.inject_lambda_context
@tracer.capture_lambda_handler
@batch_processor(record_handler=record_handler, processor=processor)
def lambda_handler(event, context: LambdaContext):
return processor.response()
```
"""

circuitBreakerError = SQSFifoCircuitBreakerError("A previous record failed processing.")

def __init__(self, model: Optional["BatchTypeModels"] = None):
super().__init__(EventType.SQS, model)

def process(self) -> List[Tuple]:
result: List[Tuple] = []

for i, record in enumerate(self.records):
"""
If we have failed messages, it means that the last message failed.
We then short circuit the process, failing the remaining messages
"""
if self.fail_messages:
return self._short_circuit_processing(i, result)

"""
Otherwise, process the message normally
"""
result.append(self._process_record(record))

return result

def _short_circuit_processing(self, first_failure_index: int, result: List[Tuple]) -> List[Tuple]:
remaining_records = self.records[first_failure_index:]
for remaining_record in remaining_records:
data = self._to_batch_type(record=remaining_record, event_type=self.event_type, model=self.model)
result.append(
self.failure_handler(
record=data, exception=(type(self.circuitBreakerError), self.circuitBreakerError, None)
)
)
return result

    async def _async_process_record(self, record: dict):
        raise NotImplementedError("SQSFifoPartialProcessor does not support async processing")
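
For reference, here is a minimal sketch (not part of this diff) that exercises the short-circuit behavior above by driving the processor directly as a context manager. The record dicts are hand-built, illustrative stand-ins carrying only the fields this code path reads (`messageId`, `body`):

```python
from aws_lambda_powertools.utilities.batch import SQSFifoPartialProcessor


def record_handler(record):
    # record is an SQSRecord data class; it supports dict-style access
    if "fail" in record["body"]:
        raise ValueError("simulated processing error")


# Illustrative records: the second fails, so the third is never processed
records = [
    {"messageId": "msg-1", "body": "ok"},
    {"messageId": "msg-2", "body": "fail"},
    {"messageId": "msg-3", "body": "ok"},
]

processor = SQSFifoPartialProcessor()
with processor(records=records, handler=record_handler):
    processor.process()

# Both the failed record and the short-circuited one are reported
print(processor.response())
# {'batchItemFailures': [{'itemIdentifier': 'msg-2'}, {'itemIdentifier': 'msg-3'}]}
```
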
18 changes: 18 additions & 0 deletions docs/utilities/batch.md
@@ -347,6 +347,24 @@ Processing batches from SQS works in four stages:
}
```

#### FIFO queues

If you're using this feature with a FIFO queue, use the `SQSFifoPartialProcessor` class instead. We will
stop processing messages after the first failure, and return all failed and unprocessed messages in `batchItemFailures`.
This helps preserve the ordering of messages in your queue. A sample response payload follows the examples below.

=== "As a decorator"

```python hl_lines="5 11"
--8<-- "examples/batch_processing/src/sqs_fifo_batch_processor.py"
```

=== "As a context manager"

```python hl_lines="4 8"
--8<-- "examples/batch_processing/src/sqs_fifo_batch_processor_context_manager.py"
```
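
For reference (not part of this diff), the returned payload follows Lambda's partial batch response contract. If the second of three records fails, the response would look like this sketch (identifiers illustrative):

```python
{
    "batchItemFailures": [
        {"itemIdentifier": "<message id of the failed record>"},
        {"itemIdentifier": "<message id of the first unprocessed record>"},
    ]
}
```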

### Processing messages from Kinesis

Processing batches from Kinesis works in four stages:
23 changes: 23 additions & 0 deletions examples/batch_processing/src/sqs_fifo_batch_processor.py
@@ -0,0 +1,23 @@
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.utilities.batch import (
SQSFifoPartialProcessor,
batch_processor,
)
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.typing import LambdaContext

processor = SQSFifoPartialProcessor()
tracer = Tracer()
logger = Logger()


@tracer.capture_method
def record_handler(record: SQSRecord):
...


@logger.inject_lambda_context
@tracer.capture_lambda_handler
@batch_processor(record_handler=record_handler, processor=processor)
def lambda_handler(event, context: LambdaContext):
return processor.response()
23 changes: 23 additions & 0 deletions examples/batch_processing/src/sqs_fifo_batch_processor_context_manager.py
@@ -0,0 +1,23 @@
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.utilities.batch import SQSFifoPartialProcessor
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.typing import LambdaContext

processor = SQSFifoPartialProcessor()
tracer = Tracer()
logger = Logger()


@tracer.capture_method
def record_handler(record: SQSRecord):
...


@logger.inject_lambda_context
@tracer.capture_lambda_handler
def lambda_handler(event, context: LambdaContext):
batch = event["Records"]
with processor(records=batch, handler=record_handler):
processor.process() # kick off processing, return List[Tuple]

return processor.response()
57 changes: 56 additions & 1 deletion tests/functional/test_utilities_batch.py
@@ -1,4 +1,5 @@
import json
import uuid
from random import randint
from typing import Any, Awaitable, Callable, Dict, Optional

@@ -9,6 +10,7 @@
AsyncBatchProcessor,
BatchProcessor,
EventType,
SQSFifoPartialProcessor,
async_batch_processor,
batch_processor,
)
@@ -40,7 +42,7 @@
def sqs_event_factory() -> Callable:
def factory(body: str):
return {
"messageId": "059f36b4-87a3-44ab-83d2-661975830a7d",
"messageId": str(uuid.uuid4()),
"receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a",
"body": body,
"attributes": {
Expand Down Expand Up @@ -117,6 +119,17 @@ def handler(record):
return handler


@pytest.fixture(scope="module")
def sqs_fifo_record_handler() -> Callable:
def handler(record):
body = record["body"]
if "fail" in body:
raise Exception("Failed to process record.")
return body

return handler


@pytest.fixture(scope="module")
def async_record_handler() -> Callable[..., Awaitable[Any]]:
async def handler(record):
@@ -654,6 +667,48 @@ def lambda_handler(event, context):
assert "All records failed processing. " in str(e.value)


def test_sqs_fifo_batch_processor_middleware_success_only(sqs_event_factory, sqs_fifo_record_handler):
# GIVEN
first_record = SQSRecord(sqs_event_factory("success"))
second_record = SQSRecord(sqs_event_factory("success"))
event = {"Records": [first_record.raw_event, second_record.raw_event]}

processor = SQSFifoPartialProcessor()

@batch_processor(record_handler=sqs_fifo_record_handler, processor=processor)
def lambda_handler(event, context):
return processor.response()

# WHEN
result = lambda_handler(event, {})

# THEN
assert result["batchItemFailures"] == []


def test_sqs_fifo_batch_processor_middleware_with_failure(sqs_event_factory, sqs_fifo_record_handler):
# GIVEN
first_record = SQSRecord(sqs_event_factory("success"))
second_record = SQSRecord(sqs_event_factory("fail"))
# this would normally succeed, but since it's a FIFO queue, it will be marked as failure

Review comment (Contributor): loved this comment. thank you for your great attention to detail @rubenfonseca

third_record = SQSRecord(sqs_event_factory("success"))
event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}

processor = SQSFifoPartialProcessor()

@batch_processor(record_handler=sqs_fifo_record_handler, processor=processor)
def lambda_handler(event, context):
return processor.response()

# WHEN
result = lambda_handler(event, {})

# THEN
assert len(result["batchItemFailures"]) == 2
assert result["batchItemFailures"][0]["itemIdentifier"] == second_record.message_id
assert result["batchItemFailures"][1]["itemIdentifier"] == third_record.message_id


def test_async_batch_processor_middleware_success_only(sqs_event_factory, async_record_handler):
# GIVEN
first_record = SQSRecord(sqs_event_factory("success"))