Skip to content

fix: prevent touching preconfigured loggers #249 #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed
- **Logger**: Bugfix to prevent parent loggers with the same name being configured more than once

## [1.9.0] - 2020-12-04

### Added
Expand Down
41 changes: 26 additions & 15 deletions aws_lambda_powertools/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,21 +148,32 @@ def _get_logger(self):
def _init_logger(self, **kwargs):
"""Configures new logger"""

# Skip configuration if it's a child logger to prevent
# multiple handlers being attached as well as different sampling mechanisms
# and multiple messages from being logged as handlers can be duplicated
if not self.child:
self._configure_sampling()
self._logger.setLevel(self.log_level)
self._logger.addHandler(self._handler)
self.structure_logs(**kwargs)

logger.debug("Adding filter in root logger to suppress child logger records to bubble up")
for handler in logging.root.handlers:
# It'll add a filter to suppress any child logger from self.service
# Where service is Order, it'll reject parent logger Order,
# and child loggers such as Order.checkout, Order.shared
handler.addFilter(SuppressFilter(self.service))
# Skip configuration if it's a child logger or a pre-configured logger
# to prevent the following:
# a) multiple handlers being attached
# b) different sampling mechanisms
# c) multiple messages from being logged as handlers can be duplicated
is_logger_preconfigured = getattr(self._logger, "init", False)
if self.child or is_logger_preconfigured:
return

self._configure_sampling()
self._logger.setLevel(self.log_level)
self._logger.addHandler(self._handler)
self.structure_logs(**kwargs)

logger.debug("Adding filter in root logger to suppress child logger records to bubble up")
for handler in logging.root.handlers:
# It'll add a filter to suppress any child logger from self.service
# Where service is Order, it'll reject parent logger Order,
# and child loggers such as Order.checkout, Order.shared
handler.addFilter(SuppressFilter(self.service))

# as per bug in #249, we should not be pre-configuring an existing logger
# therefore we set a custom attribute in the Logger that will be returned
# std logging will return the same Logger with our attribute if name is reused
logger.debug(f"Marking logger {self.service} as preconfigured")
self._logger.init = True

def _configure_sampling(self):
"""Dynamically set log level based on sampling rate
Expand Down
68 changes: 41 additions & 27 deletions tests/functional/test_aws_lambda_logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""aws_lambda_logging tests."""
import io
import json
import random
import string

import pytest

Expand All @@ -12,9 +14,15 @@ def stdout():
return io.StringIO()


@pytest.fixture
def service_name():
chars = string.ascii_letters + string.digits
return "".join(random.SystemRandom().choice(chars) for _ in range(15))


