Skip to content

Commit fae56ef

Browse files
author
Michal Ploski
committed
Fix mypy errors
1 parent c43a5e8 commit fae56ef

File tree

24 files changed

+70
-57
lines changed

24 files changed

+70
-57
lines changed

.github/workflows/python_build.yml

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ jobs:
3030
run: make dev
3131
- name: Formatting and Linting
3232
run: make lint
33+
- name: Typing
34+
run: make mypy
3335
- name: Test with pytest
3436
run: make test
3537
- name: Security baseline

aws_lambda_powertools/metrics/metrics.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def lambda_handler():
5353
----------
5454
service : str, optional
5555
service name to be used as metric dimension, by default "service_undefined"
56-
namespace : str
56+
namespace : str, optional
5757
Namespace for metrics
5858
5959
Raises
@@ -209,5 +209,6 @@ def __add_cold_start_metric(self, context: Any) -> None:
209209
logger.debug("Adding cold start metric and function_name dimension")
210210
with single_metric(name="ColdStart", unit=MetricUnit.Count, value=1, namespace=self.namespace) as metric:
211211
metric.add_dimension(name="function_name", value=context.function_name)
212-
metric.add_dimension(name="service", value=self.service)
212+
if self.service:
213+
metric.add_dimension(name="service", value=str(self.service))
213214
is_cold_start = False

aws_lambda_powertools/shared/functions.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, Union
1+
from typing import Optional, Union
22

33

44
def strtobool(value: str) -> bool:
@@ -38,21 +38,23 @@ def resolve_truthy_env_var_choice(env: str, choice: Optional[bool] = None) -> bo
3838
return choice if choice is not None else strtobool(env)
3939

4040

41-
def resolve_env_var_choice(env: Any, choice: Optional[Any] = None) -> Union[bool, Any]:
41+
def resolve_env_var_choice(
42+
env: Optional[str] = None, choice: Optional[Union[str, float]] = None
43+
) -> Optional[Union[str, float]]:
4244
"""Pick explicit choice over env, if available, otherwise return env value received
4345
4446
NOTE: Environment variable should be resolved by the caller.
4547
4648
Parameters
4749
----------
48-
env : Any
50+
env : str, Optional
4951
environment variable actual value
50-
choice : bool
52+
choice : str|float, optional
5153
explicit choice
5254
5355
Returns
5456
-------
55-
choice : str
57+
choice : str, Optional
5658
resolved choice as either bool or environment value
5759
"""
5860
return choice if choice is not None else env

aws_lambda_powertools/utilities/batch/exceptions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from types import TracebackType
66
from typing import List, Optional, Tuple, Type
77

8-
ExceptionInfo = Tuple[Type[BaseException], BaseException, TracebackType]
8+
ExceptionInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]
99

1010

1111
class BaseBatchProcessingError(Exception):

aws_lambda_powertools/utilities/batch/sqs.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
"""
66
import logging
77
import sys
8-
from typing import Callable, Dict, List, Optional, Tuple
8+
from typing import Callable, Dict, List, Optional, Tuple, cast
99

1010
import boto3
1111
from botocore.config import Config
1212

13+
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
14+
1315
from ...middleware_factory import lambda_handler_decorator
1416
from .base import BasePartialProcessor
1517
from .exceptions import SQSBatchProcessingError
@@ -84,11 +86,14 @@ def _get_queue_url(self) -> Optional[str]:
8486
*_, account_id, queue_name = self.records[0]["eventSourceARN"].split(":")
8587
return f"{self.client._endpoint.host}/{account_id}/{queue_name}"
8688

87-
def _get_entries_to_clean(self) -> List:
89+
def _get_entries_to_clean(self) -> List[Dict[str, str]]:
8890
"""
8991
Format messages to use in batch deletion
9092
"""
91-
return [{"Id": msg["messageId"], "ReceiptHandle": msg["receiptHandle"]} for msg in self.success_messages]
93+
return [
94+
{"Id": msg["messageId"], "ReceiptHandle": msg["receiptHandle"]}
95+
for msg in cast(List[SQSRecord], self.success_messages)
96+
]
9297

9398
def _process_record(self, record) -> Tuple:
9499
"""

