Skip to content

Propagate Step Function Trace Context through Managed Services #573

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 19, 2025
28 changes: 13 additions & 15 deletions datadog_lambda/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
_EventSource,
parse_event_source,
get_first_record,
is_step_function_event,
EventTypes,
EventSubtypes,
)
Expand Down Expand Up @@ -320,12 +321,17 @@ def extract_context_from_eventbridge_event(event, lambda_context):
"""
Extract datadog trace context from an EventBridge message's Details.
This is only possible if Details is a JSON string.

If we find a Step Function context, try to extract the trace context from
that header.
"""
try:
detail = event.get("detail")
dd_context = detail.get("_datadog")
if not dd_context:
return extract_context_from_lambda_context(lambda_context)
if is_step_function_event(dd_context):
return extract_context_from_step_functions(detail, lambda_context)
return propagator.extract(dd_context)
except Exception as e:
logger.debug("The trace extractor returned with error %s", e)
Expand Down Expand Up @@ -424,7 +430,7 @@ def _generate_sfn_trace_id(execution_id: str, part: str):
def extract_context_from_step_functions(event, lambda_context):
"""
Only extract datadog trace context when Step Functions Context Object is injected
into lambda's event dict.
into lambda's event dict. Unwrap "Payload" if it exists to handle Legacy Lambda cases.

If '_datadog' header is present, we have two cases:
1. Root is a Lambda and we use its traceID
Expand All @@ -435,6 +441,8 @@ def extract_context_from_step_functions(event, lambda_context):
object.
"""
try:
event = event.get("Payload", event)

meta = {}
dd_data = event.get("_datadog")

Expand Down Expand Up @@ -472,20 +480,6 @@ def extract_context_from_step_functions(event, lambda_context):
return extract_context_from_lambda_context(lambda_context)


def is_legacy_lambda_step_function(event):
"""
Check if the event is a step function that called a legacy lambda
"""
if not isinstance(event, dict) or "Payload" not in event:
return False

event = event.get("Payload")
return isinstance(event, dict) and (
"_datadog" in event
or ("Execution" in event and "StateMachine" in event and "State" in event)
)


def extract_context_custom_extractor(extractor, event, lambda_context):
"""
Extract Datadog trace context using a custom trace extractor function
Expand Down Expand Up @@ -1320,6 +1314,10 @@ def create_inferred_span_from_eventbridge_event(event, context):
if span:
span.set_tags(tags)
span.start = dt.replace(tzinfo=timezone.utc).timestamp()

# Since inferred span will later parent Lambda, preserve Lambda's current parent
span.parent_id = dd_trace_context.span_id
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is important because we have the following code in tracing.create_function_execution_span()

if parent_span:
    span.parent_id = parent_span.span_id

where parent_span is the generated inferred span so the Lambda's root span's parent_id will be set to the inferred span's span_id

If there is an upstream Step Function and we saved its trace context in dd_trace_context, we want to preserve the parenting relationship and not let the inferred span completely erase it

This line solves the issue by making the inferred span be a child of the upstream service


return span


Expand Down
30 changes: 27 additions & 3 deletions datadog_lambda/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,7 @@ def parse_event_source(event: dict) -> _EventSource:
if event.get("source") == "aws.events" or has_event_categories:
event_source = _EventSource(EventTypes.CLOUDWATCH_EVENTS)

if (
"_datadog" in event and event.get("_datadog").get("serverless-version") == "v1"
) or ("Execution" in event and "StateMachine" in event and "State" in event):
if is_step_function_event(event):
event_source = _EventSource(EventTypes.STEPFUNCTIONS)

event_record = get_first_record(event)
Expand Down Expand Up @@ -369,3 +367,29 @@ def extract_http_status_code_tag(trigger_tags, response):
status_code = response.status_code

return str(status_code)


def is_step_function_event(event):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way we can memoize this function? It looks like it can potentially be called several times in the course of a single invocation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, or it looks like the function can be called multiple times per invocation, but with different "events" each time? If that's true, then we can probably leave it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a great idea!

Correct me if I'm wrong but does the layer only handle one event per invocation? Or if it's a busy Lambda does it stay alive and potentially handle hundreds of events?

Just wondering to get an idea of how large to make the cache. I guess it can be pretty small anyway since each event is new and we don't repeat

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each runtime instance will only ever handle one event at a time. It never handles two events concurrently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah just realized we can't memoize it because event is a dict and mutable types are unhashable

We could serialize the dict and use that but I'm thinking that'd be much slower

"""
Check if the event is a step function that invoked the current lambda.

The whole event can be wrapped in "Payload" in Legacy Lambda cases. There may also be a
"_datadog" for JSONata style context propagation.

The actual event must contain "Execution", "StateMachine", and "State" fields.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really like these comments. For someone who hasn't work on step functions for a while, these comments help me recollect these historical context. It'll help future maintenance of the code as well.