@pytest.mark.parametrize("level", ["DEBUG", "WARNING", "ERROR", "INFO", "CRITICAL"])
def test_setup_with_valid_log_levels(stdout, level):
logger = Logger(level=level, stream=stdout, request_id="request id!", another="value")
def test_setup_with_valid_log_levels(stdout, level, service_name):
logger = Logger(service=service_name, level=level, stream=stdout, request_id="request id!", another="value")
msg = "This is a test"
log_command = {
"INFO": logger.info,
Expand All @@ -37,8 +45,8 @@ def test_setup_with_valid_log_levels(stdout, level):
assert "exception" not in log_dict


def test_logging_exception_traceback(stdout):
logger = Logger(level="DEBUG", stream=stdout)
def test_logging_exception_traceback(stdout, service_name):
logger = Logger(service=service_name, level="DEBUG", stream=stdout)

try:
raise ValueError("Boom")
Expand All @@ -52,9 +60,9 @@ def test_logging_exception_traceback(stdout):
assert "exception" in log_dict


def test_setup_with_invalid_log_level(stdout):
def test_setup_with_invalid_log_level(stdout, service_name):
with pytest.raises(ValueError) as e:
Logger(level="not a valid log level")
Logger(service=service_name, level="not a valid log level")
assert "Unknown level" in e.value.args[0]


Expand All @@ -65,8 +73,8 @@ def check_log_dict(log_dict):
assert "message" in log_dict


def test_with_dict_message(stdout):
logger = Logger(level="DEBUG", stream=stdout)
def test_with_dict_message(stdout, service_name):
logger = Logger(service=service_name, level="DEBUG", stream=stdout)

msg = {"x": "isx"}
logger.critical(msg)
Expand All @@ -76,8 +84,8 @@ def test_with_dict_message(stdout):
assert msg == log_dict["message"]


def test_with_json_message(stdout):
logger = Logger(stream=stdout)
def test_with_json_message(stdout, service_name):
logger = Logger(service=service_name, stream=stdout)

msg = {"x": "isx"}
logger.info(json.dumps(msg))
Expand All @@ -87,8 +95,8 @@ def test_with_json_message(stdout):
assert msg == log_dict["message"]


def test_with_unserializable_value_in_message(stdout):
logger = Logger(level="DEBUG", stream=stdout)
def test_with_unserializable_value_in_message(stdout, service_name):
logger = Logger(service=service_name, level="DEBUG", stream=stdout)

class Unserializable:
pass
Expand All @@ -101,12 +109,17 @@ class Unserializable:
assert log_dict["message"]["x"].startswith("<")


def test_with_unserializable_value_in_message_custom(stdout):
def test_with_unserializable_value_in_message_custom(stdout, service_name):
class Unserializable:
pass

# GIVEN a custom json_default
logger = Logger(level="DEBUG", stream=stdout, json_default=lambda o: f"<non-serializable: {type(o).__name__}>")
logger = Logger(
service=service_name,
level="DEBUG",
stream=stdout,
json_default=lambda o: f"<non-serializable: {type(o).__name__}>",
)

# WHEN we log a message
logger.debug({"x": Unserializable()})
Expand All @@ -118,9 +131,9 @@ class Unserializable:
assert "json_default" not in log_dict


def test_log_dict_key_seq(stdout):
def test_log_dict_key_seq(stdout, service_name):
# GIVEN the default logger configuration
logger = Logger(stream=stdout)
logger = Logger(service=service_name, stream=stdout)

# WHEN logging a message
logger.info("Message")
Expand All @@ -131,9 +144,9 @@ def test_log_dict_key_seq(stdout):
assert ",".join(list(log_dict.keys())[:4]) == "level,location,message,timestamp"


def test_log_dict_key_custom_seq(stdout):
def test_log_dict_key_custom_seq(stdout, service_name):
# GIVEN a logger configuration with log_record_order set to ["message"]
logger = Logger(stream=stdout, log_record_order=["message"])
logger = Logger(service=service_name, stream=stdout, log_record_order=["message"])

# WHEN logging a message
logger.info("Message")
Expand All @@ -144,9 +157,9 @@ def test_log_dict_key_custom_seq(stdout):
assert list(log_dict.keys())[0] == "message"


def test_log_custom_formatting(stdout):
def test_log_custom_formatting(stdout, service_name):
# GIVEN a logger where we have a custom `location`, 'datefmt' format
logger = Logger(stream=stdout, location="[%(funcName)s] %(module)s", datefmt="fake-datefmt")
logger = Logger(service=service_name, stream=stdout, location="[%(funcName)s] %(module)s", datefmt="fake-datefmt")

# WHEN logging a message
logger.info("foo")
Expand All @@ -158,7 +171,7 @@ def test_log_custom_formatting(stdout):
assert log_dict["timestamp"] == "fake-datefmt"


def test_log_dict_key_strip_nones(stdout):
def test_log_dict_key_strip_nones(stdout, service_name):
# GIVEN a logger confirmation where we set `location` and `timestamp` to None
# Note: level, sampling_rate and service can not be suppressed
logger = Logger(stream=stdout, level=None, location=None, timestamp=None, sampling_rate=None, service=None)
Expand All @@ -170,14 +183,15 @@ def test_log_dict_key_strip_nones(stdout):

# THEN the keys should only include `level`, `message`, `service`, `sampling_rate`
assert sorted(log_dict.keys()) == ["level", "message", "sampling_rate", "service"]
assert log_dict["service"] == "service_undefined"


def test_log_dict_xray_is_present_when_tracing_is_enabled(stdout, monkeypatch):
def test_log_dict_xray_is_present_when_tracing_is_enabled(stdout, monkeypatch, service_name):
# GIVEN a logger is initialized within a Lambda function with X-Ray enabled
trace_id = "1-5759e988-bd862e3fe1be46a994272793"
trace_header = f"Root={trace_id};Parent=53995c3f42cd8ad8;Sampled=1"
monkeypatch.setenv(name="_X_AMZN_TRACE_ID", value=trace_header)
logger = Logger(stream=stdout)
logger = Logger(service=service_name, stream=stdout)

# WHEN logging a message
logger.info("foo")
Expand All @@ -190,9 +204,9 @@ def test_log_dict_xray_is_present_when_tracing_is_enabled(stdout, monkeypatch):
monkeypatch.delenv(name="_X_AMZN_TRACE_ID")


def test_log_dict_xray_is_not_present_when_tracing_is_disabled(stdout, monkeypatch):
def test_log_dict_xray_is_not_present_when_tracing_is_disabled(stdout, monkeypatch, service_name):
# GIVEN a logger is initialized within a Lambda function with X-Ray disabled (default)
logger = Logger(stream=stdout)
logger = Logger(service=service_name, stream=stdout)

# WHEN logging a message
logger.info("foo")
Expand All @@ -203,12 +217,12 @@ def test_log_dict_xray_is_not_present_when_tracing_is_disabled(stdout, monkeypat
assert "xray_trace_id" not in log_dict


def test_log_dict_xray_is_updated_when_tracing_id_changes(stdout, monkeypatch):
def test_log_dict_xray_is_updated_when_tracing_id_changes(stdout, monkeypatch, service_name):
# GIVEN a logger is initialized within a Lambda function with X-Ray enabled
trace_id = "1-5759e988-bd862e3fe1be46a994272793"
trace_header = f"Root={trace_id};Parent=53995c3f42cd8ad8;Sampled=1"
monkeypatch.setenv(name="_X_AMZN_TRACE_ID", value=trace_header)
logger = Logger(stream=stdout)
logger = Logger(service=service_name, stream=stdout)

# WHEN logging a message
logger.info("foo")
Expand Down
Loading