Skip to content

Commit d606b9c

Browse files
committed
Add test to address issue pymc-devs#286
1 parent 2df8082 commit d606b9c

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

pymc_experimental/tests/model/test_marginal_model.py

+15
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,21 @@ def test_multiple_dependent_marginalized_rvs():
168168
np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y)
169169

170170

171+
def test_rv_dependent_multiple_marginalized_rvs():
172+
"""Test when random variables depend on multiple marginalized variables"""
173+
with MarginalModel() as m:
174+
x = pm.Bernoulli("x", 0.1)
175+
y = pm.Bernoulli("y", 0.3)
176+
z = pm.DiracDelta("z", c=x + y)
177+
178+
m.marginalize([x, y])
179+
logp = m.compile_logp()
180+
181+
np.testing.assert_allclose(np.exp(logp({"z": 0})), 0.9 * 0.7)
182+
np.testing.assert_allclose(np.exp(logp({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7)
183+
np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3)
184+
185+
171186
@pytest.mark.filterwarnings("error")
172187
def test_nested_marginalized_rvs():
173188
"""Test that marginalization works when there are nested marginalized RVs"""

0 commit comments

Comments
 (0)