Skip to content

fix(tracer): add name sanitization for X-Ray subsegments #4005

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aws_lambda_powertools/shared/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
XRAY_SDK_CORE_MODULE: str = "aws_xray_sdk.core"
XRAY_TRACE_ID_ENV: str = "_X_AMZN_TRACE_ID"
MIDDLEWARE_FACTORY_TRACE_ENV: str = "POWERTOOLS_TRACE_MIDDLEWARES"
INVALID_XRAY_NAME_CHARACTERS = r"[?;*()!$~^<>]"

# Logger constants
LOGGER_LOG_SAMPLING_RATE: str = "POWERTOOLS_LOGGER_SAMPLE_RATE"
Expand Down
12 changes: 5 additions & 7 deletions aws_lambda_powertools/shared/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import itertools
import logging
import os
import re
import warnings
from binascii import Error as BinAsciiError
from pathlib import Path
Expand Down Expand Up @@ -275,13 +276,10 @@ def abs_lambda_path(relative_path: str = "") -> str:
If the path is empty, it will return the current working directory.
"""
# Retrieve the LAMBDA_TASK_ROOT environment variable or default to an empty string
current_working_directory = os.environ.get("LAMBDA_TASK_ROOT", "")
current_working_directory = os.environ.get("LAMBDA_TASK_ROOT", "") or str(Path.cwd())

# If LAMBDA_TASK_ROOT is not set, use the current working directory
if not current_working_directory:
current_working_directory = str(Path.cwd())
return str(Path(current_working_directory, relative_path))

# Combine the current working directory and the relative path to get the absolute path
absolute_path = str(Path(current_working_directory, relative_path))

return absolute_path
def sanitize_xray_segment_name(name: str) -> str:
return re.sub(constants.INVALID_XRAY_NAME_CHARACTERS, "", name)
9 changes: 7 additions & 2 deletions aws_lambda_powertools/tracing/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast, overload

from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice
from aws_lambda_powertools.shared.functions import (
resolve_env_var_choice,
resolve_truthy_env_var_choice,
sanitize_xray_segment_name,
)
from aws_lambda_powertools.shared.lazy_import import LazyLoader
from aws_lambda_powertools.shared.types import AnyCallableT
from aws_lambda_powertools.tracing.base import BaseProvider, BaseSegment
Expand Down Expand Up @@ -520,7 +524,8 @@ async def async_tasks():
)

# Example: app.ClassA.get_all # noqa ERA001
method_name = f"{method.__module__}.{method.__qualname__}"
# Valid characters can be found at http://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html
method_name = sanitize_xray_segment_name(f"{method.__module__}.{method.__qualname__}")

capture_response = resolve_truthy_env_var_choice(
env=os.getenv(constants.TRACER_CAPTURE_RESPONSE_ENV, "true"),
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/test_shared_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
resolve_env_var_choice,
resolve_max_age,
resolve_truthy_env_var_choice,
sanitize_xray_segment_name,
strtobool,
)
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
Expand Down Expand Up @@ -175,3 +176,27 @@ def test_abs_lambda_path_w_filename_envvar(default_lambda_path):
os.environ["LAMBDA_TASK_ROOT"] = default_lambda_path
# Then path = env + relative_path
assert abs_lambda_path(relative_path="cert/pub.cert") == str(Path(os.environ["LAMBDA_TASK_ROOT"], relative_path))


def test_sanitize_xray_segment_name():
# GIVEN a name with invalid characters
invalid_name = "app?;*.lambda_function.(<locals>).get_todos!$~^<>"

# WHEN we sanitize this name by removing invalid characters
sanitized_name = sanitize_xray_segment_name(invalid_name)

# THEN the sanitized name should not contain invalid characters
expected_name = "app.lambda_function.locals.get_todos"
assert sanitized_name == expected_name


def test_sanitize_xray_segment_name_with_no_special_characters():
# GIVEN a name without any invalid characters
valid_name = "app#lambda_function"

# WHEN we sanitize this name
sanitized_name = sanitize_xray_segment_name(valid_name)

# THEN the sanitized name remains the same as the original name
expected_name = valid_name
assert sanitized_name == expected_name
27 changes: 13 additions & 14 deletions tests/unit/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ def greeting(name, message):
# and use service name as a metadata namespace
assert in_subsegment_mock.in_subsegment.call_count == 1
assert in_subsegment_mock.in_subsegment.call_args == mocker.call(
name=f"## {MODULE_PREFIX}.test_tracer_method.<locals>.greeting",
name=f"## {MODULE_PREFIX}.test_tracer_method.locals.greeting",
)
assert in_subsegment_mock.put_metadata.call_args == mocker.call(
key=f"{MODULE_PREFIX}.test_tracer_method.<locals>.greeting response",
key=f"{MODULE_PREFIX}.test_tracer_method.locals.greeting response",
value=dummy_response,
namespace="booking",
)
Expand Down Expand Up @@ -261,8 +261,7 @@ def greeting(name, message):
# and their service name as the namespace
put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1]
assert (
put_metadata_mock_args["key"]
== f"{MODULE_PREFIX}.test_tracer_method_exception_metadata.<locals>.greeting error"
put_metadata_mock_args["key"] == f"{MODULE_PREFIX}.test_tracer_method_exception_metadata.locals.greeting error"
)
assert put_metadata_mock_args["namespace"] == "booking"

Expand Down Expand Up @@ -316,20 +315,20 @@ async def greeting(name, message):
# THEN we should add metadata for each response like we would for a sync decorated method
assert in_subsegment_mock.in_subsegment.call_count == 2
assert in_subsegment_greeting_call_args == mocker.call(
name=f"## {MODULE_PREFIX}.test_tracer_method_nested_async.<locals>.greeting",
name=f"## {MODULE_PREFIX}.test_tracer_method_nested_async.locals.greeting",
)
assert in_subsegment_greeting2_call_args == mocker.call(
name=f"## {MODULE_PREFIX}.test_tracer_method_nested_async.<locals>.greeting_2",
name=f"## {MODULE_PREFIX}.test_tracer_method_nested_async.locals.greeting_2",
)

assert in_subsegment_mock.put_metadata.call_count == 2
assert put_metadata_greeting2_call_args == mocker.call(
key=f"{MODULE_PREFIX}.test_tracer_method_nested_async.<locals>.greeting_2 response",
key=f"{MODULE_PREFIX}.test_tracer_method_nested_async.locals.greeting_2 response",
value=dummy_response,
namespace="booking",
)
assert put_metadata_greeting_call_args == mocker.call(
key=f"{MODULE_PREFIX}.test_tracer_method_nested_async.<locals>.greeting response",
key=f"{MODULE_PREFIX}.test_tracer_method_nested_async.locals.greeting response",
value=dummy_response,
namespace="booking",
)
Expand Down Expand Up @@ -375,7 +374,7 @@ async def greeting(name, message):
put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1]
assert (
put_metadata_mock_args["key"]
== f"{MODULE_PREFIX}.test_tracer_method_exception_metadata_async.<locals>.greeting error"
== f"{MODULE_PREFIX}.test_tracer_method_exception_metadata_async.locals.greeting error"
)
assert put_metadata_mock_args["namespace"] == "booking"

Expand Down Expand Up @@ -409,7 +408,7 @@ def handler(event, context):
assert in_subsegment_mock.in_subsegment.call_count == 2
assert handler_trace == mocker.call(name="## handler")
assert yield_function_trace == mocker.call(
name=f"## {MODULE_PREFIX}.test_tracer_yield_from_context_manager.<locals>.yield_with_capture",
name=f"## {MODULE_PREFIX}.test_tracer_yield_from_context_manager.locals.yield_with_capture",
)
assert "test result" in result

Expand All @@ -436,7 +435,7 @@ def yield_with_capture():
put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1]
assert (
put_metadata_mock_args["key"]
== f"{MODULE_PREFIX}.test_tracer_yield_from_context_manager_exception_metadata.<locals>.yield_with_capture error" # noqa E501
== f"{MODULE_PREFIX}.test_tracer_yield_from_context_manager_exception_metadata.locals.yield_with_capture error" # noqa E501
)
assert isinstance(put_metadata_mock_args["value"], ValueError)
assert put_metadata_mock_args["namespace"] == "booking"
Expand Down Expand Up @@ -480,7 +479,7 @@ def handler(event, context):
assert in_subsegment_mock.in_subsegment.call_count == 2
assert handler_trace == mocker.call(name="## handler")
assert yield_function_trace == mocker.call(
name=f"## {MODULE_PREFIX}.test_tracer_yield_from_nested_context_manager.<locals>.yield_with_capture",
name=f"## {MODULE_PREFIX}.test_tracer_yield_from_nested_context_manager.locals.yield_with_capture",
)
assert "test result" in result

Expand Down Expand Up @@ -512,7 +511,7 @@ def handler(event, context):
assert in_subsegment_mock.in_subsegment.call_count == 2
assert handler_trace == mocker.call(name="## handler")
assert generator_fn_trace == mocker.call(
name=f"## {MODULE_PREFIX}.test_tracer_yield_from_generator.<locals>.generator_fn",
name=f"## {MODULE_PREFIX}.test_tracer_yield_from_generator.locals.generator_fn",
)
assert "test result" in result

Expand All @@ -538,7 +537,7 @@ def generator_fn():
put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1]
assert (
put_metadata_mock_args["key"]
== f"{MODULE_PREFIX}.test_tracer_yield_from_generator_exception_metadata.<locals>.generator_fn error"
== f"{MODULE_PREFIX}.test_tracer_yield_from_generator_exception_metadata.locals.generator_fn error"
)
assert put_metadata_mock_args["namespace"] == "booking"
assert isinstance(put_metadata_mock_args["value"], ValueError)
Expand Down