Skip to content

Commit d5c3431

Browse files
authored
fix(tracer): mypy generic to preserve decorated method signature (#529)
1 parent 89337a2 commit d5c3431

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

Diff for: aws_lambda_powertools/tracing/tracer.py

+27-9
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import numbers
77
import os
8-
from typing import Any, Callable, Dict, Optional, Sequence, Union
8+
from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, TypeVar, Union, cast, overload
99

1010
from ..shared import constants
1111
from ..shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice
@@ -18,6 +18,9 @@
1818
aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE)
1919
aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE)
2020

21+
AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001
22+
AnyAwaitableT = TypeVar("AnyAwaitableT", bound=Awaitable)
23+
2124

2225
class Tracer:
2326
"""Tracer using AWS-XRay to provide decorators with known defaults for Lambda functions
@@ -329,12 +332,26 @@ def decorate(event, context, **kwargs):
329332

330333
return decorate
331334

335+
# see #465
336+
@overload
337+
def capture_method(self, method: "AnyCallableT") -> "AnyCallableT":
338+
...
339+
340+
@overload
332341
def capture_method(
333342
self,
334-
method: Optional[Callable] = None,
343+
method: None = None,
335344
capture_response: Optional[bool] = None,
336345
capture_error: Optional[bool] = None,
337-
):
346+
) -> Callable[["AnyCallableT"], "AnyCallableT"]:
347+
...
348+
349+
def capture_method(
350+
self,
351+
method: Optional[AnyCallableT] = None,
352+
capture_response: Optional[bool] = None,
353+
capture_error: Optional[bool] = None,
354+
) -> AnyCallableT:
338355
"""Decorator to create subsegment for arbitrary functions
339356
340357
It also captures both response and exceptions as metadata
@@ -487,8 +504,9 @@ async def async_tasks():
487504
# Return a partial function with args filled
488505
if method is None:
489506
logger.debug("Decorator called with parameters")
490-
return functools.partial(
491-
self.capture_method, capture_response=capture_response, capture_error=capture_error
507+
return cast(
508+
AnyCallableT,
509+
functools.partial(self.capture_method, capture_response=capture_response, capture_error=capture_error),
492510
)
493511

494512
method_name = f"{method.__name__}"
@@ -509,7 +527,7 @@ async def async_tasks():
509527
return self._decorate_generator_function(
510528
method=method, capture_response=capture_response, capture_error=capture_error, method_name=method_name
511529
)
512-
elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__):
530+
elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__): # type: ignore
513531
return self._decorate_generator_function_with_context_manager(
514532
method=method, capture_response=capture_response, capture_error=capture_error, method_name=method_name
515533
)
@@ -602,11 +620,11 @@ def decorate(*args, **kwargs):
602620

603621
def _decorate_sync_function(
604622
self,
605-
method: Callable,
623+
method: AnyCallableT,
606624
capture_response: Optional[Union[bool, str]] = None,
607625
capture_error: Optional[Union[bool, str]] = None,
608626
method_name: Optional[str] = None,
609-
):
627+
) -> AnyCallableT:
610628
@functools.wraps(method)
611629
def decorate(*args, **kwargs):
612630
with self.provider.in_subsegment(name=f"## {method_name}") as subsegment:
@@ -628,7 +646,7 @@ def decorate(*args, **kwargs):
628646

629647
return response
630648

631-
return decorate
649+
return cast(AnyCallableT, decorate)
632650

633651
def _add_response_as_metadata(
634652
self,

0 commit comments

Comments
 (0)