Skip to content

Commit b51a653

Browse files
trungleducakrishna1995
authored andcommitted
Limit cache size
1 parent e69d415 commit b51a653

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

src/sagemaker/config/config_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
These utils may be used inside or outside the config module.
1616
"""
1717
from __future__ import absolute_import
18+
from collections import deque
1819

1920
import logging
2021
import sys
@@ -200,26 +201,31 @@ def _log_sagemaker_config_merge(
200201
logger.debug("Skipped value because no value defined\n config key = %s", config_key_path)
201202

202203

203-
def non_repeating_log_factory(logger: logging.Logger, method: str) -> Callable:
204+
def non_repeating_log_factory(logger: logging.Logger, method: str, cache_size=100) -> Callable:
204205
"""Create log function that filters the repeated messages.
205206
207+
By default. It only keeps track of last 100 messages, if a repeated
208+
message arrives after the ``cache_size`` messages, it will be displayed.
209+
206210
Args:
207211
logger (logging.Logger): the logger to be used to dispatch the message.
208212
method (str): the log method, can be info, warning or debug.
213+
cache_size (int): the number of last log messages to keep in cache.
214+
Default to 100
209215
210216
Returns:
211217
(Callable): the new log method
212218
"""
213219
if method not in ["info", "warning", "debug"]:
214220
raise ValueError("Not supported logging method.")
215221

216-
_caches = set()
222+
_caches = deque(maxlen=cache_size)
217223
log_method = getattr(logger, method)
218224

219225
def new_log_method(msg, *args, **kwargs):
220226
key = f"{msg}:{args}"
221227
if key not in _caches:
222228
log_method(msg, *args, **kwargs)
223-
_caches.add(key)
229+
_caches.append(key)
224230

225231
return new_log_method

tests/unit/sagemaker/config/test_config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,3 +458,21 @@ def test_non_repeating_log_factory(method_name):
458458
log_function("foo")
459459

460460
mock.assert_called_once()
461+
462+
463+
@pytest.mark.parametrize(
464+
"method_name",
465+
["info", "warning", "debug"],
466+
)
467+
def test_non_repeating_log_factory_cache_size(method_name):
468+
tmp_logger = logging.getLogger("test-logger")
469+
mock = MagicMock()
470+
setattr(tmp_logger, method_name, mock)
471+
472+
log_function = non_repeating_log_factory(tmp_logger, method_name, cache_size=2)
473+
log_function("foo")
474+
log_function("bar")
475+
log_function("foo2")
476+
log_function("foo")
477+
478+
assert mock.call_count == 4

0 commit comments

Comments
 (0)