diff --git a/aws_lambda_powertools/logging/logger.py b/aws_lambda_powertools/logging/logger.py index b82c510036a..4322b2b6bbc 100644 --- a/aws_lambda_powertools/logging/logger.py +++ b/aws_lambda_powertools/logging/logger.py @@ -18,12 +18,14 @@ Optional, TypeVar, Union, + overload, ) import jmespath from ..shared import constants from ..shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice +from ..shared.types import AnyCallableT from .exceptions import InvalidLoggerSamplingRateError from .filters import SuppressFilter from .formatter import ( @@ -314,13 +316,33 @@ def _configure_sampling(self): f"Please review POWERTOOLS_LOGGER_SAMPLE_RATE environment variable." ) + @overload def inject_lambda_context( self, - lambda_handler: Optional[Callable[[Dict, Any], Any]] = None, + lambda_handler: AnyCallableT, log_event: Optional[bool] = None, correlation_id_path: Optional[str] = None, clear_state: Optional[bool] = False, - ): + ) -> AnyCallableT: + ... + + @overload + def inject_lambda_context( + self, + lambda_handler: None = None, + log_event: Optional[bool] = None, + correlation_id_path: Optional[str] = None, + clear_state: Optional[bool] = False, + ) -> Callable[[AnyCallableT], AnyCallableT]: + ... + + def inject_lambda_context( + self, + lambda_handler: Optional[AnyCallableT] = None, + log_event: Optional[bool] = None, + correlation_id_path: Optional[str] = None, + clear_state: Optional[bool] = False, + ) -> Any: """Decorator to capture Lambda contextual info and inject into logger Parameters