|
4 | 4 | import numpy as np
|
5 | 5 | import pandas as pd
|
6 | 6 | import pymc as pm
|
7 |
| -import pytensor |
8 | 7 | import pytensor.tensor as pt
|
9 | 8 | import pytest
|
10 | 9 | from arviz import InferenceData, dict_to_dataset
|
11 | 10 | from pymc import ImputationWarning, inputvars
|
12 | 11 | from pymc.distributions import transforms
|
13 | 12 | from pymc.logprob.abstract import _logprob
|
14 | 13 | from pymc.util import UNSET
|
15 |
| -from pytensor.graph import vectorize_graph |
16 | 14 | from scipy.special import logsumexp
|
17 | 15 | from scipy.stats import norm
|
18 | 16 |
|
@@ -291,6 +289,58 @@ def true_logp(y):
|
291 | 289 | )
|
292 | 290 |
|
293 | 291 |
|
| 292 | +@pytest.mark.filterwarnings("error") |
| 293 | +def test_nested_recover_marginals(): |
| 294 | + """Test that marginalization works when there are nested marginalized RVs""" |
| 295 | + |
| 296 | + with MarginalModel() as m: |
| 297 | + idx = pm.Bernoulli("idx", p=0.75) |
| 298 | + sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95)) |
| 299 | + sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0) |
| 300 | + |
| 301 | + m.marginalize([idx, sub_idx]) |
| 302 | + |
| 303 | + rng = np.random.default_rng(211) |
| 304 | + |
| 305 | + with m: |
| 306 | + prior = pm.sample_prior_predictive( |
| 307 | + samples=20, |
| 308 | + random_seed=rng, |
| 309 | + return_inferencedata=False, |
| 310 | + ) |
| 311 | + idata = InferenceData(posterior=dict_to_dataset(prior)) |
| 312 | + |
| 313 | + idata = m.recover_marginals(idata, include_samples=True) |
| 314 | + assert "idx" in idata |
| 315 | + assert "lp_idx" in idata |
| 316 | + assert idata.idx.shape == idata.y.shape |
| 317 | + assert idata.lp_idx.shape == idata.idx.shape + (2,) |
| 318 | + assert "sub_idx" in idata |
| 319 | + assert "lp_sub_idx" in idata |
| 320 | + assert idata.sub_idx.shape == idata.y.shape |
| 321 | + assert idata.lp_sub_idx.shape == idata.sub_idx.shape + (2,) |
| 322 | + |
| 323 | + def true_idx_logp(y): |
| 324 | + idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1)) |
| 325 | + idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2)) |
| 326 | + return np.stack([idx_0, idx_1]).T |
| 327 | + |
| 328 | + np.testing.assert_almost_equal( |
| 329 | + true_idx_logp(idata.y.values.flatten()), |
| 330 | + idata.lp_idx[0].values, |
| 331 | + ) |
| 332 | + |
| 333 | + def true_sub_idx_logp(y): |
| 334 | + sub_idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1)) |
| 335 | + sub_idx_1 = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2)) |
| 336 | + return np.stack([sub_idx_0, sub_idx_1]).T |
| 337 | + |
| 338 | + np.testing.assert_almost_equal( |
| 339 | + true_sub_idx_logp(idata.y.values.flatten()), |
| 340 | + idata.lp_sub_idx[0].values, |
| 341 | + ) |
| 342 | + |
| 343 | + |
294 | 344 | @pytest.mark.filterwarnings("error")
|
295 | 345 | def test_not_supported_marginalized():
|
296 | 346 | """Marginalized graphs with non-Elemwise Operations are not supported as they
|
|
0 commit comments