Skip to content

Commit 6570f9d

Browse files
authored
remove save_trace and load_trace function (#5123)
* remove save_trace and load_trace function * fix pre-commit error, change imports to inline
1 parent ace2589 commit 6570f9d

File tree

4 files changed

+2
-175
lines changed

4 files changed

+2
-175
lines changed

pymc/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,7 @@ def __set_compiler_flags():
7676

7777
from pymc import gp, ode, sampling
7878
from pymc.aesaraf import *
79-
from pymc.backends import (
80-
load_trace,
81-
predictions_to_inference_data,
82-
save_trace,
83-
to_inference_data,
84-
)
79+
from pymc.backends import predictions_to_inference_data, to_inference_data
8580
from pymc.backends.tracetab import *
8681
from pymc.bart import *
8782
from pymc.blocking import *

pymc/backends/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,4 @@
6161
6262
"""
6363
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
64-
from pymc.backends.ndarray import (
65-
NDArray,
66-
load_trace,
67-
point_list_to_multitrace,
68-
save_trace,
69-
)
64+
from pymc.backends.ndarray import NDArray, point_list_to_multitrace

pymc/backends/ndarray.py

Lines changed: 0 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -32,94 +32,6 @@
3232
from pymc.model import Model, modelcontext
3333

3434

35-
def save_trace(trace: MultiTrace, directory: Optional[str] = None, overwrite=False) -> str:
36-
"""Save multitrace to file.
37-
38-
TODO: Also save warnings.
39-
40-
This is a custom data format for PyMC traces. Each chain goes inside
41-
a directory, and each directory contains a metadata json file, and a
42-
numpy compressed file. See https://docs.scipy.org/doc/numpy/neps/npy-format.html
43-
for more information about this format.
44-
45-
Parameters
46-
----------
47-
trace: pm.MultiTrace
48-
trace to save to disk
49-
directory: str (optional)
50-
path to a directory to save the trace
51-
overwrite: bool (default False)
52-
whether to overwrite an existing directory.
53-
54-
Returns
55-
-------
56-
str, path to the directory where the trace was saved
57-
"""
58-
warnings.warn(
59-
"The `save_trace` function will soon be removed."
60-
"Instead, use `arviz.to_netcdf` to save traces.",
61-
FutureWarning,
62-
)
63-
64-
if isinstance(trace, MultiTrace):
65-
if directory is None:
66-
directory = ".pymc_{}.trace"
67-
idx = 1
68-
while os.path.exists(directory.format(idx)):
69-
idx += 1
70-
directory = directory.format(idx)
71-
72-
if os.path.isdir(directory):
73-
if overwrite:
74-
shutil.rmtree(directory)
75-
else:
76-
raise OSError(
77-
"Cautiously refusing to overwrite the already existing {}! Please supply "
78-
"a different directory, or set `overwrite=True`".format(directory)
79-
)
80-
os.makedirs(directory)
81-
82-
for chain, ndarray in trace._straces.items():
83-
SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray)
84-
return directory
85-
else:
86-
raise TypeError(
87-
f"You are attempting to save an InferenceData object but this function "
88-
"works only for MultiTrace objects. Use `arviz.to_netcdf` instead"
89-
)
90-
91-
92-
def load_trace(directory: str, model=None) -> MultiTrace:
93-
"""Loads a multitrace that has been written to file.
94-
95-
A the model used for the trace must be passed in, or the command
96-
must be run in a model context.
97-
98-
Parameters
99-
----------
100-
directory: str
101-
Path to a pymc serialized trace
102-
model: pm.Model (optional)
103-
Model used to create the trace. Can also be inferred from context
104-
105-
Returns
106-
-------
107-
pm.Multitrace that was saved in the directory
108-
"""
109-
warnings.warn(
110-
"The `load_trace` function will soon be removed."
111-
"Instead, use `arviz.from_netcdf` to load traces.",
112-
FutureWarning,
113-
)
114-
straces = []
115-
for subdir in glob.glob(os.path.join(directory, "*")):
116-
if os.path.isdir(subdir):
117-
straces.append(SerializeNDArray(subdir).load(model))
118-
if not straces:
119-
raise TraceDirectoryError("%s is not a PyMC saved chain directory." % directory)
120-
return base.MultiTrace(straces)
121-
122-
12335
class SerializeNDArray:
12436
metadata_file = "metadata.json"
12537
samples_file = "samples.npz"

pymc/tests/test_ndarray_backend.py

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -205,78 +205,3 @@ def test_combine_true_squeeze_true(self):
205205
expected = np.concatenate([self.x, self.y])
206206
result = base._squeeze_cat([self.x, self.y], True, True)
207207
npt.assert_equal(result, expected)
208-
209-
210-
class TestSaveLoad:
211-
@staticmethod
212-
def model(rng_seeder=None):
213-
with pm.Model(rng_seeder=rng_seeder) as model:
214-
x = pm.Normal("x", 0, 1)
215-
y = pm.Normal("y", x, 1, observed=2)
216-
z = pm.Normal("z", x + y, 1)
217-
return model
218-
219-
@classmethod
220-
def setup_class(cls):
221-
with TestSaveLoad.model():
222-
cls.trace = pm.sample(return_inferencedata=False)
223-
224-
def test_save_new_model(self, tmpdir_factory):
225-
directory = str(tmpdir_factory.mktemp("data"))
226-
save_dir = pm.save_trace(self.trace, directory, overwrite=True)
227-
228-
assert save_dir == directory
229-
with pm.Model() as model:
230-
w = pm.Normal("w", 0, 1)
231-
new_trace = pm.sample(return_inferencedata=False)
232-
233-
with pytest.raises(OSError):
234-
_ = pm.save_trace(new_trace, directory)
235-
236-
_ = pm.save_trace(new_trace, directory, overwrite=True)
237-
with model:
238-
new_trace_copy = pm.load_trace(directory)
239-
240-
assert (new_trace["w"] == new_trace_copy["w"]).all()
241-
242-
def test_save_and_load(self, tmpdir_factory):
243-
directory = str(tmpdir_factory.mktemp("data"))
244-
save_dir = pm.save_trace(self.trace, directory, overwrite=True)
245-
246-
assert save_dir == directory
247-
248-
trace2 = pm.load_trace(directory, model=TestSaveLoad.model())
249-
250-
for var in ("x", "z"):
251-
assert (self.trace[var] == trace2[var]).all()
252-
253-
assert self.trace.stat_names == trace2.stat_names
254-
for stat in self.trace.stat_names:
255-
assert all(self.trace[stat] == trace2[stat]), (
256-
"Restored value of statistic %s does not match stored value" % stat
257-
)
258-
259-
def test_bad_load(self, tmpdir_factory):
260-
directory = str(tmpdir_factory.mktemp("data"))
261-
with pytest.raises(pm.TraceDirectoryError):
262-
pm.load_trace(directory, model=TestSaveLoad.model())
263-
264-
def test_sample_posterior_predictive(self, tmpdir_factory):
265-
directory = str(tmpdir_factory.mktemp("data"))
266-
save_dir = pm.save_trace(self.trace, directory, overwrite=True)
267-
268-
assert save_dir == directory
269-
270-
rng = np.random.RandomState(10)
271-
272-
with TestSaveLoad.model(rng_seeder=rng):
273-
ppc = pm.sample_posterior_predictive(self.trace)
274-
275-
rng = np.random.RandomState(10)
276-
277-
with TestSaveLoad.model(rng_seeder=rng):
278-
trace2 = pm.load_trace(directory)
279-
ppc2 = pm.sample_posterior_predictive(trace2)
280-
281-
for key, value in ppc.items():
282-
assert (value == ppc2[key]).all()

0 commit comments

Comments
 (0)