Skip to content

Commit 2f8f110

Browse files
ricardoV94twiecki
authored andcommitted
Replace RVs by respective value variables in the graph of untransformed variables
This fixes incorrect behavior, where deterministic projection from transformed (sampling) to untransformed space would be nonsensical when transforms depend on other graph variables
1 parent af6b3c6 commit 2f8f110

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

pymc/model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -900,10 +900,11 @@ def value_vars(self):
900900
@property
901901
def unobserved_value_vars(self):
902902
"""List of all random variables (including untransformed projections),
903-
as well as deterministics used as inputs and outputs of the the model's
903+
as well as deterministics used as inputs and outputs of the model's
904904
log-likelihood graph
905905
"""
906906
vars = []
907+
untransformed_vars = []
907908
for rv in self.free_RVs:
908909
value_var = self.rvs_to_values[rv]
909910
transform = getattr(value_var.tag, "transform", None)
@@ -912,13 +913,16 @@ def unobserved_value_vars(self):
912913
# each transformed variable
913914
untrans_value_var = transform.backward(value_var, *rv.owner.inputs)
914915
untrans_value_var.name = rv.name
915-
vars.append(untrans_value_var)
916+
untransformed_vars.append(untrans_value_var)
916917
vars.append(value_var)
917918

919+
# Remove rvs from untransformed values graph
920+
untransformed_vars, _ = rvs_to_value_vars(untransformed_vars, apply_transforms=True)
921+
918922
# Remove rvs from deterministics graph
919923
deterministics, _ = rvs_to_value_vars(self.deterministics, apply_transforms=True)
920924

921-
return vars + deterministics
925+
return vars + untransformed_vars + deterministics
922926

923927
@property
924928
def basic_RVs(self):

pymc/tests/test_sampling.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import numpy as np
2424
import numpy.testing as npt
2525
import pytest
26+
import scipy.special
2627

2728
from aesara import shared
2829
from arviz import InferenceData
@@ -326,6 +327,16 @@ def test_deterministic_of_unobserved(self):
326327

327328
np.testing.assert_allclose(idata.posterior["y"], idata.posterior["x"] + 100)
328329

330+
def test_transform_with_rv_depenency(self):
331+
# Test that untransformed variables that depend on upstream variables are properly handled
332+
with pm.Model() as m:
333+
x = pm.HalfNormal("x", observed=1)
334+
transform = pm.transforms.IntervalTransform(lambda *inputs: (inputs[-2], inputs[-1]))
335+
y = pm.Uniform("y", lower=0, upper=x, transform=transform)
336+
trace = pm.sample(tune=10, draws=50, return_inferencedata=False, random_seed=336)
337+
338+
assert np.allclose(scipy.special.expit(trace["y_interval__"]), trace["y"])
339+
329340

330341
def test_sample_find_MAP_does_not_modify_start():
331342
# see https://github.com/pymc-devs/pymc/pull/4458

0 commit comments

Comments
 (0)