diff --git a/aws_lambda_powertools/shared/constants.py b/aws_lambda_powertools/shared/constants.py index 47bb384bc76..82c131843fb 100644 --- a/aws_lambda_powertools/shared/constants.py +++ b/aws_lambda_powertools/shared/constants.py @@ -6,6 +6,7 @@ XRAY_SDK_CORE_MODULE: str = "aws_xray_sdk.core" XRAY_TRACE_ID_ENV: str = "_X_AMZN_TRACE_ID" MIDDLEWARE_FACTORY_TRACE_ENV: str = "POWERTOOLS_TRACE_MIDDLEWARES" +INVALID_XRAY_NAME_CHARACTERS = r"[?;*()!$~^<>]" # Logger constants LOGGER_LOG_SAMPLING_RATE: str = "POWERTOOLS_LOGGER_SAMPLE_RATE" diff --git a/aws_lambda_powertools/shared/functions.py b/aws_lambda_powertools/shared/functions.py index ee274498260..0f943f36d39 100644 --- a/aws_lambda_powertools/shared/functions.py +++ b/aws_lambda_powertools/shared/functions.py @@ -4,6 +4,7 @@ import itertools import logging import os +import re import warnings from binascii import Error as BinAsciiError from pathlib import Path @@ -275,13 +276,10 @@ def abs_lambda_path(relative_path: str = "") -> str: If the path is empty, it will return the current working directory. """ # Retrieve the LAMBDA_TASK_ROOT environment variable or default to an empty string - current_working_directory = os.environ.get("LAMBDA_TASK_ROOT", "") + current_working_directory = os.environ.get("LAMBDA_TASK_ROOT", "") or str(Path.cwd()) - # If LAMBDA_TASK_ROOT is not set, use the current working directory - if not current_working_directory: - current_working_directory = str(Path.cwd()) + return str(Path(current_working_directory, relative_path)) - # Combine the current working directory and the relative path to get the absolute path - absolute_path = str(Path(current_working_directory, relative_path)) - return absolute_path +def sanitize_xray_segment_name(name: str) -> str: + return re.sub(constants.INVALID_XRAY_NAME_CHARACTERS, "", name) diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py index 8c673294faa..a79ac4ec738 100644 --- a/aws_lambda_powertools/tracing/tracer.py +++ b/aws_lambda_powertools/tracing/tracer.py @@ -8,7 +8,11 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast, overload from aws_lambda_powertools.shared import constants -from aws_lambda_powertools.shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice +from aws_lambda_powertools.shared.functions import ( + resolve_env_var_choice, + resolve_truthy_env_var_choice, + sanitize_xray_segment_name, +) from aws_lambda_powertools.shared.lazy_import import LazyLoader from aws_lambda_powertools.shared.types import AnyCallableT from aws_lambda_powertools.tracing.base import BaseProvider, BaseSegment @@ -520,7 +524,8 @@ async def async_tasks(): ) # Example: app.ClassA.get_all # noqa ERA001 - method_name = f"{method.__module__}.{method.__qualname__}" + # Valid characters can be found at http://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html + method_name = sanitize_xray_segment_name(f"{method.__module__}.{method.__qualname__}") capture_response = resolve_truthy_env_var_choice( env=os.getenv(constants.TRACER_CAPTURE_RESPONSE_ENV, "true"), diff --git a/tests/unit/test_shared_functions.py b/tests/unit/test_shared_functions.py index c8c4bb2afb2..b286c536249 100644 --- a/tests/unit/test_shared_functions.py +++ b/tests/unit/test_shared_functions.py @@ -15,6 +15,7 @@ resolve_env_var_choice, resolve_max_age, resolve_truthy_env_var_choice, + sanitize_xray_segment_name, strtobool, ) from aws_lambda_powertools.utilities.data_classes.common import DictWrapper @@ -175,3 +176,27 @@ def test_abs_lambda_path_w_filename_envvar(default_lambda_path): os.environ["LAMBDA_TASK_ROOT"] = default_lambda_path # Then path = env + relative_path assert abs_lambda_path(relative_path="cert/pub.cert") == str(Path(os.environ["LAMBDA_TASK_ROOT"], relative_path)) + + +def test_sanitize_xray_segment_name(): + # GIVEN a name with invalid characters + invalid_name = "app?;*.lambda_function.().get_todos!$~^<>" + + # WHEN we sanitize this name by removing invalid characters + sanitized_name = sanitize_xray_segment_name(invalid_name) + + # THEN the sanitized name should not contain invalid characters + expected_name = "app.lambda_function.locals.get_todos" + assert sanitized_name == expected_name + + +def test_sanitize_xray_segment_name_with_no_special_characters(): + # GIVEN a name without any invalid characters + valid_name = "app#lambda_function" + + # WHEN we sanitize this name + sanitized_name = sanitize_xray_segment_name(valid_name) + + # THEN the sanitized name remains the same as the original name + expected_name = valid_name + assert sanitized_name == expected_name diff --git a/tests/unit/test_tracing.py b/tests/unit/test_tracing.py index 7b09bcde885..0d12afa629b 100644 --- a/tests/unit/test_tracing.py +++ b/tests/unit/test_tracing.py @@ -127,10 +127,10 @@ def greeting(name, message): # and use service name as a metadata namespace assert in_subsegment_mock.in_subsegment.call_count == 1 assert in_subsegment_mock.in_subsegment.call_args == mocker.call( - name=f"## {MODULE_PREFIX}.test_tracer_method..greeting", + name=f"## {MODULE_PREFIX}.test_tracer_method.locals.greeting", ) assert in_subsegment_mock.put_metadata.call_args == mocker.call( - key=f"{MODULE_PREFIX}.test_tracer_method..greeting response", + key=f"{MODULE_PREFIX}.test_tracer_method.locals.greeting response", value=dummy_response, namespace="booking", ) @@ -261,8 +261,7 @@ def greeting(name, message): # and their service name as the namespace put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1] assert ( - put_metadata_mock_args["key"] - == f"{MODULE_PREFIX}.test_tracer_method_exception_metadata..greeting error" + put_metadata_mock_args["key"] == f"{MODULE_PREFIX}.test_tracer_method_exception_metadata.locals.greeting error" ) assert put_metadata_mock_args["namespace"] == "booking" @@ -316,20 +315,20 @@ async def greeting(name, message): # THEN we should add metadata for each response like we would for a sync decorated method assert in_subsegment_mock.in_subsegment.call_count == 2 assert in_subsegment_greeting_call_args == mocker.call( - name=f"## {MODULE_PREFIX}.test_tracer_method_nested_async..greeting", + name=f"## {MODULE_PREFIX}.test_tracer_method_nested_async.locals.greeting", ) assert in_subsegment_greeting2_call_args == mocker.call( - name=f"## {MODULE_PREFIX}.test_tracer_method_nested_async..greeting_2", + name=f"## {MODULE_PREFIX}.test_tracer_method_nested_async.locals.greeting_2", ) assert in_subsegment_mock.put_metadata.call_count == 2 assert put_metadata_greeting2_call_args == mocker.call( - key=f"{MODULE_PREFIX}.test_tracer_method_nested_async..greeting_2 response", + key=f"{MODULE_PREFIX}.test_tracer_method_nested_async.locals.greeting_2 response", value=dummy_response, namespace="booking", ) assert put_metadata_greeting_call_args == mocker.call( - key=f"{MODULE_PREFIX}.test_tracer_method_nested_async..greeting response", + key=f"{MODULE_PREFIX}.test_tracer_method_nested_async.locals.greeting response", value=dummy_response, namespace="booking", ) @@ -375,7 +374,7 @@ async def greeting(name, message): put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1] assert ( put_metadata_mock_args["key"] - == f"{MODULE_PREFIX}.test_tracer_method_exception_metadata_async..greeting error" + == f"{MODULE_PREFIX}.test_tracer_method_exception_metadata_async.locals.greeting error" ) assert put_metadata_mock_args["namespace"] == "booking" @@ -409,7 +408,7 @@ def handler(event, context): assert in_subsegment_mock.in_subsegment.call_count == 2 assert handler_trace == mocker.call(name="## handler") assert yield_function_trace == mocker.call( - name=f"## {MODULE_PREFIX}.test_tracer_yield_from_context_manager..yield_with_capture", + name=f"## {MODULE_PREFIX}.test_tracer_yield_from_context_manager.locals.yield_with_capture", ) assert "test result" in result @@ -436,7 +435,7 @@ def yield_with_capture(): put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1] assert ( put_metadata_mock_args["key"] - == f"{MODULE_PREFIX}.test_tracer_yield_from_context_manager_exception_metadata..yield_with_capture error" # noqa E501 + == f"{MODULE_PREFIX}.test_tracer_yield_from_context_manager_exception_metadata.locals.yield_with_capture error" # noqa E501 ) assert isinstance(put_metadata_mock_args["value"], ValueError) assert put_metadata_mock_args["namespace"] == "booking" @@ -480,7 +479,7 @@ def handler(event, context): assert in_subsegment_mock.in_subsegment.call_count == 2 assert handler_trace == mocker.call(name="## handler") assert yield_function_trace == mocker.call( - name=f"## {MODULE_PREFIX}.test_tracer_yield_from_nested_context_manager..yield_with_capture", + name=f"## {MODULE_PREFIX}.test_tracer_yield_from_nested_context_manager.locals.yield_with_capture", ) assert "test result" in result @@ -512,7 +511,7 @@ def handler(event, context): assert in_subsegment_mock.in_subsegment.call_count == 2 assert handler_trace == mocker.call(name="## handler") assert generator_fn_trace == mocker.call( - name=f"## {MODULE_PREFIX}.test_tracer_yield_from_generator..generator_fn", + name=f"## {MODULE_PREFIX}.test_tracer_yield_from_generator.locals.generator_fn", ) assert "test result" in result @@ -538,7 +537,7 @@ def generator_fn(): put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1] assert ( put_metadata_mock_args["key"] - == f"{MODULE_PREFIX}.test_tracer_yield_from_generator_exception_metadata..generator_fn error" + == f"{MODULE_PREFIX}.test_tracer_yield_from_generator_exception_metadata.locals.generator_fn error" ) assert put_metadata_mock_args["namespace"] == "booking" assert isinstance(put_metadata_mock_args["value"], ValueError)