Skip to content

Commit 022cefa

Browse files
committed
Add NotImplementedError for partial observed multivariate variables
1 parent fd12b83 commit 022cefa

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

pymc/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,6 +1239,11 @@ def make_obs_var(
12391239
)
12401240
warnings.warn(impute_message, ImputationWarning)
12411241

1242+
if rv_var.owner.op.ndim_supp > 0:
1243+
raise NotImplementedError(
1244+
f"Automatic inputation is only supported for univariate RandomVariables, but {rv_var} is multivariate"
1245+
)
1246+
12421247
# We can get a random variable comprised of only the unobserved
12431248
# entries by lifting the indices through the `RandomVariable` `Op`.
12441249

pymc/tests/test_missing.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from numpy import array, ma
2121

22-
from pymc.distributions import Gamma, Normal, Uniform
22+
from pymc.distributions import Dirichlet, Gamma, Normal, Uniform
2323
from pymc.exceptions import ImputationWarning
2424
from pymc.model import Model
2525
from pymc.sampling import sample, sample_posterior_predictive, sample_prior_predictive
@@ -163,3 +163,28 @@ def test_missing_logp():
163163
m_missing_logp = m_missing.logp({"theta1_missing": [2, 4], "theta2_missing": [0, 1, 3]})
164164

165165
assert m_logp == m_missing_logp
166+
167+
168+
def test_missing_multivariate():
169+
"""Test model with missing variables whose transform changes base shape still works"""
170+
171+
with Model() as m_miss:
172+
with pytest.raises(
173+
NotImplementedError,
174+
match="Automatic inputation is only supported for univariate RandomVariables",
175+
):
176+
x = Dirichlet(
177+
"x", a=[1, 2, 3], observed=np.array([[0.3, 0.3, 0.4], [np.nan, np.nan, np.nan]])
178+
)
179+
180+
# TODO: Test can be used when local_subtensor_rv_lift supports multivariate distributions
181+
# from pymc.distributions.transforms import simplex
182+
#
183+
# with Model() as m_unobs:
184+
# x = Dirichlet("x", a=[1, 2, 3])
185+
#
186+
# inp_vals = simplex.forward(np.array([0.3, 0.3, 0.4])).eval()
187+
# assert np.isclose(
188+
# m_miss.logp({"x_missing_simplex__": inp_vals}),
189+
# m_unobs.logp_nojac({"x_simplex__": inp_vals}) * 2,
190+
# )

0 commit comments

Comments
 (0)