Skip to content

Commit 338ec15

Browse files
explicit error about empirical approx (#5874)
* explicit error about empirical approx * fix an error message * Write more detailed error message Co-authored-by: Michael Osthege <[email protected]>
1 parent 78b1f17 commit 338ec15

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## PyMC 4.0.1 (vNext)
44
+ Fixed an incorrect entry in `pm.Metropolis.stats_dtypes` (see #5582).
5+
+ Added a check in `Empirical` approximation which does not yet support `InferenceData` inputs (see #5874, #5884).
56
+ ...
67

78
## PyMC 4.0.0 (2022-06-03)

pymc/tests/test_variational_inference.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,14 @@ def test_empirical_from_trace(another_simple_model):
791791
assert emp.histogram.shape[0].eval() == 400
792792

793793

794+
def test_empirical_does_not_support_inference_data(another_simple_model):
795+
with another_simple_model:
796+
step = pm.Metropolis()
797+
trace = pm.sample(100, step=step, chains=1, tune=0, return_inferencedata=True)
798+
with pytest.raises(NotImplementedError, match="return_inferencedata=False"):
799+
Empirical(trace)
800+
801+
794802
@pytest.mark.parametrize("score", [True, False])
795803
def test_fit_with_nans(score):
796804
X_mean = pm.floatX(np.linspace(0, 10, 10))

pymc/variational/approximations.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from aesara import tensor as at
1919
from aesara.graph.basic import Variable
2020
from aesara.tensor.var import TensorVariable
21+
from arviz import InferenceData
2122

2223
import pymc as pm
2324

@@ -198,7 +199,6 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None):
198199
# Initialize particles
199200
histogram = np.tile(start, (size, 1))
200201
histogram += pm.floatX(np.random.normal(0, jitter, histogram.shape))
201-
202202
else:
203203
histogram = np.empty((len(trace) * len(trace.chains), self.ddim))
204204
i = 0
@@ -210,7 +210,13 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None):
210210

211211
def _check_trace(self):
212212
trace = self._kwargs.get("trace", None)
213-
if trace is not None and not all(
213+
if isinstance(trace, InferenceData):
214+
raise NotImplementedError(
215+
"The `Empirical` approximation does not yet support `InferenceData` inputs."
216+
" Pass `pm.sample(return_inferencedata=False)` to get a `MultiTrace` to use with `Empirical`."
217+
" Please help us to refactor: https://github.com/pymc-devs/pymc/issues/5884"
218+
)
219+
elif trace is not None and not all(
214220
[self.model.rvs_to_values[var].name in trace.varnames for var in self.group]
215221
):
216222
raise ValueError("trace has not all free RVs in the group")

0 commit comments

Comments
 (0)