Skip to content

Commit 69f8f92

Browse files
committed
add trace id to sqs SendMessage message attributes (#3)
1 parent e875db3 commit 69f8f92

File tree

2 files changed

+93
-24
lines changed
  • instrumentation
    • opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda
    • opentelemetry-instrumentation-botocore/src/opentelemetry/instrumentation/botocore

2 files changed

+93
-24
lines changed

instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,23 @@ def custom_event_context_extractor(lambda_event):
8787
TRACE_HEADER_KEY,
8888
AwsXRayPropagator,
8989
)
90+
from opentelemetry.propagators import textmap
9091
from opentelemetry.semconv.resource import ResourceAttributes
9192
from opentelemetry.semconv.trace import SpanAttributes
9293
from opentelemetry.trace import (
9394
Span,
9495
SpanKind,
96+
Link,
9597
TracerProvider,
98+
get_current_span,
9699
get_tracer,
97100
get_tracer_provider,
98101
set_span_in_context
99102
)
100103
from opentelemetry.trace.propagation import get_current_span
101104
from opentelemetry.trace.status import Status, StatusCode
102105
import json
106+
import typing
103107
#import traceback
104108
#import tracemalloc
105109

@@ -424,8 +428,15 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
424428
if lambda_event["Records"][0]["eventSource"] in {
425429
"aws:sqs",
426430
}:
431+
links = []
432+
for record in lambda_event["Records"]:
433+
attributes = record.get("messageAttributes")
434+
if attributes is not None:
435+
ctx = get_global_textmap().extract(carrier=attributes, getter=SQSGetter())
436+
links.append(Link(get_current_span(ctx).get_span_context()))
437+
427438
span_name = orig_handler_name
428-
sqsTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.PRODUCER)
439+
sqsTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.PRODUCER, links=links)
429440
sqsTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "pubsub")
430441
sqsTriggerSpan.set_attribute("faas.trigger.type", "SQS")
431442

@@ -435,7 +446,8 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
435446
sqsTriggerSpan.set_attribute(
436447
"rpc.request.body",
437448
lambda_event["Records"][0].get("body"),
438-
)
449+
)
450+
439451
except Exception as ex:
440452
pass
441453

@@ -810,3 +822,29 @@ def _uninstrument(self, **kwargs):
810822
import_module(self._wrapped_module_name),
811823
self._wrapped_function_name,
812824
)
825+
826+
827+
class SQSGetter():
828+
def get(
829+
self, carrier: typing.Mapping[str, textmap.CarrierValT], key: str
830+
) -> typing.Optional[typing.List[str]]:
831+
"""Getter implementation to retrieve a value from a dictionary.
832+
833+
Args:
834+
carrier: dictionary in which to get value
835+
key: the key used to get the value
836+
Returns:
837+
A list with a single string with the value if it exists, else None.
838+
"""
839+
val = carrier.get(key, None)
840+
if val is None:
841+
return None
842+
if val.get("stringValue") is not None:
843+
return [val.get("stringValue")]
844+
return None
845+
846+
def keys(
847+
self, carrier: typing.Mapping[str, textmap.CarrierValT]
848+
) -> typing.List[str]:
849+
"""Keys implementation that returns all keys from a dictionary."""
850+
return list(carrier.keys())

instrumentation/opentelemetry-instrumentation-botocore/src/opentelemetry/instrumentation/botocore/__init__.py

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,13 @@ def response_hook(span, service_name, operation_name, result):
102102
suppress_http_instrumentation,
103103
unwrap,
104104
)
105-
from opentelemetry.propagators.aws.aws_xray_propagator import AwsXRayPropagator
105+
from opentelemetry.propagate import inject
106+
from opentelemetry.propagators import textmap
106107
from opentelemetry.semconv.trace import SpanAttributes
107108
from opentelemetry.trace import get_tracer
108109
from opentelemetry.trace.span import Span
109-
import copy
110110
import base64
111-
import traceback
111+
import typing
112112
logger = logging.getLogger(__name__)
113113

114114

