Skip to content

Commit 9432a53

Browse files
authored
fix(mypy): a few return types, type signatures, and untyped areas (aws-powertools#718)
1 parent 68c810e commit 9432a53

File tree

12 files changed

+47
-34
lines changed

12 files changed

+47
-34
lines changed

aws_lambda_powertools/logging/formatter.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class LambdaPowertoolsFormatter(BasePowertoolsFormatter):
5858
def __init__(
5959
self,
6060
json_serializer: Optional[Callable[[Dict], str]] = None,
61-
json_deserializer: Optional[Callable[[Dict], str]] = None,
61+
json_deserializer: Optional[Callable[[Union[Dict, str, bool, int, float]], str]] = None,
6262
json_default: Optional[Callable[[Any], Any]] = None,
6363
datefmt: Optional[str] = None,
6464
log_record_order: Optional[List[str]] = None,
@@ -106,7 +106,7 @@ def __init__(
106106
self.update_formatter = self.append_keys # alias to old method
107107

108108
if self.utc:
109-
self.converter = time.gmtime
109+
self.converter = time.gmtime # type: ignore
110110

111111
super(LambdaPowertoolsFormatter, self).__init__(datefmt=self.datefmt)
112112

@@ -128,7 +128,7 @@ def format(self, record: logging.LogRecord) -> str: # noqa: A003
128128
return self.serialize(log=formatted_log)
129129

130130
def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -> str:
131-
record_ts = self.converter(record.created)
131+
record_ts = self.converter(record.created) # type: ignore
132132
if datefmt:
133133
return time.strftime(datefmt, record_ts)
134134

@@ -201,7 +201,7 @@ def _extract_log_exception(self, log_record: logging.LogRecord) -> Union[Tuple[s
201201
Log record with constant traceback info and exception name
202202
"""
203203
if log_record.exc_info:
204-
return self.formatException(log_record.exc_info), log_record.exc_info[0].__name__
204+
return self.formatException(log_record.exc_info), log_record.exc_info[0].__name__ # type: ignore
205205

206206
return None, None
207207

aws_lambda_powertools/logging/logger.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def registered_handler(self) -> logging.Handler:
361361
return handlers[0]
362362

363363
@property
364-
def registered_formatter(self) -> Optional[PowertoolsFormatter]:
364+
def registered_formatter(self) -> PowertoolsFormatter:
365365
"""Convenience property to access logger formatter"""
366366
return self.registered_handler.formatter # type: ignore
367367

@@ -405,7 +405,9 @@ def get_correlation_id(self) -> Optional[str]:
405405
str, optional
406406
Value for the correlation id
407407
"""
408-
return self.registered_formatter.log_format.get("correlation_id")
408+
if isinstance(self.registered_formatter, LambdaPowertoolsFormatter):
409+
return self.registered_formatter.log_format.get("correlation_id")
410+
return None
409411

410412
@staticmethod
411413
def _get_log_level(level: Union[str, int, None]) -> Union[str, int]:

aws_lambda_powertools/metrics/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __init__(
9090
self._metric_unit_options = list(MetricUnit.__members__)
9191
self.metadata_set = metadata_set if metadata_set is not None else {}
9292

93-
def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float):
93+
def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float) -> None:
9494
"""Adds given metric
9595
9696
Example
@@ -215,7 +215,7 @@ def serialize_metric_set(
215215
**metric_names_and_values, # "single_metric": 1.0
216216
}
217217

218-
def add_dimension(self, name: str, value: str):
218+
def add_dimension(self, name: str, value: str) -> None:
219219
"""Adds given dimension to all metrics
220220
221221
Example
@@ -241,7 +241,7 @@ def add_dimension(self, name: str, value: str):
241241
# checking before casting improves performance in most cases
242242
self.dimension_set[name] = value if isinstance(value, str) else str(value)
243243

244-
def add_metadata(self, key: str, value: Any):
244+
def add_metadata(self, key: str, value: Any) -> None:
245245
"""Adds high cardinal metadata for metrics object
246246
247247
This will not be available during metrics visualization.

aws_lambda_powertools/metrics/metric.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class SingleMetric(MetricManager):
4242
Inherits from `aws_lambda_powertools.metrics.base.MetricManager`
4343
"""
4444

45-
def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float):
45+
def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float) -> None:
4646
"""Method to prevent more than one metric being created
4747
4848
Parameters

aws_lambda_powertools/metrics/metrics.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import json
33
import logging
44
import warnings
5-
from typing import Any, Callable, Dict, Optional
5+
from typing import Any, Callable, Dict, Optional, Union, cast
66

7+
from ..shared.types import AnyCallableT
78
from .base import MetricManager, MetricUnit
89
from .metric import single_metric
910

@@ -87,7 +88,7 @@ def __init__(self, service: Optional[str] = None, namespace: Optional[str] = Non
8788
service=self.service,
8889
)
8990

90-
def set_default_dimensions(self, **dimensions):
91+
def set_default_dimensions(self, **dimensions) -> None:
9192
"""Persist dimensions across Lambda invocations
9293
9394
Parameters
@@ -113,10 +114,10 @@ def lambda_handler():
113114

114115
self.default_dimensions.update(**dimensions)
115116

116-
def clear_default_dimensions(self):
117+
def clear_default_dimensions(self) -> None:
117118
self.default_dimensions.clear()
118119

119-
def clear_metrics(self):
120+
def clear_metrics(self) -> None:
120121
logger.debug("Clearing out existing metric set from memory")
121122
self.metric_set.clear()
122123
self.dimension_set.clear()
@@ -125,11 +126,11 @@ def clear_metrics(self):
125126

126127
def log_metrics(
127128
self,
128-
lambda_handler: Optional[Callable[[Any, Any], Any]] = None,
129+
lambda_handler: Union[Callable[[Dict, Any], Any], Optional[Callable[[Dict, Any, Optional[Dict]], Any]]] = None,
129130
capture_cold_start_metric: bool = False,
130131
raise_on_empty_metrics: bool = False,
131132
default_dimensions: Optional[Dict[str, str]] = None,
132-
):
133+
) -> AnyCallableT:
133134
"""Decorator to serialize and publish metrics at the end of a function execution.
134135
135136
Be aware that the log_metrics **does call* the decorated function (e.g. lambda_handler).
@@ -169,11 +170,14 @@ def handler(event, context):
169170
# Return a partial function with args filled
170171
if lambda_handler is None:
171172
logger.debug("Decorator called with parameters")
172-
return functools.partial(
173-
self.log_metrics,
174-
capture_cold_start_metric=capture_cold_start_metric,
175-
raise_on_empty_metrics=raise_on_empty_metrics,
176-
default_dimensions=default_dimensions,
173+
return cast(
174+
AnyCallableT,
175+
functools.partial(
176+
self.log_metrics,
177+
capture_cold_start_metric=capture_cold_start_metric,
178+
raise_on_empty_metrics=raise_on_empty_metrics,
179+
default_dimensions=default_dimensions,
180+
),
177181
)
178182

179183
@functools.wraps(lambda_handler)
@@ -194,9 +198,9 @@ def decorate(event, context):
194198

195199
return response
196200

197-
return decorate
201+
return cast(AnyCallableT, decorate)
198202

199-
def __add_cold_start_metric(self, context: Any):
203+
def __add_cold_start_metric(self, context: Any) -> None:
200204
"""Add cold start metric and function_name dimension
201205
202206
Parameters

aws_lambda_powertools/middleware_factory/factory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def final_decorator(func: Optional[Callable] = None, **kwargs):
118118
if not inspect.isfunction(func):
119119
# @custom_middleware(True) vs @custom_middleware(log_event=True)
120120
raise MiddlewareInvalidArgumentError(
121-
f"Only keyword arguments is supported for middlewares: {decorator.__qualname__} received {func}"
121+
f"Only keyword arguments is supported for middlewares: {decorator.__qualname__} received {func}" # type: ignore # noqa: E501
122122
)
123123

124124
@functools.wraps(func)

aws_lambda_powertools/shared/jmespath_utils.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,23 @@
66

77
import jmespath
88
from jmespath.exceptions import LexerError
9+
from jmespath.functions import Functions, signature
910

1011
from aws_lambda_powertools.exceptions import InvalidEnvelopeExpressionError
1112

1213
logger = logging.getLogger(__name__)
1314

1415

15-
class PowertoolsFunctions(jmespath.functions.Functions):
16-
@jmespath.functions.signature({"types": ["string"]})
16+
class PowertoolsFunctions(Functions):
17+
@signature({"types": ["string"]})
1718
def _func_powertools_json(self, value):
1819
return json.loads(value)
1920

20-
@jmespath.functions.signature({"types": ["string"]})
21+
@signature({"types": ["string"]})
2122
def _func_powertools_base64(self, value):
2223
return base64.b64decode(value).decode()
2324

24-
@jmespath.functions.signature({"types": ["string"]})
25+
@signature({"types": ["string"]})
2526
def _func_powertools_base64_gzip(self, value):
2627
encoded = base64.b64decode(value)
2728
uncompressed = gzip.decompress(encoded)

aws_lambda_powertools/tracing/tracer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
logger = logging.getLogger(__name__)
1818

1919
aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE)
20-
aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE)
20+
aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE) # type: ignore # noqa: E501
2121

