Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 365c3b8

Browse files
author
Artem Krivonos
committedAug 22, 2023
Add structured logging implementation
1 parent 0161f76 commit 365c3b8

File tree

5 files changed

+436
-55
lines changed

5 files changed

+436
-55
lines changed
 

‎awslambdaric/bootstrap.py

+97-43
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,19 @@
1313
from .lambda_context import LambdaContext
1414
from .lambda_runtime_client import LambdaRuntimeClient
1515
from .lambda_runtime_exception import FaultException
16+
from .lambda_runtime_log_utils import (
17+
_DATETIME_FORMAT,
18+
_DEFAULT_FRAME_TYPE,
19+
_JSON_FRAME_TYPES,
20+
JsonFormatter,
21+
LogFormat,
22+
)
1623
from .lambda_runtime_marshaller import to_json
1724

1825
ERROR_LOG_LINE_TERMINATE = "\r"
1926
ERROR_LOG_IDENT = "\u00a0" # NO-BREAK SPACE U+00A0
27+
_AWS_LAMBDA_LOG_FORMAT = LogFormat.from_str(os.environ.get("AWS_LAMBDA_LOG_FORMAT"))
28+
_AWS_LAMBDA_LOG_LEVEL = os.environ.get("AWS_LAMBDA_LOG_LEVEL", "").upper()
2029

2130

2231
def _get_handler(handler):
@@ -73,7 +82,12 @@ def result(*args):
7382
return result
7483

7584

