Skip to content

Commit a7d9001

Browse files
committed
feat: add xray_trace_id key when tracing is active #137
1 parent 319c363 commit a7d9001

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

Diff for: aws_lambda_powertools/logging/formatter.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import logging
3+
import os
34

45

56
class JsonFormatter(logging.Formatter):
@@ -29,14 +30,27 @@ def __init__(self, **kwargs):
2930
# Set the default unserializable function, by default values will be cast as str.
3031
self.default_json_formatter = kwargs.pop("json_default", str)
3132
# Set the insertion order for the log messages
32-
self.format_dict = dict.fromkeys(kwargs.pop("log_record_order", ["level", "location", "message", "timestamp"]))
33+
self.format_dict = dict.fromkeys(
34+
kwargs.pop("log_record_order", ["level", "location", "message", "xray_trace_id", "timestamp"])
35+
)
36+
self.reserved_keys = ["timestamp", "level", "location"]
3337
# Set the date format used by `asctime`
3438
super(JsonFormatter, self).__init__(datefmt=kwargs.pop("datefmt", None))
3539

36-
self.reserved_keys = ["timestamp", "level", "location"]
37-
self.format_dict.update(
38-
{"level": "%(levelname)s", "location": "%(funcName)s:%(lineno)d", "timestamp": "%(asctime)s", **kwargs}
39-
)
40+
self.format_dict.update(self._build_root_keys(**kwargs))
41+
42+
@staticmethod
43+
def _build_root_keys(**kwargs):
44+
xray_trace_id = os.getenv("_X_AMZN_TRACE_ID")
45+
trace_id = xray_trace_id.split(";")[0].replace("Root=", "") if xray_trace_id else None
46+
47+
return {
48+
"level": "%(levelname)s",
49+
"location": "%(funcName)s:%(lineno)d",
50+
"xray_trace_id": trace_id,
51+
"timestamp": "%(asctime)s",
52+
**kwargs,
53+
}
4054

4155
def update_formatter(self, **kwargs):
4256
self.format_dict.update(kwargs)

Diff for: tests/functional/test_aws_lambda_logging.py

+31
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,34 @@ def test_log_dict_key_strip_nones(stdout):
170170

171171
# THEN the keys should only include `level`, `message`, `service`, `sampling_rate`
172172
assert sorted(log_dict.keys()) == ["level", "message", "sampling_rate", "service"]
173+
174+
175+
def test_log_dict_xray_is_present_when_tracing_is_enabled(stdout, monkeypatch):
176+
# GIVEN a logger is initialized within a Lambda function with X-Ray enabled
177+
trace_id = "1-5759e988-bd862e3fe1be46a994272793"
178+
trace_header = f"Root={trace_id};Parent=53995c3f42cd8ad8;Sampled=1"
179+
monkeypatch.setenv(name="_X_AMZN_TRACE_ID", value=trace_header)
180+
logger = Logger(stream=stdout)
181+
182+
# WHEN logging a message
183+
logger.info("foo")
184+
185+
log_dict: dict = json.loads(stdout.getvalue())
186+
187+
# THEN `xray_trace_id`` key should be present
188+
assert log_dict["xray_trace_id"] == trace_id
189+
190+
monkeypatch.delenv(name="_X_AMZN_TRACE_ID")
191+
192+
193+
def test_log_dict_xray_is_not_present_when_tracing_is_disabled(stdout, monkeypatch):
194+
# GIVEN a logger is initialized within a Lambda function with X-Ray disabled (default)
195+
logger = Logger(stream=stdout)
196+
197+
# WHEN logging a message
198+
logger.info("foo")
199+
200+
log_dict: dict = json.loads(stdout.getvalue())
201+
202+
# THEN `xray_trace_id`` key should not be present
203+
assert "xray_trace_id" not in log_dict

0 commit comments

Comments
 (0)