Skip to content

feat: add non-repeating config logger #4268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/sagemaker/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@
from botocore.utils import merge_dicts
from six.moves.urllib.parse import urlparse
from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA
from sagemaker.config.config_utils import get_sagemaker_config_logger
from sagemaker.config.config_utils import non_repeating_log_factory, get_sagemaker_config_logger

logger = get_sagemaker_config_logger()
log_info_function = non_repeating_log_factory(logger, "info")

_APP_NAME = "sagemaker"
# The default name of the config file.
Expand All @@ -52,7 +53,9 @@
S3_PREFIX = "s3://"


def load_sagemaker_config(additional_config_paths: List[str] = None, s3_resource=None) -> dict:
def load_sagemaker_config(
additional_config_paths: List[str] = None, s3_resource=None, repeat_log=False
) -> dict:
"""Loads config files and merges them.

By default, this method first searches for config files in the default locations
Expand Down Expand Up @@ -99,6 +102,8 @@ def load_sagemaker_config(additional_config_paths: List[str] = None, s3_resource
<https://boto3.amazonaws.com/v1/documentation/api\
/latest/reference/core/session.html#boto3.session.Session.resource>`__.
This argument is not needed if the config files are present in the local file system.
repeat_log (bool): Whether the log with the same contents should be emitted.
Default to ``False``
"""
default_config_path = os.getenv(
ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH
Expand All @@ -109,6 +114,11 @@ def load_sagemaker_config(additional_config_paths: List[str] = None, s3_resource
config_paths += additional_config_paths
config_paths = list(filter(lambda item: item is not None, config_paths))
merged_config = {}

log_info = log_info_function
if repeat_log:
log_info = logger.info

for file_path in config_paths:
config_from_file = {}
if file_path.startswith(S3_PREFIX):
Expand All @@ -130,9 +140,9 @@ def load_sagemaker_config(additional_config_paths: List[str] = None, s3_resource
if config_from_file:
validate_sagemaker_config(config_from_file)
merge_dicts(merged_config, config_from_file)
logger.info("Fetched defaults config from location: %s", file_path)
log_info("Fetched defaults config from location: %s", file_path)
else:
logger.info("Not applying SDK defaults from location: %s", file_path)
log_info("Not applying SDK defaults from location: %s", file_path)

return merged_config

Expand Down
32 changes: 32 additions & 0 deletions src/sagemaker/config/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
These utils may be used inside or outside the config module.
"""
from __future__ import absolute_import
from collections import deque

import logging
import sys
from typing import Callable


def get_sagemaker_config_logger():
Expand Down Expand Up @@ -197,3 +199,33 @@ def _log_sagemaker_config_merge(
else:
# nothing was specified in the config and nothing is being automatically applied
logger.debug("Skipped value because no value defined\n config key = %s", config_key_path)


def non_repeating_log_factory(logger: logging.Logger, method: str, cache_size=100) -> Callable:
    """Build a log function that suppresses recently emitted duplicates.

    The returned callable forwards to ``logger.<method>`` only when the
    combination of message and positional arguments is not among the most
    recently emitted entries. Only the last ``cache_size`` emitted entries are
    remembered (100 by default), so a repeated message is displayed again once
    enough other messages have pushed it out of the cache.

    Args:
        logger (logging.Logger): the logger used to dispatch the message.
        method (str): the log method; one of ``info``, ``warning`` or ``debug``.
        cache_size (int): the number of most recent log messages kept in the
            cache. Defaults to 100.

    Returns:
        (Callable): the new log method, accepting the same arguments as the
            wrapped logger method.

    Raises:
        ValueError: if ``method`` is not one of the supported log methods.
    """
    supported_methods = ("info", "warning", "debug")
    if method not in supported_methods:
        raise ValueError("Not supported logging method.")

    # A bounded deque evicts the oldest entry automatically once it is full.
    recent_keys = deque(maxlen=cache_size)
    emit = getattr(logger, method)

    def new_log_method(msg, *args, **kwargs):
        # Keyword arguments are deliberately not part of the identity of a message.
        signature = f"{msg}:{args}"
        if signature in recent_keys:
            return
        emit(msg, *args, **kwargs)
        recent_keys.append(signature)

    return new_log_method
108 changes: 85 additions & 23 deletions tests/unit/sagemaker/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
import pytest
import yaml
import logging
from mock import Mock, MagicMock, patch
from mock import Mock, MagicMock, patch, call

from sagemaker.config.config import (
load_local_mode_config,
load_sagemaker_config,
logger,
non_repeating_log_factory,
_DEFAULT_ADMIN_CONFIG_FILE_PATH,
_DEFAULT_USER_CONFIG_FILE_PATH,
_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH,
Expand All @@ -45,14 +46,14 @@ def expected_merged_config(get_data_dir):


def test_config_when_default_config_file_and_user_config_file_is_not_found():
assert load_sagemaker_config() == {}
assert load_sagemaker_config(repeat_log=True) == {}


def test_config_when_overriden_default_config_file_is_not_found(get_data_dir):
fake_config_file_path = os.path.join(get_data_dir, "config-not-found.yaml")
os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = fake_config_file_path
with pytest.raises(ValueError):
load_sagemaker_config()
load_sagemaker_config(repeat_log=True)
del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"]


Expand All @@ -63,14 +64,14 @@ def test_invalid_config_file_which_has_python_code(get_data_dir):
# PyYAML will throw exceptions for yaml.safe_load. SageMaker Config is using
# yaml.safe_load internally
with pytest.raises(ConstructorError) as exception_info:
load_sagemaker_config(additional_config_paths=[invalid_config_file_path])
load_sagemaker_config(additional_config_paths=[invalid_config_file_path], repeat_log=True)
assert "python/object/apply:eval" in str(exception_info.value)


def test_config_when_additional_config_file_path_is_not_found(get_data_dir):
fake_config_file_path = os.path.join(get_data_dir, "config-not-found.yaml")
with pytest.raises(ValueError):
load_sagemaker_config(additional_config_paths=[fake_config_file_path])
load_sagemaker_config(additional_config_paths=[fake_config_file_path], repeat_log=True)


def test_config_factory_when_override_user_config_file_is_not_found(get_data_dir):
Expand All @@ -79,15 +80,15 @@ def test_config_factory_when_override_user_config_file_is_not_found(get_data_dir
)
os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = fake_additional_override_config_file_path
with pytest.raises(ValueError):
load_sagemaker_config()
load_sagemaker_config(repeat_log=True)
del os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"]


def test_default_config_file_with_invalid_schema(get_data_dir):
config_file_path = os.path.join(get_data_dir, "invalid_config_file.yaml")
os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = config_file_path
with pytest.raises(exceptions.ValidationError):
load_sagemaker_config()
load_sagemaker_config(repeat_log=True)
del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"]


Expand All @@ -98,7 +99,7 @@ def test_default_config_file_when_directory_is_provided_as_the_path(
expected_config = base_config_with_schema
expected_config["SageMaker"] = valid_config_with_all_the_scopes
os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = get_data_dir
assert expected_config == load_sagemaker_config()
assert expected_config == load_sagemaker_config(repeat_log=True)
del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"]


Expand All @@ -108,7 +109,9 @@ def test_additional_config_paths_when_directory_is_provided(
# This will try to load config.yaml file from that directory if present.
expected_config = base_config_with_schema
expected_config["SageMaker"] = valid_config_with_all_the_scopes
assert expected_config == load_sagemaker_config(additional_config_paths=[get_data_dir])
assert expected_config == load_sagemaker_config(
additional_config_paths=[get_data_dir], repeat_log=True
)


def test_default_config_file_when_path_is_provided_as_environment_variable(
Expand All @@ -118,7 +121,7 @@ def test_default_config_file_when_path_is_provided_as_environment_variable(
# This will try to load config.yaml file from that directory if present.
expected_config = base_config_with_schema
expected_config["SageMaker"] = valid_config_with_all_the_scopes
assert expected_config == load_sagemaker_config()
assert expected_config == load_sagemaker_config(repeat_log=True)
del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"]


Expand All @@ -131,7 +134,9 @@ def test_merge_behavior_when_additional_config_file_path_is_not_found(
)
os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = valid_config_file_path
with pytest.raises(ValueError):
load_sagemaker_config(additional_config_paths=[fake_additional_override_config_file_path])
load_sagemaker_config(
additional_config_paths=[fake_additional_override_config_file_path], repeat_log=True
)
del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"]


Expand All @@ -142,10 +147,10 @@ def test_merge_behavior(get_data_dir, expected_merged_config):
)
os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = valid_config_file_path
assert expected_merged_config == load_sagemaker_config(
additional_config_paths=[additional_override_config_file_path]
additional_config_paths=[additional_override_config_file_path], repeat_log=True
)
os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = additional_override_config_file_path
assert expected_merged_config == load_sagemaker_config()
assert expected_merged_config == load_sagemaker_config(repeat_log=True)
del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"]
del os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"]

Expand All @@ -169,7 +174,7 @@ def test_s3_config_file(
expected_config = base_config_with_schema
expected_config["SageMaker"] = valid_config_with_all_the_scopes
assert expected_config == load_sagemaker_config(
additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock
additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock, repeat_log=True
)


Expand All @@ -183,7 +188,9 @@ def test_config_factory_when_default_s3_config_file_is_not_found(s3_resource_moc
config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix)
with pytest.raises(ValueError):
load_sagemaker_config(
additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock
additional_config_paths=[config_file_s3_uri],
s3_resource=s3_resource_mock,
repeat_log=True,
)


Expand Down Expand Up @@ -213,7 +220,7 @@ def test_s3_config_file_when_uri_provided_corresponds_to_a_path(
expected_config = base_config_with_schema
expected_config["SageMaker"] = valid_config_with_all_the_scopes
assert expected_config == load_sagemaker_config(
additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock
additional_config_paths=[config_file_s3_uri], s3_resource=s3_resource_mock, repeat_log=True
)


Expand Down Expand Up @@ -242,6 +249,7 @@ def test_merge_of_s3_default_config_file_and_regular_config_file(
assert expected_merged_config == load_sagemaker_config(
additional_config_paths=[additional_override_config_file_path],
s3_resource=s3_resource_mock,
repeat_log=True,
)
del os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"]

Expand All @@ -254,7 +262,7 @@ def test_logging_when_overridden_admin_is_found_and_overridden_user_config_is_fo

os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = get_data_dir
os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = get_data_dir
load_sagemaker_config()
load_sagemaker_config(repeat_log=True)
assert "Fetched defaults config from location: {}".format(get_data_dir) in caplog.text
assert (
"Not applying SDK defaults from location: {}".format(_DEFAULT_ADMIN_CONFIG_FILE_PATH)
Expand All @@ -275,7 +283,7 @@ def test_logging_when_overridden_admin_is_found_and_default_user_config_not_foun
logger.propagate = True
caplog.set_level(logging.DEBUG, logger=logger.name)
os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = get_data_dir
load_sagemaker_config()
load_sagemaker_config(repeat_log=True)
assert "Fetched defaults config from location: {}".format(get_data_dir) in caplog.text
assert (
"Not applying SDK defaults from location: {}".format(_DEFAULT_USER_CONFIG_FILE_PATH)
Expand All @@ -297,7 +305,7 @@ def test_logging_when_default_admin_not_found_and_overriden_user_config_is_found
logger.propagate = True
caplog.set_level(logging.DEBUG, logger=logger.name)
os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = get_data_dir
load_sagemaker_config()
load_sagemaker_config(repeat_log=True)
assert "Fetched defaults config from location: {}".format(get_data_dir) in caplog.text
assert (
"Not applying SDK defaults from location: {}".format(_DEFAULT_ADMIN_CONFIG_FILE_PATH)
Expand All @@ -318,7 +326,7 @@ def test_logging_when_default_admin_not_found_and_default_user_config_not_found(
# for admin and user config since both are missing from default location
logger.propagate = True
caplog.set_level(logging.DEBUG, logger=logger.name)
load_sagemaker_config()
load_sagemaker_config(repeat_log=True)
assert (
"Not applying SDK defaults from location: {}".format(_DEFAULT_ADMIN_CONFIG_FILE_PATH)
in caplog.text
Expand All @@ -342,6 +350,26 @@ def test_logging_when_default_admin_not_found_and_default_user_config_not_found(
logger.propagate = False


@patch("sagemaker.config.config.log_info_function")
def test_load_config_without_repeating_log(log_info):
    """With ``repeat_log=False`` the loader logs through ``log_info_function``."""
    # One "not applying" message is expected per default config location.
    expected_calls = [
        call("Not applying SDK defaults from location: %s", default_path)
        for default_path in (
            _DEFAULT_ADMIN_CONFIG_FILE_PATH,
            _DEFAULT_USER_CONFIG_FILE_PATH,
        )
    ]

    load_sagemaker_config(repeat_log=False)

    assert log_info.call_count == 2
    log_info.assert_has_calls(expected_calls, any_order=True)


def test_logging_when_default_admin_not_found_and_overriden_user_config_not_found(
get_data_dir, caplog
):
Expand All @@ -351,7 +379,7 @@ def test_logging_when_default_admin_not_found_and_overriden_user_config_not_foun
fake_config_file_path = os.path.join(get_data_dir, "config-not-found.yaml")
os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = fake_config_file_path
with pytest.raises(ValueError):
load_sagemaker_config()
load_sagemaker_config(repeat_log=True)
assert (
"Not applying SDK defaults from location: {}".format(_DEFAULT_ADMIN_CONFIG_FILE_PATH)
in caplog.text
Expand All @@ -374,7 +402,7 @@ def test_logging_when_overriden_admin_not_found_and_overridden_user_config_not_f
os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = fake_config_file_path
os.environ["SAGEMAKER_ADMIN_CONFIG_OVERRIDE"] = fake_config_file_path
with pytest.raises(ValueError):
load_sagemaker_config()
load_sagemaker_config(repeat_log=True)
assert (
"Not applying SDK defaults from location: {}".format(_DEFAULT_ADMIN_CONFIG_FILE_PATH)
not in caplog.text
Expand All @@ -394,7 +422,7 @@ def test_logging_with_additional_configs_and_none_are_found(caplog):
# Should throw exception when config in additional_config_path is missing
logger.propagate = True
with pytest.raises(ValueError):
load_sagemaker_config(additional_config_paths=["fake-path"])
load_sagemaker_config(additional_config_paths=["fake-path"], repeat_log=True)
assert (
"Not applying SDK defaults from location: {}".format(_DEFAULT_ADMIN_CONFIG_FILE_PATH)
in caplog.text
Expand All @@ -414,3 +442,37 @@ def test_load_local_mode_config(mock_load_config):

def test_load_local_mode_config_when_config_file_is_not_found():
assert load_local_mode_config() is None


@pytest.mark.parametrize(
    "method_name",
    ["info", "warning", "debug"],
)
def test_non_repeating_log_factory(method_name):
    """A message repeated back to back is dispatched to the logger only once."""
    tmp_logger = logging.getLogger("test-logger")

    # patch.object restores the real logger method on exit, so the mock does not
    # leak into other parametrized cases or tests sharing this global logger.
    with patch.object(tmp_logger, method_name) as mock:
        log_function = non_repeating_log_factory(tmp_logger, method_name)
        log_function("foo")
        log_function("foo")

    mock.assert_called_once()


@pytest.mark.parametrize(
    "method_name",
    ["info", "warning", "debug"],
)
def test_non_repeating_log_factory_cache_size(method_name):
    """A repeated message is emitted again once it has been evicted from the cache."""
    tmp_logger = logging.getLogger("test-logger")

    # patch.object restores the real logger method on exit, so the mock does not
    # leak into other parametrized cases or tests sharing this global logger.
    with patch.object(tmp_logger, method_name) as mock:
        log_function = non_repeating_log_factory(tmp_logger, method_name, cache_size=2)
        log_function("foo")
        log_function("bar")
        # The third distinct message pushes "foo" out of the two-entry cache.
        log_function("foo2")
        log_function("foo")

    assert mock.call_count == 4