Skip to content

Commit b4fade9

Browse files
author
Artem Krivonos
committed
bugfixes and speedups
1 parent 31c93cc commit b4fade9

File tree

4 files changed

+60
-112
lines changed

4 files changed

+60
-112
lines changed

awslambdaric/bootstrap.py

+39-37
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,21 @@
1414
from .lambda_runtime_client import LambdaRuntimeClient
1515
from .lambda_runtime_exception import FaultException
1616
from .lambda_runtime_log_utils import (
17-
INVOCATION_LOGGING_CONTEXT,
1817
JsonFormatter,
1918
DATETIME_FORMAT,
20-
LambdaLogFormat,
19+
JSON_FORMAT,
20+
TEXT_FORMAT,
21+
get_log_format_from_str,
2122
)
2223
from .lambda_runtime_marshaller import to_json
2324

2425
ERROR_LOG_LINE_TERMINATE = "\r"
2526
ERROR_LOG_IDENT = "\u00a0" # NO-BREAK SPACE U+00A0
2627
RUNTIME_ERROR_LOGGER_NAME = "system"
28+
AWS_LAMBDA_LOG_FORMAT = get_log_format_from_str(
29+
os.environ.get("AWS_LAMBDA_LOG_FORMAT", "TEXT").upper()
30+
)
31+
AWS_LAMBDA_LOG_LEVEL = os.environ.get("AWS_LAMBDA_LOG_LEVEL")
2732

2833

2934
def _get_handler(handler):
@@ -106,18 +111,11 @@ def replace_line_indentation(line, indent_char, new_indent_char):
106111
return (new_indent_char * ident_chars_count) + line[ident_chars_count:]
107112

108113

109-
def log_error(error_result, log_sink):
110-
if AWS_LAMBDA_LOG_FORMAT == LambdaLogFormat.JSON:
111-
log_error_json(error_result, log_sink)
112-
else:
113-
log_error_text(error_result, log_sink)
114-
115-
116114
def log_error_json(error_result, log_sink):
117115
log_level = error_result.pop("log_level", logging.ERROR)
118116
error_result["level"] = logging.getLevelName(log_level)
119117
log_sink.log_error(
120-
[to_json(error_result)],
118+
[to_json(error_result, ensure_ascii=False)],
121119
log_level=log_level,
122120
)
123121

@@ -154,6 +152,12 @@ def log_error_text(error_result, log_sink):
154152
)
155153

156154