aws_lambda_powertools/utilities/idempotency/idempotency.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def process_order(customer_id: str, order: dict, **kwargs):
112112
return {"StatusCode": 200}
113113
"""
114114

115-
if function is None:
115+
if not function:
116116
return cast(
117117
AnyCallableT,
118118
functools.partial(
@@ -132,7 +132,7 @@ def decorate(*args, **kwargs):
132132

133133
payload = kwargs.get(data_keyword_argument)
134134

135-
if payload is None:
135+
if not payload:
136136
raise RuntimeError(
137137
f"Unable to extract '{data_keyword_argument}' from keyword arguments."
138138
f" Ensure this exists in your function's signature as well as the caller used it as a keyword argument"

aws_lambda_powertools/utilities/idempotency/persistence/base.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
idempotency_key,
4141
status: str = "",
4242
expiry_timestamp: Optional[int] = None,
43-
response_data: Optional[str] = "",
43+
response_data: str = "",
4444
payload_hash: Optional[str] = None,
4545
) -> None:
4646
"""
@@ -279,14 +279,14 @@ def _save_to_cache(self, data_record: DataRecord):
279279
-------
280280
281281
"""
282-
if not self.use_local_cache:
282+
if not self.use_local_cache or self._cache is None:
283283
return
284284
if data_record.status == STATUS_CONSTANTS["INPROGRESS"]:
285285
return
286286
self._cache[data_record.idempotency_key] = data_record
287287

288288
def _retrieve_from_cache(self, idempotency_key: str):
289-
if not self.use_local_cache:
289+
if not self.use_local_cache or self._cache is None:
290290
return
291291
cached_record = self._cache.get(key=idempotency_key)
292292
if cached_record:
@@ -296,7 +296,7 @@ def _retrieve_from_cache(self, idempotency_key: str):
296296
self._delete_from_cache(idempotency_key=idempotency_key)
297297

298298
def _delete_from_cache(self, idempotency_key: str):
299-
if not self.use_local_cache:
299+
if not self.use_local_cache or self._cache is None:
300300
return
301301
if idempotency_key in self._cache:
302302
del self._cache[idempotency_key]

aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _item_to_data_record(self, item: Dict[str, Any]) -> DataRecord:
130130
idempotency_key=item[self.key_attr],
131131
status=item[self.status_attr],
132132
expiry_timestamp=item[self.expiry_attr],
133-
response_data=item.get(self.data_attr),
133+
response_data=item.get(self.data_attr, ""),
134134
payload_hash=item.get(self.validation_key_attr),
135135
)
136136

aws_lambda_powertools/utilities/parameters/base.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ExpirableValue = namedtuple("ExpirableValue", ["value", "ttl"])
1616
# These providers will be dynamically initialized on first use of the helper functions
1717
DEFAULT_PROVIDERS: Dict[str, Any] = {}
18+
MULTIPLE_VALUES_TYPE: Union[Dict[str, str], Dict[str, dict], Dict[str, bytes], Dict[str, None]]
1819
TRANSFORM_METHOD_JSON = "json"
1920
TRANSFORM_METHOD_BINARY = "binary"
2021
SUPPORTED_TRANSFORM_METHODS = [TRANSFORM_METHOD_JSON, TRANSFORM_METHOD_BINARY]
@@ -44,7 +45,7 @@ def get(
4445
transform: Optional[str] = None,
4546
force_fetch: bool = False,
4647
**sdk_options,
47-
) -> Union[str, list, dict, bytes]:
48+
) -> Optional[Union[str, dict, bytes]]:
4849
"""
4950
Retrieve a parameter value or return the cached value
5051
@@ -81,6 +82,7 @@ def get(
8182
# of supported transform is small and the probability that a given
8283
# parameter will always be used in a specific transform, this should be
8384
# an acceptable tradeoff.
85+
value: Optional[Union[str, bytes, dict]] = None
8486
key = (name, transform)
8587

8688
if not force_fetch and self._has_not_expired(key):
@@ -92,12 +94,12 @@ def get(
9294
except Exception as exc:
9395
raise GetParameterError(str(exc))
9496

95-
if transform is not None:
97+
if transform:
9698
if isinstance(value, bytes):
9799
value = value.decode("utf-8")
98100
value = transform_value(value, transform)
99-
100-
self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age))
101+
if value:
102+
self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age))
101103

