diff --git a/azure/durable_functions/models/DurableOrchestrationClient.py b/azure/durable_functions/models/DurableOrchestrationClient.py index 1be5a28e..bd812be9 100644 --- a/azure/durable_functions/models/DurableOrchestrationClient.py +++ b/azure/durable_functions/models/DurableOrchestrationClient.py @@ -4,6 +4,7 @@ from time import time from asyncio import sleep from urllib.parse import urlparse, quote +from opentelemetry import trace import azure.functions as func @@ -71,8 +72,25 @@ async def start_new(self, request_url = self._get_start_new_url( instance_id=instance_id, orchestration_function_name=orchestration_function_name) + # Get the current span + current_span = trace.get_current_span() + span_context = current_span.get_span_context() + + # Get the traceparent and tracestate from the span context + # Follows the W3C Trace Context specification for traceparent + # https://www.w3.org/TR/trace-context/#traceparent-header + trace_id = format(span_context.trace_id, '032x') + span_id = format(span_context.span_id, '016x') + trace_flags = format(span_context.trace_flags, '02x') + trace_parent = f"00-{trace_id}-{span_id}-{trace_flags}" + + trace_state = span_context.trace_state + response: List[Any] = await self._post_async_request( - request_url, self._get_json_input(client_input)) + request_url, + self._get_json_input(client_input), + trace_parent, + trace_state) status_code: int = response[0] if status_code <= 202 and response[1]: diff --git a/azure/durable_functions/models/utils/http_utils.py b/azure/durable_functions/models/utils/http_utils.py index e45cef68..eaa3a07d 100644 --- a/azure/durable_functions/models/utils/http_utils.py +++ b/azure/durable_functions/models/utils/http_utils.py @@ -3,7 +3,10 @@ import aiohttp -async def post_async_request(url: str, data: Any = None) -> List[Union[int, Any]]: +async def post_async_request(url: str, + data: Any = None, + trace_parent: str = None, + trace_state: str = None) -> List[Union[int, Any]]: """Post request with the data provided to the url provided. Parameters @@ -12,6 +15,10 @@ async def post_async_request(url: str, data: Any = None) -> List[Union[int, Any] url to make the post to data: Any object to post + trace_parent: str + traceparent header to send with the request + trace_state: str + tracestate header to send with the request Returns ------- @@ -19,8 +26,12 @@ async def post_async_request(url: str, data: Any = None) -> List[Union[int, Any] Tuple with the Response status code and the data returned from the request """ async with aiohttp.ClientSession() as session: - async with session.post(url, - json=data) as response: + headers = {} + if trace_parent: + headers["traceparent"] = trace_parent + if trace_state: + headers["tracestate"] = trace_state + async with session.post(url, json=data, headers=headers) as response: # We disable aiohttp's input type validation # as the server may respond with alternative # data encodings. This is potentially unsafe. diff --git a/requirements.txt b/requirements.txt index 69900faf..acba43a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ furl==2.1.0 pytest-asyncio==0.20.2 autopep8 types-python-dateutil +opentelemetry-api \ No newline at end of file diff --git a/tests/models/test_DurableOrchestrationClient.py b/tests/models/test_DurableOrchestrationClient.py index 028975ed..6660a2e2 100644 --- a/tests/models/test_DurableOrchestrationClient.py +++ b/tests/models/test_DurableOrchestrationClient.py @@ -67,7 +67,7 @@ async def delete(self, url: str): assert url == self._expected_url return self._response - async def post(self, url: str, data: Any = None): + async def post(self, url: str, data: Any = None, trace_parent: str = None, trace_state: str = None): assert url == self._expected_url return self._response