Skip to content

Commit 43b40de

Browse files
committed
Allow mutable shape in PartialObservedRVs
1 parent 3729614 commit 43b40de

File tree

3 files changed

+61
-10
lines changed

3 files changed

+61
-10
lines changed

pymc/distributions/distribution.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1630,7 +1630,9 @@ def partial_observed_rv_logprob(op, values, dist, mask, **kwargs):
16301630
# For the logp, simply join the values
16311631
[obs_value, unobs_value] = values
16321632
antimask = ~mask
1633-
joined_value = pt.empty(constant_fold([dist.shape])[0])
1633+
# We don't need it to be completely folded, just to avoid any RVs in the graph of the shape
1634+
[folded_shape] = constant_fold([dist.shape], raise_not_constant=False)
1635+
joined_value = pt.empty(folded_shape)
16341636
joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
16351637
joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
16361638
joined_logp = logp(dist, joined_value)

tests/distributions/test_distribution.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -979,16 +979,21 @@ def test_univariate(self, symbolic_rv):
979979
np.testing.assert_allclose(obs_logp, st.norm([1, 2]).logpdf([0.25, 0.5]))
980980
np.testing.assert_allclose(unobs_logp, st.norm([3]).logpdf([0.25]))
981981

982+
@pytest.mark.parametrize("mutable_shape", (False, True))
982983
@pytest.mark.parametrize("obs_component_selected", (True, False))
983-
def test_multivariate_constant_mask_separable(self, obs_component_selected):
984+
def test_multivariate_constant_mask_separable(self, obs_component_selected, mutable_shape):
984985
if obs_component_selected:
985986
mask = np.zeros((1, 4), dtype=bool)
986987
else:
987988
mask = np.ones((1, 4), dtype=bool)
988989
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
989990
unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])
990991

991-
rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
992+
if mutable_shape:
993+
shape = (1, pytensor.shared(np.array(4, dtype=int)))
994+
else:
995+
shape = (1, 4)
996+
rv = pm.Dirichlet.dist(pt.arange(shape[-1]) + 1, shape=shape)
992997
(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)
993998

994999
# Test types
@@ -1023,6 +1028,10 @@ def test_multivariate_constant_mask_separable(self, obs_component_selected):
10231028
np.testing.assert_allclose(obs_logp, expected_obs_logp)
10241029
np.testing.assert_allclose(unobs_logp, expected_unobs_logp)
10251030

1031+
if mutable_shape:
1032+
shape[-1].set_value(7)
1033+
assert tuple(joined_rv.shape.eval()) == (1, 7)
1034+
10261035
def test_multivariate_constant_mask_unseparable(self):
10271036
mask = pt.constant(np.array([[True, True, False, False]]))
10281037
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
@@ -1097,14 +1106,19 @@ def test_multivariate_shared_mask_separable(self):
10971106
np.testing.assert_almost_equal(obs_logp, new_expected_logp)
10981107
np.testing.assert_array_equal(unobs_logp, [])
10991108

1100-
def test_multivariate_shared_mask_unseparable(self):
1109+
@pytest.mark.parametrize("mutable_shape", (False, True))
1110+
def test_multivariate_shared_mask_unseparable(self, mutable_shape):
11011111
# Even if the mask is initially not mixing support dims,
11021112
# it could later be changed in a way that does!
11031113
mask = shared(np.array([[True, True, True, True]]))
11041114
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
11051115
unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])
11061116

1107-
rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
1117+
if mutable_shape:
1118+
shape = mask.shape
1119+
else:
1120+
shape = (1, 4)
1121+
rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=shape)
11081122
(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)
11091123

11101124
# Test types
@@ -1134,16 +1148,22 @@ def test_multivariate_shared_mask_unseparable(self):
11341148

11351149
# Test that we can update a shared mask
11361150
mask.set_value(np.array([[False, False, True, True]]))
1151+
equivalent_value = np.array([0.1, 0.4, 0.4, 0.1])
11371152

11381153
assert tuple(obs_rv.shape.eval()) == (2,)
11391154
assert tuple(unobs_rv.shape.eval()) == (2,)
11401155

1141-
new_expected_logp = pm.logp(rv, [0.1, 0.4, 0.4, 0.1]).eval()
1156+
new_expected_logp = pm.logp(rv, equivalent_value).eval()
11421157
assert not np.isclose(expected_logp, new_expected_logp) # Otherwise test is weak
11431158
obs_logp, unobs_logp = logp_fn()
11441159
np.testing.assert_almost_equal(obs_logp, new_expected_logp)
11451160
np.testing.assert_array_equal(unobs_logp, [])
11461161

1162+
if mutable_shape:
1163+
mask.set_value(np.array([[False, False, True, False], [False, False, False, True]]))
1164+
assert tuple(obs_rv.shape.eval()) == (6,)
1165+
assert tuple(unobs_rv.shape.eval()) == (2,)
1166+
11471167
def test_support_point(self):
11481168
x = pm.GaussianRandomWalk.dist(init_dist=pm.Normal.dist(-5), mu=1, steps=9)
11491169
ref_support_point = support_point(x).eval()

tests/model/test_core.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import pytensor.sparse as sparse
2929
import pytensor.tensor as pt
3030
import pytest
31+
import scipy
3132
import scipy.sparse as sps
3233
import scipy.stats as st
3334

@@ -38,7 +39,7 @@
3839

3940
import pymc as pm
4041

41-
from pymc import Deterministic, Model, Potential
42+
from pymc import Deterministic, Model, MvNormal, Potential
4243
from pymc.blocking import DictToArrayBijection, RaveledVars
4344
from pymc.distributions import Normal, transforms
4445
from pymc.distributions.distribution import PartialObservedRV
@@ -1504,11 +1505,39 @@ def test_truncated_normal(self):
15041505
"""
15051506
with Model() as m:
15061507
mu = pm.TruncatedNormal("mu", mu=1, sigma=2, lower=0)
1507-
x = pm.TruncatedNormal(
1508-
"x", mu=mu, sigma=0.5, lower=0, observed=np.array([0.1, 0.2, 0.5, np.nan, np.nan])
1509-
)
1508+
with pytest.warns(ImputationWarning):
1509+
x = pm.TruncatedNormal(
1510+
"x",
1511+
mu=mu,
1512+
sigma=0.5,
1513+
lower=0,
1514+
observed=np.array([0.1, 0.2, 0.5, np.nan, np.nan]),
1515+
)
15101516
m.check_start_vals(m.initial_point())
15111517

1518+
def test_coordinates(self):
1519+
# Regression test for https://github.com/pymc-devs/pymc/issues/7304
1520+
1521+
coords = {"trial": range(30), "feature": range(2)}
1522+
observed = np.zeros((30, 2))
1523+
observed[0, 0] = np.nan
1524+
1525+
with Model(coords=coords) as model:
1526+
with pytest.warns(ImputationWarning):
1527+
MvNormal(
1528+
"y",
1529+
mu=np.zeros(2),
1530+
cov=np.eye(2),
1531+
observed=observed,
1532+
dims=("trial", "feature"),
1533+
)
1534+
1535+
logp_fn = model.compile_logp()
1536+
np.testing.assert_allclose(
1537+
logp_fn({"y_unobserved": [0]}),
1538+
scipy.stats.multivariate_normal.logpdf([0, 0], cov=np.eye(2)) * 30,
1539+
)
1540+
15121541

15131542
class TestShared:
15141543
def test_deterministic(self):

0 commit comments

Comments
 (0)