Skip to content

Commit dea696e

Browse files
mploskiMichal Ploskiheitorlessa
authored
feat(mypy): complete mypy support for the entire codebase (aws-powertools#943)
Co-authored-by: Michal Ploski <[email protected]> Co-authored-by: Heitor Lessa <[email protected]> Co-authored-by: heitorlessa <[email protected]>
1 parent ed7a978 commit dea696e

File tree

30 files changed

+145
-83
lines changed

30 files changed

+145
-83
lines changed

.github/workflows/python_build.yml

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ jobs:
3030
run: make dev
3131
- name: Formatting and Linting
3232
run: make lint
33+
- name: Static type checking
34+
run: make mypy
3335
- name: Test with pytest
3436
run: make test
3537
- name: Security baseline

CONTRIBUTING.md

+5-4
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,14 @@ You might find useful to run both the documentation website and the API referenc
4949

5050
Category | Convention
5151
------------------------------------------------- | ---------------------------------------------------------------------------------
52-
**Docstring** | We use a slight variation of numpy convention with markdown to help generate more readable API references.
53-
**Style guide** | We use black as well as flake8 extensions to enforce beyond good practices [PEP8](https://pep8.org/). We strive to make use of type annotation as much as possible, but don't overdo in creating custom types.
52+
**Docstring** | We use a slight variation of Numpy convention with markdown to help generate more readable API references.
53+
**Style guide** | We use black as well as flake8 extensions to enforce beyond good practices [PEP8](https://pep8.org/). We use type annotations and enforce static type checking at CI (mypy).
5454
**Core utilities** | Core utilities use a Class, always accept `service` as a constructor parameter, can work in isolation, and are also available in other languages implementation.
5555
**Utilities** | Utilities are not as strict as core and focus on solving a developer experience problem while following the project [Tenets](https://awslabs.github.io/aws-lambda-powertools-python/#tenets).
5656
**Exceptions** | Specific exceptions live within utilities themselves and use `Error` suffix e.g. `MetricUnitError`.
57-
**Git commits** | We follow [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/). These are not enforced as we squash and merge PRs, but PR titles are enforced during CI.
58-
**Documentation** | API reference docs are generated from docstrings which should have Examples section to allow developers to have what they need within their own IDE. Documentation website covers the wider usage, tips, and strive to be concise.
57+
**Git commits** | We follow [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/). We do not enforce conventional commits on contributors to lower the entry bar. Instead, we enforce a conventional PR title so our label automation and changelog are generated correctly.
58+
**API documentation** | API reference docs are generated from docstrings which should have Examples section to allow developers to have what they need within their own IDE. Documentation website covers the wider usage, tips, and strive to be concise.
59+
**Documentation** | We treat it like a product. We sub-divide content aimed at getting started (80% of customers) vs advanced usage (20%). We also ensure customers know how to unit test their code when using our features.
5960

6061
## Finding contributions to work on
6162

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ coverage-html:
2929
pre-commit:
3030
pre-commit run --show-diff-on-failure
3131

32-
pr: lint pre-commit test security-baseline complexity-baseline
32+
pr: lint mypy pre-commit test security-baseline complexity-baseline
3333

3434
build: pr
3535
poetry build

aws_lambda_powertools/event_handler/api_gateway.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from enum import Enum
1111
from functools import partial
1212
from http import HTTPStatus
13-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
13+
from typing import Any, Callable, Dict, List, Match, Optional, Pattern, Set, Tuple, Type, Union
1414

1515
from aws_lambda_powertools.event_handler import content_types
1616
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
@@ -167,7 +167,7 @@ class Route:
167167
"""Internally used Route Configuration"""
168168

169169
def __init__(
170-
self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str]
170+
self, method: str, rule: Pattern, func: Callable, cors: bool, compress: bool, cache_control: Optional[str]
171171
):
172172
self.method = method.upper()
173173
self.rule = rule
@@ -555,7 +555,7 @@ def _resolve(self) -> ResponseBuilder:
555555
for route in self._routes:
556556
if method != route.method:
557557
continue
558-
match_results: Optional[re.Match] = route.rule.match(path)
558+
match_results: Optional[Match] = route.rule.match(path)
559559
if match_results:
560560
logger.debug("Found a registered route. Calling function")
561561
return self._call_route(route, match_results.groupdict()) # pass fn args

aws_lambda_powertools/metrics/metrics.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def lambda_handler():
5353
----------
5454
service : str, optional
5555
service name to be used as metric dimension, by default "service_undefined"
56-
namespace : str
56+
namespace : str, optional
5757
Namespace for metrics
5858
5959
Raises
@@ -209,5 +209,6 @@ def __add_cold_start_metric(self, context: Any) -> None:
209209
logger.debug("Adding cold start metric and function_name dimension")
210210
with single_metric(name="ColdStart", unit=MetricUnit.Count, value=1, namespace=self.namespace) as metric:
211211
metric.add_dimension(name="function_name", value=context.function_name)
212-
metric.add_dimension(name="service", value=self.service)
212+
if self.service:
213+
metric.add_dimension(name="service", value=str(self.service))
213214
is_cold_start = False

aws_lambda_powertools/shared/functions.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, Union
1+
from typing import Optional, Union
22

33

44
def strtobool(value: str) -> bool:
@@ -38,21 +38,23 @@ def resolve_truthy_env_var_choice(env: str, choice: Optional[bool] = None) -> bo
3838
return choice if choice is not None else strtobool(env)
3939

4040

41-
def resolve_env_var_choice(env: Any, choice: Optional[Any] = None) -> Union[bool, Any]:
41+
def resolve_env_var_choice(
42+
env: Optional[str] = None, choice: Optional[Union[str, float]] = None
43+
) -> Optional[Union[str, float]]:
4244
"""Pick explicit choice over env, if available, otherwise return env value received
4345
4446
NOTE: Environment variable should be resolved by the caller.
4547
4648
Parameters
4749
----------
48-
env : Any
50+
env : str, Optional
4951
environment variable actual value
50-
choice : bool
52+
choice : str|float, optional
5153
explicit choice
5254
5355
Returns
5456
-------
55-
choice : str
57+
choice : str, Optional
5658
resolved choice as either bool or environment value
5759
"""
5860
return choice if choice is not None else env

aws_lambda_powertools/utilities/batch/exceptions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from types import TracebackType
66
from typing import List, Optional, Tuple, Type
77

8-
ExceptionInfo = Tuple[Type[BaseException], BaseException, TracebackType]
8+
ExceptionInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]
99

1010

1111
class BaseBatchProcessingError(Exception):

aws_lambda_powertools/utilities/batch/sqs.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
"""
66
import logging
77
import sys
8-
from typing import Callable, Dict, List, Optional, Tuple
8+
from typing import Callable, Dict, List, Optional, Tuple, cast
99

1010
import boto3
1111
from botocore.config import Config
1212

13+
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
14+
1315
from ...middleware_factory import lambda_handler_decorator
1416
from .base import BasePartialProcessor
1517
from .exceptions import SQSBatchProcessingError
@@ -84,11 +86,17 @@ def _get_queue_url(self) -> Optional[str]:
8486
*_, account_id, queue_name = self.records[0]["eventSourceARN"].split(":")
8587
return f"{self.client._endpoint.host}/{account_id}/{queue_name}"
8688

87-
def _get_entries_to_clean(self) -> List:
89+
def _get_entries_to_clean(self) -> List[Dict[str, str]]:
8890
"""
8991
Format messages to use in batch deletion
9092
"""
91-
return [{"Id": msg["messageId"], "ReceiptHandle": msg["receiptHandle"]} for msg in self.success_messages]
93+
entries = []
94+
# success_messages has generic type of union of SQS, Dynamodb and Kinesis Streams records or Pydantic models.
95+
# Here we get SQS Record only
96+
messages = cast(List[SQSRecord], self.success_messages)
97+
for msg in messages:
98+
entries.append({"Id": msg["messageId"], "ReceiptHandle": msg["receiptHandle"]})
99+
return entries
92100

93101
def _process_record(self, record) -> Tuple:
94102
"""

aws_lambda_powertools/utilities/idempotency/idempotency.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def process_order(customer_id: str, order: dict, **kwargs):
112112
return {"StatusCode": 200}
113113
"""
114114

115-
if function is None:
115+
if not function:
116116
return cast(
117117
AnyCallableT,
118118
functools.partial(
@@ -132,7 +132,7 @@ def decorate(*args, **kwargs):
132132

133133
payload = kwargs.get(data_keyword_argument)
134134

135-
if payload is None:
135+
if not payload:
136136
raise RuntimeError(
137137
f"Unable to extract '{data_keyword_argument}' from keyword arguments."
138138
f" Ensure this exists in your function's signature as well as the caller used it as a keyword argument"

aws_lambda_powertools/utilities/idempotency/persistence/base.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -92,16 +92,16 @@ def status(self) -> str:
9292
else:
9393
raise IdempotencyInvalidStatusError(self._status)
9494

95-
def response_json_as_dict(self) -> dict:
95+
def response_json_as_dict(self) -> Optional[dict]:
9696
"""
9797
Get response data deserialized to python dict
9898
9999
Returns
100100
-------
101-
dict
101+
Optional[dict]
102102
previous response data deserialized
103103
"""
104-
return json.loads(self.response_data)
104+
return json.loads(self.response_data) if self.response_data else None
105105

106106

107107
class BasePersistenceLayer(ABC):
@@ -121,7 +121,6 @@ def __init__(self):
121121
self.raise_on_no_idempotency_key = False
122122
self.expires_after_seconds: int = 60 * 60 # 1 hour default
123123
self.use_local_cache = False
124-
self._cache: Optional[LRUDict] = None
125124
self.hash_function = None
126125

127126
def configure(self, config: IdempotencyConfig, function_name: Optional[str] = None) -> None:

aws_lambda_powertools/utilities/parameters/base.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def get(
4444
transform: Optional[str] = None,
4545
force_fetch: bool = False,
4646
**sdk_options,
47-
) -> Union[str, list, dict, bytes]:
47+
) -> Optional[Union[str, dict, bytes]]:
4848
"""
4949
Retrieve a parameter value or return the cached value
5050
@@ -81,6 +81,7 @@ def get(
8181
# of supported transform is small and the probability that a given
8282
# parameter will always be used in a specific transform, this should be
8383
# an acceptable tradeoff.
84+
value: Optional[Union[str, bytes, dict]] = None
8485
key = (name, transform)
8586

8687
if not force_fetch and self._has_not_expired(key):
@@ -92,7 +93,7 @@ def get(
9293
except Exception as exc:
9394
raise GetParameterError(str(exc))
9495

95-
if transform is not None:
96+
if transform:
9697
if isinstance(value, bytes):
9798
value = value.decode("utf-8")
9899
value = transform_value(value, transform)
@@ -146,26 +147,25 @@ def get_multiple(
146147
TransformParameterError
147148
When the parameter provider fails to transform a parameter value.
148149
"""
149-
150150
key = (path, transform)
151151

152152
if not force_fetch and self._has_not_expired(key):
153153
return self.store[key].value
154154

155155
try:
156-
values: Dict[str, Union[str, bytes, dict, None]] = self._get_multiple(path, **sdk_options)
156+
values = self._get_multiple(path, **sdk_options)
157157
# Encapsulate all errors into a generic GetParameterError
158158
except Exception as exc:
159159
raise GetParameterError(str(exc))
160160

161-
if transform is not None:
162-
for (key, value) in values.items():
163-
_transform = get_transform_method(key, transform)
164-
if _transform is None:
161+
if transform:
162+
transformed_values: dict = {}
163+
for (item, value) in values.items():
164+
_transform = get_transform_method(item, transform)
165+
if not _transform:
165166
continue
166-
167-
values[key] = transform_value(value, _transform, raise_on_transform_error)
168-
167+
transformed_values[item] = transform_value(value, _transform, raise_on_transform_error)
168+
values.update(transformed_values)
169169
self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age))
170170

171171
return values
@@ -217,7 +217,9 @@ def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[
217217
return None
218218

219219

220-
def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]:
220+
def transform_value(
221+
value: str, transform: str, raise_on_transform_error: Optional[bool] = True
222+
) -> Optional[Union[dict, bytes]]:
221223
"""
222224
Apply a transform to a value
223225

aws_lambda_powertools/utilities/parameters/ssm.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,17 @@ def __init__(self, config: Optional[Config] = None, boto3_session: Optional[boto
8787

8888
super().__init__()
8989

90-
def get(
90+
# We break Liskov substitution principle due to differences in signatures of this method and superclass get method
91+
# We ignore mypy error, as changes to the signature here or in a superclass is a breaking change to users
92+
def get( # type: ignore[override]
9193
self,
9294
name: str,
9395
max_age: int = DEFAULT_MAX_AGE_SECS,
9496
transform: Optional[str] = None,
9597
decrypt: bool = False,
9698
force_fetch: bool = False,
9799
**sdk_options
98-
) -> Union[str, list, dict, bytes]:
100+
) -> Optional[Union[str, dict, bytes]]:
99101
"""
100102
Retrieve a parameter value or return the cached value
101103

aws_lambda_powertools/utilities/parser/envelopes/apigw.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
2727
Parsed detail payload with model provided
2828
"""
2929
logger.debug(f"Parsing incoming data with Api Gateway model {APIGatewayProxyEventModel}")
30-
parsed_envelope = APIGatewayProxyEventModel.parse_obj(data)
30+
parsed_envelope: APIGatewayProxyEventModel = APIGatewayProxyEventModel.parse_obj(data)
3131
logger.debug(f"Parsing event payload in `detail` with {model}")
3232
return self._parse(data=parsed_envelope.body, model=model)

aws_lambda_powertools/utilities/parser/envelopes/apigwv2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
2727
Parsed detail payload with model provided
2828
"""
2929
logger.debug(f"Parsing incoming data with Api Gateway model V2 {APIGatewayProxyEventV2Model}")
30-
parsed_envelope = APIGatewayProxyEventV2Model.parse_obj(data)
30+
parsed_envelope: APIGatewayProxyEventV2Model = APIGatewayProxyEventV2Model.parse_obj(data)
3131
logger.debug(f"Parsing event payload in `detail` with {model}")
3232
return self._parse(data=parsed_envelope.body, model=model)

aws_lambda_powertools/utilities/parser/envelopes/event_bridge.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
2727
Parsed detail payload with model provided
2828
"""
2929
logger.debug(f"Parsing incoming data with EventBridge model {EventBridgeModel}")
30-
parsed_envelope = EventBridgeModel.parse_obj(data)
30+
parsed_envelope: EventBridgeModel = EventBridgeModel.parse_obj(data)
3131
logger.debug(f"Parsing event payload in `detail` with {model}")
3232
return self._parse(data=parsed_envelope.detail, model=model)

aws_lambda_powertools/utilities/parser/envelopes/kinesis.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Dict, List, Optional, Type, Union
2+
from typing import Any, Dict, List, Optional, Type, Union, cast
33

44
from ..models import KinesisDataStreamModel
55
from ..types import Model
@@ -37,6 +37,9 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
3737
logger.debug(f"Parsing incoming data with Kinesis model {KinesisDataStreamModel}")
3838
parsed_envelope: KinesisDataStreamModel = KinesisDataStreamModel.parse_obj(data)
3939
logger.debug(f"Parsing Kinesis records in `body` with {model}")
40-
return [
41-
self._parse(data=record.kinesis.data.decode("utf-8"), model=model) for record in parsed_envelope.Records
42-
]
40+
models = []
41+
for record in parsed_envelope.Records:
42+
# We allow either AWS expected contract (bytes) or a custom Model, see #943
43+
data = cast(bytes, record.kinesis.data)
44+
models.append(self._parse(data=data.decode("utf-8"), model=model))
45+
return models

aws_lambda_powertools/utilities/parser/envelopes/sns.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Dict, List, Optional, Type, Union
2+
from typing import Any, Dict, List, Optional, Type, Union, cast
33

44
from ..models import SnsModel, SnsNotificationModel, SqsModel
55
from ..types import Model
@@ -69,6 +69,8 @@ def parse(self, data: Optional[Union[Dict[str, Any], Any]], model: Type[Model])
6969
parsed_envelope = SqsModel.parse_obj(data)
7070
output = []
7171
for record in parsed_envelope.Records:
72-
sns_notification = SnsNotificationModel.parse_raw(record.body)
72+
# We allow either AWS expected contract (str) or a custom Model, see #943
73+
body = cast(str, record.body)
74+
sns_notification = SnsNotificationModel.parse_raw(body)
7375
output.append(self._parse(data=sns_notification.Message, model=model))
7476
return output

aws_lambda_powertools/utilities/parser/models/alb.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
from typing import Dict, Union
1+
from typing import Dict, Type, Union
22

33
from pydantic import BaseModel
44

5-
from aws_lambda_powertools.utilities.parser.types import Model
6-
75

86
class AlbRequestContextData(BaseModel):
97
targetGroupArn: str
@@ -16,7 +14,7 @@ class AlbRequestContext(BaseModel):
1614
class AlbModel(BaseModel):
1715
httpMethod: str
1816
path: str
19-
body: Union[str, Model]
17+
body: Union[str, Type[BaseModel]]
2018
isBase64Encoded: bool
2119
headers: Dict[str, str]
2220
queryStringParameters: Dict[str, str]

aws_lambda_powertools/utilities/parser/models/apigw.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from datetime import datetime
2-
from typing import Any, Dict, List, Optional, Union
2+
from typing import Any, Dict, List, Optional, Type, Union
33

44
from pydantic import BaseModel, root_validator
55
from pydantic.networks import IPvAnyNetwork
66

7-
from aws_lambda_powertools.utilities.parser.types import Literal, Model
7+
from aws_lambda_powertools.utilities.parser.types import Literal
88

99

1010
class ApiGatewayUserCertValidity(BaseModel):
@@ -89,4 +89,4 @@ class APIGatewayProxyEventModel(BaseModel):
8989
pathParameters: Optional[Dict[str, str]]
9090
stageVariables: Optional[Dict[str, str]]
9191
isBase64Encoded: bool
92-
body: Optional[Union[str, Model]]
92+
body: Optional[Union[str, Type[BaseModel]]]

0 commit comments

Comments
 (0)