|
19 | 19 |
|
20 | 20 | from numpy import array, ma
|
21 | 21 |
|
22 |
| -from pymc.distributions import Gamma, Normal, Uniform |
| 22 | +from pymc.distributions import Dirichlet, Gamma, Normal, Uniform |
23 | 23 | from pymc.exceptions import ImputationWarning
|
24 | 24 | from pymc.model import Model
|
25 | 25 | from pymc.sampling import sample, sample_posterior_predictive, sample_prior_predictive
|
@@ -163,3 +163,28 @@ def test_missing_logp():
|
163 | 163 | m_missing_logp = m_missing.logp({"theta1_missing": [2, 4], "theta2_missing": [0, 1, 3]})
|
164 | 164 |
|
165 | 165 | assert m_logp == m_missing_logp
|
| 166 | + |
| 167 | + |
| 168 | +def test_missing_multivariate(): |
| 169 | + """Test model with missing variables whose transform changes base shape still works""" |
| 170 | + |
| 171 | + with Model() as m_miss: |
| 172 | + with pytest.raises( |
| 173 | + NotImplementedError, |
| 174 | + match="Automatic inputation is only supported for univariate RandomVariables", |
| 175 | + ): |
| 176 | + x = Dirichlet( |
| 177 | + "x", a=[1, 2, 3], observed=np.array([[0.3, 0.3, 0.4], [np.nan, np.nan, np.nan]]) |
| 178 | + ) |
| 179 | + |
| 180 | + # TODO: Test can be used when local_subtensor_rv_lift supports multivariate distributions |
| 181 | + # from pymc.distributions.transforms import simplex |
| 182 | + # |
| 183 | + # with Model() as m_unobs: |
| 184 | + # x = Dirichlet("x", a=[1, 2, 3]) |
| 185 | + # |
| 186 | + # inp_vals = simplex.forward(np.array([0.3, 0.3, 0.4])).eval() |
| 187 | + # assert np.isclose( |
| 188 | + # m_miss.logp({"x_missing_simplex__": inp_vals}), |
| 189 | + # m_unobs.logp_nojac({"x_simplex__": inp_vals}) * 2, |
| 190 | + # ) |
0 commit comments