Skip to content

Commit 534b89b

Browse files
williambdeanmichaelosthege
authored andcommitted
assume supports_sample_stats=True
1 parent d47dac0 commit 534b89b

File tree

4 files changed

+13
-38
lines changed

4 files changed

+13
-38
lines changed

pymc/backends/base.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ class BaseTrace(ABC):
5252
use different test point that might be with changed variables shapes
5353
"""
5454

55-
supports_sampler_stats = False
56-
5755
def __init__(self, name, model=None, vars=None, test_point=None):
5856
self.name = name
5957

@@ -88,9 +86,6 @@ def _add_warnings(self, warnings):
8886
# Sampling methods
8987

9088
def _set_sampler_vars(self, sampler_vars):
91-
if sampler_vars is not None and not self.supports_sampler_stats:
92-
raise ValueError("Backend does not support sampler stats.")
93-
9489
if self._is_base_setup and self.sampler_vars != sampler_vars:
9590
raise ValueError("Can't change sampler_vars")
9691

@@ -117,9 +112,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
117112
chain: int
118113
Chain number
119114
sampler_vars: list of dictionaries (name -> dtype), optional
120-
Diagnostics / statistics for each sampler. Before passing this
121-
to a backend, you should check, that the `supports_sampler_state`
122-
flag is set.
115+
Diagnostics / statistics for each sampler
123116
"""
124117
self._set_sampler_vars(sampler_vars)
125118
self._is_base_setup = True
@@ -190,9 +183,6 @@ def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
190183
a numpy array of shape (m, n), where `m` is the number of
191184
such samplers, and `n` is the number of samples.
192185
"""
193-
if not self.supports_sampler_stats:
194-
raise ValueError("This backend does not support sampler stats")
195-
196186
if sampler_idx is not None:
197187
return self._get_sampler_stats(stat_name, sampler_idx, burn, thin)
198188

@@ -232,13 +222,11 @@ def point(self, idx):
232222

233223
@property
234224
def stat_names(self):
235-
if self.supports_sampler_stats:
236-
names = set()
237-
for vars in self.sampler_vars or []:
238-
names.update(vars.keys())
239-
return names
240-
else:
241-
return set()
225+
names = set()
226+
for vars in self.sampler_vars or []:
227+
names.update(vars.keys())
228+
229+
return names
242230

243231

244232
class MultiTrace:
@@ -356,7 +344,7 @@ def __getitem__(self, idx):
356344
return self.get_sampler_stats(var, burn=burn, thin=thin)
357345
raise KeyError("Unknown variable %s" % var)
358346

359-
_attrs = {"_straces", "varnames", "chains", "stat_names", "supports_sampler_stats", "_report"}
347+
_attrs = {"_straces", "varnames", "chains", "stat_names", "_report"}
360348

361349
def __getattr__(self, name):
362350
# Avoid infinite recursion when called before __init__

pymc/backends/ndarray.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ class NDArray(base.BaseTrace):
4141
`model.unobserved_RVs` is used.
4242
"""
4343

44-
supports_sampler_stats = True
45-
4644
def __init__(self, name=None, model=None, vars=None, test_point=None):
4745
super().__init__(name, model, vars, test_point)
4846
self.draw_idx = 0

pymc/sampling.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,11 +1048,8 @@ def _iter_sample(
10481048
step = stop_tuning(step)
10491049
if step.generates_stats:
10501050
point, stats = step.step(point)
1051-
if strace.supports_sampler_stats:
1052-
strace.record(point, stats)
1053-
diverging = i > tune and stats and stats[0].get("diverging")
1054-
else:
1055-
strace.record(point)
1051+
strace.record(point, stats)
1052+
diverging = i > tune and stats and stats[0].get("diverging")
10561053
else:
10571054
point = step.step(point)
10581055
strace.record(point)
@@ -1359,10 +1356,7 @@ def _iter_population(
13591356
for c, strace in enumerate(traces):
13601357
if steppers[c].generates_stats:
13611358
points[c], stats = updates[c]
1362-
if strace.supports_sampler_stats:
1363-
strace.record(points[c], stats)
1364-
else:
1365-
strace.record(points[c])
1359+
strace.record(points[c], stats)
13661360
else:
13671361
points[c] = updates[c]
13681362
strace.record(points[c])
@@ -1428,7 +1422,7 @@ def _init_trace(
14281422
else:
14291423
strace = _choose_backend(None, model=model)
14301424

1431-
if step.generates_stats and strace.supports_sampler_stats:
1425+
if step.generates_stats:
14321426
strace.setup(expected_length, chain_number, step.stats_dtypes)
14331427
else:
14341428
strace.setup(expected_length, chain_number)
@@ -1520,7 +1514,7 @@ def _mp_sample(
15201514
with sampler:
15211515
for draw in sampler:
15221516
strace = traces[draw.chain]
1523-
if strace.supports_sampler_stats and draw.stats is not None:
1517+
if draw.stats is not None:
15241518
strace.record(draw.point, draw.stats)
15251519
else:
15261520
strace.record(draw.point)

pymc/tests/backends/fixtures.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def setup_method(self):
5151
if not hasattr(self, "sampler_vars"):
5252
self.sampler_vars = None
5353
if self.sampler_vars is not None:
54-
assert self.strace.supports_sampler_stats
5554
self.strace.setup(self.draws, self.chain, self.sampler_vars)
5655
else:
5756
self.strace.setup(self.draws, self.chain)
@@ -110,11 +109,7 @@ def test_bad_dtype(self):
110109
with pytest.raises((ValueError, TypeError)):
111110
strace.setup(self.draws, self.chain, bad_vars)
112111
strace.setup(self.draws, self.chain, good_vars)
113-
if strace.supports_sampler_stats:
114-
assert strace.stat_names == {"a"}
115-
else:
116-
with pytest.raises((ValueError, TypeError)):
117-
strace.setup(self.draws, self.chain, good_vars)
112+
assert strace.stat_names == {"a"}
118113

119114
def teardown_method(self):
120115
if self.name is not None:

0 commit comments

Comments
 (0)