102104
return value
103105

@@ -146,26 +148,25 @@ def get_multiple(
146148
TransformParameterError
147149
When the parameter provider fails to transform a parameter value.
148150
"""
149-
150151
key = (path, transform)
151152

152153
if not force_fetch and self._has_not_expired(key):
153154
return self.store[key].value
154155

155156
try:
156-
values: Dict[str, Union[str, bytes, dict, None]] = self._get_multiple(path, **sdk_options)
157+
values = self._get_multiple(path, **sdk_options)
157158
# Encapsulate all errors into a generic GetParameterError
158159
except Exception as exc:
159160
raise GetParameterError(str(exc))
160161

161-
if transform is not None:
162-
for (key, value) in values.items():
163-
_transform = get_transform_method(key, transform)
164-
if _transform is None:
162+
if transform:
163+
transformed_values: dict = {}
164+
for (item, value) in values.items():
165+
_transform = get_transform_method(item, transform)
166+
if not _transform:
165167
continue
166-
167-
values[key] = transform_value(value, _transform, raise_on_transform_error)
168-
168+
transformed_values[item] = transform_value(value, _transform, raise_on_transform_error)
169+
values.update(transformed_values)
169170
self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age))
170171

171172
return values
@@ -217,7 +218,9 @@ def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[
217218
return None
218219

219220

220-
def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]:
221+
def transform_value(
222+
value: str, transform: str, raise_on_transform_error: Optional[bool] = True
223+
) -> Optional[Union[dict, bytes]]:
221224
"""
222225
Apply a transform to a value
223226

aws_lambda_powertools/utilities/parameters/ssm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def get(
9595
decrypt: bool = False,
9696
force_fetch: bool = False,
9797
**sdk_options
98-
) -> Union[str, list, dict, bytes]:
98+
) -> Optional[Union[str, dict, bytes]]:
9999
"""
100100
Retrieve a parameter value or return the cached value
101101

aws_lambda_powertools/utilities/parser/envelopes/apigw.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
2727
Parsed detail payload with model provided
2828
"""
2929
logger.debug(f"Parsing incoming data with Api Gateway model {APIGatewayProxyEventModel}")
30-
parsed_envelope = APIGatewayProxyEventModel.parse_obj(data)
30+
parsed_envelope: APIGatewayProxyEventModel = APIGatewayProxyEventModel.parse_obj(data)
3131
logger.debug(f"Parsing event payload in `detail` with {model}")
3232
return self._parse(data=parsed_envelope.body, model=model)

aws_lambda_powertools/utilities/parser/envelopes/apigwv2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
2727
Parsed detail payload with model provided
2828
"""
2929
logger.debug(f"Parsing incoming data with Api Gateway model V2 {APIGatewayProxyEventV2Model}")
30-
parsed_envelope = APIGatewayProxyEventV2Model.parse_obj(data)
30+
parsed_envelope: APIGatewayProxyEventV2Model = APIGatewayProxyEventV2Model.parse_obj(data)
3131
logger.debug(f"Parsing event payload in `detail` with {model}")
3232
return self._parse(data=parsed_envelope.body, model=model)

aws_lambda_powertools/utilities/parser/envelopes/event_bridge.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
2727
Parsed detail payload with model provided
2828
"""
2929
logger.debug(f"Parsing incoming data with EventBridge model {EventBridgeModel}")
30-
parsed_envelope = EventBridgeModel.parse_obj(data)
30+
parsed_envelope: EventBridgeModel = EventBridgeModel.parse_obj(data)
3131
logger.debug(f"Parsing event payload in `detail` with {model}")
3232
return self._parse(data=parsed_envelope.detail, model=model)

