diff --git a/pymc/model.py b/pymc/model.py index aa6a8bfbad..7ffa5f6d74 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -900,10 +900,11 @@ def value_vars(self): @property def unobserved_value_vars(self): """List of all random variables (including untransformed projections), - as well as deterministics used as inputs and outputs of the the model's + as well as deterministics used as inputs and outputs of the model's log-likelihood graph """ vars = [] + untransformed_vars = [] for rv in self.free_RVs: value_var = self.rvs_to_values[rv] transform = getattr(value_var.tag, "transform", None) @@ -912,13 +913,16 @@ def unobserved_value_vars(self): # each transformed variable untrans_value_var = transform.backward(value_var, *rv.owner.inputs) untrans_value_var.name = rv.name - vars.append(untrans_value_var) + untransformed_vars.append(untrans_value_var) vars.append(value_var) + # Remove rvs from untransformed values graph + untransformed_vars, _ = rvs_to_value_vars(untransformed_vars, apply_transforms=True) + # Remove rvs from deterministics graph deterministics, _ = rvs_to_value_vars(self.deterministics, apply_transforms=True) - return vars + deterministics + return vars + untransformed_vars + deterministics @property def basic_RVs(self): diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 7c8be684fa..d145631eb8 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -23,6 +23,7 @@ import numpy as np import numpy.testing as npt import pytest +import scipy.special from aesara import shared from arviz import InferenceData @@ -313,6 +314,29 @@ def test_exceptions(self): with pytest.raises(NotImplementedError): xvars = [t["mu"] for t in trace] + def test_deterministic_of_unobserved(self): + with pm.Model() as model: + x = pm.HalfNormal("x", 1) + y = pm.Deterministic("y", x + 100) + idata = pm.sample( + chains=1, + tune=10, + draws=50, + compute_convergence_checks=False, + ) + + np.testing.assert_allclose(idata.posterior["y"], idata.posterior["x"] + 100) + + def test_transform_with_rv_depenency(self): + # Test that untransformed variables that depend on upstream variables are properly handled + with pm.Model() as m: + x = pm.HalfNormal("x", observed=1) + transform = pm.transforms.IntervalTransform(lambda *inputs: (inputs[-2], inputs[-1])) + y = pm.Uniform("y", lower=0, upper=x, transform=transform) + trace = pm.sample(tune=10, draws=50, return_inferencedata=False, random_seed=336) + + assert np.allclose(scipy.special.expit(trace["y_interval__"]), trace["y"]) + def test_sample_find_MAP_does_not_modify_start(): # see https://github.com/pymc-devs/pymc/pull/4458 @@ -1220,15 +1244,6 @@ def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture): pp = pm.sample_posterior_predictive(idat.posterior, var_names=["d"]) -def test_sample_deterministic(): - with pm.Model() as model: - x = pm.HalfNormal("x", 1) - y = pm.Deterministic("y", x + 100) - idata = pm.sample(chains=1, draws=50, compute_convergence_checks=False) - - np.testing.assert_allclose(idata.posterior["y"], idata.posterior["x"] + 100) - - class TestDraw(SeededTest): def test_univariate(self): with pm.Model():