Skip to content

Commit e4c9db7

Browse files
committed
Add nested model test
1 parent 94fc1b1 commit e4c9db7

File tree

1 file changed

+52
-2
lines changed

1 file changed

+52
-2
lines changed

pymc_experimental/tests/model/test_marginal_model.py

+52-2
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,13 @@
44
import numpy as np
55
import pandas as pd
66
import pymc as pm
7-
import pytensor
87
import pytensor.tensor as pt
98
import pytest
109
from arviz import InferenceData, dict_to_dataset
1110
from pymc import ImputationWarning, inputvars
1211
from pymc.distributions import transforms
1312
from pymc.logprob.abstract import _logprob
1413
from pymc.util import UNSET
15-
from pytensor.graph import vectorize_graph
1614
from scipy.special import logsumexp
1715
from scipy.stats import norm
1816

@@ -291,6 +289,58 @@ def true_logp(y):
291289
)
292290

293291

292+
@pytest.mark.filterwarnings("error")
293+
def test_nested_recover_marginals():
294+
"""Test that marginalization works when there are nested marginalized RVs"""
295+
296+
with MarginalModel() as m:
297+
idx = pm.Bernoulli("idx", p=0.75)
298+
sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
299+
sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0)
300+
301+
m.marginalize([idx, sub_idx])
302+
303+
rng = np.random.default_rng(211)
304+
305+
with m:
306+
prior = pm.sample_prior_predictive(
307+
samples=20,
308+
random_seed=rng,
309+
return_inferencedata=False,
310+
)
311+
idata = InferenceData(posterior=dict_to_dataset(prior))
312+
313+
idata = m.recover_marginals(idata, include_samples=True)
314+
assert "idx" in idata
315+
assert "lp_idx" in idata
316+
assert idata.idx.shape == idata.y.shape
317+
assert idata.lp_idx.shape == idata.idx.shape + (2,)
318+
assert "sub_idx" in idata
319+
assert "lp_sub_idx" in idata
320+
assert idata.sub_idx.shape == idata.y.shape
321+
assert idata.lp_sub_idx.shape == idata.sub_idx.shape + (2,)
322+
323+
def true_idx_logp(y):
324+
idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1))
325+
idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
326+
return np.stack([idx_0, idx_1]).T
327+
328+
np.testing.assert_almost_equal(
329+
true_idx_logp(idata.y.values.flatten()),
330+
idata.lp_idx[0].values,
331+
)
332+
333+
def true_sub_idx_logp(y):
334+
sub_idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1))
335+
sub_idx_1 = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
336+
return np.stack([sub_idx_0, sub_idx_1]).T
337+
338+
np.testing.assert_almost_equal(
339+
true_sub_idx_logp(idata.y.values.flatten()),
340+
idata.lp_sub_idx[0].values,
341+
)
342+
343+
294344
@pytest.mark.filterwarnings("error")
295345
def test_not_supported_marginalized():
296346
"""Marginalized graphs with non-Elemwise Operations are not supported as they

0 commit comments

Comments
 (0)