|
1 |
| -from typing import List, Optional, Tuple |
2 |
| - |
3 |
| -from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType |
| 1 | +import logging |
| 2 | +from typing import Optional, Set |
| 3 | + |
| 4 | +from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, ExceptionInfo, FailureResponse |
| 5 | +from aws_lambda_powertools.utilities.batch.exceptions import ( |
| 6 | + SQSFifoCircuitBreakerError, |
| 7 | + SQSFifoMessageGroupCircuitBreakerError, |
| 8 | +) |
4 | 9 | from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel
|
5 | 10 |
|
6 |
| - |
7 |
| -class SQSFifoCircuitBreakerError(Exception): |
8 |
| - """ |
9 |
| - Signals a record not processed due to the SQS FIFO processing being interrupted |
10 |
| - """ |
11 |
| - |
12 |
| - pass |
# Module-level logger, following the Powertools convention of one
# logger per module (configured by the application, not here).
logger = logging.getLogger(__name__)
13 | 12 |
|
14 | 13 |
|
15 | 14 | class SqsFifoPartialProcessor(BatchProcessor):
|
@@ -57,36 +56,59 @@ def lambda_handler(event, context: LambdaContext):
|
57 | 56 | None,
|
58 | 57 | )
|
59 | 58 |
|
60 |
| - def __init__(self, model: Optional["BatchSqsTypeModel"] = None): |
61 |
| - super().__init__(EventType.SQS, model) |
| 59 | + group_circuit_breaker_exc = ( |
| 60 | + SQSFifoMessageGroupCircuitBreakerError, |
| 61 | + SQSFifoMessageGroupCircuitBreakerError("A previous record from this message group failed processing"), |
| 62 | + None, |
| 63 | + ) |
62 | 64 |
|
63 |
| - def process(self) -> List[Tuple]: |
| 65 | + def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_error: bool = False): |
64 | 66 | """
|
65 |
| - Call instance's handler for each record. When the first failed message is detected, |
66 |
| - the process is short-circuited, and the remaining messages are reported as failed items. |
| 67 | + Initialize the SqsFifoProcessor. |
| 68 | +
|
| 69 | + Parameters |
| 70 | + ---------- |
| 71 | + model: Optional["BatchSqsTypeModel"] |
| 72 | + An optional model for batch processing. |
| 73 | + skip_group_on_error: bool |
| 74 | + Determines whether to exclusively skip messages from the MessageGroupID that encountered processing failures |
| 75 | + Default is False. |
| 76 | +
|
67 | 77 | """
|
68 |
| - result: List[Tuple] = [] |
| 78 | + self._skip_group_on_error: bool = skip_group_on_error |
| 79 | + self._current_group_id = None |
| 80 | + self._failed_group_ids: Set[str] = set() |
| 81 | + super().__init__(EventType.SQS, model) |
69 | 82 |
|
70 |
| - for i, record in enumerate(self.records): |
71 |
| - # If we have failed messages, it means that the last message failed. |
72 |
| - # We then short circuit the process, failing the remaining messages |
73 |
| - if self.fail_messages: |
74 |
| - return self._short_circuit_processing(i, result) |
| 83 | + def _process_record(self, record): |
| 84 | + self._current_group_id = record.get("attributes", {}).get("MessageGroupId") |
75 | 85 |
|
76 |
| - # Otherwise, process the message normally |
77 |
| - result.append(self._process_record(record)) |
| 86 | + # Short-circuits the process if: |
| 87 | + # - There are failed messages, OR |
| 88 | + # - The `skip_group_on_error` option is on, and the current message is part of a failed group. |
| 89 | + fail_entire_batch = bool(self.fail_messages) and not self._skip_group_on_error |
| 90 | + fail_group_id = self._skip_group_on_error and self._current_group_id in self._failed_group_ids |
| 91 | + if fail_entire_batch or fail_group_id: |
| 92 | + return self.failure_handler( |
| 93 | + record=self._to_batch_type(record, event_type=self.event_type, model=self.model), |
| 94 | + exception=self.group_circuit_breaker_exc if self._skip_group_on_error else self.circuit_breaker_exc, |
| 95 | + ) |
78 | 96 |
|
79 |
| - return result |
| 97 | + return super()._process_record(record) |
80 | 98 |
|
81 |
| - def _short_circuit_processing(self, first_failure_index: int, result: List[Tuple]) -> List[Tuple]: |
82 |
| - """ |
83 |
| - Starting from the first failure index, fail all the remaining messages, and append them to the result list. |
84 |
| - """ |
85 |
| - remaining_records = self.records[first_failure_index:] |
86 |
| - for remaining_record in remaining_records: |
87 |
| - data = self._to_batch_type(record=remaining_record, event_type=self.event_type, model=self.model) |
88 |
| - result.append(self.failure_handler(record=data, exception=self.circuit_breaker_exc)) |
89 |
| - return result |
| 99 | + def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse: |
| 100 | + # If we are failing a message and the `skip_group_on_error` is on, we store the failed group ID |
| 101 | + # This way, future messages with the same group ID will be failed automatically. |
| 102 | + if self._skip_group_on_error and self._current_group_id: |
| 103 | + self._failed_group_ids.add(self._current_group_id) |
| 104 | + |
| 105 | + return super().failure_handler(record, exception) |
| 106 | + |
| 107 | + def _clean(self): |
| 108 | + self._failed_group_ids.clear() |
| 109 | + self._current_group_id = None |
| 110 | + |
| 111 | + super()._clean() |
90 | 112 |
|
91 | 113 | async def _async_process_record(self, record: dict):
|
92 | 114 | raise NotImplementedError()
|
0 commit comments