Skip to content

Commit 334a211

Browse files
committed
.wip force load
1 parent 8c82b48 commit 334a211

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

pymc_experimental/tests/utils/test_cache.py

+5
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,13 @@ def test_cache_sampling(tmpdir):
1717
assert len(os.listdir(tmpdir)) == 0
1818

1919
prior1, prior2 = (cache_prior(samples=5) for _ in range(2))
20+
prior3 = cache_sampling(pm.sample_prior_predictive, dir=tmpdir, force_sample=True)(
21+
samples=5
22+
)
2023
assert len(os.listdir(tmpdir)) == 1
2124
assert prior1.prior["x"].mean() == prior2.prior["x"].mean()
25+
assert prior2.prior["x"].mean() != prior3.prior["x"].mean()
26+
assert prior2.prior_predictive["y"].mean() != prior3.prior_predictive["y"].mean()
2227

2328
post1, post2 = (cache_post(tune=5, draws=5, progressbar=False) for _ in range(2))
2429
assert len(os.listdir(tmpdir)) == 2

pymc_experimental/utils/cache.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def cache_sampling(
7676
sampling_fn: Literal[sample, sample_prior_predictive, sample_posterior_predictive],
7777
dir: str = "",
7878
force_sample: bool = False,
79+
force_load: bool = True,
7980
) -> Callable:
8081
"""Cache the result of PyMC sampling.
8182
@@ -88,6 +89,7 @@ def cache_sampling(
8889
The directory where the results should be saved or retrieved from. Defaults to working directory.
8990
force_sample: bool, Optional
9091
Whether to force sampling even if cache is found. Defaults to False.
92+
force_load:
9193
9294
Returns
9395
-------
@@ -163,7 +165,9 @@ def wrapped_sampling_fn(*args, model=None, random_seed=None, **kwargs):
163165

164166
if not force_sample and os.path.exists(file_path):
165167
print("Cache hit! Returning stored result", file=sys.stdout)
166-
idata_out = az.from_netcdf(file_path)
168+
idata_out: az.InferenceData = az.from_netcdf(file_path)
169+
if force_load:
170+
idata_out.load()
167171

168172
else:
169173
idata_out = sampling_fn(*args, **kwargs, model=model, random_seed=random_seed)

0 commit comments

Comments
 (0)