"""
event = event.get("Payload", event)

# JSONPath style
if all(field in event for field in ("Execution", "StateMachine", "State")):
return True

# JSONata style
if "_datadog" in event:
event = event["_datadog"]
return all(
field in event
for field in ("Execution", "StateMachine", "State", "serverless-version")
)

return False
3 changes: 0 additions & 3 deletions datadog_lambda/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
is_authorizer_response,
tracer,
propagator,
is_legacy_lambda_step_function,
)
from datadog_lambda.trigger import (
extract_trigger_tags,
Expand Down Expand Up @@ -279,8 +278,6 @@ def _before(self, event, context):
self.response = None
set_cold_start(init_timestamp_ns)
submit_invocations_metric(context)
if is_legacy_lambda_step_function(event):
event = event["Payload"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this unwrapping to happen inside of tracing.extract_context_from_step_functions()

self.trigger_tags = extract_trigger_tags(event, context)
# Extract Datadog trace context and source from incoming requests
dd_context, trace_context_source, event_source = extract_dd_trace_context(
Expand Down
50 changes: 49 additions & 1 deletion tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
service_mapping as global_service_mapping,
propagator,
emit_telemetry_on_exception_outside_of_handler,
is_legacy_lambda_step_function,
)
from datadog_lambda.trigger import EventTypes

Expand Down Expand Up @@ -836,6 +835,55 @@ def test_step_function_trace_data_sfn_root(self):
expected_context,
)

@with_trace_propagation_style("datadog")
def test_step_function_trace_data_event_bridge(self):
lambda_ctx = get_mock_context()
sfn_event = {
"_datadog": {
"Execution": {
"StartTime": "2025-03-11T01:16:31.408Z",
"Id": "arn:aws:states:sa-east-1:425362996713:execution:abhinav-inner-state-machine:eb6298d0-93b5-4fe0-8af9-fefe2933b0ed",
"RedriveCount": 0,
"RoleArn": "arn:aws:iam::425362996713:role/service-role/StepFunctions-abhinav-activity-state-machine-role-22jpbgl6j",
"Name": "eb6298d0-93b5-4fe0-8af9-fefe2933b0ed",
},
"StateMachine": {
"Id": "arn:aws:states:sa-east-1:425362996713:stateMachine:abhinav-inner-state-machine",
"Name": "abhinav-inner-state-machine",
},
"State": {
"EnteredTime": "2025-03-11T01:16:31.448Z",
"RetryCount": 0,
"Name": "EventBridge PutEvents",
},
"serverless-version": "v1",
"RootExecutionId": "arn:aws:states:sa-east-1:425362996713:execution:abhinav-inner-state-machine:eb6298d0-93b5-4fe0-8af9-fefe2933b0ed",
}
}
ctx, source, event_source = extract_dd_trace_context(sfn_event, lambda_ctx)
self.assertEqual(source, "event")
expected_context = Context(
trace_id=4728686021345621131,
span_id=2685222157636933868,
sampling_priority=1,
meta={"_dd.p.tid": "7683d2257c051fce"},
)
self.assertEqual(ctx, expected_context)
self.assertEqual(
get_dd_trace_context(),
{
TraceHeader.TRACE_ID: "4728686021345621131",
TraceHeader.PARENT_ID: "10713633173203262661",
TraceHeader.SAMPLING_PRIORITY: "1",
TraceHeader.TAGS: "_dd.p.tid=7683d2257c051fce",
},
)
create_dd_dummy_metadata_subsegment(ctx, XraySubsegment.TRACE_KEY)
self.mock_send_segment.assert_called_with(
XraySubsegment.TRACE_KEY,
expected_context,
)


class TestXRayContextConversion(unittest.TestCase):
def test_convert_xray_trace_id(self):
Expand Down
66 changes: 66 additions & 0 deletions tests/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
get_event_source_arn,
extract_trigger_tags,
extract_http_status_code_tag,
is_step_function_event,
)

from tests.utils import get_mock_context
Expand Down Expand Up @@ -543,3 +544,68 @@ def test_extract_http_status_code_tag_from_response_object(self):
response.status_code = 403
status_code = extract_http_status_code_tag(trigger_tags, response)
self.assertEqual(status_code, "403")


class IsStepFunctionEvent(unittest.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Thanks for putting the tests here which also make the code easier to understand for the future.

def test_is_step_function_event_jsonata(self):
event = {
"_datadog": {
"Execution": {
"Id": "665c417c-1237-4742-aaca-8b3becbb9e75",
"RedriveCount": 0,
},
"StateMachine": {},
"State": {
"Name": "my-awesome-state",
"EnteredTime": "Mon Nov 13 12:43:33 PST 2023",
"RetryCount": 0,
},
"x-datadog-trace-id": "5821803790426892636",
"x-datadog-tags": "_dd.p.dm=-0,_dd.p.tid=672a7cb100000000",
"serverless-version": "v1",
}
}
self.assertTrue(is_step_function_event(event))

def test_is_step_function_event_jsonpath(self):
event = {
"Execution": {
"Id": "665c417c-1237-4742-aaca-8b3becbb9e75",
"RedriveCount": 0,
},
"StateMachine": {},
"State": {
"Name": "my-awesome-state",
"EnteredTime": "Mon Nov 13 12:43:33 PST 2023",
"RetryCount": 0,
},
}
self.assertTrue(is_step_function_event(event))

def test_is_step_function_event_legacy_lambda(self):
event = {
"Payload": {
"Execution": {
"Id": "665c417c-1237-4742-aaca-8b3becbb9e75",
"RedriveCount": 0,
},
"StateMachine": {},
"State": {
"Name": "my-awesome-state",
"EnteredTime": "Mon Nov 13 12:43:33 PST 2023",
"RetryCount": 0,
},
}
}
self.assertTrue(is_step_function_event(event))

def test_is_step_function_event_dd_header(self):
event = {
"_datadog": {
"x-datadog-trace-id": "5821803790426892636",
"x-datadog-parent-id": "5821803790426892636",
"x-datadog-tags": "_dd.p.dm=-0,_dd.p.tid=672a7cb100000000",
"x-datadog-sampling-priority": "1",
}
}
self.assertFalse(is_step_function_event(event))
Loading