Skip to content

Commit 32d5fc0

Browse files
committed
add tests
1 parent e4d56e3 commit 32d5fc0

File tree

3 files changed

+41
-4
lines changed

3 files changed

+41
-4
lines changed

src/sagemaker/config/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
from botocore.utils import merge_dicts
2929
from six.moves.urllib.parse import urlparse
3030
from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA
31-
from sagemaker.config.config_utils import non_repeating_log, get_sagemaker_config_logger
31+
from sagemaker.config.config_utils import non_repeating_log_factory, get_sagemaker_config_logger
3232

3333
logger = get_sagemaker_config_logger()
34-
log_info_function = non_repeating_log(logger, "info")
34+
log_info_function = non_repeating_log_factory(logger, "info")
3535

3636
_APP_NAME = "sagemaker"
3737
# The default name of the config file.

src/sagemaker/config/config_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def _log_sagemaker_config_merge(
200200
logger.debug("Skipped value because no value defined\n config key = %s", config_key_path)
201201

202202

203-
def non_repeating_log(logger: logging.Logger, method: str) -> Callable:
203+
def non_repeating_log_factory(logger: logging.Logger, method: str) -> Callable:
204204
"""Create log function that filters the repeated messages.
205205
206206
Args:

tests/unit/sagemaker/config/test_config.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
import pytest
1717
import yaml
1818
import logging
19-
from mock import Mock, MagicMock, patch
19+
from mock import Mock, MagicMock, patch, call
2020

2121
from sagemaker.config.config import (
2222
load_local_mode_config,
2323
load_sagemaker_config,
2424
logger,
25+
non_repeating_log_factory,
2526
_DEFAULT_ADMIN_CONFIG_FILE_PATH,
2627
_DEFAULT_USER_CONFIG_FILE_PATH,
2728
_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH,
@@ -349,6 +350,26 @@ def test_logging_when_default_admin_not_found_and_default_user_config_not_found(
349350
logger.propagate = False
350351

351352

353+
@patch("sagemaker.config.config.log_info_function")
354+
def test_load_config_without_repeating_log(log_info):
355+
356+
load_sagemaker_config(repeat_log=False)
357+
assert log_info.call_count == 2
358+
log_info.assert_has_calls(
359+
[
360+
call(
361+
"Not applying SDK defaults from location: %s",
362+
_DEFAULT_ADMIN_CONFIG_FILE_PATH,
363+
),
364+
call(
365+
"Not applying SDK defaults from location: %s",
366+
_DEFAULT_USER_CONFIG_FILE_PATH,
367+
),
368+
],
369+
any_order=True,
370+
)
371+
372+
352373
def test_logging_when_default_admin_not_found_and_overriden_user_config_not_found(
353374
get_data_dir, caplog
354375
):
@@ -421,3 +442,19 @@ def test_load_local_mode_config(mock_load_config):
421442

422443
def test_load_local_mode_config_when_config_file_is_not_found():
423444
assert load_local_mode_config() is None
445+
446+
447+
@pytest.mark.parametrize(
448+
"method_name",
449+
["info", "warning", "debug"],
450+
)
451+
def test_non_repeating_log_factory(method_name):
452+
tmp_logger = logging.getLogger("test-logger")
453+
mock = MagicMock()
454+
setattr(tmp_logger, method_name, mock)
455+
456+
log_function = non_repeating_log_factory(tmp_logger, method_name)
457+
log_function("foo")
458+
log_function("foo")
459+
460+
mock.assert_called_once()

0 commit comments

Comments
 (0)