Skip to content

Commit 918844d

Browse files
refactor(batch): add from __future__ import annotations (#4993)
* refactor(batch): add from __future__ import annotations and update code according to ruff rules TCH, UP006, UP007, UP037 and FA100. * Fixing types in Python 3.8 and 3.9 --------- Co-authored-by: Leandro Damascena <[email protected]>
1 parent 689072f commit 918844d

File tree

5 files changed

+63
-48
lines changed

5 files changed

+63
-48
lines changed

aws_lambda_powertools/utilities/batch/base.py

+32-25
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# -*- coding: utf-8 -*-
2-
32
"""
43
Batch processing utilities
54
"""
5+
from __future__ import annotations
6+
67
import asyncio
78
import copy
89
import inspect
@@ -11,22 +12,28 @@
1112
import sys
1213
from abc import ABC, abstractmethod
1314
from enum import Enum
14-
from typing import Any, Callable, List, Optional, Tuple, Union, overload
15+
from typing import TYPE_CHECKING, Any, Callable, Tuple, Union, overload
1516

1617
from aws_lambda_powertools.shared import constants
1718
from aws_lambda_powertools.utilities.batch.exceptions import (
1819
BatchProcessingError,
1920
ExceptionInfo,
2021
)
21-
from aws_lambda_powertools.utilities.batch.types import BatchTypeModels, PartialItemFailureResponse, PartialItemFailures
22+
from aws_lambda_powertools.utilities.batch.types import BatchTypeModels
2223
from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import (
2324
DynamoDBRecord,
2425
)
2526
from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import (
2627
KinesisStreamRecord,
2728
)
2829
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
29-
from aws_lambda_powertools.utilities.typing import LambdaContext
30+
31+
if TYPE_CHECKING:
32+
from aws_lambda_powertools.utilities.batch.types import (
33+
PartialItemFailureResponse,
34+
PartialItemFailures,
35+
)
36+
from aws_lambda_powertools.utilities.typing import LambdaContext
3037

3138
logger = logging.getLogger(__name__)
3239

@@ -41,7 +48,7 @@ class EventType(Enum):
4148
# and depending on what EventType it's passed it'll correctly map to the right record
4249
# When using Pydantic Models, it'll accept any subclass from SQS, DynamoDB and Kinesis
4350
EventSourceDataClassTypes = Union[SQSRecord, KinesisStreamRecord, DynamoDBRecord]
44-
BatchEventTypes = Union[EventSourceDataClassTypes, "BatchTypeModels"]
51+
BatchEventTypes = Union[EventSourceDataClassTypes, BatchTypeModels]
4552
SuccessResponse = Tuple[str, Any, BatchEventTypes]
4653
FailureResponse = Tuple[str, str, BatchEventTypes]
4754

@@ -54,9 +61,9 @@ class BasePartialProcessor(ABC):
5461
lambda_context: LambdaContext
5562

5663
def __init__(self):
57-
self.success_messages: List[BatchEventTypes] = []
58-
self.fail_messages: List[BatchEventTypes] = []
59-
self.exceptions: List[ExceptionInfo] = []
64+
self.success_messages: list[BatchEventTypes] = []
65+
self.fail_messages: list[BatchEventTypes] = []
66+
self.exceptions: list[ExceptionInfo] = []
6067

6168
@abstractmethod
6269
def _prepare(self):
@@ -79,7 +86,7 @@ def _process_record(self, record: dict):
7986
"""
8087
raise NotImplementedError()
8188

82-
def process(self) -> List[Tuple]:
89+
def process(self) -> list[tuple]:
8390
"""
8491
Call instance's handler for each record.
8592
"""
@@ -92,7 +99,7 @@ async def _async_process_record(self, record: dict):
9299
"""
93100
raise NotImplementedError()
94101

95-
def async_process(self) -> List[Tuple]:
102+
def async_process(self) -> list[tuple]:
96103
"""
97104
Async call instance's handler for each record.
98105
@@ -135,13 +142,13 @@ def __enter__(self):
135142
def __exit__(self, exception_type, exception_value, traceback):
136143
self._clean()
137144

138-
def __call__(self, records: List[dict], handler: Callable, lambda_context: Optional[LambdaContext] = None):
145+
def __call__(self, records: list[dict], handler: Callable, lambda_context: LambdaContext | None = None):
139146
"""
140147
Set instance attributes before execution
141148
142149
Parameters
143150
----------
144-
records: List[dict]
151+
records: list[dict]
145152
List with objects to be processed.
146153
handler: Callable
147154
Callable to process "records" entries.
@@ -222,14 +229,14 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
222229
class BasePartialBatchProcessor(BasePartialProcessor): # noqa
223230
DEFAULT_RESPONSE: PartialItemFailureResponse = {"batchItemFailures": []}
224231

225-
def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = None):
232+
def __init__(self, event_type: EventType, model: BatchTypeModels | None = None):
226233
"""Process batch and partially report failed items
227234
228235
Parameters
229236
----------
230237
event_type: EventType
231238
Whether this is a SQS, DynamoDB Streams, or Kinesis Data Stream event
232-
model: Optional["BatchTypeModels"]
239+
model: BatchTypeModels | None
233240
Parser's data model using either SqsRecordModel, DynamoDBStreamRecordModel, KinesisDataStreamRecord
234241
235242
Exceptions
@@ -294,7 +301,7 @@ def _has_messages_to_report(self) -> bool:
294301
def _entire_batch_failed(self) -> bool:
295302
return len(self.exceptions) == len(self.records)
296303

297-
def _get_messages_to_report(self) -> List[PartialItemFailures]:
304+
def _get_messages_to_report(self) -> list[PartialItemFailures]:
298305
"""
299306
Format messages to use in batch deletion
300307
"""
@@ -343,13 +350,13 @@ def _to_batch_type(
343350
self,
344351
record: dict,
345352
event_type: EventType,
346-
model: "BatchTypeModels",
347-
) -> "BatchTypeModels": ... # pragma: no cover
353+
model: BatchTypeModels,
354+
) -> BatchTypeModels: ... # pragma: no cover
348355

349356
@overload
350357
def _to_batch_type(self, record: dict, event_type: EventType) -> EventSourceDataClassTypes: ... # pragma: no cover
351358

352-
def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["BatchTypeModels"] = None):
359+
def _to_batch_type(self, record: dict, event_type: EventType, model: BatchTypeModels | None = None):
353360
if model is not None:
354361
# If a model is provided, we assume Pydantic is installed and we need to disable v2 warnings
355362
return model.model_validate(record)
@@ -363,7 +370,7 @@ def _register_model_validation_error_record(self, record: dict):
363370
# and downstream we can correctly collect the correct message id identifier and make the failed record available
364371
# see https://github.com/aws-powertools/powertools-lambda-python/issues/2091
365372
logger.debug("Record cannot be converted to customer's model; converting without model")
366-
failed_record: "EventSourceDataClassTypes" = self._to_batch_type(record=record, event_type=self.event_type)
373+
failed_record: EventSourceDataClassTypes = self._to_batch_type(record=record, event_type=self.event_type)
367374
return self.failure_handler(record=failed_record, exception=sys.exc_info())
368375

369376

@@ -453,7 +460,7 @@ def record_handler(record: DynamoDBRecord):
453460
logger.info(record.dynamodb.new_image)
454461
payload: dict = json.loads(record.dynamodb.new_image.get("item"))
455462
# alternatively:
456-
# changes: Dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
463+
# changes: dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
457464
# payload = change.get("Message") -> "<payload>"
458465
...
459466
@@ -481,7 +488,7 @@ def lambda_handler(event, context: LambdaContext):
481488
async def _async_process_record(self, record: dict):
482489
raise NotImplementedError()
483490

484-
def _process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]:
491+
def _process_record(self, record: dict) -> SuccessResponse | FailureResponse:
485492
"""
486493
Process a record with instance's handler
487494
@@ -490,7 +497,7 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons
490497
record: dict
491498
A batch record to be processed.
492499
"""
493-
data: Optional["BatchTypeModels"] = None
500+
data: BatchTypeModels | None = None
494501
try:
495502
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
496503
if self._handler_accepts_lambda_context:
@@ -602,7 +609,7 @@ async def record_handler(record: DynamoDBRecord):
602609
logger.info(record.dynamodb.new_image)
603610
payload: dict = json.loads(record.dynamodb.new_image.get("item"))
604611
# alternatively:
605-
# changes: Dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
612+
# changes: dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
606613
# payload = change.get("Message") -> "<payload>"
607614
...
608615
@@ -630,7 +637,7 @@ def lambda_handler(event, context: LambdaContext):
630637
def _process_record(self, record: dict):
631638
raise NotImplementedError()
632639

633-
async def _async_process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]:
640+
async def _async_process_record(self, record: dict) -> SuccessResponse | FailureResponse:
634641
"""
635642
Process a record with instance's handler
636643
@@ -639,7 +646,7 @@ async def _async_process_record(self, record: dict) -> Union[SuccessResponse, Fa
639646
record: dict
640647
A batch record to be processed.
641648
"""
642-
data: Optional["BatchTypeModels"] = None
649+
data: BatchTypeModels | None = None
643650
try:
644651
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
645652
if self._handler_accepts_lambda_context:

aws_lambda_powertools/utilities/batch/decorators.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4-
from typing import Any, Awaitable, Callable, Dict, List
4+
from typing import TYPE_CHECKING, Any, Awaitable, Callable
55

66
from typing_extensions import deprecated
77

@@ -12,10 +12,12 @@
1212
BatchProcessor,
1313
EventType,
1414
)
15-
from aws_lambda_powertools.utilities.batch.types import PartialItemFailureResponse
16-
from aws_lambda_powertools.utilities.typing import LambdaContext
1715
from aws_lambda_powertools.warnings import PowertoolsDeprecationWarning
1816

17+
if TYPE_CHECKING:
18+
from aws_lambda_powertools.utilities.batch.types import PartialItemFailureResponse
19+
from aws_lambda_powertools.utilities.typing import LambdaContext
20+
1921

2022
@lambda_handler_decorator
2123
@deprecated(
@@ -24,7 +26,7 @@
2426
)
2527
def async_batch_processor(
2628
handler: Callable,
27-
event: Dict,
29+
event: dict,
2830
context: LambdaContext,
2931
record_handler: Callable[..., Awaitable[Any]],
3032
processor: AsyncBatchProcessor,
@@ -40,7 +42,7 @@ def async_batch_processor(
4042
----------
4143
handler: Callable
4244
Lambda's handler
43-
event: Dict
45+
event: dict
4446
Lambda's Event
4547
context: LambdaContext
4648
Lambda's Context
@@ -92,7 +94,7 @@ def async_batch_processor(
9294
)
9395
def batch_processor(
9496
handler: Callable,
95-
event: Dict,
97+
event: dict,
9698
context: LambdaContext,
9799
record_handler: Callable,
98100
processor: BatchProcessor,
@@ -108,7 +110,7 @@ def batch_processor(
108110
----------
109111
handler: Callable
110112
Lambda's handler
111-
event: Dict
113+
event: dict
112114
Lambda's Event
113115
context: LambdaContext
114116
Lambda's Context
@@ -154,7 +156,7 @@ def batch_processor(
154156

155157

156158
def process_partial_response(
157-
event: Dict,
159+
event: dict,
158160
record_handler: Callable,
159161
processor: BasePartialBatchProcessor,
160162
context: LambdaContext | None = None,
@@ -164,7 +166,7 @@ def process_partial_response(
164166
165167
Parameters
166168
----------
167-
event: Dict
169+
event: dict
168170
Lambda's original event
169171
record_handler: Callable
170172
Callable to process each record from the batch
@@ -202,7 +204,7 @@ def handler(event, context):
202204
* Async batch processors. Use `async_process_partial_response` instead.
203205
"""
204206
try:
205-
records: List[Dict] = event.get("Records", [])
207+
records: list[dict] = event.get("Records", [])
206208
except AttributeError:
207209
event_types = ", ".join(list(EventType.__members__))
208210
docs = "https://docs.powertools.aws.dev/lambda/python/latest/utilities/batch/#processing-messages-from-sqs" # noqa: E501 # long-line
@@ -218,7 +220,7 @@ def handler(event, context):
218220

219221

220222
def async_process_partial_response(
221-
event: Dict,
223+
event: dict,
222224
record_handler: Callable,
223225
processor: AsyncBatchProcessor,
224226
context: LambdaContext | None = None,
@@ -228,7 +230,7 @@ def async_process_partial_response(
228230
229231
Parameters
230232
----------
231-
event: Dict
233+
event: dict
232234
Lambda's original event
233235
record_handler: Callable
234236
Callable to process each record from the batch
@@ -266,7 +268,7 @@ def handler(event, context):
266268
* Sync batch processors. Use `process_partial_response` instead.
267269
"""
268270
try:
269-
records: List[Dict] = event.get("Records", [])
271+
records: list[dict] = event.get("Records", [])
270272
except AttributeError:
271273
event_types = ", ".join(list(EventType.__members__))
272274
docs = "https://docs.powertools.aws.dev/lambda/python/latest/utilities/batch/#processing-messages-from-sqs" # noqa: E501 # long-line

aws_lambda_powertools/utilities/batch/exceptions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66

77
import traceback
88
from types import TracebackType
9-
from typing import List, Optional, Tuple, Type
9+
from typing import Optional, Tuple, Type
1010

1111
ExceptionInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]
1212

1313

1414
class BaseBatchProcessingError(Exception):
15-
def __init__(self, msg="", child_exceptions: List[ExceptionInfo] | None = None):
15+
def __init__(self, msg="", child_exceptions: list[ExceptionInfo] | None = None):
1616
super().__init__(msg)
1717
self.msg = msg
1818
self.child_exceptions = child_exceptions or []
@@ -30,7 +30,7 @@ def format_exceptions(self, parent_exception_str):
3030
class BatchProcessingError(BaseBatchProcessingError):
3131
"""When all batch records failed to be processed"""
3232

33-
def __init__(self, msg="", child_exceptions: List[ExceptionInfo] | None = None):
33+
def __init__(self, msg="", child_exceptions: list[ExceptionInfo] | None = None):
3434
super().__init__(msg, child_exceptions)
3535

3636
def __str__(self):

aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
from __future__ import annotations
2+
13
import logging
2-
from typing import Optional, Set
4+
from typing import TYPE_CHECKING
35

46
from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, ExceptionInfo, FailureResponse
57
from aws_lambda_powertools.utilities.batch.exceptions import (
68
SQSFifoCircuitBreakerError,
79
SQSFifoMessageGroupCircuitBreakerError,
810
)
9-
from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel
11+
12+
if TYPE_CHECKING:
13+
from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel
1014

1115
logger = logging.getLogger(__name__)
1216

@@ -62,13 +66,13 @@ def lambda_handler(event, context: LambdaContext):
6266
None,
6367
)
6468

65-
def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_error: bool = False):
69+
def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error: bool = False):
6670
"""
6771
Initialize the SqsFifoProcessor.
6872
6973
Parameters
7074
----------
71-
model: Optional["BatchSqsTypeModel"]
75+
model: BatchSqsTypeModel | None
7276
An optional model for batch processing.
7377
skip_group_on_error: bool
7478
Determines whether to exclusively skip messages from the MessageGroupID that encountered processing failures
@@ -77,7 +81,7 @@ def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_er
7781
"""
7882
self._skip_group_on_error: bool = skip_group_on_error
7983
self._current_group_id = None
80-
self._failed_group_ids: Set[str] = set()
84+
self._failed_group_ids: set[str] = set()
8185
super().__init__(EventType.SQS, model)
8286

8387
def _process_record(self, record):

0 commit comments

Comments
 (0)