2222

2323
class Tracer:

aws_lambda_powertools/utilities/data_classes/sqs_event.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ def data_type(self) -> str:
7575

7676

7777
class SQSMessageAttributes(Dict[str, SQSMessageAttribute]):
78-
def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]:
78+
def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]: # type: ignore
7979
item = super(SQSMessageAttributes, self).get(key)
80-
return None if item is None else SQSMessageAttribute(item)
80+
return None if item is None else SQSMessageAttribute(item) # type: ignore
8181

8282

8383
class SQSRecord(DictWrapper):

aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def _update_record(self, data_record: DataRecord):
155155
"ExpressionAttributeNames": expression_attr_names,
156156
}
157157

158-
self.table.update_item(**kwargs) # type: ignore
158+
self.table.update_item(**kwargs)
159159

160160
def _delete_record(self, data_record: DataRecord) -> None:
161161
logger.debug(f"Deleting record for idempotency key: {data_record.idempotency_key}")

aws_lambda_powertools/utilities/validation/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ def validate_data_against_schema(data: Union[Dict, str], schema: Dict, formats:
3333
except (TypeError, AttributeError, fastjsonschema.JsonSchemaDefinitionException) as e:
3434
raise InvalidSchemaFormatError(f"Schema received: {schema}, Formats: {formats}. Error: {e}")
3535
except fastjsonschema.JsonSchemaValueException as e:
36-
message = f"Failed schema validation. Error: {e.message}, Path: {e.path}, Data: {e.value}"
36+
message = f"Failed schema validation. Error: {e.message}, Path: {e.path}, Data: {e.value}" # noqa: B306
3737
raise SchemaValidationError(
3838
message,
39-
validation_message=e.message,
39+
validation_message=e.message, # noqa: B306
4040
name=e.name,
4141
path=e.path,
4242
value=e.value,

mypy.ini

+6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ show_error_context = True
1111
[mypy-jmespath]
1212
ignore_missing_imports=True
1313

14+
[mypy-jmespath.exceptions]
15+
ignore_missing_imports=True
16+
17+
[mypy-jmespath.functions]
18+
ignore_missing_imports=True
19+
1420
[mypy-boto3]
1521
ignore_missing_imports = True
1622

0 commit comments

Comments
 (0)