Skip to content

Commit aa44b88

Browse files
ColCarrollrpgoldman
andcommitted
Make sampler stats accessible after loading (#3773)
* Make sampler stats accessible after loading * Add check to verify the fix. Co-authored-by: rpgoldman <[email protected]>
1 parent 60757fa commit aa44b88

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

pymc3/backends/ndarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,21 @@ def to_metadata(ndarray):
101101
"""Extract ndarray metadata into json-serializable content"""
102102
if ndarray._stats is None:
103103
stats = ndarray._stats
104+
sampler_vars = None
104105
else:
105106
stats = []
107+
sampler_vars = []
106108
for stat in ndarray._stats:
107109
stats.append({key: value.tolist() for key, value in stat.items()})
110+
sampler_vars.append({key: str(value.dtype) for key, value in stat.items()})
111+
108112

109113
metadata = {
110114
'draw_idx': ndarray.draw_idx,
111115
'draws': ndarray.draws,
112116
'_stats': stats,
113117
'chain': ndarray.chain,
118+
'sampler_vars': sampler_vars
114119
}
115120
return metadata
116121

@@ -145,6 +150,9 @@ def load(self, model: Model) -> 'NDArray':
145150

146151
metadata['_stats'] = [{k: np.array(v) for k, v in stat.items()} for stat in metadata['_stats']]
147152

153+
sampler_vars = metadata.pop('sampler_vars')
154+
new_trace._set_sampler_vars(sampler_vars)
155+
148156
for key, value in metadata.items():
149157
setattr(new_trace, key, value)
150158
new_trace.samples = dict(np.load(self.samples_path))

pymc3/tests/test_ndarray_backend.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,9 @@
66
import pytest
77

88

9-
STATS1 = [{
10-
'a': np.float64,
11-
'b': np.bool
12-
}]
9+
STATS1 = [{"a": np.float64, "b": np.bool}]
1310

14-
STATS2 = [{
15-
'a': np.float64
16-
}, {
17-
'a': np.float64,
18-
'b': np.int64,
19-
}]
11+
STATS2 = [{"a": np.float64}, {"a": np.float64, "b": np.int64,}]
2012

2113

2214
class TestNDArray0dSampling(bf.SamplingTestCase):
@@ -152,7 +144,7 @@ class TestMultiTrace_add_remove_values(bf.ModelBackendSampledTestCase):
152144
def test_add_values(self):
153145
mtrace = self.mtrace
154146
orig_varnames = list(mtrace.varnames)
155-
name = 'new_var'
147+
name = "new_var"
156148
vals = mtrace[orig_varnames[0]]
157149
mtrace.add_values({name: vals})
158150
assert len(orig_varnames) == len(mtrace.varnames) - 1
@@ -164,7 +156,6 @@ def test_add_values(self):
164156

165157

166158
class TestSqueezeCat:
167-
168159
def setup_method(self):
169160
self.x = np.arange(10)
170161
self.y = np.arange(10, 20)
@@ -194,13 +185,14 @@ def test_combine_true_squeeze_true(self):
194185
result = base._squeeze_cat([self.x, self.y], True, True)
195186
npt.assert_equal(result, expected)
196187

188+
197189
class TestSaveLoad:
198190
@staticmethod
199191
def model():
200192
with pm.Model() as model:
201-
x = pm.Normal('x', 0, 1)
202-
y = pm.Normal('y', x, 1, observed=2)
203-
z = pm.Normal('z', x + y, 1)
193+
x = pm.Normal("x", 0, 1)
194+
y = pm.Normal("y", x, 1, observed=2)
195+
z = pm.Normal("z", x + y, 1)
204196
return model
205197

206198
@classmethod
@@ -209,12 +201,12 @@ def setup_class(cls):
209201
cls.trace = pm.sample()
210202

211203
def test_save_new_model(self, tmpdir_factory):
212-
directory = str(tmpdir_factory.mktemp('data'))
204+
directory = str(tmpdir_factory.mktemp("data"))
213205
save_dir = pm.save_trace(self.trace, directory, overwrite=True)
214206

215207
assert save_dir == directory
216208
with pm.Model() as model:
217-
w = pm.Normal('w', 0, 1)
209+
w = pm.Normal("w", 0, 1)
218210
new_trace = pm.sample()
219211

220212
with pytest.raises(OSError):
@@ -224,26 +216,32 @@ def test_save_new_model(self, tmpdir_factory):
224216
with model:
225217
new_trace_copy = pm.load_trace(directory)
226218

227-
assert (new_trace['w'] == new_trace_copy['w']).all()
219+
assert (new_trace["w"] == new_trace_copy["w"]).all()
228220

229221
def test_save_and_load(self, tmpdir_factory):
230-
directory = str(tmpdir_factory.mktemp('data'))
222+
directory = str(tmpdir_factory.mktemp("data"))
231223
save_dir = pm.save_trace(self.trace, directory, overwrite=True)
232224

233225
assert save_dir == directory
234226

235227
trace2 = pm.load_trace(directory, model=TestSaveLoad.model())
236228

237-
for var in ('x', 'z'):
229+
for var in ("x", "z"):
238230
assert (self.trace[var] == trace2[var]).all()
239231

232+
assert self.trace.stat_names == trace2.stat_names
233+
for stat in self.trace.stat_names:
234+
assert all(self.trace[stat] == trace2[stat]), (
235+
"Restored value of statistic %s does not match stored value" % stat
236+
)
237+
240238
def test_bad_load(self, tmpdir_factory):
241-
directory = str(tmpdir_factory.mktemp('data'))
239+
directory = str(tmpdir_factory.mktemp("data"))
242240
with pytest.raises(pm.TraceDirectoryError):
243241
pm.load_trace(directory, model=TestSaveLoad.model())
244242

245243
def test_sample_posterior_predictive(self, tmpdir_factory):
246-
directory = str(tmpdir_factory.mktemp('data'))
244+
directory = str(tmpdir_factory.mktemp("data"))
247245
save_dir = pm.save_trace(self.trace, directory, overwrite=True)
248246

249247
assert save_dir == directory

0 commit comments

Comments
 (0)