
Commit 8c82b48

Add function that caches sampling results

1 parent 00d7a2b

File tree: 3 files changed, +229 -0 lines

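The commit adds cache_sampling, which wraps one of PyMC's sampling functions so that repeated calls with the same model, data, and kwargs reuse a stored NetCDF file instead of resampling. A minimal sketch of the new API before the diff itself; the directory name "cached" and the draw counts are illustrative choices, not part of the commit:

    import pymc as pm

    from pymc_experimental.utils.cache import cache_sampling

    with pm.Model():
        x = pm.Normal("x", 0, 1)

        cached_prior = cache_sampling(pm.sample_prior_predictive, dir="cached")
        idata = cached_prior(samples=100)        # samples and stores a .nc file in "cached/"
        idata_again = cached_prior(samples=100)  # prints "Cache hit! Returning stored result"

        # force_sample=True bypasses the lookup and overwrites the stored file
        fresh = cache_sampling(pm.sample_prior_predictive, dir="cached", force_sample=True)
        idata_fresh = fresh(samples=100)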

docs/api_reference.rst (+1)

@@ -49,6 +49,7 @@ Utils

     spline.bspline_interpolation
     prior.prior_from_idata
+    cache.cache_sampling


 Statespace Models

pymc_experimental/tests/utils/test_cache.py (+44)

@@ -0,0 +1,44 @@
import os

import pymc as pm

from pymc_experimental.utils.cache import cache_sampling


def test_cache_sampling(tmpdir):

    with pm.Model() as m:
        x = pm.Normal("x", 0, 1)
        y = pm.Normal("y", mu=x, observed=[0, 1, 2])

        cache_prior = cache_sampling(pm.sample_prior_predictive, dir=tmpdir)
        cache_post = cache_sampling(pm.sample, dir=tmpdir)
        cache_pred = cache_sampling(pm.sample_posterior_predictive, dir=tmpdir)
        assert len(os.listdir(tmpdir)) == 0

        prior1, prior2 = (cache_prior(samples=5) for _ in range(2))
        assert len(os.listdir(tmpdir)) == 1
        assert prior1.prior["x"].mean() == prior2.prior["x"].mean()

        post1, post2 = (cache_post(tune=5, draws=5, progressbar=False) for _ in range(2))
        assert len(os.listdir(tmpdir)) == 2
        assert post1.posterior["x"].mean() == post2.posterior["x"].mean()

    # Change model
    with pm.Model() as m:
        x = pm.Normal("x", 0, 1)
        y = pm.Normal("y", mu=x, observed=[0, 1, 2, 3])

        post3 = cache_post(tune=5, draws=5, progressbar=False)
        assert len(os.listdir(tmpdir)) == 3
        assert post3.posterior["x"].mean() != post1.posterior["x"].mean()

        pred1, pred2 = (cache_pred(trace=post3, progressbar=False) for _ in range(2))
        assert len(os.listdir(tmpdir)) == 4
        assert pred1.posterior_predictive["y"].mean() == pred2.posterior_predictive["y"].mean()
        assert "x" not in pred1.posterior_predictive

        # Change kwargs
        pred3 = cache_pred(trace=post3, progressbar=False, var_names=["x"])
        assert len(os.listdir(tmpdir)) == 5
        assert "x" in pred3.posterior_predictive

pymc_experimental/utils/cache.py (+184)

@@ -0,0 +1,184 @@
import hashlib
import os
import sys
from typing import Callable

import arviz as az
import numpy as np
import pymc
import pytensor
from pymc import (
    modelcontext,
    sample,
    sample_posterior_predictive,
    sample_prior_predictive,
)
from pymc.model.fgraph import fgraph_from_model
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, FunctionGraph, Variable
from pytensor.scalar import ScalarType
from pytensor.tensor import TensorType
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.type_other import NoneTypeT

import pymc_experimental


def hash_data(c: Variable) -> str:
    """Return a stable hash of the data carried by a constant or shared variable."""
    if isinstance(c.type, NoneTypeT):
        return "None"
    if isinstance(c.type, (ScalarType, TensorType)):
        if isinstance(c, Constant):
            arr = c.data
        elif isinstance(c, SharedVariable):
            arr = c.get_value(borrow=True)
        arr_data = arr.view(np.uint8) if arr.size > 1 else arr.tobytes()
        return hashlib.sha1(arr_data).hexdigest()
    else:
        raise NotImplementedError(f"Hashing not implemented for type {c.type}")


def get_name_and_props(obj):
    """Return string representations of an object and its `_props`, when present."""
    name = str(obj)
    props = str(getattr(obj, "_props", lambda: {})())
    return name, props


def hash_from_fg(fg: FunctionGraph) -> str:
    """Hash the structure, types, names, and constant data of a FunctionGraph."""
    objects_to_hash = []
    for node in fg.toposort():
        objects_to_hash.append(
            (
                get_name_and_props(node.op),
                tuple(get_name_and_props(inp.type) for inp in node.inputs),
                tuple(get_name_and_props(out.type) for out in node.outputs),
                # Name is not a symbolic input in the fgraph representation; maybe it should be?
                tuple(inp.name for inp in node.inputs if inp.name),
                tuple(out.name for out in node.outputs if out.name),
            )
        )
        objects_to_hash.append(
            tuple(
                hash_data(c)
                for c in node.inputs
                if (
                    isinstance(c, (Constant, SharedVariable))
                    # Ignore RNG values
                    and not isinstance(c.type, RandomType)
                )
            )
        )
    str_hash = "\n".join(map(str, objects_to_hash))
    return hashlib.sha1(str_hash.encode()).hexdigest()


def cache_sampling(
    sampling_fn: Callable,
    dir: str = "",
    force_sample: bool = False,
) -> Callable:
    """Cache the result of PyMC sampling.

    Parameters
    ----------
    sampling_fn : Callable
        Must be one of `pymc.sample`, `pymc.sample_prior_predictive` or
        `pymc.sample_posterior_predictive`. Positional arguments are disallowed.
    dir : str, optional
        The directory where the results should be saved or retrieved from.
        Defaults to the working directory.
    force_sample : bool, optional
        Whether to force sampling even if a cache is found. Defaults to False.

    Returns
    -------
    cached_sampling_fn : Callable
        Function that wraps the sampling_fn. When called, the wrapped function
        will look for a valid cached result. A valid cache requires the same:

        1. Model and data
        2. Sampling function
        3. Sampling kwargs, ignoring ``random_seed``, ``trace``, ``progressbar``,
           ``extend_inferencedata`` and ``compile_kwargs``
        4. PyMC, PyTensor, and PyMC-Experimental versions

        If a valid cache is found, sampling is bypassed altogether, unless
        ``force_sample=True``. Otherwise, sampling is performed and the result
        is cached for future reuse. Caching is based on SHA-1 hashing, so false
        positives are possible, though unlikely.


    Examples
    --------

    .. code-block:: python

        import pymc as pm
        from pymc_experimental.utils.cache import cache_sampling

        with pm.Model() as m:
            y_data = pm.MutableData("y_data", [0, 1, 2])
            x = pm.Normal("x", 0, 1)
            y = pm.Normal("y", mu=x, observed=y_data)

            cache_sample = cache_sampling(pm.sample, dir="traces")
            idata1 = cache_sample(chains=2)

            # Cache hit! Returning stored result
            idata2 = cache_sample(chains=2)

            pm.set_data({"y_data": [1, 1, 1]})
            idata3 = cache_sample(chains=2)

        assert idata1.posterior["x"].mean() == idata2.posterior["x"].mean()
        assert idata1.posterior["x"].mean() != idata3.posterior["x"].mean()

    """
    allowed_fns = (sample, sample_prior_predictive, sample_posterior_predictive)
    if sampling_fn not in allowed_fns:
        raise ValueError(f"Cache sampling can only be used with {allowed_fns}")

    def wrapped_sampling_fn(*args, model=None, random_seed=None, **kwargs):
        if args:
            raise ValueError("Non-keyword arguments not allowed in cache_sampling")

        extend_inferencedata = kwargs.pop("extend_inferencedata", False)

        # Model hash
        model = modelcontext(model)
        fg, _ = fgraph_from_model(model)
        model_hash = hash_from_fg(fg)

        # Sampling hash: kwargs that do not affect the sampled values are excluded
        sampling_hash_dict = kwargs.copy()
        sampling_hash_dict.pop("trace", None)
        sampling_hash_dict.pop("random_seed", None)
        sampling_hash_dict.pop("progressbar", None)
        sampling_hash_dict.pop("compile_kwargs", None)
        sampling_hash_dict["sampling_fn"] = str(sampling_fn)
        sampling_hash_dict["versions"] = (
            pymc.__version__,
            pytensor.__version__,
            pymc_experimental.__version__,
        )
        sampling_hash = str(sampling_hash_dict)

        file_name = hashlib.sha1((model_hash + sampling_hash).encode()).hexdigest() + ".nc"
        file_path = os.path.join(dir, file_name)

        if not force_sample and os.path.exists(file_path):
            print("Cache hit! Returning stored result", file=sys.stdout)
            idata_out = az.from_netcdf(file_path)

        else:
            idata_out = sampling_fn(*args, **kwargs, model=model, random_seed=random_seed)
            if os.path.exists(file_path):
                os.remove(file_path)
            # An empty dir means the working directory, which always exists
            if dir and not os.path.exists(dir):
                os.mkdir(dir)
            az.to_netcdf(idata_out, file_path)

        # We save the InferenceData separately and extend only afterwards,
        # so the cached file never includes the pre-existing trace
        if extend_inferencedata:
            trace = kwargs["trace"]
            trace.extend(idata_out)
            idata_out = trace

        return idata_out

    return wrapped_sampling_fn
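
Because the cache key is built from hash_from_fg, two independently constructed but structurally identical models map to the same file name, while a change to any constant (including observed data) produces a new one. A minimal sketch using the helpers defined above; the model_hash wrapper is purely for illustration:

    import pymc as pm

    from pymc.model.fgraph import fgraph_from_model

    from pymc_experimental.utils.cache import hash_from_fg


    def model_hash(model):
        fg, _ = fgraph_from_model(model)
        return hash_from_fg(fg)


    with pm.Model() as m1:
        x = pm.Normal("x", 0, 1)

    with pm.Model() as m2:  # same structure, same constants
        x = pm.Normal("x", 0, 1)

    with pm.Model() as m3:  # different sigma constant
        x = pm.Normal("x", 0, 2)

    assert model_hash(m1) == model_hash(m2)
    assert model_hash(m1) != model_hash(m3)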

0 commit comments