@@ -148,7 +148,7 @@ def _instrument(self, **kwargs):
148148
self.request_hook = kwargs.get("request_hook")
149149
self.response_hook = kwargs.get("response_hook")
150150
try:
151-
self.payload_size_limit = int(os.environ.get("OTEL_PAYLOAD_SIZE_LIMIT", 204800))
151+
self.payload_size_limit = int(os.environ.get("OTEL_PAYLOAD_SIZE_LIMIT", 51200))
152152
except ValueError:
153153
logger.error(
154154
"OTEL_PAYLOAD_SIZE_LIMIT is not a number"
@@ -201,11 +201,6 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
201201
if call_context is None:
202202
return original_func(*args, **kwargs)
203203

204-
#print("parsing context")
205-
#print(call_context.service)
206-
#print(call_context.operation)
207-
#print(args[1].get("ClientContext"))
208-
209204
extension = _find_extension(call_context)
210205
if not extension.should_trace_service_call():
211206
return original_func(*args, **kwargs)
@@ -224,21 +219,21 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
224219
elif call_context.operation == "PutObject":
225220
body = call_context.params.get("Body")
226221
if body is not None:
227-
attributes["rpc.request.payload"] = body.decode('ascii')
222+
attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, body.decode('ascii'))
228223
elif call_context.operation == "PutItem":
229224
body = call_context.params.get("Item")
230225
if body is not None:
231-
attributes["rpc.request.payload"] = json.dumps(body, default=str)
226+
attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, json.dumps(body, default=str))
232227
elif call_context.operation == "GetItem":
233228
body = call_context.params.get("Key")
234229
if body is not None:
235-
attributes["rpc.request.payload"] = json.dumps(body, default=str)
230+
attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit,json.dumps(body, default=str))
236231
elif call_context.operation == "Publish":
237232
body = call_context.params.get("Message")
238233
if body is not None:
239-
attributes["rpc.request.payload"] = json.dumps(body, default=str)
234+
attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit,json.dumps(body, default=str))
240235
else:
241-
attributes["rpc.request.payload"] = json.dumps(call_context.params, default=str)
236+
attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, json.dumps(call_context.params, default=str))
242237
except Exception as ex:
243238
pass
244239

@@ -267,19 +262,34 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
267262
jctx = json.dumps(ctx)
268263
args[1]['ClientContext'] = base64.b64encode(jctx.encode('ascii')).decode('ascii')
269264
else:
270-
#ctx = {'custom': {'traceContext':{}}}
271-
#inject(ctx['custom']['traceContext'])
272265
ctx = {'custom': {}}
273266
inject(ctx['custom'])
274267
jctx = json.dumps(ctx)
275268
args[1]['ClientContext'] = base64.b64encode(jctx.encode('ascii')).decode('ascii')
276269

277270
except Exception as ex:
278-
#print(traceback.format_exc())
279-
#print("exception")
280-
#print(ex)
281271
pass
282272

273+
try:
274+
if call_context.service == "sqs" and call_context.operation == "SendMessage":
275+
if args[1].get("MessageAttributes") is not None:
276+
inject(carrier = args[1].get("MessageAttributes"), setter=SQSSetter())
277+
else:
278+
args[1]['MessageAttributes'] = {}
279+
inject(carrier = args[1].get("MessageAttributes"), setter=SQSSetter())
280+
281+
if call_context.service == "sqs" and call_context.operation == "SendMessageBatch":
282+
if args[1].get("Entries") is not None:
283+
for entry in args[1].get("Entries"):
284+
if entry.get("MessageAttributes") is not None:
285+
inject(carrier = entry.get("MessageAttributes"), setter=SQSSetter())
286+
else:
287+
entry['MessageAttributes'] = {}
288+
inject(carrier = entry.get("MessageAttributes"), setter=SQSSetter())
289+
290+
except Exception as ex:
291+
pass
292+
283293
result = None
284294
try:
285295
#print("calling original func")
@@ -404,9 +414,6 @@ def _apply_response_attributes(span: Span, result, payload_size_limit):
404414
span.set_attribute(
405415
"rpc.response.payload", json.dumps(result, default=str))
406416
except Exception as ex:
407-
#print(traceback.format_exc())
408-
#print("exception")
409-
#print(ex)
410417
pass
411418

412419

@@ -440,3 +447,27 @@ def _safe_invoke(function: Callable, *args):
440447
logger.error(
441448
"Error when invoking function '%s'", function_name, exc_info=ex
442449
)
450+
451+
class SQSSetter():
452+
def set(
453+
self,
454+
carrier: typing.MutableMapping[str, textmap.CarrierValT],
455+
key: str,
456+
value: textmap.CarrierValT,
457+
) -> None:
458+
"""Setter implementation to set a value into a dictionary.
459+
460+
Args:
461+
carrier: dictionary in which to set value
462+
key: the key used to set the value
463+
value: the value to set
464+
"""
465+
val = {"DataType": "String", "StringValue": value}
466+
carrier[key] = val
467+
468+
def limit_string_size(s: str, max_size: int) -> str:
469+
if len(s) > max_size:
470+
return s[:max_size]
471+
else:
472+
return s
473+

0 commit comments

Comments
 (0)