Skip to content

Commit a61da6c

Browse files
authored
implement disabling specific modules (brain-score#6)
* implement disabling specific modules * make example more specific wrt. brain-score
1 parent b897e53 commit a61da6c

File tree

3 files changed

+79
-8
lines changed

3 files changed

+79
-8
lines changed

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,5 @@ changed through the environment variable `RESULTCACHING_HOME`.
3232
| Variable | description |
3333
|-----------------------|----------------------------------------------------------------------------------|
3434
| RESULTCACHING_HOME | directory to cache results (benchmark ceilings) in, `~/.result_caching` by default |
35-
| RESULTCACHING_DISABLE | `'1'` to disable loading and saving of results, functions will be called directly |
35+
| RESULTCACHING_DISABLE | * `'1'` to disable loading and saving of results, functions will be called directly |
36+
| | * `'candidate_models.score_model,model_tools.activations`' to disable loading and saving of function identifiers starting with one of the specifiers separated by a comma (e.g. any package or function inside `model_tools.activations` will not be considered) |

result_caching/__init__.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,14 @@ def get_function_identifier(function, call_args):
3131
return function_identifier
3232

3333

34-
def is_enabled():
35-
return os.getenv('RESULTCACHING_DISABLE', '0') != '1'
34+
def is_enabled(function_identifier):
35+
disable = os.getenv('RESULTCACHING_DISABLE', '0')
36+
if disable == '1':
37+
return False
38+
if disable == '':
39+
return True
40+
disabled_modules = disable.split(',')
41+
return not any(function_identifier.startswith(disabled_module) for disabled_module in disabled_modules)
3642

3743

3844
class _Storage(object):
@@ -49,12 +55,12 @@ def __call__(self, function):
4955
def wrapper(*args, **kwargs):
5056
call_args = self.getcallargs(function, *args, **kwargs)
5157
function_identifier = self.get_function_identifier(function, call_args)
52-
if is_enabled() and self.is_stored(function_identifier):
58+
if is_enabled(function_identifier) and self.is_stored(function_identifier):
5359
self._logger.debug("Loading from storage: {}".format(function_identifier))
5460
return self.load(function_identifier)
5561
self._logger.debug("Running function: {}".format(function_identifier))
5662
result = function(*args, **kwargs)
57-
if is_enabled():
63+
if is_enabled(function_identifier):
5864
self._logger.debug("Saving to storage: {}".format(function_identifier))
5965
self.save(result, function_identifier)
6066
return result
@@ -258,7 +264,7 @@ def wrapper(*args, **kwargs):
258264
if key in self._combine_fields}
259265
function_identifier = self.get_function_identifier(function, call_args)
260266
stored_result, reduced_call_args = None, call_args
261-
if is_enabled() and self.is_stored(function_identifier):
267+
if is_enabled(function_identifier) and self.is_stored(function_identifier):
262268
self._logger.debug(f"Loading from storage: {function_identifier}")
263269
stored_result = self.load(function_identifier)
264270
missing_call_args = self.filter_coords(infile_call_args, stored_result) if not self._sub_fields \

tests/test___init__.py

+66-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import tempfile
3+
from collections import defaultdict
34

45
import numpy as np
56
import pytest
@@ -210,8 +211,8 @@ def __repr__(self):
210211
# second call returns same thing and doesn't actually call function again
211212
assert c.f(1) == 1
212213

213-
def test_disable_store(self):
214-
previous_disable_value = os.getenv('RESULTCACHING_DISABLE', '0')
214+
def test_disable_all(self):
215+
previous_disable_value = os.getenv('RESULTCACHING_DISABLE', '')
215216
with tempfile.TemporaryDirectory() as storage_dir:
216217
os.environ['RESULTCACHING_HOME'] = storage_dir
217218
os.environ['RESULTCACHING_DISABLE'] = '1'
@@ -232,6 +233,69 @@ def func(x):
232233

233234
os.environ['RESULTCACHING_DISABLE'] = previous_disable_value
234235

236+
def test_disable_specific(self):
237+
previous_disable_value = os.getenv('RESULTCACHING_DISABLE', '')
238+
with tempfile.TemporaryDirectory() as storage_dir:
239+
os.environ['RESULTCACHING_HOME'] = storage_dir
240+
os.environ['RESULTCACHING_DISABLE'] = 'test___init__.func1'
241+
242+
function_calls = defaultdict(lambda: 0)
243+
244+
@store()
245+
def func1(x):
246+
nonlocal function_calls
247+
function_calls[1] += 1
248+
return x
249+
250+
@store()
251+
def func2(x):
252+
nonlocal function_calls
253+
function_calls[2] += 1
254+
return x
255+
256+
assert func1(1) == 1
257+
assert function_calls[1] == 1
258+
assert not os.listdir(storage_dir)
259+
assert func1(1) == 1
260+
assert function_calls[1] == 2
261+
262+
assert func2(1) == 1
263+
assert function_calls[2] == 1
264+
assert func2(1) == 1
265+
assert function_calls[2] == 1
266+
267+
os.environ['RESULTCACHING_DISABLE'] = previous_disable_value
268+
269+
def test_disable_module(self):
270+
previous_disable_value = os.getenv('RESULTCACHING_DISABLE', '')
271+
with tempfile.TemporaryDirectory() as storage_dir:
272+
os.environ['RESULTCACHING_HOME'] = storage_dir
273+
os.environ['RESULTCACHING_DISABLE'] = 'test___init__'
274+
275+
function_calls = defaultdict(lambda: 0)
276+
277+
@store()
278+
def func1(x):
279+
nonlocal function_calls
280+
function_calls[1] += 1
281+
return x
282+
283+
@store()
284+
def func2(x):
285+
nonlocal function_calls
286+
function_calls[2] += 1
287+
return x
288+
289+
assert func1(1) == 1
290+
assert func1(1) == 1
291+
assert function_calls[1] == 2
292+
assert func2(1) == 1
293+
assert func2(1) == 1
294+
assert function_calls[2] == 2
295+
assert not os.listdir(storage_dir)
296+
297+
os.environ['RESULTCACHING_DISABLE'] = previous_disable_value
298+
235299

236300
class TestCache:
237301
def test(self):

0 commit comments

Comments
 (0)