76-
def make_error(error_message, error_type, stack_trace, invoke_id=None):
85+
def make_error(
86+
error_message,
87+
error_type,
88+
stack_trace,
89+
invoke_id=None,
90+
):
7791
result = {
7892
"errorMessage": error_message if error_message else "",
7993
"errorType": error_type if error_type else "",
@@ -92,34 +106,52 @@ def replace_line_indentation(line, indent_char, new_indent_char):
92106
return (new_indent_char * ident_chars_count) + line[ident_chars_count:]
93107

94108

95-
def log_error(error_result, log_sink):
96-
error_description = "[ERROR]"
109+
if _AWS_LAMBDA_LOG_FORMAT == LogFormat.JSON:
110+
_ERROR_FRAME_TYPE = _JSON_FRAME_TYPES[logging.ERROR]
111+
112+
def log_error(error_result, log_sink):
113+
error_result = {
114+
"timestamp": time.strftime(
115+
_DATETIME_FORMAT, logging.Formatter.converter(time.time())
116+
),
117+
"log_level": "ERROR",
118+
**error_result,
119+
}
120+
log_sink.log_error(
121+
[to_json(error_result)],
122+
)
97123

98-
error_result_type = error_result.get("errorType")
99-
if error_result_type:
100-
error_description += " " + error_result_type
124+
else:
125+
_ERROR_FRAME_TYPE = _DEFAULT_FRAME_TYPE
101126

102-
error_result_message = error_result.get("errorMessage")
103-
if error_result_message:
127+
def log_error(error_result, log_sink):
128+
error_description = "[ERROR]"
129+
130+
error_result_type = error_result.get("errorType")
104131
if error_result_type:
105-
error_description += ":"
106-
error_description += " " + error_result_message
132+
error_description += " " + error_result_type
133+
134+
error_result_message = error_result.get("errorMessage")
135+
if error_result_message:
136+
if error_result_type:
137+
error_description += ":"
138+
error_description += " " + error_result_message
107139

108-
error_message_lines = [error_description]
140+
error_message_lines = [error_description]
109141

110-
stack_trace = error_result.get("stackTrace")
111-
if stack_trace is not None:
112-
error_message_lines += ["Traceback (most recent call last):"]
113-
for trace_element in stack_trace:
114-
if trace_element == "":
115-
error_message_lines += [""]
116-
else:
117-
for trace_line in trace_element.splitlines():
118-
error_message_lines += [
119-
replace_line_indentation(trace_line, " ", ERROR_LOG_IDENT)
120-
]
142+
stack_trace = error_result.get("stackTrace")
143+
if stack_trace is not None:
144+
error_message_lines += ["Traceback (most recent call last):"]
145+
for trace_element in stack_trace:
146+
if trace_element == "":
147+
error_message_lines += [""]
148+
else:
149+
for trace_line in trace_element.splitlines():
150+
error_message_lines += [
151+
replace_line_indentation(trace_line, " ", ERROR_LOG_IDENT)
152+
]
121153

122-
log_sink.log_error(error_message_lines)
154+
log_sink.log_error(error_message_lines)
123155

124156

125157
def handle_event_request(
@@ -152,7 +184,12 @@ def handle_event_request(
152184
)
153185
except FaultException as e:
154186
xray_fault = make_xray_fault("LambdaValidationError", e.msg, os.getcwd(), [])
155-
error_result = make_error(e.msg, e.exception_type, e.trace, invoke_id)
187+
error_result = make_error(
188+
e.msg,
189+
e.exception_type,
190+
e.trace,
191+
invoke_id,
192+
)
156193

157194
except Exception:
158195
etype, value, tb = sys.exc_info()
@@ -221,7 +258,9 @@ def build_fault_result(exc_info, msg):
221258
break
222259

223260
return make_error(
224-
msg if msg else str(value), etype.__name__, traceback.format_list(tb_tuples)
261+
msg if msg else str(value),
262+
etype.__name__,
263+
traceback.format_list(tb_tuples),
225264
)
226265

227266

@@ -257,7 +296,8 @@ def __init__(self, log_sink):
257296

258297
def emit(self, record):
259298
msg = self.format(record)
260-
self.log_sink.log(msg)
299+
300+
self.log_sink.log(msg, frame_type=getattr(record, "_frame_type", None))
261301

262302

263303
class LambdaLoggerFilter(logging.Filter):
@@ -298,7 +338,7 @@ def __enter__(self):
298338
def __exit__(self, exc_type, exc_value, exc_tb):
299339
pass
300340

301-
def log(self, msg):
341+
def log(self, msg, frame_type=None):
302342
sys.stdout.write(msg)
303343

304344
def log_error(self, message_lines):
@@ -324,7 +364,6 @@ class FramedTelemetryLogSink(object):
324364

325365
def __init__(self, fd):
326366
self.fd = int(fd)
327-
self.frame_type = 0xA55A0003.to_bytes(4, "big")
328367

329368
def __enter__(self):
330369
self.file = os.fdopen(self.fd, "wb", 0)
@@ -333,11 +372,12 @@ def __enter__(self):
333372
def __exit__(self, exc_type, exc_value, exc_tb):
334373
self.file.close()
335374

336-
def log(self, msg):
375+
def log(self, msg, frame_type=None):
337376
encoded_msg = msg.encode("utf8")
377+
338378
timestamp = int(time.time_ns() / 1000) # UNIX timestamp in microseconds
339379
log_msg = (
340-
self.frame_type
380+
(frame_type or _DEFAULT_FRAME_TYPE)
341381
+ len(encoded_msg).to_bytes(4, "big")
342382
+ timestamp.to_bytes(8, "big")
343383
+ encoded_msg
@@ -346,7 +386,10 @@ def log(self, msg):
346386

347387
def log_error(self, message_lines):
348388
error_message = "\n".join(message_lines)
349-
self.log(error_message)
389+
self.log(
390+
error_message,
391+
frame_type=_ERROR_FRAME_TYPE,
392+
)
350393

351394

352395
def update_xray_env_variable(xray_trace_id):
@@ -370,6 +413,28 @@ def create_log_sink():
370413
_GLOBAL_AWS_REQUEST_ID = None
371414

372415

416+
def _setup_logging(log_format, log_level, log_sink):
417+
logging.Formatter.converter = time.gmtime
418+
logger = logging.getLogger()
419+
logger_handler = LambdaLoggerHandler(log_sink)
420+
if log_format == LogFormat.JSON:
421+
logger_handler.setFormatter(JsonFormatter())
422+
423+
logging.addLevelName(logging.DEBUG, "TRACE")
424+
if log_level in logging._nameToLevel:
425+
logger.setLevel(log_level)
426+
else:
427+
logger_handler.setFormatter(
428+
logging.Formatter(
429+
"[%(levelname)s]\t%(asctime)s.%(msecs)03dZ\t%(aws_request_id)s\t%(message)s\n",
430+
"%Y-%m-%dT%H:%M:%S",
431+
)
432+
)
433+
434+
logger_handler.addFilter(LambdaLoggerFilter())
435+
logger.addHandler(logger_handler)
436+
437+
373438
def run(app_root, handler, lambda_runtime_api_addr):
374439
sys.stdout = Unbuffered(sys.stdout)
375440
sys.stderr = Unbuffered(sys.stderr)
@@ -378,18 +443,7 @@ def run(app_root, handler, lambda_runtime_api_addr):
378443
lambda_runtime_client = LambdaRuntimeClient(lambda_runtime_api_addr)
379444

380445
try:
381-
logging.Formatter.converter = time.gmtime
382-
logger = logging.getLogger()
383-
logger_handler = LambdaLoggerHandler(log_sink)
384-
logger_handler.setFormatter(
385-
logging.Formatter(
386-
"[%(levelname)s]\t%(asctime)s.%(msecs)03dZ\t%(aws_request_id)s\t%(message)s\n",
387-
"%Y-%m-%dT%H:%M:%S",
388-
)
389-
)
390-
logger_handler.addFilter(LambdaLoggerFilter())
391-
logger.addHandler(logger_handler)
392-
446+
_setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink)
393447
global _GLOBAL_AWS_REQUEST_ID
394448

395449
request_handler = _get_handler(handler)
+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
"""
4+
5+
import json
6+
import logging
7+
import traceback
8+
from enum import IntEnum
9+
10+
_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
11+
_RESERVED_FIELDS = {
12+
"name",
13+
"msg",
14+
"args",
15+
"levelname",
16+
"levelno",
17+
"pathname",
18+
"filename",
19+
"module",
20+
"exc_info",
21+
"exc_text",
22+
"stack_info",
23+
"lineno",
24+
"funcName",
25+
"created",
26+
"msecs",
27+
"relativeCreated",
28+
"thread",
29+
"threadName",
30+
"processName",
31+
"process",
32+
"aws_request_id",
33+
"_frame_type",
34+
}
35+
36+
37+
class LogFormat(IntEnum):
38+
JSON = 0b0
39+
TEXT = 0b1
40+
41+
@classmethod
42+
def from_str(cls, value: str):
43+
if value and value.upper() == "JSON":
44+
return cls.JSON.value
45+
return cls.TEXT.value
46+
47+
48+
_JSON_FRAME_TYPES = {
49+
logging.NOTSET: 0xA55A0002.to_bytes(4, "big"),
50+
logging.DEBUG: 0xA55A000A.to_bytes(4, "big"),
51+
logging.INFO: 0xA55A000E.to_bytes(4, "big"),
52+
logging.WARNING: 0xA55A0012.to_bytes(4, "big"),
53+
logging.ERROR: 0xA55A0016.to_bytes(4, "big"),
54+
logging.CRITICAL: 0xA55A001A.to_bytes(4, "big"),
55+
}
56+
_DEFAULT_FRAME_TYPE = 0xA55A0003.to_bytes(4, "big")
57+
58+
_json_encoder = json.JSONEncoder(ensure_ascii=False)
59+
_encode_json = _json_encoder.encode
60+
61+
62+
class JsonFormatter(logging.Formatter):
63+
def __init__(self):
64+
super().__init__(datefmt=_DATETIME_FORMAT)
65+
66+
@staticmethod
67+
def __format_stacktrace(exc_info):
68+
if not exc_info:
69+
return None
70+
return traceback.format_tb(exc_info[2])
71+
72+
@staticmethod
73+
def __format_exception_name(exc_info):
74+
if not exc_info:
75+
return None
76+
77+
return exc_info[0].__name__
78+
79+
@staticmethod
80+
def __format_exception(exc_info):
81+
if not exc_info:
82+
return None
83+
84+
return str(exc_info[1])
85+
86+
@staticmethod
87+
def __format_location(record: logging.LogRecord):
88+
if not record.exc_info:
89+
return None
90+
91+
return f"{record.pathname}:{record.funcName}:{record.lineno}"
92+
93+
@staticmethod
94+
def __format_log_level(record: logging.LogRecord):
95+
record.levelno = min(50, max(0, record.levelno)) // 10 * 10
96+
record.levelname = logging.getLevelName(record.levelno)
97+
98+
def format(self, record: logging.LogRecord) -> str:
99+
self.__format_log_level(record)
100+
record._frame_type = _JSON_FRAME_TYPES.get(
101+
record.levelno, _JSON_FRAME_TYPES[logging.NOTSET]
102+
)
103+
104+
result = {
105+
"timestamp": self.formatTime(record, self.datefmt),
106+
"level": record.levelname,
107+
"message": record.getMessage(),
108+
"logger": record.name,
109+
"stackTrace": self.__format_stacktrace(record.exc_info),
110+
"errorType": self.__format_exception_name(record.exc_info),
111+
"errorMessage": self.__format_exception(record.exc_info),
112+
"requestId": getattr(record, "aws_request_id", None),
113+
"location": self.__format_location(record),
114+
}
115+
result.update(
116+
(key, value)
117+
for key, value in record.__dict__.items()
118+
if key not in _RESERVED_FIELDS and key not in result
119+
)
120+
121+
result = {k: v for k, v in result.items() if v is not None}
122+
123+
return _encode_json(result) + "\n"

‎tests/test_bootstrap.py

+213-8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import importlib
66
import json
7+
import logging
8+
import logging.config
79
import os
810
import re
911
import tempfile
@@ -16,6 +18,7 @@
1618

1719
import awslambdaric.bootstrap as bootstrap
1820
from awslambdaric.lambda_runtime_exception import FaultException
21+
from awslambdaric.lambda_runtime_log_utils import LogFormat
1922
from awslambdaric.lambda_runtime_marshaller import LambdaMarshaller
2023

2124

@@ -613,14 +616,7 @@ def test_handle_event_request_fault_exception_logging_syntax_error(
613616
bootstrap.StandardLogSink(),
614617
)
615618

616-
import sys
617-
618-
sys.stderr.write(mock_stdout.getvalue())
619-
620-
error_logs = (
621-
"[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': "
622-
"unexpected EOF while parsing (<string>, line 1)\r"
623-
)
619+
error_logs = f"[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': {syntax_error}\r"
624620
error_logs += "Traceback (most recent call last):\r"
625621
error_logs += '  File "<string>" Line 1\r'
626622
error_logs += "    -\n"
@@ -1174,6 +1170,215 @@ def test_multiple_frame(self):
11741170
self.assertEqual(content[pos:], b"")
11751171

11761172

1173+
class TestLoggingSetup(unittest.TestCase):
1174+
def test_log_level(self) -> None:
1175+
test_cases = [
1176+
(LogFormat.JSON, "TRACE", logging.DEBUG),
1177+
(LogFormat.JSON, "DEBUG", logging.DEBUG),
1178+
(LogFormat.JSON, "INFO", logging.INFO),
1179+
(LogFormat.JSON, "WARN", logging.WARNING),
1180+
(LogFormat.JSON, "ERROR", logging.ERROR),
1181+
(LogFormat.JSON, "FATAL", logging.CRITICAL),
1182+
# Log level is set only for Json format
1183+
(LogFormat.TEXT, "TRACE", logging.NOTSET),
1184+
(LogFormat.TEXT, "DEBUG", logging.NOTSET),
1185+
(LogFormat.TEXT, "INFO", logging.NOTSET),
1186+
(LogFormat.TEXT, "WARN", logging.NOTSET),
1187+
(LogFormat.TEXT, "ERROR", logging.NOTSET),
1188+
(LogFormat.TEXT, "FATAL", logging.NOTSET),
1189+
("Unknown format", "INFO", logging.NOTSET),
1190+
# if level is unknown fall back to default
1191+
(LogFormat.JSON, "Unknown level", logging.NOTSET),
1192+
]
1193+
for fmt, log_level, expected_level in test_cases:
1194+
with self.subTest():
1195+
# Drop previous setup
1196+
logging.getLogger().handlers.clear()
1197+
logging.getLogger().level = logging.NOTSET
1198+
1199+
bootstrap._setup_logging(fmt, log_level, bootstrap.StandardLogSink())
1200+
1201+
self.assertEqual(expected_level, logging.getLogger().level)
1202+
1203+
1204+
class TestLogging(unittest.TestCase):
1205+
@classmethod
1206+
def setUpClass(cls) -> None:
1207+
logging.getLogger().handlers.clear()
1208+
logging.getLogger().level = logging.NOTSET
1209+
bootstrap._setup_logging(
1210+
LogFormat.from_str("JSON"), "INFO", bootstrap.StandardLogSink()
1211+
)
1212+
1213+
@patch("sys.stderr", new_callable=StringIO)
1214+
def test_json_formatter(self, mock_stderr):
1215+
logger = logging.getLogger("a.b")
1216+
1217+
test_cases = [
1218+
(
1219+
logging.ERROR,
1220+
"TEST 1",
1221+
{
1222+
"level": "ERROR",
1223+
"logger": "a.b",
1224+
"message": "TEST 1",
1225+
"requestId": "",
1226+
},
1227+
),
1228+
(
1229+
logging.ERROR,
1230+
"test \nwith \nnew \nlines",
1231+
{
1232+
"level": "ERROR",
1233+
"logger": "a.b",
1234+
"message": "test \nwith \nnew \nlines",
1235+
"requestId": "",
1236+
},
1237+
),
1238+
(
1239+
logging.CRITICAL,
1240+
"TEST CRITICAL",
1241+
{
1242+
"level": "CRITICAL",
1243+
"logger": "a.b",
1244+
"message": "TEST CRITICAL",
1245+
"requestId": "",
1246+
},
1247+
),
1248+
]
1249+
for level, msg, expected in test_cases:
1250+
with self.subTest(msg):
1251+
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
1252+
logger.log(level, msg)
1253+
1254+
data = json.loads(mock_stdout.getvalue())
1255+
data.pop("timestamp")
1256+
self.assertEqual(
1257+
data,
1258+
expected,
1259+
)
1260+
self.assertEqual(mock_stderr.getvalue(), "")
1261+
1262+
@patch("sys.stdout", new_callable=StringIO)
1263+
@patch("sys.stderr", new_callable=StringIO)
1264+
def test_exception(self, mock_stderr, mock_stdout):
1265+
try:
1266+
raise ValueError("error message")
1267+
except ValueError:
1268+
logging.getLogger("test.logger").exception("test exception")
1269+
1270+
exception_log = json.loads(mock_stdout.getvalue())
1271+
self.assertIn("location", exception_log)
1272+
self.assertIn("stackTrace", exception_log)
1273+
exception_log.pop("timestamp")
1274+
exception_log.pop("location")
1275+
stack_trace = exception_log.pop("stackTrace")
1276+
1277+
self.assertEqual(len(stack_trace), 1)
1278+
1279+
self.assertEqual(
1280+
exception_log,
1281+
{
1282+
"errorMessage": "error message",
1283+
"errorType": "ValueError",
1284+
"level": "ERROR",
1285+
"logger": "test.logger",
1286+
"message": "test exception",
1287+
"requestId": "",
1288+
},
1289+
)
1290+
1291+
self.assertEqual(mock_stderr.getvalue(), "")
1292+
1293+
@patch("sys.stdout", new_callable=StringIO)
1294+
@patch("sys.stderr", new_callable=StringIO)
1295+
def test_log_level(self, mock_stderr, mock_stdout):
1296+
logger = logging.getLogger("test.logger")
1297+
1298+
logger.debug("debug message")
1299+
logger.info("info message")
1300+
1301+
data = json.loads(mock_stdout.getvalue())
1302+
data.pop("timestamp")
1303+
1304+
self.assertEqual(
1305+
data,
1306+
{
1307+
"level": "INFO",
1308+
"logger": "test.logger",
1309+
"message": "info message",
1310+
"requestId": "",
1311+
},
1312+
)
1313+
self.assertEqual(mock_stderr.getvalue(), "")
1314+
1315+
@patch("sys.stdout", new_callable=StringIO)
1316+
@patch("sys.stderr", new_callable=StringIO)
1317+
def test_set_log_level_manually(self, mock_stderr, mock_stdout):
1318+
logger = logging.getLogger("test.logger")
1319+
1320+
# Changing log level after `bootstrap.setup_logging`
1321+
logging.getLogger().setLevel(logging.CRITICAL)
1322+
1323+
logger.debug("debug message")
1324+
logger.info("info message")
1325+
logger.warning("warning message")
1326+
logger.error("error message")
1327+
logger.critical("critical message")
1328+
1329+
data = json.loads(mock_stdout.getvalue())
1330+
data.pop("timestamp")
1331+
1332+
self.assertEqual(
1333+
data,
1334+
{
1335+
"level": "CRITICAL",
1336+
"logger": "test.logger",
1337+
"message": "critical message",
1338+
"requestId": "",
1339+
},
1340+
)
1341+
self.assertEqual(mock_stderr.getvalue(), "")
1342+
1343+
@patch("sys.stdout", new_callable=StringIO)
1344+
@patch("sys.stderr", new_callable=StringIO)
1345+
def test_set_log_level_with_dictConfig(self, mock_stderr, mock_stdout):
1346+
# Changing log level after `bootstrap.setup_logging`
1347+
logging.config.dictConfig(
1348+
{
1349+
"version": 1,
1350+
"disable_existing_loggers": False,
1351+
"formatters": {"simple": {"format": "%(levelname)-8s - %(message)s"}},
1352+
"handlers": {
1353+
"stdout": {
1354+
"class": "logging.StreamHandler",
1355+
"formatter": "simple",
1356+
},
1357+
},
1358+
"root": {
1359+
"level": "CRITICAL",
1360+
"handlers": [
1361+
"stdout",
1362+
],
1363+
},
1364+
}
1365+
)
1366+
1367+
logger = logging.getLogger("test.logger")
1368+
logger.debug("debug message")
1369+
logger.info("info message")
1370+
logger.warning("warning message")
1371+
logger.error("error message")
1372+
logger.critical("critical message")
1373+
1374+
data = mock_stderr.getvalue()
1375+
self.assertEqual(
1376+
data,
1377+
"CRITICAL - critical message\n",
1378+
)
1379+
self.assertEqual(mock_stdout.getvalue(), "")
1380+
1381+
11771382
class TestBootstrapModule(unittest.TestCase):
11781383
@patch("awslambdaric.bootstrap.handle_event_request")
11791384
@patch("awslambdaric.bootstrap.LambdaRuntimeClient")

‎tests/test_lambda_context.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import os
66
import unittest
7-
from unittest.mock import patch, MagicMock
7+
from unittest.mock import MagicMock, patch
88

99
from awslambdaric.lambda_context import LambdaContext
1010

‎tests/test_lambda_runtime_client.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
import http.client
77
import unittest.mock
88
from unittest.mock import MagicMock, patch
9-
from awslambdaric import __version__
10-
119

10+
from awslambdaric import __version__
1211
from awslambdaric.lambda_runtime_client import (
12+
InvocationRequest,
1313
LambdaRuntimeClient,
1414
LambdaRuntimeClientError,
15-
InvocationRequest,
1615
_user_agent,
1716
)
1817

0 commit comments

Comments
 (0)
Please sign in to comment.