Skip to content

Commit 273052c

Browse files
authored
add an environment flag to raise an error if result is not cached (brain-score#8)
1 parent cd21071 commit 273052c

File tree

3 files changed

+57
-6
lines changed

3 files changed

+57
-6
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@ changed through the environment variable `RESULTCACHING_HOME`.
3434
| RESULTCACHING_HOME | directory to cache results (benchmark ceilings) in, `~/.result_caching` by default |
3535
| RESULTCACHING_DISABLE | * `'1'` to disable loading and saving of results, functions will be called directly |
3636
| | * `'candidate_models.score_model,model_tools.activations`' to disable loading and saving of function identifiers starting with one of the specifiers separated by a comma (e.g. any package or function inside `model_tools.activations` will not be considered) |
37+
| RESULTCACHING_CACHEDONLY | If enabled, raises an error when trying to run a function that does not have its result already cached (follows the same matching rules as `RESULTCACHING_DISABLE`) |

result_caching/__init__.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,25 @@ def get_function_identifier(function, call_args):
3333

3434
def is_enabled(function_identifier):
3535
disable = os.getenv('RESULTCACHING_DISABLE', '0')
36-
if disable == '1':
37-
return False
38-
if disable == '':
36+
return not _match_identifier(function_identifier, disable)
37+
38+
39+
def cached_only(function_identifier):
40+
cachedonly = os.getenv('RESULTCACHING_CACHEDONLY', '0')
41+
return _match_identifier(function_identifier, cachedonly)
42+
43+
44+
def _match_identifier(function_identifier, match_value):
45+
if match_value == '1':
3946
return True
40-
disabled_modules = disable.split(',')
41-
return not any(function_identifier.startswith(disabled_module) for disabled_module in disabled_modules)
47+
if match_value == '':
48+
return False
49+
disabled_modules = match_value.split(',')
50+
return any(function_identifier.startswith(disabled_module) for disabled_module in disabled_modules)
51+
52+
53+
class NotCachedError(Exception):
54+
pass
4255

4356

4457
class _Storage(object):
@@ -58,6 +71,8 @@ def wrapper(*args, **kwargs):
5871
if is_enabled(function_identifier) and self.is_stored(function_identifier):
5972
self._logger.debug("Loading from storage: {}".format(function_identifier))
6073
return self.load(function_identifier)
74+
if cached_only(function_identifier):
75+
raise NotCachedError(f"No result stored for '{function_identifier}'")
6176
self._logger.debug("Running function: {}".format(function_identifier))
6277
result = function(*args, **kwargs)
6378
if is_enabled(function_identifier):
@@ -197,6 +212,9 @@ def wrapper(*args, **kwargs):
197212
reduced_call_args = {**non_variable_call_args, **infile_missing_call_args}
198213
self._logger.debug(f"Computing missing: {reduced_call_args}")
199214
if reduced_call_args:
215+
if cached_only(function_identifier):
216+
raise NotCachedError(f"The following arguments for '{function_identifier}' "
217+
f"are not stored: {reduced_call_args}")
200218
# run function if some args are uncomputed
201219
self._logger.debug(f"Running function: {function_identifier}")
202220
result = function(**reduced_call_args)
@@ -283,6 +301,9 @@ def wrapper(*args, **kwargs):
283301
reduced_call_args = {**non_variable_call_args, **missing_call_args}
284302
self._logger.debug(f"Computing missing: {reduced_call_args}")
285303
if reduced_call_args:
304+
if cached_only(function_identifier):
305+
raise NotCachedError(f"The following arguments for '{function_identifier}' "
306+
f"are not stored: {reduced_call_args}")
286307
self._logger.debug(f"Running function: {function_identifier}")
287308
# run function if some args are uncomputed
288309
result = function(**reduced_call_args)

tests/test___init__.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
import xarray as xr
88

9-
from result_caching import store_xarray, store, cache, store_dict, get_function_identifier
9+
from result_caching import store_xarray, store, cache, store_dict, get_function_identifier, NotCachedError
1010

1111

1212
class TestFunctionIdentifier:
@@ -296,6 +296,35 @@ def func2(x):
296296

297297
os.environ['RESULTCACHING_DISABLE'] = previous_disable_value
298298

299+
def test_cachedonly_specific(self):
300+
previous_cached_value = os.getenv('RESULTCACHING_CACHEDONLY', '')
301+
with tempfile.TemporaryDirectory() as storage_dir:
302+
os.environ['RESULTCACHING_HOME'] = storage_dir
303+
304+
@store()
305+
def func1(x):
306+
return x
307+
308+
@store()
309+
def func2(x):
310+
return x
311+
312+
# when allowing only cached results from func2, func1 should work, but func2 should err
313+
os.environ['RESULTCACHING_CACHEDONLY'] = 'test___init__.func2'
314+
assert func1(1) == 1
315+
with pytest.raises(NotCachedError):
316+
func2(2)
317+
318+
# when allow reruns, func2 should work again
319+
os.environ['RESULTCACHING_CACHEDONLY'] = ''
320+
assert func2(2) == 2
321+
322+
# when now only allowing cached results again, func2 should work because results are already cached
323+
os.environ['RESULTCACHING_CACHEDONLY'] = 'test___init__.func2'
324+
assert func2(2) == 2
325+
326+
os.environ['RESULTCACHING_CACHEDONLY'] = previous_cached_value
327+
299328

300329
class TestCache:
301330
def test(self):

0 commit comments

Comments
 (0)