5
5
import logging
6
6
import numbers
7
7
import os
8
- from typing import Any , Callable , Dict , Optional , Sequence , Union
8
+ from typing import Any , Awaitable , Callable , Dict , Optional , Sequence , TypeVar , Union , cast , overload
9
9
10
10
from ..shared import constants
11
11
from ..shared .functions import resolve_env_var_choice , resolve_truthy_env_var_choice
18
18
aws_xray_sdk = LazyLoader (constants .XRAY_SDK_MODULE , globals (), constants .XRAY_SDK_MODULE )
19
19
aws_xray_sdk .core = LazyLoader (constants .XRAY_SDK_CORE_MODULE , globals (), constants .XRAY_SDK_CORE_MODULE )
20
20
21
+ AnyCallableT = TypeVar ("AnyCallableT" , bound = Callable [..., Any ]) # noqa: VNE001
22
+ AnyAwaitableT = TypeVar ("AnyAwaitableT" , bound = Awaitable )
23
+
21
24
22
25
class Tracer :
23
26
"""Tracer using AWS-XRay to provide decorators with known defaults for Lambda functions
@@ -329,12 +332,26 @@ def decorate(event, context, **kwargs):
329
332
330
333
return decorate
331
334
335
+ # see #465
336
+ @overload
337
+ def capture_method (self , method : "AnyCallableT" ) -> "AnyCallableT" :
338
+ ...
339
+
340
+ @overload
332
341
def capture_method (
333
342
self ,
334
- method : Optional [ Callable ] = None ,
343
+ method : None = None ,
335
344
capture_response : Optional [bool ] = None ,
336
345
capture_error : Optional [bool ] = None ,
337
- ):
346
+ ) -> Callable [["AnyCallableT" ], "AnyCallableT" ]:
347
+ ...
348
+
349
+ def capture_method (
350
+ self ,
351
+ method : Optional [AnyCallableT ] = None ,
352
+ capture_response : Optional [bool ] = None ,
353
+ capture_error : Optional [bool ] = None ,
354
+ ) -> AnyCallableT :
338
355
"""Decorator to create subsegment for arbitrary functions
339
356
340
357
It also captures both response and exceptions as metadata
@@ -487,8 +504,9 @@ async def async_tasks():
487
504
# Return a partial function with args filled
488
505
if method is None :
489
506
logger .debug ("Decorator called with parameters" )
490
- return functools .partial (
491
- self .capture_method , capture_response = capture_response , capture_error = capture_error
507
+ return cast (
508
+ AnyCallableT ,
509
+ functools .partial (self .capture_method , capture_response = capture_response , capture_error = capture_error ),
492
510
)
493
511
494
512
method_name = f"{ method .__name__ } "
@@ -509,7 +527,7 @@ async def async_tasks():
509
527
return self ._decorate_generator_function (
510
528
method = method , capture_response = capture_response , capture_error = capture_error , method_name = method_name
511
529
)
512
- elif hasattr (method , "__wrapped__" ) and inspect .isgeneratorfunction (method .__wrapped__ ):
530
+ elif hasattr (method , "__wrapped__" ) and inspect .isgeneratorfunction (method .__wrapped__ ): # type: ignore
513
531
return self ._decorate_generator_function_with_context_manager (
514
532
method = method , capture_response = capture_response , capture_error = capture_error , method_name = method_name
515
533
)
@@ -602,11 +620,11 @@ def decorate(*args, **kwargs):
602
620
603
621
def _decorate_sync_function (
604
622
self ,
605
- method : Callable ,
623
+ method : AnyCallableT ,
606
624
capture_response : Optional [Union [bool , str ]] = None ,
607
625
capture_error : Optional [Union [bool , str ]] = None ,
608
626
method_name : Optional [str ] = None ,
609
- ):
627
+ ) -> AnyCallableT :
610
628
@functools .wraps (method )
611
629
def decorate (* args , ** kwargs ):
612
630
with self .provider .in_subsegment (name = f"## { method_name } " ) as subsegment :
@@ -628,7 +646,7 @@ def decorate(*args, **kwargs):
628
646
629
647
return response
630
648
631
- return decorate
649
+ return cast ( AnyCallableT , decorate )
632
650
633
651
def _add_response_as_metadata (
634
652
self ,
0 commit comments