Skip to content

Commit fbd41bb

Browse files
Only pack variables for which prior samples are available
Closes #5337
1 parent 95bd5e5 commit fbd41bb

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

pymc/backends/arviz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def priors_to_xarray(self):
427427
if self.prior is None:
428428
return {"prior": None, "prior_predictive": None}
429429
if self.observations is not None:
430-
prior_predictive_vars = list(self.observations.keys())
430+
prior_predictive_vars = list(set(self.observations).intersection(self.prior))
431431
prior_vars = [key for key in self.prior.keys() if key not in prior_predictive_vars]
432432
else:
433433
prior_vars = list(self.prior.keys())

pymc/tests/test_idata_conversion.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,17 @@ def test_priors_separation(self, use_context):
564564
fails = check_multiple_attrs(test_dict, inference_data)
565565
assert not fails
566566

567+
def test_conversion_from_variables_subset(self):
568+
"""This is a regression test for issue #5337."""
569+
with pm.Model() as model:
570+
x = pm.Normal("x")
571+
pm.Normal("y", x, observed=5)
572+
idata = pm.sample(
573+
tune=10, draws=20, chains=1, step=pm.Metropolis(), compute_convergence_checks=False
574+
)
575+
pm.sample_posterior_predictive(idata, var_names=["x"])
576+
pm.sample_prior_predictive(var_names=["x"])
577+
567578
def test_multivariate_observations(self):
568579
coords = {"direction": ["x", "y", "z"], "experiment": np.arange(20)}
569580
data = np.random.multinomial(20, [0.2, 0.3, 0.5], size=20)

0 commit comments

Comments
 (0)