From 365c3b88aa2e6d15a866be2bdae73f7684479777 Mon Sep 17 00:00:00 2001 From: Artem Krivonos Date: Mon, 21 Aug 2023 17:15:10 +0100 Subject: [PATCH] Add structured logging implementation --- awslambdaric/bootstrap.py | 140 +++++++++----- awslambdaric/lambda_runtime_log_utils.py | 123 +++++++++++++ tests/test_bootstrap.py | 221 ++++++++++++++++++++++- tests/test_lambda_context.py | 2 +- tests/test_lambda_runtime_client.py | 5 +- 5 files changed, 436 insertions(+), 55 deletions(-) create mode 100644 awslambdaric/lambda_runtime_log_utils.py diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py index e7b9e5a..5ad7bb5 100644 --- a/awslambdaric/bootstrap.py +++ b/awslambdaric/bootstrap.py @@ -13,10 +13,19 @@ from .lambda_context import LambdaContext from .lambda_runtime_client import LambdaRuntimeClient from .lambda_runtime_exception import FaultException +from .lambda_runtime_log_utils import ( + _DATETIME_FORMAT, + _DEFAULT_FRAME_TYPE, + _JSON_FRAME_TYPES, + JsonFormatter, + LogFormat, +) from .lambda_runtime_marshaller import to_json ERROR_LOG_LINE_TERMINATE = "\r" ERROR_LOG_IDENT = "\u00a0" # NO-BREAK SPACE U+00A0 +_AWS_LAMBDA_LOG_FORMAT = LogFormat.from_str(os.environ.get("AWS_LAMBDA_LOG_FORMAT")) +_AWS_LAMBDA_LOG_LEVEL = os.environ.get("AWS_LAMBDA_LOG_LEVEL", "").upper() def _get_handler(handler): @@ -73,7 +82,12 @@ def result(*args): return result -def make_error(error_message, error_type, stack_trace, invoke_id=None): +def make_error( + error_message, + error_type, + stack_trace, + invoke_id=None, +): result = { "errorMessage": error_message if error_message else "", "errorType": error_type if error_type else "", @@ -92,34 +106,52 @@ def replace_line_indentation(line, indent_char, new_indent_char): return (new_indent_char * ident_chars_count) + line[ident_chars_count:] -def log_error(error_result, log_sink): - error_description = "[ERROR]" +if _AWS_LAMBDA_LOG_FORMAT == LogFormat.JSON: + _ERROR_FRAME_TYPE = _JSON_FRAME_TYPES[logging.ERROR] + + def log_error(error_result, log_sink): + error_result = { + "timestamp": time.strftime( + _DATETIME_FORMAT, logging.Formatter.converter(time.time()) + ), + "log_level": "ERROR", + **error_result, + } + log_sink.log_error( + [to_json(error_result)], + ) - error_result_type = error_result.get("errorType") - if error_result_type: - error_description += " " + error_result_type +else: + _ERROR_FRAME_TYPE = _DEFAULT_FRAME_TYPE - error_result_message = error_result.get("errorMessage") - if error_result_message: + def log_error(error_result, log_sink): + error_description = "[ERROR]" + + error_result_type = error_result.get("errorType") if error_result_type: - error_description += ":" - error_description += " " + error_result_message + error_description += " " + error_result_type + + error_result_message = error_result.get("errorMessage") + if error_result_message: + if error_result_type: + error_description += ":" + error_description += " " + error_result_message - error_message_lines = [error_description] + error_message_lines = [error_description] - stack_trace = error_result.get("stackTrace") - if stack_trace is not None: - error_message_lines += ["Traceback (most recent call last):"] - for trace_element in stack_trace: - if trace_element == "": - error_message_lines += [""] - else: - for trace_line in trace_element.splitlines(): - error_message_lines += [ - replace_line_indentation(trace_line, " ", ERROR_LOG_IDENT) - ] + stack_trace = error_result.get("stackTrace") + if stack_trace is not None: + error_message_lines += ["Traceback (most 
recent call last):"] + for trace_element in stack_trace: + if trace_element == "": + error_message_lines += [""] + else: + for trace_line in trace_element.splitlines(): + error_message_lines += [ + replace_line_indentation(trace_line, " ", ERROR_LOG_IDENT) + ] - log_sink.log_error(error_message_lines) + log_sink.log_error(error_message_lines) def handle_event_request( @@ -152,7 +184,12 @@ def handle_event_request( ) except FaultException as e: xray_fault = make_xray_fault("LambdaValidationError", e.msg, os.getcwd(), []) - error_result = make_error(e.msg, e.exception_type, e.trace, invoke_id) + error_result = make_error( + e.msg, + e.exception_type, + e.trace, + invoke_id, + ) except Exception: etype, value, tb = sys.exc_info() @@ -221,7 +258,9 @@ def build_fault_result(exc_info, msg): break return make_error( - msg if msg else str(value), etype.__name__, traceback.format_list(tb_tuples) + msg if msg else str(value), + etype.__name__, + traceback.format_list(tb_tuples), ) @@ -257,7 +296,8 @@ def __init__(self, log_sink): def emit(self, record): msg = self.format(record) - self.log_sink.log(msg) + + self.log_sink.log(msg, frame_type=getattr(record, "_frame_type", None)) class LambdaLoggerFilter(logging.Filter): @@ -298,7 +338,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_tb): pass - def log(self, msg): + def log(self, msg, frame_type=None): sys.stdout.write(msg) def log_error(self, message_lines): @@ -324,7 +364,6 @@ class FramedTelemetryLogSink(object): def __init__(self, fd): self.fd = int(fd) - self.frame_type = 0xA55A0003.to_bytes(4, "big") def __enter__(self): self.file = os.fdopen(self.fd, "wb", 0) @@ -333,11 +372,12 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_tb): self.file.close() - def log(self, msg): + def log(self, msg, frame_type=None): encoded_msg = msg.encode("utf8") + timestamp = int(time.time_ns() / 1000) # UNIX timestamp in microseconds log_msg = ( - self.frame_type + (frame_type or _DEFAULT_FRAME_TYPE) + len(encoded_msg).to_bytes(4, "big") + timestamp.to_bytes(8, "big") + encoded_msg @@ -346,7 +386,10 @@ def log(self, msg): def log_error(self, message_lines): error_message = "\n".join(message_lines) - self.log(error_message) + self.log( + error_message, + frame_type=_ERROR_FRAME_TYPE, + ) def update_xray_env_variable(xray_trace_id): @@ -370,6 +413,28 @@ def create_log_sink(): _GLOBAL_AWS_REQUEST_ID = None +def _setup_logging(log_format, log_level, log_sink): + logging.Formatter.converter = time.gmtime + logger = logging.getLogger() + logger_handler = LambdaLoggerHandler(log_sink) + if log_format == LogFormat.JSON: + logger_handler.setFormatter(JsonFormatter()) + + logging.addLevelName(logging.DEBUG, "TRACE") + if log_level in logging._nameToLevel: + logger.setLevel(log_level) + else: + logger_handler.setFormatter( + logging.Formatter( + "[%(levelname)s]\t%(asctime)s.%(msecs)03dZ\t%(aws_request_id)s\t%(message)s\n", + "%Y-%m-%dT%H:%M:%S", + ) + ) + + logger_handler.addFilter(LambdaLoggerFilter()) + logger.addHandler(logger_handler) + + def run(app_root, handler, lambda_runtime_api_addr): sys.stdout = Unbuffered(sys.stdout) sys.stderr = Unbuffered(sys.stderr) @@ -378,18 +443,7 @@ def run(app_root, handler, lambda_runtime_api_addr): lambda_runtime_client = LambdaRuntimeClient(lambda_runtime_api_addr) try: - logging.Formatter.converter = time.gmtime - logger = logging.getLogger() - logger_handler = LambdaLoggerHandler(log_sink) - logger_handler.setFormatter( - logging.Formatter( - 
"[%(levelname)s]\t%(asctime)s.%(msecs)03dZ\t%(aws_request_id)s\t%(message)s\n", - "%Y-%m-%dT%H:%M:%S", - ) - ) - logger_handler.addFilter(LambdaLoggerFilter()) - logger.addHandler(logger_handler) - + _setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink) global _GLOBAL_AWS_REQUEST_ID request_handler = _get_handler(handler) diff --git a/awslambdaric/lambda_runtime_log_utils.py b/awslambdaric/lambda_runtime_log_utils.py new file mode 100644 index 0000000..f140253 --- /dev/null +++ b/awslambdaric/lambda_runtime_log_utils.py @@ -0,0 +1,123 @@ +""" +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +""" + +import json +import logging +import traceback +from enum import IntEnum + +_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" +_RESERVED_FIELDS = { + "name", + "msg", + "args", + "levelname", + "levelno", + "pathname", + "filename", + "module", + "exc_info", + "exc_text", + "stack_info", + "lineno", + "funcName", + "created", + "msecs", + "relativeCreated", + "thread", + "threadName", + "processName", + "process", + "aws_request_id", + "_frame_type", +} + + +class LogFormat(IntEnum): + JSON = 0b0 + TEXT = 0b1 + + @classmethod + def from_str(cls, value: str): + if value and value.upper() == "JSON": + return cls.JSON.value + return cls.TEXT.value + + +_JSON_FRAME_TYPES = { + logging.NOTSET: 0xA55A0002.to_bytes(4, "big"), + logging.DEBUG: 0xA55A000A.to_bytes(4, "big"), + logging.INFO: 0xA55A000E.to_bytes(4, "big"), + logging.WARNING: 0xA55A0012.to_bytes(4, "big"), + logging.ERROR: 0xA55A0016.to_bytes(4, "big"), + logging.CRITICAL: 0xA55A001A.to_bytes(4, "big"), +} +_DEFAULT_FRAME_TYPE = 0xA55A0003.to_bytes(4, "big") + +_json_encoder = json.JSONEncoder(ensure_ascii=False) +_encode_json = _json_encoder.encode + + +class JsonFormatter(logging.Formatter): + def __init__(self): + super().__init__(datefmt=_DATETIME_FORMAT) + + @staticmethod + def __format_stacktrace(exc_info): + if not exc_info: + return None + return traceback.format_tb(exc_info[2]) + + @staticmethod + def __format_exception_name(exc_info): + if not exc_info: + return None + + return exc_info[0].__name__ + + @staticmethod + def __format_exception(exc_info): + if not exc_info: + return None + + return str(exc_info[1]) + + @staticmethod + def __format_location(record: logging.LogRecord): + if not record.exc_info: + return None + + return f"{record.pathname}:{record.funcName}:{record.lineno}" + + @staticmethod + def __format_log_level(record: logging.LogRecord): + record.levelno = min(50, max(0, record.levelno)) // 10 * 10 + record.levelname = logging.getLevelName(record.levelno) + + def format(self, record: logging.LogRecord) -> str: + self.__format_log_level(record) + record._frame_type = _JSON_FRAME_TYPES.get( + record.levelno, _JSON_FRAME_TYPES[logging.NOTSET] + ) + + result = { + "timestamp": self.formatTime(record, self.datefmt), + "level": record.levelname, + "message": record.getMessage(), + "logger": record.name, + "stackTrace": self.__format_stacktrace(record.exc_info), + "errorType": self.__format_exception_name(record.exc_info), + "errorMessage": self.__format_exception(record.exc_info), + "requestId": getattr(record, "aws_request_id", None), + "location": self.__format_location(record), + } + result.update( + (key, value) + for key, value in record.__dict__.items() + if key not in _RESERVED_FIELDS and key not in result + ) + + result = {k: v for k, v in result.items() if v is not None} + + return _encode_json(result) + "\n" diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 
edb0737..5614a2e 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -4,6 +4,8 @@ import importlib import json +import logging +import logging.config import os import re import tempfile @@ -16,6 +18,7 @@ import awslambdaric.bootstrap as bootstrap from awslambdaric.lambda_runtime_exception import FaultException +from awslambdaric.lambda_runtime_log_utils import LogFormat from awslambdaric.lambda_runtime_marshaller import LambdaMarshaller @@ -613,14 +616,7 @@ def test_handle_event_request_fault_exception_logging_syntax_error( bootstrap.StandardLogSink(), ) - import sys - - sys.stderr.write(mock_stdout.getvalue()) - - error_logs = ( - "[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': " - "unexpected EOF while parsing (, line 1)\r" - ) + error_logs = f"[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': {syntax_error}\r" error_logs += "Traceback (most recent call last):\r" error_logs += '  File "" Line 1\r' error_logs += "    -\n" @@ -1174,6 +1170,215 @@ def test_multiple_frame(self): self.assertEqual(content[pos:], b"") +class TestLoggingSetup(unittest.TestCase): + def test_log_level(self) -> None: + test_cases = [ + (LogFormat.JSON, "TRACE", logging.DEBUG), + (LogFormat.JSON, "DEBUG", logging.DEBUG), + (LogFormat.JSON, "INFO", logging.INFO), + (LogFormat.JSON, "WARN", logging.WARNING), + (LogFormat.JSON, "ERROR", logging.ERROR), + (LogFormat.JSON, "FATAL", logging.CRITICAL), + # Log level is set only for Json format + (LogFormat.TEXT, "TRACE", logging.NOTSET), + (LogFormat.TEXT, "DEBUG", logging.NOTSET), + (LogFormat.TEXT, "INFO", logging.NOTSET), + (LogFormat.TEXT, "WARN", logging.NOTSET), + (LogFormat.TEXT, "ERROR", logging.NOTSET), + (LogFormat.TEXT, "FATAL", logging.NOTSET), + ("Unknown format", "INFO", logging.NOTSET), + # if level is unknown fall back to default + (LogFormat.JSON, "Unknown level", logging.NOTSET), + ] + for fmt, log_level, expected_level in test_cases: + with self.subTest(): + # Drop previous setup + logging.getLogger().handlers.clear() + logging.getLogger().level = logging.NOTSET + + bootstrap._setup_logging(fmt, log_level, bootstrap.StandardLogSink()) + + self.assertEqual(expected_level, logging.getLogger().level) + + +class TestLogging(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + logging.getLogger().handlers.clear() + logging.getLogger().level = logging.NOTSET + bootstrap._setup_logging( + LogFormat.from_str("JSON"), "INFO", bootstrap.StandardLogSink() + ) + + @patch("sys.stderr", new_callable=StringIO) + def test_json_formatter(self, mock_stderr): + logger = logging.getLogger("a.b") + + test_cases = [ + ( + logging.ERROR, + "TEST 1", + { + "level": "ERROR", + "logger": "a.b", + "message": "TEST 1", + "requestId": "", + }, + ), + ( + logging.ERROR, + "test \nwith \nnew \nlines", + { + "level": "ERROR", + "logger": "a.b", + "message": "test \nwith \nnew \nlines", + "requestId": "", + }, + ), + ( + logging.CRITICAL, + "TEST CRITICAL", + { + "level": "CRITICAL", + "logger": "a.b", + "message": "TEST CRITICAL", + "requestId": "", + }, + ), + ] + for level, msg, expected in test_cases: + with self.subTest(msg): + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + logger.log(level, msg) + + data = json.loads(mock_stdout.getvalue()) + data.pop("timestamp") + self.assertEqual( + data, + expected, + ) + self.assertEqual(mock_stderr.getvalue(), "") + + @patch("sys.stdout", new_callable=StringIO) + @patch("sys.stderr", new_callable=StringIO) + def test_exception(self, mock_stderr, 
mock_stdout): + try: + raise ValueError("error message") + except ValueError: + logging.getLogger("test.logger").exception("test exception") + + exception_log = json.loads(mock_stdout.getvalue()) + self.assertIn("location", exception_log) + self.assertIn("stackTrace", exception_log) + exception_log.pop("timestamp") + exception_log.pop("location") + stack_trace = exception_log.pop("stackTrace") + + self.assertEqual(len(stack_trace), 1) + + self.assertEqual( + exception_log, + { + "errorMessage": "error message", + "errorType": "ValueError", + "level": "ERROR", + "logger": "test.logger", + "message": "test exception", + "requestId": "", + }, + ) + + self.assertEqual(mock_stderr.getvalue(), "") + + @patch("sys.stdout", new_callable=StringIO) + @patch("sys.stderr", new_callable=StringIO) + def test_log_level(self, mock_stderr, mock_stdout): + logger = logging.getLogger("test.logger") + + logger.debug("debug message") + logger.info("info message") + + data = json.loads(mock_stdout.getvalue()) + data.pop("timestamp") + + self.assertEqual( + data, + { + "level": "INFO", + "logger": "test.logger", + "message": "info message", + "requestId": "", + }, + ) + self.assertEqual(mock_stderr.getvalue(), "") + + @patch("sys.stdout", new_callable=StringIO) + @patch("sys.stderr", new_callable=StringIO) + def test_set_log_level_manually(self, mock_stderr, mock_stdout): + logger = logging.getLogger("test.logger") + + # Changing log level after `bootstrap.setup_logging` + logging.getLogger().setLevel(logging.CRITICAL) + + logger.debug("debug message") + logger.info("info message") + logger.warning("warning message") + logger.error("error message") + logger.critical("critical message") + + data = json.loads(mock_stdout.getvalue()) + data.pop("timestamp") + + self.assertEqual( + data, + { + "level": "CRITICAL", + "logger": "test.logger", + "message": "critical message", + "requestId": "", + }, + ) + self.assertEqual(mock_stderr.getvalue(), "") + + @patch("sys.stdout", new_callable=StringIO) + @patch("sys.stderr", new_callable=StringIO) + def test_set_log_level_with_dictConfig(self, mock_stderr, mock_stdout): + # Changing log level after `bootstrap.setup_logging` + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": {"simple": {"format": "%(levelname)-8s - %(message)s"}}, + "handlers": { + "stdout": { + "class": "logging.StreamHandler", + "formatter": "simple", + }, + }, + "root": { + "level": "CRITICAL", + "handlers": [ + "stdout", + ], + }, + } + ) + + logger = logging.getLogger("test.logger") + logger.debug("debug message") + logger.info("info message") + logger.warning("warning message") + logger.error("error message") + logger.critical("critical message") + + data = mock_stderr.getvalue() + self.assertEqual( + data, + "CRITICAL - critical message\n", + ) + self.assertEqual(mock_stdout.getvalue(), "") + + class TestBootstrapModule(unittest.TestCase): @patch("awslambdaric.bootstrap.handle_event_request") @patch("awslambdaric.bootstrap.LambdaRuntimeClient") diff --git a/tests/test_lambda_context.py b/tests/test_lambda_context.py index 545efa1..34d59da 100644 --- a/tests/test_lambda_context.py +++ b/tests/test_lambda_context.py @@ -4,7 +4,7 @@ import os import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch from awslambdaric.lambda_context import LambdaContext diff --git a/tests/test_lambda_runtime_client.py b/tests/test_lambda_runtime_client.py index 814ca96..47d95cf 100644 --- a/tests/test_lambda_runtime_client.py 
+++ b/tests/test_lambda_runtime_client.py @@ -6,13 +6,12 @@ import http.client import unittest.mock from unittest.mock import MagicMock, patch -from awslambdaric import __version__ - +from awslambdaric import __version__ from awslambdaric.lambda_runtime_client import ( + InvocationRequest, LambdaRuntimeClient, LambdaRuntimeClientError, - InvocationRequest, _user_agent, )
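
Usage sketch (illustrative, not part of the patch): with this change the runtime installs the new JsonFormatter from awslambdaric/lambda_runtime_log_utils.py whenever the AWS_LAMBDA_LOG_FORMAT environment variable is set to "JSON", so handler log records are emitted as JSON documents. The snippet below is a minimal reproduction of that formatting outside the runtime, assuming the package from this patch is importable; the logger name "demo" and the "orderId" extra field are made up for the example.

    import logging
    import sys

    from awslambdaric.lambda_runtime_log_utils import JsonFormatter

    # Same formatter _setup_logging attaches when the log format is JSON,
    # here wired to a plain stdout StreamHandler for demonstration.
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(JsonFormatter())

    logger = logging.getLogger("demo")
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

    # Attributes passed via `extra` that are not reserved LogRecord fields
    # are merged into the JSON document; None-valued fields (e.g. requestId
    # outside an invocation) are dropped.
    logger.info("order processed", extra={"orderId": "12345"})

    # Prints something like:
    # {"timestamp": "2023-08-21T16:15:10Z", "level": "INFO",
    #  "message": "order processed", "logger": "demo", "orderId": "12345"}

Inside the runtime the same records are additionally tagged with the per-level frame types from _JSON_FRAME_TYPES before being written to the FramedTelemetryLogSink, which this standalone sketch does not reproduce.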