155+
if AWS_LAMBDA_LOG_FORMAT == JSON_FORMAT:
156+
log_error = log_error_json
157+
else:
158+
log_error = log_error_text
159+
160+
157161
def handle_event_request(
158162
lambda_runtime_client,
159163
request_handler,
@@ -292,13 +296,13 @@ def __init__(self, log_sink):
292296
logging.Handler.__init__(self)
293297
self.log_sink = log_sink
294298

295-
def emit(self, record: logging.LogRecord):
299+
def emit(self, record):
296300
msg = self.format(record)
297301

298302
self.log_sink.log(
299303
msg,
300304
log_level=record.levelno,
301-
log_format=getattr(record, "log_format", LambdaLogFormat.TEXT),
305+
log_format=getattr(record, "log_format", TEXT_FORMAT),
302306
)
303307

304308

@@ -340,14 +344,31 @@ def __enter__(self):
340344
def __exit__(self, exc_type, exc_value, exc_tb):
341345
pass
342346

343-
def log(self, msg, log_level=logging.NOTSET, log_format=LambdaLogFormat.TEXT):
347+
def log(self, msg, log_level=None, log_format=None):
344348
sys.stdout.write(msg)
345349

346350
def log_error(self, message_lines, log_level=logging.ERROR):
347351
error_message = ERROR_LOG_LINE_TERMINATE.join(message_lines) + "\n"
348352
sys.stdout.write(error_message)
349353

350354

355+
FRAME_TYPES = {
356+
(JSON_FORMAT, logging.NOTSET): 0xA55A0002.to_bytes(4, "big"),
357+
(JSON_FORMAT, logging.DEBUG): 0xA55A000A.to_bytes(4, "big"),
358+
(JSON_FORMAT, logging.INFO): 0xA55A000E.to_bytes(4, "big"),
359+
(JSON_FORMAT, logging.WARNING): 0xA55A0012.to_bytes(4, "big"),
360+
(JSON_FORMAT, logging.ERROR): 0xA55A0016.to_bytes(4, "big"),
361+
(JSON_FORMAT, logging.CRITICAL): 0xA55A001A.to_bytes(4, "big"),
362+
(TEXT_FORMAT, logging.NOTSET): 0xA55A0003.to_bytes(4, "big"),
363+
(TEXT_FORMAT, logging.DEBUG): 0xA55A000B.to_bytes(4, "big"),
364+
(TEXT_FORMAT, logging.INFO): 0xA55A000F.to_bytes(4, "big"),
365+
(TEXT_FORMAT, logging.WARNING): 0xA55A0013.to_bytes(4, "big"),
366+
(TEXT_FORMAT, logging.ERROR): 0xA55A0017.to_bytes(4, "big"),
367+
(TEXT_FORMAT, logging.CRITICAL): 0xA55A001B.to_bytes(4, "big"),
368+
}
369+
DEFAULT_FRAME_TYPE = 0xA55A0003.to_bytes(4, "big")
370+
371+
351372
class FramedTelemetryLogSink(object):
352373
"""
353374
FramedTelemetryLogSink implements the logging contract between runtimes and the platform. It implements a simple
@@ -364,19 +385,8 @@ class FramedTelemetryLogSink(object):
364385
The next 'len' bytes contain the message. The byte order is big-endian.
365386
"""
366387

367-
LEVEL_TO_MASK = {
368-
logging.NOTSET: 0b00000,
369-
logging.DEBUG: 0b01000,
370-
logging.INFO: 0b01100,
371-
logging.WARNING: 0b10000,
372-
logging.ERROR: 0b10100,
373-
logging.FATAL: 0b11000,
374-
}
375-
DEFAULT_LEVEL_MASK = 0b00000
376-
377388
def __init__(self, fd):
378389
self.fd = int(fd)
379-
self.frame_type = 0xA55A0002
380390

381391
def __enter__(self):
382392
self.file = os.fdopen(self.fd, "wb", 0)
@@ -385,16 +395,13 @@ def __enter__(self):
385395
def __exit__(self, exc_type, exc_value, exc_tb):
386396
self.file.close()
387397

388-
def log(self, msg, log_level=logging.NOTSET, log_format=LambdaLogFormat.TEXT):
389-
frame_type = self.frame_type | self.LEVEL_TO_MASK.get(
390-
log_level, self.DEFAULT_LEVEL_MASK
391-
)
392-
frame_type = frame_type | log_format
393-
398+
def log(self, msg, log_level=logging.NOTSET, log_format: int = TEXT_FORMAT):
399+
frame_type = FRAME_TYPES.get((log_format, log_level), DEFAULT_FRAME_TYPE)
394400
encoded_msg = msg.encode("utf8")
401+
395402
timestamp = int(time.time_ns() / 1000) # UNIX timestamp in microseconds
396403
log_msg = (
397-
frame_type.to_bytes(4, "big")
404+
frame_type
398405
+ len(encoded_msg).to_bytes(4, "big")
399406
+ timestamp.to_bytes(8, "big")
400407
+ encoded_msg
@@ -424,10 +431,6 @@ def create_log_sink():
424431
return StandardLogSink()
425432

426433

427-
AWS_LAMBDA_LOG_FORMAT = LambdaLogFormat.from_str(
428-
os.environ.get("AWS_LAMBDA_LOG_FORMAT", "TEXT")
429-
)
430-
AWS_LAMBDA_LOG_LEVEL = os.environ.get("AWS_LAMBDA_LOG_LEVEL")
431434
_GLOBAL_AWS_REQUEST_ID = None
432435

433436

@@ -436,7 +439,7 @@ def setup_logging(log_format, log_level, log_sink):
436439
logging.Formatter.converter = time.gmtime
437440
logger = logging.getLogger()
438441
logger_handler = LambdaLoggerHandler(log_sink)
439-
if log_format == LambdaLogFormat.JSON:
442+
if log_format == JSON_FORMAT:
440443
logger_handler.setFormatter(JsonFormatter())
441444
else:
442445
logger_handler.setFormatter(
@@ -475,7 +478,6 @@ def run(app_root, handler, lambda_runtime_api_addr):
475478

476479
while True:
477480
event_request = lambda_runtime_client.wait_next_invocation()
478-
INVOCATION_LOGGING_CONTEXT.clear()
479481

480482
_GLOBAL_AWS_REQUEST_ID = event_request.invoke_id
481483

awslambdaric/lambda_runtime_log_utils.py

+9-27
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import logging
66
import traceback
7-
from enum import IntFlag
87

98
from .lambda_runtime_marshaller import to_json
109

@@ -33,17 +32,14 @@
3332
"aws_request_id",
3433
"log_format",
3534
}
35+
JSON_FORMAT = 0b0
36+
TEXT_FORMAT = 0b1
3637

3738

38-
class LambdaLogFormat(IntFlag):
39-
JSON = 0b0
40-
TEXT = 0b1
41-
42-
@classmethod
43-
def from_str(cls, value):
44-
if value == "JSON":
45-
return cls.JSON
46-
return cls.TEXT
39+
def get_log_format_from_str(value: str):
40+
if value == "JSON":
41+
return JSON_FORMAT
42+
return TEXT_FORMAT
4743

4844

4945
class JsonFormatter(logging.Formatter):
@@ -79,12 +75,12 @@ def format_location(record: logging.LogRecord):
7975

8076
@staticmethod
8177
def format_log_level(record: logging.LogRecord):
82-
level = record.levelno % 51 // 10 * 10
78+
level = min(50, max(0, record.levelno)) // 10 * 10
8379
record.levelname = logging.getLevelName(level)
8480

8581
def format(self, record: logging.LogRecord) -> str:
8682
self.format_log_level(record)
87-
record.log_format = LambdaLogFormat.JSON
83+
record.log_format = JSON_FORMAT
8884

8985
result = {
9086
"timestamp": self.formatTime(record, self.datefmt),
@@ -102,21 +98,7 @@ def format(self, record: logging.LogRecord) -> str:
10298
for key, value in record.__dict__.items()
10399
if key not in RESERVED_FIELDS and key not in result
104100
)
105-
result.update(INSTANCE_LOGGING_CONTEXT)
106-
result.update(INVOCATION_LOGGING_CONTEXT)
107101

108102
result = {k: v for k, v in result.items() if v is not None}
109103

110-
return to_json(result) + "\n"
111-
112-
113-
INSTANCE_LOGGING_CONTEXT = dict()
114-
INVOCATION_LOGGING_CONTEXT = dict()
115-
116-
117-
def lambda_instance_logs_update(key: str, value):
118-
INSTANCE_LOGGING_CONTEXT[str(key)] = value
119-
120-
121-
def lambda_invocation_logs_update(key: str, value):
122-
INVOCATION_LOGGING_CONTEXT[str(key)] = value
104+
return to_json(result, ensure_ascii=False) + "\n"

awslambdaric/lambda_runtime_marshaller.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# simplejson's Decimal encoding allows '-NaN' as an output, which is a parse error for json.loads
1414
# to get the good parts of Decimal support, we'll special-case NaN decimals and otherwise duplicate the encoding for decimals the same way simplejson does
1515
class Encoder(json.JSONEncoder):
16-
def __init__(self):
17-
super().__init__(use_decimal=False)
16+
def __init__(self, ensure_ascii=True):
17+
super().__init__(use_decimal=False, ensure_ascii=ensure_ascii)
1818

1919
def default(self, obj):
2020
if isinstance(obj, decimal.Decimal):
@@ -24,8 +24,8 @@ def default(self, obj):
2424
return super().default(obj)
2525

2626

27-
def to_json(obj):
28-
return Encoder().encode(obj)
27+
def to_json(obj, ensure_ascii=True):
28+
return Encoder(ensure_ascii=ensure_ascii).encode(obj)
2929

3030

3131
class LambdaMarshaller:

tests/test_bootstrap.py

+8-44
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,9 @@
2020
from awslambdaric.lambda_runtime_exception import FaultException
2121
from awslambdaric.lambda_runtime_marshaller import LambdaMarshaller
2222
from awslambdaric.lambda_runtime_log_utils import (
23-
lambda_instance_logs_update,
24-
lambda_invocation_logs_update,
25-
INVOCATION_LOGGING_CONTEXT,
26-
INSTANCE_LOGGING_CONTEXT,
27-
LambdaLogFormat,
23+
JSON_FORMAT,
24+
TEXT_FORMAT,
25+
get_log_format_from_str,
2826
)
2927

3028

@@ -1144,12 +1142,12 @@ def test_log_level_frame_type(self):
11441142
content = f.read()
11451143

11461144
frame_type = int.from_bytes(content[:4], "big")
1147-
self.assertEqual(frame_type, expected_frame_type, hex(frame_type))
1145+
self.assertEqual(bin(frame_type), bin(expected_frame_type))
11481146

11491147
def test_log_format_frame_type(self):
11501148
test_cases = [
1151-
(LambdaLogFormat.TEXT, 0xA55A0003),
1152-
(LambdaLogFormat.JSON, 0xA55A0002),
1149+
(TEXT_FORMAT, 0xA55A0003),
1150+
(JSON_FORMAT, 0xA55A0002),
11531151
]
11541152

11551153
for fmt, expected_frame_type in test_cases:
@@ -1162,7 +1160,7 @@ def test_log_format_frame_type(self):
11621160
content = f.read()
11631161

11641162
frame_type = int.from_bytes(content[:4], "big")
1165-
self.assertEqual(frame_type, expected_frame_type, hex(frame_type))
1163+
self.assertEqual(hex(frame_type), hex(expected_frame_type))
11661164

11671165
def test_single_frame(self):
11681166
with NamedTemporaryFile() as temp_file:
@@ -1230,13 +1228,9 @@ class TestLogging(unittest.TestCase):
12301228
@classmethod
12311229
def setUpClass(cls) -> None:
12321230
bootstrap.setup_logging(
1233-
LambdaLogFormat.from_str("JSON"), "INFO", bootstrap.StandardLogSink()
1231+
get_log_format_from_str("JSON"), "INFO", bootstrap.StandardLogSink()
12341232
)
12351233

1236-
def tearDown(self) -> None:
1237-
INVOCATION_LOGGING_CONTEXT.clear()
1238-
INSTANCE_LOGGING_CONTEXT.clear()
1239-
12401234
@patch("sys.stderr", new_callable=StringIO)
12411235
def test_json_formatter(self, mock_stderr):
12421236
logger = logging.getLogger("a.b")
@@ -1317,36 +1311,6 @@ def test_exception(self, mock_stderr, mock_stdout):
13171311

13181312
self.assertEqual(mock_stderr.getvalue(), "")
13191313

1320-
@patch("sys.stdout", new_callable=StringIO)
1321-
@patch("sys.stderr", new_callable=StringIO)
1322-
def test_log_with_extra_params(self, mock_stderr, mock_stdout):
1323-
lambda_instance_logs_update("instance_key", "instance_value")
1324-
lambda_invocation_logs_update("invocation_key", "invocation_value")
1325-
lambda_invocation_logs_update("int_param", 42)
1326-
lambda_invocation_logs_update("list_param", ["1", 2, {}])
1327-
lambda_invocation_logs_update("dict_param", {"a": "b"})
1328-
1329-
logging.getLogger("test.logger").error("test extra params")
1330-
1331-
data = json.loads(mock_stdout.getvalue())
1332-
data.pop("timestamp")
1333-
1334-
self.assertEqual(
1335-
data,
1336-
{
1337-
"instance_key": "instance_value",
1338-
"invocation_key": "invocation_value",
1339-
"int_param": 42,
1340-
"list_param": ["1", 2, {}],
1341-
"dict_param": {"a": "b"},
1342-
"level": "ERROR",
1343-
"logger": "test.logger",
1344-
"message": "test extra params",
1345-
"requestId": "",
1346-
},
1347-
)
1348-
self.assertEqual(mock_stderr.getvalue(), "")
1349-
13501314
@patch("sys.stdout", new_callable=StringIO)
13511315
@patch("sys.stderr", new_callable=StringIO)
13521316
def test_log_level(self, mock_stderr, mock_stdout):

0 commit comments

Comments
 (0)