aws_lambda_powertools/utilities/parser/envelopes/sns.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
6969
parsed_envelope = SqsModel.parse_obj(data)
7070
output = []
7171
for record in parsed_envelope.Records:
72-
sns_notification = SnsNotificationModel.parse_raw(record.body)
72+
sns_notification: SnsNotificationModel = SnsNotificationModel.parse_raw(record.body)
7373
output.append(self._parse(data=sns_notification.Message, model=model))
7474
return output

aws_lambda_powertools/utilities/parser/models/alb.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Union
1+
from typing import Dict, Generic, Union
22

33
from pydantic import BaseModel
44

@@ -13,7 +13,7 @@ class AlbRequestContext(BaseModel):
1313
elb: AlbRequestContextData
1414

1515

16-
class AlbModel(BaseModel):
16+
class AlbModel(BaseModel, Generic[Model]):
1717
httpMethod: str
1818
path: str
1919
body: Union[str, Model]

aws_lambda_powertools/utilities/parser/models/apigw.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import Any, Dict, List, Optional, Union
2+
from typing import Any, Dict, Generic, List, Optional, Union
33

44
from pydantic import BaseModel, root_validator
55
from pydantic.networks import IPvAnyNetwork
@@ -76,7 +76,7 @@ def check_message_id(cls, values):
7676
return values
7777

7878

79-
class APIGatewayProxyEventModel(BaseModel):
79+
class APIGatewayProxyEventModel(BaseModel, Generic[Model]):
8080
version: Optional[str]
8181
resource: str
8282
path: str

aws_lambda_powertools/utilities/parser/models/apigwv2.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import Any, Dict, List, Optional, Union
2+
from typing import Any, Dict, Generic, List, Optional, Union
33

44
from pydantic import BaseModel, Field
55
from pydantic.networks import IPvAnyNetwork
@@ -56,7 +56,7 @@ class RequestContextV2(BaseModel):
5656
http: RequestContextV2Http
5757

5858

59-
class APIGatewayProxyEventV2Model(BaseModel):
59+
class APIGatewayProxyEventV2Model(BaseModel, Generic[Model]):
6060
version: str
6161
routeKey: str
6262
rawPath: str

aws_lambda_powertools/utilities/parser/models/cloudwatch.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import zlib
55
from datetime import datetime
6-
from typing import List, Union
6+
from typing import Generic, List, Union
77

88
from pydantic import BaseModel, Field, validator
99

@@ -12,7 +12,7 @@
1212
logger = logging.getLogger(__name__)
1313

1414

15-
class CloudWatchLogsLogEvent(BaseModel):
15+
class CloudWatchLogsLogEvent(BaseModel, Generic[Model]):
1616
id: str # noqa AA03 VNE003
1717
timestamp: datetime
1818
message: Union[str, Model]

aws_lambda_powertools/utilities/parser/models/dynamodb.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from datetime import date
2-
from typing import Any, Dict, List, Optional, Union
2+
from typing import Any, Dict, Generic, List, Optional, Union
33

44
from pydantic import BaseModel
55

66
from aws_lambda_powertools.utilities.parser.types import Literal, Model
77

88

9-
class DynamoDBStreamChangedRecordModel(BaseModel):
9+
class DynamoDBStreamChangedRecordModel(BaseModel, Generic[Model]):
1010
ApproximateCreationDateTime: Optional[date]
1111
Keys: Dict[str, Dict[str, Any]]
1212
NewImage: Optional[Union[Dict[str, Any], Model]]

aws_lambda_powertools/utilities/parser/models/event_bridge.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from datetime import datetime
2-
from typing import Any, Dict, List, Optional, Union
2+
from typing import Any, Dict, Generic, List, Optional, Union
33

44
from pydantic import BaseModel, Field
55

66
from aws_lambda_powertools.utilities.parser.types import Model
77

88

9-
class EventBridgeModel(BaseModel):
9+
class EventBridgeModel(BaseModel, Generic[Model]):
1010
version: str
1111
id: str # noqa: A003,VNE003
1212
source: str

0 commit comments

Comments
 (0)