From 195dcf7c34a3f6f62887d23653a78425bea6b252 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 7 Aug 2024 00:08:58 +0100 Subject: [PATCH 1/4] Adding tracer to v3 --- aws_lambda_powertools/tracing/base.py | 30 ++-- .../tracing/provider/__init__.py | 0 .../provider/aws_xray/aws_xray_tracer.py | 142 ++++++++++++++++++ .../tracing/provider/base.py | 112 ++++++++++++++ aws_lambda_powertools/tracing/tracer.py | 88 +++++++---- tests/unit/test_tracing.py | 42 +++++- 6 files changed, 366 insertions(+), 48 deletions(-) create mode 100644 aws_lambda_powertools/tracing/provider/__init__.py create mode 100644 aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py create mode 100644 aws_lambda_powertools/tracing/provider/base.py diff --git a/aws_lambda_powertools/tracing/base.py b/aws_lambda_powertools/tracing/base.py index 6ea58da6b5a..5adbb585756 100644 --- a/aws_lambda_powertools/tracing/base.py +++ b/aws_lambda_powertools/tracing/base.py @@ -1,14 +1,14 @@ -import abc import numbers import traceback +from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Generator, List, Optional, Sequence, Union -class BaseSegment(abc.ABC): +class BaseSegment(ABC): """Holds common properties and methods on segment and subsegment.""" - @abc.abstractmethod + @abstractmethod def close(self, end_time: Optional[int] = None): """Close the trace entity by setting `end_time` and flip the in progress flag to False. @@ -19,15 +19,15 @@ def close(self, end_time: Optional[int] = None): Time in epoch seconds, by default current time will be used. """ - @abc.abstractmethod + @abstractmethod def add_subsegment(self, subsegment: Any): """Add input subsegment as a child subsegment.""" - @abc.abstractmethod + @abstractmethod def remove_subsegment(self, subsegment: Any): """Remove input subsegment from child subsegments.""" - @abc.abstractmethod + @abstractmethod def put_annotation(self, key: str, value: Union[str, numbers.Number, bool]) -> None: """Annotate segment or subsegment with a key-value pair. @@ -41,7 +41,7 @@ def put_annotation(self, key: str, value: Union[str, numbers.Number, bool]) -> N Annotation value """ - @abc.abstractmethod + @abstractmethod def put_metadata(self, key: str, value: Any, namespace: str = "default") -> None: """Add metadata to segment or subsegment. Metadata is not indexed but can be later retrieved by BatchGetTraces API. @@ -56,7 +56,7 @@ def put_metadata(self, key: str, value: Any, namespace: str = "default") -> None Metadata namespace, by default 'default' """ - @abc.abstractmethod + @abstractmethod def add_exception(self, exception: BaseException, stack: List[traceback.StackSummary], remote: bool = False): """Add an exception to trace entities. @@ -73,8 +73,8 @@ def add_exception(self, exception: BaseException, stack: List[traceback.StackSum """ -class BaseProvider(abc.ABC): - @abc.abstractmethod +class BaseProvider(ABC): + @abstractmethod @contextmanager def in_subsegment(self, name=None, **kwargs) -> Generator[BaseSegment, None, None]: """Return a subsegment context manger. @@ -87,7 +87,7 @@ def in_subsegment(self, name=None, **kwargs) -> Generator[BaseSegment, None, Non Optional parameters to be propagated to segment """ - @abc.abstractmethod + @abstractmethod @contextmanager def in_subsegment_async(self, name=None, **kwargs) -> Generator[BaseSegment, None, None]: """Return a subsegment async context manger. @@ -100,7 +100,7 @@ def in_subsegment_async(self, name=None, **kwargs) -> Generator[BaseSegment, Non Optional parameters to be propagated to segment """ - @abc.abstractmethod + @abstractmethod def put_annotation(self, key: str, value: Union[str, numbers.Number, bool]) -> None: """Annotate current active trace entity with a key-value pair. @@ -114,7 +114,7 @@ def put_annotation(self, key: str, value: Union[str, numbers.Number, bool]) -> N Annotation value """ - @abc.abstractmethod + @abstractmethod def put_metadata(self, key: str, value: Any, namespace: str = "default") -> None: """Add metadata to the current active trace entity. @@ -130,7 +130,7 @@ def put_metadata(self, key: str, value: Any, namespace: str = "default") -> None Metadata namespace, by default 'default' """ - @abc.abstractmethod + @abstractmethod def patch(self, modules: Sequence[str]) -> None: """Instrument a set of supported libraries @@ -140,6 +140,6 @@ def patch(self, modules: Sequence[str]) -> None: Set of modules to be patched """ - @abc.abstractmethod + @abstractmethod def patch_all(self) -> None: """Instrument all supported libraries""" diff --git a/aws_lambda_powertools/tracing/provider/__init__.py b/aws_lambda_powertools/tracing/provider/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py b/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py new file mode 100644 index 00000000000..b786c2e92d6 --- /dev/null +++ b/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager, contextmanager +from numbers import Number +from typing import Any, AsyncGenerator, Generator, Literal, Sequence, Union + +from aws_lambda_powertools.shared import constants +from aws_lambda_powertools.shared.lazy_import import LazyLoader +from aws_lambda_powertools.tracing.provider.base import BaseProvider, BaseSpan + +aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE) + + +class XraySpan(BaseSpan): + def __init__(self, subsegment): + self.subsegment = subsegment + self.add_subsegment = self.subsegment.add_subsegment + self.remove_subsegment = self.subsegment.remove_subsegment + self.put_annotation = self.subsegment.put_annotation + self.put_metadata = self.subsegment.put_metadata + self.add_exception = self.subsegment.add_exception + self.close = self.subsegment.close + + def set_attribute( + self, + key: str, + value: Any, + category: Literal["Annotation", "Metadata", "Auto"] = "Auto", + **kwargs, + ) -> None: + """ + Set an attribute on this span with a key-value pair. + + Parameters + ---------- + key: str + attribute key + value: Any + Value for attribute + category: Literal["Annotation","Metadata","Auto"] = "Auto" + This parameter specifies the category of attribute to set. + - **"Annotation"**: Sets the attribute as an Annotation. + - **"Metadata"**: Sets the attribute as Metadata. + - **"Auto" (default)**: Automatically determines the attribute + type based on its value. + + kwargs: Optional[dict] + Optional parameters to be passed to provider.set_attributes + """ + if category == "Annotation": + self.put_annotation(key=key, value=value) + return + + if category == "Metadata": + self.put_metadata(key=key, value=value, namespace=kwargs.get("namespace", "dafault")) + return + + # Auto + if isinstance(value, (str, Number, bool)): + self.put_annotation(key=key, value=value) + return + + # Auto & not in (str, Number, bool) + self.put_metadata(key=key, value=value, namespace=kwargs.get("namespace", "dafault")) + + def record_exception(self, exception: BaseException, **kwargs): + stack = aws_xray_sdk.core.utils.stacktrace.get_stacktrace() + self.add_exception(exception=exception, stack=stack) + + +class AwsXrayProvider(BaseProvider): + def __init__(self, xray_recorder=None): + if not xray_recorder: + from aws_xray_sdk.core import xray_recorder + + self.recorder = xray_recorder + self.in_subsegment = self.recorder.in_subsegment + self.in_subsegment_async = self.recorder.in_subsegment_async + + @contextmanager + def trace(self, name: str, **kwargs) -> Generator[XraySpan, None, None]: + with self.in_subsegment(name=name, **kwargs) as sub_segment: + yield XraySpan(subsegment=sub_segment) + + @asynccontextmanager + async def trace_async(self, name: str, **kwargs) -> AsyncGenerator[XraySpan, None]: + async with self.in_subsegment_async(name=name, **kwargs) as subsegment: + yield XraySpan(subsegment=subsegment) + + def set_attribute( + self, + key: str, + value: Any, + category: Literal["Annotation", "Metadata", "Auto"] = "Auto", + **kwargs, + ) -> None: + """ + Set an attribute on the current active span with a key-value pair. + + Parameters + ---------- + key: str + attribute key + value: Any + Value for attribute + category: Literal["Annotation","Metadata","Auto"] = "Auto" + This parameter specifies the type of attribute to set. + - **"Annotation"**: Sets the attribute as an Annotation. + - **"Metadata"**: Sets the attribute as Metadata. + - **"Auto" (default)**: Automatically determines the attribute + type based on its value. + + kwargs: Optional[dict] + Optional parameters to be passed to provider.set_attributes + """ + if category == "Annotation": + self.put_annotation(key=key, value=value) + return + + if category == "Metadata": + self.put_metadata(key=key, value=value, namespace=kwargs.get("namespace", "dafault")) + return + + # Auto + if isinstance(value, (str, Number, bool)): + self.put_annotation(key=key, value=value) + return + + # Auto & not in (str, Number, bool) + self.put_metadata(key=key, value=value, namespace=kwargs.get("namespace", "dafault")) + + def put_annotation(self, key: str, value: Union[str, Number, bool]) -> None: + return self.recorder.put_annotation(key=key, value=value) + + def put_metadata(self, key: str, value: Any, namespace: str = "default") -> None: + return self.recorder.put_metadata(key=key, value=value, namespace=namespace) + + def patch(self, modules: Sequence[str]) -> None: + return aws_xray_sdk.core.patch(modules) + + def patch_all(self) -> None: + return aws_xray_sdk.core.patch_all() diff --git a/aws_lambda_powertools/tracing/provider/base.py b/aws_lambda_powertools/tracing/provider/base.py new file mode 100644 index 00000000000..eb184d65245 --- /dev/null +++ b/aws_lambda_powertools/tracing/provider/base.py @@ -0,0 +1,112 @@ +from abc import ABC, abstractmethod +from contextlib import asynccontextmanager, contextmanager +from typing import Any, AsyncGenerator, Generator, Sequence + + +class BaseSpan(ABC): + """A span represents a unit of work or operation within a trace. + Spans are the building blocks of Traces.""" + + @abstractmethod + def set_attribute(self, key: str, value: Any, **kwargs) -> None: + """Set an attribute for a span with a key-value pair. + + Parameters + ---------- + key: str + Attribute key + value: Any + Attribute value + kwargs: Optional[dict] + Optional parameters + """ + + @abstractmethod + def record_exception(self, exception: BaseException, **kwargs): + """Records an exception to this Span. + + Parameters + ---------- + exception: Exception + Caught exception during the exectution of this Span + kwargs: Optional[dict] + Optional parameters + """ + + +class BaseProvider(ABC): + """BaseProvider is an abstract base class that defines the expected behavior for tracing providers + used by Tracer. Inheriting classes must implement this interface to be compatible with Tracer. + """ + + @abstractmethod + @contextmanager + def trace(self, name: str, **kwargs) -> Generator[BaseSpan, None, None]: + """Context manager for creating a new span and set it + as the current span in this tracer's context. + + Exiting the context manager will call the span's end method, + as well as return the current span to its previous value by + returning to the previous context. + + Parameters + ---------- + name: str + Span name + kwargs: Optional[dict] + Optional parameters to be propagated to the span + """ + + @abstractmethod + @asynccontextmanager + def trace_async(self, name: str, **kwargs) -> AsyncGenerator[BaseSpan, None]: + """Async Context manager for creating a new span async and set it + as the current span in this tracer's context. + + Exiting the context manager will call the span's end method, + as well as return the current span to its previous value by + returning to the previous context. + + Parameters + ---------- + name: str + Span name + kwargs: Optional[dict] + Optional parameters to be propagated to the span + """ + + @abstractmethod + def set_attribute(self, key: str, value: Any, **kwargs) -> None: + """set attribute on current active span with a key-value pair. + + Parameters + ---------- + key: str + attribute key + value: Any + attribute value + kwargs: Optional[dict] + Optional parameters to be propagated to the span + """ + + @abstractmethod + def patch(self, modules: Sequence[str]) -> None: + """Instrument a set of given libraries if supported by provider + See specific provider for more detail + + Exmaple + ------- + tracer = Tracer(service="payment") + libraries = (['aioboto3',mysql]) + # provider.patch will be called by tracer.patch + tracer.patch(libraries) + + Parameters + ---------- + modules: Set[str] + Set of modules to be patched + """ + + @abstractmethod + def patch_all(self) -> None: + """Instrument all supported libraries""" diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py index a79ac4ec738..fceab00b7bd 100644 --- a/aws_lambda_powertools/tracing/tracer.py +++ b/aws_lambda_powertools/tracing/tracer.py @@ -15,7 +15,8 @@ ) from aws_lambda_powertools.shared.lazy_import import LazyLoader from aws_lambda_powertools.shared.types import AnyCallableT -from aws_lambda_powertools.tracing.base import BaseProvider, BaseSegment +from aws_lambda_powertools.tracing.provider.aws_xray.aws_xray_tracer import AwsXrayProvider +from aws_lambda_powertools.tracing.provider.base import BaseProvider, BaseSpan is_cold_start = True logger = logging.getLogger(__name__) @@ -154,7 +155,7 @@ def __init__( disabled: Optional[bool] = None, auto_patch: Optional[bool] = None, patch_modules: Optional[Sequence[str]] = None, - provider: Optional[BaseProvider] = None, + provider: Optional[AwsXrayProvider] = None, ): self.__build_config( service=service, @@ -163,7 +164,7 @@ def __init__( patch_modules=patch_modules, provider=provider, ) - self.provider: BaseProvider = self._config["provider"] + self.provider: AwsXrayProvider = self._config["provider"] self.disabled = self._config["disabled"] self.service = self._config["service"] self.auto_patch = self._config["auto_patch"] @@ -177,6 +178,36 @@ def __init__( if self._is_xray_provider(): self._disable_xray_trace_batching() + def set_attribute(self, key: str, value: Any, **kwargs): + """Set an attribute on current active span with a key-value pair. + + Parameters + ---------- + key: str + attribute key + value: Any + Value for attribute + kwargs: Optional[dict] + Optional parameters to be passed to provider.set_attributes + + Example + ------- + Set an attribute for a pseudo service named payment + + tracer = Tracer(service="payment") + tracer.set_attribute("PaymentStatus", "CONFIRMED") + """ + if self.disabled: + logger.debug("Tracing has been disabled, aborting set_attribute") + return + + logger.debug(f"setting attribute on key '{key}' with '{value}'") + + namespace = kwargs.get("namespace") or self.service + kwargs["namespace"] = namespace + + self.provider.set_attribute(key=key, value=value, **kwargs) + def put_annotation(self, key: str, value: Union[str, numbers.Number, bool]): """Adds annotation to existing segment or subsegment @@ -199,7 +230,8 @@ def put_annotation(self, key: str, value: Union[str, numbers.Number, bool]): return logger.debug(f"Annotating on key '{key}' with '{value}'") - self.provider.put_annotation(key=key, value=value) + + self.provider.set_attribute(key=key, value=value, category="Annotation") def put_metadata(self, key: str, value: Any, namespace: Optional[str] = None): """Adds metadata to existing segment or subsegment @@ -227,7 +259,8 @@ def put_metadata(self, key: str, value: Any, namespace: Optional[str] = None): namespace = namespace or self.service logger.debug(f"Adding metadata on key '{key}' with '{value}' at namespace '{namespace}'") - self.provider.put_metadata(key=key, value=value, namespace=namespace) + + self.provider.set_attribute(key=key, value=value, namespace=namespace, category="Metadata") def patch(self, modules: Optional[Sequence[str]] = None): """Patch modules for instrumentation. @@ -311,7 +344,7 @@ def handler(event, context): @functools.wraps(lambda_handler) def decorate(event, context, **kwargs): - with self.provider.in_subsegment(name=f"## {lambda_handler_name}") as subsegment: + with self.provider.trace(name=f"## {lambda_handler_name}") as subsegment: try: logger.debug("Calling lambda handler") response = lambda_handler(event, context, **kwargs) @@ -335,13 +368,13 @@ def decorate(event, context, **kwargs): finally: global is_cold_start logger.debug("Annotating cold start") - subsegment.put_annotation(key="ColdStart", value=is_cold_start) + subsegment.set_attribute(key="ColdStart", value=is_cold_start) if is_cold_start: is_cold_start = False if self.service: - subsegment.put_annotation(key="Service", value=self.service) + subsegment.set_attribute(key="Service", value=self.service) return response @@ -575,7 +608,7 @@ def _decorate_async_function( ): @functools.wraps(method) async def decorate(*args, **kwargs): - async with self.provider.in_subsegment_async(name=f"## {method_name}") as subsegment: + async with self.provider.trace_async(name=f"## {method_name}") as subsegment: try: logger.debug(f"Calling method: {method_name}") response = await method(*args, **kwargs) @@ -608,7 +641,7 @@ def _decorate_generator_function( ): @functools.wraps(method) def decorate(*args, **kwargs): - with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: + with self.provider.trace(name=f"## {method_name}") as subsegment: try: logger.debug(f"Calling method: {method_name}") result = yield from method(*args, **kwargs) @@ -642,7 +675,7 @@ def _decorate_generator_function_with_context_manager( @functools.wraps(method) @contextlib.contextmanager def decorate(*args, **kwargs): - with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: + with self.provider.trace(name=f"## {method_name}") as subsegment: try: logger.debug(f"Calling method: {method_name}") with method(*args, **kwargs) as return_val: @@ -675,7 +708,7 @@ def _decorate_sync_function( ) -> AnyCallableT: @functools.wraps(method) def decorate(*args, **kwargs): - with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: + with self.provider.trace(name=f"## {method_name}") as subsegment: try: logger.debug(f"Calling method: {method_name}") response = method(*args, **kwargs) @@ -703,7 +736,7 @@ def _add_response_as_metadata( self, method_name: Optional[str] = None, data: Optional[Any] = None, - subsegment: Optional[BaseSegment] = None, + subsegment: Optional[BaseSpan] = None, capture_response: Optional[Union[bool, str]] = None, ): """Add response as metadata for given subsegment @@ -714,21 +747,20 @@ def _add_response_as_metadata( method name to add as metadata key, by default None data : Any, optional data to add as subsegment metadata, by default None - subsegment : BaseSegment, optional + subsegment : BaseSpan, optional existing subsegment to add metadata on, by default None capture_response : bool, optional Do not include response as metadata """ if data is None or not capture_response or subsegment is None: return - - subsegment.put_metadata(key=f"{method_name} response", value=data, namespace=self.service) + subsegment.set_attribute(key=f"{method_name} response", value=data, namespace=self.service, category="Metadata") def _add_full_exception_as_metadata( self, method_name: str, error: Exception, - subsegment: BaseSegment, + subsegment: BaseSpan, capture_error: Optional[bool] = None, ): """Add full exception object as metadata for given subsegment @@ -739,7 +771,7 @@ def _add_full_exception_as_metadata( method name to add as metadata key, by default None error : Exception error to add as subsegment metadata, by default None - subsegment : BaseSegment + subsegment : BaseSpan existing subsegment to add metadata on, by default None capture_error : bool, optional Do not include error as metadata, by default True @@ -747,7 +779,12 @@ def _add_full_exception_as_metadata( if not capture_error: return - subsegment.put_metadata(key=f"{method_name} error", value=error, namespace=self.service) + subsegment.set_attribute( + key=f"{method_name} error", + value=error, + namespace=self.service, + category="Metadata", + ) @staticmethod def _disable_tracer_provider(): @@ -809,16 +846,7 @@ def _reset_config(cls): cls._config = copy.copy(cls._default_config) def _patch_xray_provider(self): - # Due to Lazy Import, we need to activate `core` attrib via import - # we also need to include `patch`, `patch_all` methods - # to ensure patch calls are done via the provider - from aws_xray_sdk.core import xray_recorder # type: ignore - - provider = xray_recorder - provider.patch = aws_xray_sdk.core.patch - provider.patch_all = aws_xray_sdk.core.patch_all - - return provider + return AwsXrayProvider() def _disable_xray_trace_batching(self): """Configure X-Ray SDK to send subsegment individually over batching @@ -831,7 +859,7 @@ def _disable_xray_trace_batching(self): aws_xray_sdk.core.xray_recorder.configure(streaming_threshold=0) def _is_xray_provider(self): - return "aws_xray_sdk" in self.provider.__module__ + return isinstance(self.provider, AwsXrayProvider) def ignore_endpoint(self, hostname: Optional[str] = None, urls: Optional[List[str]] = None): """If you want to ignore certain httplib requests you can do so based on the hostname or URL that is being diff --git a/tests/unit/test_tracing.py b/tests/unit/test_tracing.py index 0d12afa629b..9d9731b0147 100644 --- a/tests/unit/test_tracing.py +++ b/tests/unit/test_tracing.py @@ -1,4 +1,5 @@ import contextlib +from numbers import Number from typing import NamedTuple from unittest import mock from unittest.mock import MagicMock @@ -31,10 +32,24 @@ def __init__( ): self.put_metadata_mock = put_metadata_mock or mocker.MagicMock() self.put_annotation_mock = put_annotation_mock or mocker.MagicMock() + self.trace = self.in_subsegment self.in_subsegment = in_subsegment or mocker.MagicMock() self.patch_mock = patch_mock or mocker.MagicMock() self.disable_tracing_provider_mock = disable_tracing_provider_mock or mocker.MagicMock() self.in_subsegment_async = in_subsegment_async or mocker.MagicMock(spec=True) + self.trace_async = self.in_subsegment_async + + def set_attribute(self, *args, **kwargs): + if kwargs.get("category") == "Metadata": + return self.put_metadata(*args, **kwargs) + + if kwargs.get("category") == "Annotation": + return self.put_annotation(*args, **kwargs) + + if isinstance(kwargs.get("value"), (str, Number, bool)): + return self.put_annotation(*args, **kwargs) + + return self.put_metadata(*args, **kwargs) def put_metadata(self, *args, **kwargs): return self.put_metadata_mock(*args, **kwargs) @@ -65,8 +80,8 @@ def reset_tracing_config(mocker): @pytest.fixture -def in_subsegment_mock(): - class AsyncContextManager(mock.MagicMock): +def in_subsegment_mock(mocker): + class AsyncContextManager(mocker.MagicMock): async def __aenter__(self, *args, **kwargs): return self.__enter__() @@ -78,10 +93,26 @@ class InSubsegment(NamedTuple): put_annotation: mock.MagicMock = mock.MagicMock() put_metadata: mock.MagicMock = mock.MagicMock() + def set_attribute(self, *args, **kwargs): + if kwargs.get("category") == "Metadata": + kwargs.pop("category") + return self.put_metadata(*args, **kwargs) + + if kwargs.get("category") == "Annotation": + kwargs.pop("category") + return self.put_annotation(*args, **kwargs) + + if isinstance(kwargs.get("value"), (str, Number, bool)): + return self.put_annotation(*args, **kwargs) + + return self.put_metadata(*args, **kwargs) + in_subsegment = InSubsegment() in_subsegment.in_subsegment.return_value.__enter__.return_value.put_annotation = in_subsegment.put_annotation in_subsegment.in_subsegment.return_value.__enter__.return_value.put_metadata = in_subsegment.put_metadata in_subsegment.in_subsegment.return_value.__aenter__.return_value.put_metadata = in_subsegment.put_metadata + in_subsegment.in_subsegment.return_value.__enter__.return_value.set_attribute = in_subsegment.set_attribute + in_subsegment.in_subsegment.return_value.__aenter__.return_value.set_attribute = in_subsegment.set_attribute yield in_subsegment @@ -155,6 +186,7 @@ def test_tracer_custom_metadata(monkeypatch, mocker, dummy_response, provider_st key=annotation_key, value=annotation_value, namespace="booking", + category="Metadata", ) @@ -172,7 +204,11 @@ def test_tracer_custom_annotation(monkeypatch, mocker, dummy_response, provider_ # THEN we should have an annotation as expected assert put_annotation_mock.call_count == 1 - assert put_annotation_mock.call_args == mocker.call(key=annotation_key, value=annotation_value) + assert put_annotation_mock.call_args == mocker.call( + key=annotation_key, + value=annotation_value, + category="Annotation", + ) @mock.patch("aws_lambda_powertools.tracing.Tracer.patch") From f0fbd01bfe805ee390fd40d6f0b0081633647b86 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 7 Aug 2024 00:21:28 +0100 Subject: [PATCH 2/4] Adding tracer to v3 --- poetry.lock | 13 ++++++++++++- pyproject.toml | 1 + 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index e0b25c11ccf..ba6b60f0a54 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3746,6 +3746,17 @@ files = [ doc = ["sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] test = ["mypy", "pytest", "typing-extensions"] +[[package]] +name = "types-aws-xray-sdk" +version = "2.14.0.20240606" +description = "Typing stubs for aws-xray-sdk" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-aws-xray-sdk-2.14.0.20240606.tar.gz", hash = "sha256:3215f8f80b48c9da9f7ff16021234cd631b538095933d5432e3fa4c5e2d76a22"}, + {file = "types_aws_xray_sdk-2.14.0.20240606-py3-none-any.whl", hash = "sha256:c238ad639bb50896f1326c12bcc36b7832b5bc7c4b5e2b19a7efcd89d7d28b94"}, +] + [[package]] name = "types-awscrt" version = "0.21.0" @@ -4196,4 +4207,4 @@ validation = ["fastjsonschema"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0.0" -content-hash = "e6a93ae2514bd23686e766fcf06cd42cba18822272b07e116436edcaf9b3bfa7" +content-hash = "3b7bb5f4264d95c47b306fa4187235e0662c5ed2cf6c318c8f5b6fe722b8d56e" diff --git a/pyproject.toml b/pyproject.toml index 57df08cd021..c49745887e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ types-redis = "^4.6.0.7" testcontainers = { extras = ["redis"], version = "^3.7.1" } multiprocess = "^0.70.16" boto3-stubs = {extras = ["appconfig", "appconfigdata", "cloudformation", "cloudwatch", "dynamodb", "lambda", "logs", "s3", "secretsmanager", "ssm", "xray"], version = "^1.34.139"} +types-aws-xray-sdk = "^2.14.0.20240606" [tool.coverage.run] source = ["aws_lambda_powertools"] From e158ca03460c71a0c187ba644f077452bd6724f9 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 19 Aug 2024 19:37:14 +0100 Subject: [PATCH 3/4] Fixing base provider + won't change Tracer --- .../provider/aws_xray/aws_xray_tracer.py | 20 +- .../tracing/provider/base.py | 521 +++++++++++++++++- aws_lambda_powertools/tracing/tracer.py | 91 ++- tests/functional/test_tracing.py | 3 +- tests/unit/test_tracing.py | 42 +- 5 files changed, 572 insertions(+), 105 deletions(-) diff --git a/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py b/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py index b786c2e92d6..1163e87df45 100644 --- a/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py +++ b/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py @@ -2,7 +2,7 @@ from contextlib import asynccontextmanager, contextmanager from numbers import Number -from typing import Any, AsyncGenerator, Generator, Literal, Sequence, Union +from typing import Any, AsyncGenerator, Generator, Literal, Optional, Sequence, Union from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.lazy_import import LazyLoader @@ -69,14 +69,26 @@ def record_exception(self, exception: BaseException, **kwargs): class AwsXrayProvider(BaseProvider): - def __init__(self, xray_recorder=None): - if not xray_recorder: - from aws_xray_sdk.core import xray_recorder + + def __init__( + self, + service: str = "", + disabled: Optional[bool] = None, + auto_patch: Optional[bool] = None, + patch_modules: Optional[Sequence[str]] = None, + ): + from aws_xray_sdk.core import xray_recorder # type: ignore self.recorder = xray_recorder self.in_subsegment = self.recorder.in_subsegment self.in_subsegment_async = self.recorder.in_subsegment_async + self.service = service + + super().__init__( + service=self.service, + ) + @contextmanager def trace(self, name: str, **kwargs) -> Generator[XraySpan, None, None]: with self.in_subsegment(name=name, **kwargs) as sub_segment: diff --git a/aws_lambda_powertools/tracing/provider/base.py b/aws_lambda_powertools/tracing/provider/base.py index eb184d65245..316d49604f7 100644 --- a/aws_lambda_powertools/tracing/provider/base.py +++ b/aws_lambda_powertools/tracing/provider/base.py @@ -1,6 +1,21 @@ +import contextlib +import functools +import inspect +import logging +import os from abc import ABC, abstractmethod from contextlib import asynccontextmanager, contextmanager -from typing import Any, AsyncGenerator, Generator, Sequence +from typing import Any, AsyncGenerator, Callable, Generator, Optional, Sequence, Union, cast, overload + +from aws_lambda_powertools.shared import constants +from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice, sanitize_xray_segment_name +from aws_lambda_powertools.shared.types import AnyCallableT +from aws_lambda_powertools.tracing.base import BaseSegment + +logger = logging.getLogger(__name__) + + +is_cold_start = True class BaseSpan(ABC): @@ -39,6 +54,9 @@ class BaseProvider(ABC): used by Tracer. Inheriting classes must implement this interface to be compatible with Tracer. """ + def __init__(self, service: str = ""): + self.service = service + @abstractmethod @contextmanager def trace(self, name: str, **kwargs) -> Generator[BaseSpan, None, None]: @@ -110,3 +128,504 @@ def patch(self, modules: Sequence[str]) -> None: @abstractmethod def patch_all(self) -> None: """Instrument all supported libraries""" + + def capture_lambda_handler( + self, + lambda_handler: Any = None, + capture_response: Optional[bool] = None, + capture_error: Optional[bool] = None, + ): + """Decorator to create subsegment for lambda handlers + + As Lambda follows (event, context) signature we can remove some of the boilerplate + and also capture any exception any Lambda function throws or its response as metadata + + Parameters + ---------- + lambda_handler : Callable + Method to annotate on + capture_response : bool, optional + Instructs tracer to not include handler's response as metadata + capture_error : bool, optional + Instructs tracer to not include handler's error as metadata, by default True + + Example + ------- + **Lambda function using capture_lambda_handler decorator** + + tracer = Tracer(service="payment") + @tracer.capture_lambda_handler + def handler(event, context): + ... + + **Preventing Tracer to log response as metadata** + + tracer = Tracer(service="payment") + @tracer.capture_lambda_handler(capture_response=False) + def handler(event, context): + ... + + Raises + ------ + err + Exception raised by method + """ + # If handler is None we've been called with parameters + # Return a partial function with args filled + if lambda_handler is None: + logger.debug("Decorator called with parameters") + return functools.partial( + self.capture_lambda_handler, + capture_response=capture_response, + capture_error=capture_error, + ) + + lambda_handler_name = lambda_handler.__name__ + capture_response = resolve_truthy_env_var_choice( + env=os.getenv(constants.TRACER_CAPTURE_RESPONSE_ENV, "true"), + choice=capture_response, + ) + capture_error = resolve_truthy_env_var_choice( + env=os.getenv(constants.TRACER_CAPTURE_ERROR_ENV, "true"), + choice=capture_error, + ) + + @functools.wraps(lambda_handler) + def decorate(event, context, **kwargs): + with self.trace(name=f"## {lambda_handler_name}") as subsegment: + try: + logger.debug("Calling lambda handler") + response = lambda_handler(event, context, **kwargs) + logger.debug("Received lambda handler response successfully") + self._add_response_as_metadata( + method_name=lambda_handler_name, + data=response, + subsegment=subsegment, + capture_response=capture_response, + ) + except Exception as err: + logger.exception(f"Exception received from {lambda_handler_name}") + self._add_full_exception_as_metadata( + method_name=lambda_handler_name, + error=err, + subsegment=subsegment, + capture_error=capture_error, + ) + + raise + finally: + global is_cold_start + logger.debug("Annotating cold start") + subsegment.put_annotation(key="ColdStart", value=is_cold_start) + + if is_cold_start: + is_cold_start = False + + if self.service: + subsegment.put_annotation(key="Service", value=self.service) + + return response + + return decorate + + # see #465 + @overload + def capture_method(self, method: "AnyCallableT") -> "AnyCallableT": ... # pragma: no cover + + @overload + def capture_method( + self, + method: None = None, + capture_response: Optional[bool] = None, + capture_error: Optional[bool] = None, + ) -> Callable[["AnyCallableT"], "AnyCallableT"]: ... # pragma: no cover + + def capture_method( + self, + method: Optional[AnyCallableT] = None, + capture_response: Optional[bool] = None, + capture_error: Optional[bool] = None, + ) -> AnyCallableT: + """Decorator to create subsegment for arbitrary functions + + It also captures both response and exceptions as metadata + and creates a subsegment named `## ` + # see here: [Qualified name for classes and functions](https://peps.python.org/pep-3155/) + + When running [async functions concurrently](https://docs.python.org/3/library/asyncio-task.html#id6), + methods may impact each others subsegment, and can trigger + and AlreadyEndedException from X-Ray due to async nature. + + For this use case, either use `capture_method` only where + `async.gather` is called, or use `in_subsegment_async` + context manager via our escape hatch mechanism - See examples. + + Parameters + ---------- + method : Callable + Method to annotate on + capture_response : bool, optional + Instructs tracer to not include method's response as metadata + capture_error : bool, optional + Instructs tracer to not include handler's error as metadata, by default True + + Example + ------- + **Custom function using capture_method decorator** + + tracer = Tracer(service="payment") + @tracer.capture_method + def some_function() + + **Custom async method using capture_method decorator** + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + @tracer.capture_method + async def confirm_booking(booking_id: str) -> Dict: + resp = call_to_booking_service() + + tracer.put_annotation("BookingConfirmation", resp["requestId"]) + tracer.put_metadata("Booking confirmation", resp) + + return resp + + def lambda_handler(event: dict, context: Any) -> Dict: + booking_id = event.get("booking_id") + asyncio.run(confirm_booking(booking_id=booking_id)) + + **Custom generator function using capture_method decorator** + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + @tracer.capture_method + def bookings_generator(booking_id): + resp = call_to_booking_service() + yield resp[0] + yield resp[1] + + def lambda_handler(event: dict, context: Any) -> Dict: + gen = bookings_generator(booking_id=booking_id) + result = list(gen) + + **Custom generator context manager using capture_method decorator** + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + @tracer.capture_method + @contextlib.contextmanager + def booking_actions(booking_id): + resp = call_to_booking_service() + yield "example result" + cleanup_stuff() + + def lambda_handler(event: dict, context: Any) -> Dict: + booking_id = event.get("booking_id") + + with booking_actions(booking_id=booking_id) as booking: + result = booking + + **Tracing nested async calls** + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + @tracer.capture_method + async def get_identity(): + ... + + @tracer.capture_method + async def long_async_call(): + ... + + @tracer.capture_method + async def async_tasks(): + await get_identity() + ret = await long_async_call() + + return { "task": "done", **ret } + + **Safely tracing concurrent async calls with decorator** + + This may not needed once [this bug is closed](https://github.com/aws/aws-xray-sdk-python/issues/164) + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + async def get_identity(): + async with aioboto3.client("sts") as sts: + account = await sts.get_caller_identity() + return account + + async def long_async_call(): + ... + + @tracer.capture_method + async def async_tasks(): + _, ret = await asyncio.gather(get_identity(), long_async_call(), return_exceptions=True) + + return { "task": "done", **ret } + + **Safely tracing each concurrent async calls with escape hatch** + + This may not needed once [this bug is closed](https://github.com/aws/aws-xray-sdk-python/issues/164) + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + async def get_identity(): + async tracer.provider.in_subsegment_async("## get_identity"): + ... + + async def long_async_call(): + async tracer.provider.in_subsegment_async("## long_async_call"): + ... + + @tracer.capture_method + async def async_tasks(): + _, ret = await asyncio.gather(get_identity(), long_async_call(), return_exceptions=True) + + return { "task": "done", **ret } + + Raises + ------ + err + Exception raised by method + """ + # If method is None we've been called with parameters + # Return a partial function with args filled + if method is None: + logger.debug("Decorator called with parameters") + return cast( + AnyCallableT, + functools.partial(self.capture_method, capture_response=capture_response, capture_error=capture_error), + ) + + # Example: app.ClassA.get_all # noqa ERA001 + # Valid characters can be found at http://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html + method_name = sanitize_xray_segment_name(f"{method.__module__}.{method.__qualname__}") + + capture_response = resolve_truthy_env_var_choice( + env=os.getenv(constants.TRACER_CAPTURE_RESPONSE_ENV, "true"), + choice=capture_response, + ) + capture_error = resolve_truthy_env_var_choice( + env=os.getenv(constants.TRACER_CAPTURE_ERROR_ENV, "true"), + choice=capture_error, + ) + + # Maintenance: Need a factory/builder here to simplify this now + if inspect.iscoroutinefunction(method): + return self._decorate_async_function( + method=method, + capture_response=capture_response, + capture_error=capture_error, + method_name=method_name, + ) + elif inspect.isgeneratorfunction(method): + return self._decorate_generator_function( + method=method, + capture_response=capture_response, + capture_error=capture_error, + method_name=method_name, + ) + elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__): + return self._decorate_generator_function_with_context_manager( + method=method, + capture_response=capture_response, + capture_error=capture_error, + method_name=method_name, + ) + else: + return self._decorate_sync_function( + method=method, + capture_response=capture_response, + capture_error=capture_error, + method_name=method_name, + ) + + def _decorate_async_function( + self, + method: Callable, + capture_response: Optional[Union[bool, str]] = None, + capture_error: Optional[Union[bool, str]] = None, + method_name: Optional[str] = None, + ): + @functools.wraps(method) + async def decorate(*args, **kwargs): + async with self.trace_async(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + response = await method(*args, **kwargs) + self._add_response_as_metadata( + method_name=method_name, + data=response, + subsegment=subsegment, + capture_response=capture_response, + ) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata( + method_name=method_name, + error=err, + subsegment=subsegment, + capture_error=capture_error, + ) + raise + + return response + + return decorate + + def _decorate_generator_function( + self, + method: Callable, + capture_response: Optional[Union[bool, str]] = None, + capture_error: Optional[Union[bool, str]] = None, + method_name: Optional[str] = None, + ): + @functools.wraps(method) + def decorate(*args, **kwargs): + with self.trace(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + result = yield from method(*args, **kwargs) + self._add_response_as_metadata( + method_name=method_name, + data=result, + subsegment=subsegment, + capture_response=capture_response, + ) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata( + method_name=method_name, + error=err, + subsegment=subsegment, + capture_error=capture_error, + ) + raise + + return result + + return decorate + + def _decorate_generator_function_with_context_manager( + self, + method: Callable, + capture_response: Optional[Union[bool, str]] = None, + capture_error: Optional[Union[bool, str]] = None, + method_name: Optional[str] = None, + ): + @functools.wraps(method) + @contextlib.contextmanager + def decorate(*args, **kwargs): + with self.trace(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + with method(*args, **kwargs) as return_val: + result = return_val + yield result + self._add_response_as_metadata( + method_name=method_name, + data=result, + subsegment=subsegment, + capture_response=capture_response, + ) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata( + method_name=method_name, + error=err, + subsegment=subsegment, + capture_error=capture_error, + ) + raise + + return decorate + + def _decorate_sync_function( + self, + method: AnyCallableT, + capture_response: Optional[Union[bool, str]] = None, + capture_error: Optional[Union[bool, str]] = None, + method_name: Optional[str] = None, + ) -> AnyCallableT: + @functools.wraps(method) + def decorate(*args, **kwargs): + with self.trace(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + response = method(*args, **kwargs) + self._add_response_as_metadata( + method_name=method_name, + data=response, + subsegment=subsegment, + capture_response=capture_response, + ) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata( + method_name=method_name, + error=err, + subsegment=subsegment, + capture_error=capture_error, + ) + raise + + return response + + return cast(AnyCallableT, decorate) + + def _add_response_as_metadata( + self, + method_name: Optional[str] = None, + data: Optional[Any] = None, + subsegment: Optional[BaseSegment] = None, + capture_response: Optional[Union[bool, str]] = None, + ): + """Add response as metadata for given subsegment + + Parameters + ---------- + method_name : str, optional + method name to add as metadata key, by default None + data : Any, optional + data to add as subsegment metadata, by default None + subsegment : BaseSegment, optional + existing subsegment to add metadata on, by default None + capture_response : bool, optional + Do not include response as metadata + """ + if data is None or not capture_response or subsegment is None: + return + + subsegment.put_metadata(key=f"{method_name} response", value=data, namespace=self.service) + + def _add_full_exception_as_metadata( + self, + method_name: str, + error: Exception, + subsegment: BaseSegment, + capture_error: Optional[bool] = None, + ): + """Add full exception object as metadata for given subsegment + + Parameters + ---------- + method_name : str + method name to add as metadata key, by default None + error : Exception + error to add as subsegment metadata, by default None + subsegment : BaseSegment + existing subsegment to add metadata on, by default None + capture_error : bool, optional + Do not include error as metadata, by default True + """ + if not capture_error: + return + + subsegment.put_metadata(key=f"{method_name} error", value=error, namespace=self.service) diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py index fceab00b7bd..6b0c26a42cf 100644 --- a/aws_lambda_powertools/tracing/tracer.py +++ b/aws_lambda_powertools/tracing/tracer.py @@ -5,7 +5,7 @@ import logging import numbers import os -from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast, overload +from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union, cast, overload from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.functions import ( @@ -15,14 +15,16 @@ ) from aws_lambda_powertools.shared.lazy_import import LazyLoader from aws_lambda_powertools.shared.types import AnyCallableT +from aws_lambda_powertools.tracing.base import BaseProvider, BaseSegment from aws_lambda_powertools.tracing.provider.aws_xray.aws_xray_tracer import AwsXrayProvider -from aws_lambda_powertools.tracing.provider.base import BaseProvider, BaseSpan is_cold_start = True logger = logging.getLogger(__name__) aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE) +T = TypeVar("T") + class Tracer: """Tracer using AWS-XRay to provide decorators with known defaults for Lambda functions @@ -155,7 +157,7 @@ def __init__( disabled: Optional[bool] = None, auto_patch: Optional[bool] = None, patch_modules: Optional[Sequence[str]] = None, - provider: Optional[AwsXrayProvider] = None, + provider: Optional[BaseProvider] = None, ): self.__build_config( service=service, @@ -164,7 +166,7 @@ def __init__( patch_modules=patch_modules, provider=provider, ) - self.provider: AwsXrayProvider = self._config["provider"] + self.provider: BaseProvider = self._config["provider"] self.disabled = self._config["disabled"] self.service = self._config["service"] self.auto_patch = self._config["auto_patch"] @@ -178,36 +180,6 @@ def __init__( if self._is_xray_provider(): self._disable_xray_trace_batching() - def set_attribute(self, key: str, value: Any, **kwargs): - """Set an attribute on current active span with a key-value pair. - - Parameters - ---------- - key: str - attribute key - value: Any - Value for attribute - kwargs: Optional[dict] - Optional parameters to be passed to provider.set_attributes - - Example - ------- - Set an attribute for a pseudo service named payment - - tracer = Tracer(service="payment") - tracer.set_attribute("PaymentStatus", "CONFIRMED") - """ - if self.disabled: - logger.debug("Tracing has been disabled, aborting set_attribute") - return - - logger.debug(f"setting attribute on key '{key}' with '{value}'") - - namespace = kwargs.get("namespace") or self.service - kwargs["namespace"] = namespace - - self.provider.set_attribute(key=key, value=value, **kwargs) - def put_annotation(self, key: str, value: Union[str, numbers.Number, bool]): """Adds annotation to existing segment or subsegment @@ -230,8 +202,7 @@ def put_annotation(self, key: str, value: Union[str, numbers.Number, bool]): return logger.debug(f"Annotating on key '{key}' with '{value}'") - - self.provider.set_attribute(key=key, value=value, category="Annotation") + self.provider.put_annotation(key=key, value=value) def put_metadata(self, key: str, value: Any, namespace: Optional[str] = None): """Adds metadata to existing segment or subsegment @@ -259,8 +230,7 @@ def put_metadata(self, key: str, value: Any, namespace: Optional[str] = None): namespace = namespace or self.service logger.debug(f"Adding metadata on key '{key}' with '{value}' at namespace '{namespace}'") - - self.provider.set_attribute(key=key, value=value, namespace=namespace, category="Metadata") + self.provider.put_metadata(key=key, value=value, namespace=namespace) def patch(self, modules: Optional[Sequence[str]] = None): """Patch modules for instrumentation. @@ -283,7 +253,7 @@ def patch(self, modules: Optional[Sequence[str]] = None): def capture_lambda_handler( self, - lambda_handler: Union[Callable[[Dict, Any], Any], Optional[Callable[[Dict, Any, Optional[Dict]], Any]]] = None, + lambda_handler: Optional[Union[Callable[[T, Any], Any], Callable[[T, Any, Any], Any]]] = None, capture_response: Optional[bool] = None, capture_error: Optional[bool] = None, ): @@ -344,7 +314,7 @@ def handler(event, context): @functools.wraps(lambda_handler) def decorate(event, context, **kwargs): - with self.provider.trace(name=f"## {lambda_handler_name}") as subsegment: + with self.provider.in_subsegment(name=f"## {lambda_handler_name}") as subsegment: try: logger.debug("Calling lambda handler") response = lambda_handler(event, context, **kwargs) @@ -368,13 +338,13 @@ def decorate(event, context, **kwargs): finally: global is_cold_start logger.debug("Annotating cold start") - subsegment.set_attribute(key="ColdStart", value=is_cold_start) + subsegment.put_annotation(key="ColdStart", value=is_cold_start) if is_cold_start: is_cold_start = False if self.service: - subsegment.set_attribute(key="Service", value=self.service) + subsegment.put_annotation(key="Service", value=self.service) return response @@ -608,7 +578,7 @@ def _decorate_async_function( ): @functools.wraps(method) async def decorate(*args, **kwargs): - async with self.provider.trace_async(name=f"## {method_name}") as subsegment: + async with self.provider.in_subsegment_async(name=f"## {method_name}") as subsegment: try: logger.debug(f"Calling method: {method_name}") response = await method(*args, **kwargs) @@ -641,7 +611,7 @@ def _decorate_generator_function( ): @functools.wraps(method) def decorate(*args, **kwargs): - with self.provider.trace(name=f"## {method_name}") as subsegment: + with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: try: logger.debug(f"Calling method: {method_name}") result = yield from method(*args, **kwargs) @@ -675,7 +645,7 @@ def _decorate_generator_function_with_context_manager( @functools.wraps(method) @contextlib.contextmanager def decorate(*args, **kwargs): - with self.provider.trace(name=f"## {method_name}") as subsegment: + with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: try: logger.debug(f"Calling method: {method_name}") with method(*args, **kwargs) as return_val: @@ -708,7 +678,7 @@ def _decorate_sync_function( ) -> AnyCallableT: @functools.wraps(method) def decorate(*args, **kwargs): - with self.provider.trace(name=f"## {method_name}") as subsegment: + with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: try: logger.debug(f"Calling method: {method_name}") response = method(*args, **kwargs) @@ -736,7 +706,7 @@ def _add_response_as_metadata( self, method_name: Optional[str] = None, data: Optional[Any] = None, - subsegment: Optional[BaseSpan] = None, + subsegment: Optional[BaseSegment] = None, capture_response: Optional[Union[bool, str]] = None, ): """Add response as metadata for given subsegment @@ -747,20 +717,21 @@ def _add_response_as_metadata( method name to add as metadata key, by default None data : Any, optional data to add as subsegment metadata, by default None - subsegment : BaseSpan, optional + subsegment : BaseSegment, optional existing subsegment to add metadata on, by default None capture_response : bool, optional Do not include response as metadata """ if data is None or not capture_response or subsegment is None: return - subsegment.set_attribute(key=f"{method_name} response", value=data, namespace=self.service, category="Metadata") + + subsegment.put_metadata(key=f"{method_name} response", value=data, namespace=self.service) def _add_full_exception_as_metadata( self, method_name: str, error: Exception, - subsegment: BaseSpan, + subsegment: BaseSegment, capture_error: Optional[bool] = None, ): """Add full exception object as metadata for given subsegment @@ -771,7 +742,7 @@ def _add_full_exception_as_metadata( method name to add as metadata key, by default None error : Exception error to add as subsegment metadata, by default None - subsegment : BaseSpan + subsegment : BaseSegment existing subsegment to add metadata on, by default None capture_error : bool, optional Do not include error as metadata, by default True @@ -779,12 +750,7 @@ def _add_full_exception_as_metadata( if not capture_error: return - subsegment.set_attribute( - key=f"{method_name} error", - value=error, - namespace=self.service, - category="Metadata", - ) + subsegment.put_metadata(key=f"{method_name} error", value=error, namespace=self.service) @staticmethod def _disable_tracer_provider(): @@ -835,18 +801,23 @@ def __build_config( is_service = resolve_env_var_choice(choice=service, env=os.getenv(constants.SERVICE_NAME_ENV)) # Logic: Choose overridden option first, previously cached config, or default if available - self._config["provider"] = provider or self._config["provider"] or self._patch_xray_provider() self._config["auto_patch"] = auto_patch if auto_patch is not None else self._config["auto_patch"] self._config["service"] = is_service or self._config["service"] self._config["disabled"] = is_disabled or self._config["disabled"] self._config["patch_modules"] = patch_modules or self._config["patch_modules"] + self._config["provider"] = provider or self._config["provider"] or self._patch_xray_provider() @classmethod def _reset_config(cls): cls._config = copy.copy(cls._default_config) def _patch_xray_provider(self): - return AwsXrayProvider() + return AwsXrayProvider( + service=self._config["service"], + auto_patch=self._config["auto_patch"], + patch_modules=self._config["patch_modules"], + disabled=self._config["disabled"], + ) def _disable_xray_trace_batching(self): """Configure X-Ray SDK to send subsegment individually over batching @@ -859,7 +830,7 @@ def _disable_xray_trace_batching(self): aws_xray_sdk.core.xray_recorder.configure(streaming_threshold=0) def _is_xray_provider(self): - return isinstance(self.provider, AwsXrayProvider) + return any(module in self.provider.__module__ for module in ("aws_xray_sdk", "aws_xray_tracer")) def ignore_endpoint(self, hostname: Optional[str] = None, urls: Optional[List[str]] = None): """If you want to ignore certain httplib requests you can do so based on the hostname or URL that is being diff --git a/tests/functional/test_tracing.py b/tests/functional/test_tracing.py index 5f48b233d91..c7926c24b7a 100644 --- a/tests/functional/test_tracing.py +++ b/tests/functional/test_tracing.py @@ -3,6 +3,7 @@ import pytest from aws_lambda_powertools import Tracer +from aws_lambda_powertools.tracing.provider.aws_xray.aws_xray_tracer import AwsXrayProvider @pytest.fixture @@ -23,7 +24,7 @@ def service_name(): def test_capture_lambda_handler(dummy_response): # GIVEN tracer lambda handler decorator is used - tracer = Tracer(disabled=True) + tracer = AwsXrayProvider(disabled=True) # WHEN a lambda handler is run @tracer.capture_lambda_handler diff --git a/tests/unit/test_tracing.py b/tests/unit/test_tracing.py index 9d9731b0147..0d12afa629b 100644 --- a/tests/unit/test_tracing.py +++ b/tests/unit/test_tracing.py @@ -1,5 +1,4 @@ import contextlib -from numbers import Number from typing import NamedTuple from unittest import mock from unittest.mock import MagicMock @@ -32,24 +31,10 @@ def __init__( ): self.put_metadata_mock = put_metadata_mock or mocker.MagicMock() self.put_annotation_mock = put_annotation_mock or mocker.MagicMock() - self.trace = self.in_subsegment self.in_subsegment = in_subsegment or mocker.MagicMock() self.patch_mock = patch_mock or mocker.MagicMock() self.disable_tracing_provider_mock = disable_tracing_provider_mock or mocker.MagicMock() self.in_subsegment_async = in_subsegment_async or mocker.MagicMock(spec=True) - self.trace_async = self.in_subsegment_async - - def set_attribute(self, *args, **kwargs): - if kwargs.get("category") == "Metadata": - return self.put_metadata(*args, **kwargs) - - if kwargs.get("category") == "Annotation": - return self.put_annotation(*args, **kwargs) - - if isinstance(kwargs.get("value"), (str, Number, bool)): - return self.put_annotation(*args, **kwargs) - - return self.put_metadata(*args, **kwargs) def put_metadata(self, *args, **kwargs): return self.put_metadata_mock(*args, **kwargs) @@ -80,8 +65,8 @@ def reset_tracing_config(mocker): @pytest.fixture -def in_subsegment_mock(mocker): - class AsyncContextManager(mocker.MagicMock): +def in_subsegment_mock(): + class AsyncContextManager(mock.MagicMock): async def __aenter__(self, *args, **kwargs): return self.__enter__() @@ -93,26 +78,10 @@ class InSubsegment(NamedTuple): put_annotation: mock.MagicMock = mock.MagicMock() put_metadata: mock.MagicMock = mock.MagicMock() - def set_attribute(self, *args, **kwargs): - if kwargs.get("category") == "Metadata": - kwargs.pop("category") - return self.put_metadata(*args, **kwargs) - - if kwargs.get("category") == "Annotation": - kwargs.pop("category") - return self.put_annotation(*args, **kwargs) - - if isinstance(kwargs.get("value"), (str, Number, bool)): - return self.put_annotation(*args, **kwargs) - - return self.put_metadata(*args, **kwargs) - in_subsegment = InSubsegment() in_subsegment.in_subsegment.return_value.__enter__.return_value.put_annotation = in_subsegment.put_annotation in_subsegment.in_subsegment.return_value.__enter__.return_value.put_metadata = in_subsegment.put_metadata in_subsegment.in_subsegment.return_value.__aenter__.return_value.put_metadata = in_subsegment.put_metadata - in_subsegment.in_subsegment.return_value.__enter__.return_value.set_attribute = in_subsegment.set_attribute - in_subsegment.in_subsegment.return_value.__aenter__.return_value.set_attribute = in_subsegment.set_attribute yield in_subsegment @@ -186,7 +155,6 @@ def test_tracer_custom_metadata(monkeypatch, mocker, dummy_response, provider_st key=annotation_key, value=annotation_value, namespace="booking", - category="Metadata", ) @@ -204,11 +172,7 @@ def test_tracer_custom_annotation(monkeypatch, mocker, dummy_response, provider_ # THEN we should have an annotation as expected assert put_annotation_mock.call_count == 1 - assert put_annotation_mock.call_args == mocker.call( - key=annotation_key, - value=annotation_value, - category="Annotation", - ) + assert put_annotation_mock.call_args == mocker.call(key=annotation_key, value=annotation_value) @mock.patch("aws_lambda_powertools.tracing.Tracer.patch") From 4c70f121774c0161b961e98a0b772c96c809a227 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 19 Aug 2024 20:02:29 +0100 Subject: [PATCH 4/4] Merging from V3 --- .../provider/aws_xray/aws_xray_tracer.py | 10 +-- .../tracing/provider/base.py | 72 ++++++++++--------- 2 files changed, 43 insertions(+), 39 deletions(-) diff --git a/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py b/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py index 1163e87df45..6eedd3aac43 100644 --- a/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py +++ b/aws_lambda_powertools/tracing/provider/aws_xray/aws_xray_tracer.py @@ -2,7 +2,7 @@ from contextlib import asynccontextmanager, contextmanager from numbers import Number -from typing import Any, AsyncGenerator, Generator, Literal, Optional, Sequence, Union +from typing import Any, AsyncGenerator, Generator, Literal, Sequence from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.lazy_import import LazyLoader @@ -73,9 +73,9 @@ class AwsXrayProvider(BaseProvider): def __init__( self, service: str = "", - disabled: Optional[bool] = None, - auto_patch: Optional[bool] = None, - patch_modules: Optional[Sequence[str]] = None, + disabled: bool | None = None, + auto_patch: bool | None = None, + patch_modules: Sequence[str] | None = None, ): from aws_xray_sdk.core import xray_recorder # type: ignore @@ -141,7 +141,7 @@ def set_attribute( # Auto & not in (str, Number, bool) self.put_metadata(key=key, value=value, namespace=kwargs.get("namespace", "dafault")) - def put_annotation(self, key: str, value: Union[str, Number, bool]) -> None: + def put_annotation(self, key: str, value: str | Number | bool) -> None: return self.recorder.put_annotation(key=key, value=value) def put_metadata(self, key: str, value: Any, namespace: str = "default") -> None: diff --git a/aws_lambda_powertools/tracing/provider/base.py b/aws_lambda_powertools/tracing/provider/base.py index 316d49604f7..a29af067887 100644 --- a/aws_lambda_powertools/tracing/provider/base.py +++ b/aws_lambda_powertools/tracing/provider/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import functools import inspect @@ -5,12 +7,14 @@ import os from abc import ABC, abstractmethod from contextlib import asynccontextmanager, contextmanager -from typing import Any, AsyncGenerator, Callable, Generator, Optional, Sequence, Union, cast, overload +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Generator, Sequence, cast, overload from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice, sanitize_xray_segment_name from aws_lambda_powertools.shared.types import AnyCallableT -from aws_lambda_powertools.tracing.base import BaseSegment + +if TYPE_CHECKING: + from aws_lambda_powertools.tracing.base import BaseSegment logger = logging.getLogger(__name__) @@ -32,7 +36,7 @@ def set_attribute(self, key: str, value: Any, **kwargs) -> None: Attribute key value: Any Attribute value - kwargs: Optional[dict] + kwargs: dict | None Optional parameters """ @@ -43,8 +47,8 @@ def record_exception(self, exception: BaseException, **kwargs): Parameters ---------- exception: Exception - Caught exception during the exectution of this Span - kwargs: Optional[dict] + Caught exception during the execution of this Span + kwargs: dict | None Optional parameters """ @@ -71,7 +75,7 @@ def trace(self, name: str, **kwargs) -> Generator[BaseSpan, None, None]: ---------- name: str Span name - kwargs: Optional[dict] + kwargs: dict | None Optional parameters to be propagated to the span """ @@ -89,7 +93,7 @@ def trace_async(self, name: str, **kwargs) -> AsyncGenerator[BaseSpan, None]: ---------- name: str Span name - kwargs: Optional[dict] + kwargs: dict | None Optional parameters to be propagated to the span """ @@ -103,7 +107,7 @@ def set_attribute(self, key: str, value: Any, **kwargs) -> None: attribute key value: Any attribute value - kwargs: Optional[dict] + kwargs: dict | None Optional parameters to be propagated to the span """ @@ -132,8 +136,8 @@ def patch_all(self) -> None: def capture_lambda_handler( self, lambda_handler: Any = None, - capture_response: Optional[bool] = None, - capture_error: Optional[bool] = None, + capture_response: bool | None = None, + capture_error: bool | None = None, ): """Decorator to create subsegment for lambda handlers @@ -230,21 +234,21 @@ def decorate(event, context, **kwargs): # see #465 @overload - def capture_method(self, method: "AnyCallableT") -> "AnyCallableT": ... # pragma: no cover + def capture_method(self, method: AnyCallableT) -> AnyCallableT: ... # pragma: no cover @overload def capture_method( self, method: None = None, - capture_response: Optional[bool] = None, - capture_error: Optional[bool] = None, - ) -> Callable[["AnyCallableT"], "AnyCallableT"]: ... # pragma: no cover + capture_response: bool | None = None, + capture_error: bool | None = None, + ) -> Callable[[AnyCallableT], AnyCallableT]: ... # pragma: no cover def capture_method( self, - method: Optional[AnyCallableT] = None, - capture_response: Optional[bool] = None, - capture_error: Optional[bool] = None, + method: AnyCallableT | None = None, + capture_response: bool | None = None, + capture_error: bool | None = None, ) -> AnyCallableT: """Decorator to create subsegment for arbitrary functions @@ -450,9 +454,9 @@ async def async_tasks(): def _decorate_async_function( self, method: Callable, - capture_response: Optional[Union[bool, str]] = None, - capture_error: Optional[Union[bool, str]] = None, - method_name: Optional[str] = None, + capture_response: bool | str | None = None, + capture_error: bool | str | None = None, + method_name: str | None = None, ): @functools.wraps(method) async def decorate(*args, **kwargs): @@ -483,9 +487,9 @@ async def decorate(*args, **kwargs): def _decorate_generator_function( self, method: Callable, - capture_response: Optional[Union[bool, str]] = None, - capture_error: Optional[Union[bool, str]] = None, - method_name: Optional[str] = None, + capture_response: bool | str | None = None, + capture_error: bool | str | None = None, + method_name: str | None = None, ): @functools.wraps(method) def decorate(*args, **kwargs): @@ -516,9 +520,9 @@ def decorate(*args, **kwargs): def _decorate_generator_function_with_context_manager( self, method: Callable, - capture_response: Optional[Union[bool, str]] = None, - capture_error: Optional[Union[bool, str]] = None, - method_name: Optional[str] = None, + capture_response: bool | str | None = None, + capture_error: bool | str | None = None, + method_name: str | None = None, ): @functools.wraps(method) @contextlib.contextmanager @@ -550,9 +554,9 @@ def decorate(*args, **kwargs): def _decorate_sync_function( self, method: AnyCallableT, - capture_response: Optional[Union[bool, str]] = None, - capture_error: Optional[Union[bool, str]] = None, - method_name: Optional[str] = None, + capture_response: bool | str | None = None, + capture_error: bool | str | None = None, + method_name: str | None = None, ) -> AnyCallableT: @functools.wraps(method) def decorate(*args, **kwargs): @@ -582,10 +586,10 @@ def decorate(*args, **kwargs): def _add_response_as_metadata( self, - method_name: Optional[str] = None, - data: Optional[Any] = None, - subsegment: Optional[BaseSegment] = None, - capture_response: Optional[Union[bool, str]] = None, + method_name: str | None = None, + data: Any | None = None, + subsegment: BaseSegment | None = None, + capture_response: bool | str | None = None, ): """Add response as metadata for given subsegment @@ -610,7 +614,7 @@ def _add_full_exception_as_metadata( method_name: str, error: Exception, subsegment: BaseSegment, - capture_error: Optional[bool] = None, + capture_error: bool | None = None, ): """Add full exception object as metadata for given subsegment