Skip to content

Commit 1ca372c

Browse files
committed
chore: refactor
1 parent 28d68af commit 1ca372c

File tree

1 file changed

+51
-78
lines changed

1 file changed

+51
-78
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
2-
from typing import Dict, List, Optional, Tuple
2+
from typing import Optional, Set, override
33

4-
from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType
4+
from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, ExceptionInfo, FailureResponse
55
from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel
66

77
logger = logging.getLogger(__name__)
@@ -15,6 +15,14 @@ class SQSFifoCircuitBreakerError(Exception):
1515
pass
1616

1717

18+
class SQSFifoMessageGroupCircuitBreakerError(Exception):
19+
"""
20+
Signals that a record was not processed because processing of its SQS FIFO message group was interrupted
21+
"""
22+
23+
pass
24+
25+
1826
class SqsFifoPartialProcessor(BatchProcessor):
1927
"""Process native partial responses from SQS FIFO queues.
2028
@@ -60,6 +68,12 @@ def lambda_handler(event, context: LambdaContext):
6068
None,
6169
)
6270

71+
group_circuit_breaker_exc = (
72+
SQSFifoMessageGroupCircuitBreakerError,
73+
SQSFifoMessageGroupCircuitBreakerError("A previous record from this message group failed processing"),
74+
None,
75+
)
76+
6377
def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_error: bool = False):
6478
"""
6579
Initialize the SqsFifoProcessor.
@@ -70,86 +84,45 @@ def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_er
7084
An optional model for batch processing.
7185
skip_group_on_error: bool
7286
Determines whether to skip only the messages from the MessageGroupID that encountered processing failures
73-
Default is False
87+
Default is False.
7488
7589
"""
76-
self._skip_group_on_error = skip_group_on_error
90+
self._skip_group_on_error: bool = skip_group_on_error
91+
self._current_group_id = None
92+
self._failed_group_ids: Set[str] = set()
7793
super().__init__(EventType.SQS, model)
7894

79-
def process(self) -> List[Tuple]:
80-
"""
81-
Call instance's handler for each record.
82-
83-
If skip_group_on_error is set to False, the process short-circuits upon detecting the first failed message,
84-
and the remaining messages are reported as failed items.
85-
86-
If skip_group_on_error is set to True, upon encountering the first failed message for a specific MessageGroupID,
87-
all messages from that MessageGroupID are skipped and reported as failed items.
88-
"""
89-
result: List[Tuple] = []
90-
skip_message_ids: List = []
91-
92-
for i, record in enumerate(self.records):
93-
# If we have failed messages and we are set to return on the first error,
94-
# short circuit the process and return the remaining messages as failed items
95-
if self.fail_messages and not self._skip_group_on_error:
96-
logger.debug("Processing of failed messages stopped because 'skip_group_on_error' is False")
97-
return self._short_circuit_processing(i, result)
98-
99-
msg_id = record.get("messageId")
100-
101-
# skip_group_on_error is True:
102-
# Skip processing the current message if its ID belongs to a group with failed messages
103-
if msg_id in skip_message_ids:
104-
logger.debug(
105-
f"Skipping message with ID '{msg_id}' as it is part of a group containing failed messages.",
106-
)
107-
continue
108-
109-
processed_message = self._process_record(record)
110-
111-
# If a processed message fail and skip_group_on_error is True,
112-
# mark subsequent messages from the same MessageGroupId as skipped
113-
if processed_message[0] == "fail" and self._skip_group_on_error:
114-
self._process_failed_subsequent_messages(record, i, skip_message_ids, result)
115-
116-
# Append the processed message normally
117-
result.append(processed_message)
118-
119-
return result
120-
121-
def _process_failed_subsequent_messages(
122-
self,
123-
record: Dict,
124-
i: int,
125-
skip_messages_group_id: List,
126-
result: List[Tuple],
127-
) -> None:
128-
"""
129-
Process failed subsequent messages from the same MessageGroupId and mark them as skipped.
130-
"""
131-
_attributes_record = record.get("attributes", {})
132-
133-
for subsequent_record in self.records[i + 1 :]:
134-
_attributes = subsequent_record.get("attributes", {})
135-
if _attributes.get("MessageGroupId") == _attributes_record.get("MessageGroupId"):
136-
skip_messages_group_id.append(subsequent_record.get("messageId"))
137-
data = self._to_batch_type(
138-
record=subsequent_record,
139-
event_type=self.event_type,
140-
model=self.model,
141-
)
142-
result.append(self.failure_handler(record=data, exception=self.circuit_breaker_exc))
143-
144-
def _short_circuit_processing(self, first_failure_index: int, result: List[Tuple]) -> List[Tuple]:
145-
"""
146-
Starting from the first failure index, fail all the remaining messages, and append them to the result list.
147-
"""
148-
remaining_records = self.records[first_failure_index:]
149-
for remaining_record in remaining_records:
150-
data = self._to_batch_type(record=remaining_record, event_type=self.event_type, model=self.model)
151-
result.append(self.failure_handler(record=data, exception=self.circuit_breaker_exc))
152-
return result
95+
@override
96+
def _process_record(self, record):
97+
self._current_group_id = record.get("attributes", {}).get("MessageGroupId")
98+
99+
# Short-circuits the process if:
100+
# - There are failed messages, OR
101+
# - The `skip_group_on_error` option is on, and the current message is part of a failed group.
102+
fail_group_id = self._skip_group_on_error and self._current_group_id in self._failed_group_ids
103+
if self.fail_messages or fail_group_id:
104+
return self.failure_handler(
105+
record=self._to_batch_type(record, event_type=self.event_type, model=self.model),
106+
exception=self.group_circuit_breaker_exc if self._skip_group_on_error else self.circuit_breaker_exc,
107+
)
108+
109+
return super()._process_record(record)
110+
111+
@override
112+
def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
113+
# If we are failing a message and `skip_group_on_error` is on, we store the failed group ID.
114+
# This way, future messages with the same group ID will be failed automatically.
115+
if self._skip_group_on_error and self._current_group_id:
116+
self._failed_group_ids.add(self._current_group_id)
117+
118+
return super().failure_handler(record, exception)
119+
120+
@override
121+
def _clean(self):
122+
self._failed_group_ids.clear()
123+
self._current_group_id = None
124+
125+
super()._clean()
153126

154127
async def _async_process_record(self, record: dict):
155128
raise NotImplementedError()

0 commit comments

Comments
 (0)