Skip to content

Commit 74776cb

Browse files
Add test for univariate and multivariate marginal mixture
Co-authored-by: Jesse Grabowski <[email protected]>
1 parent 8dff969 commit 74776cb

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

pymc_experimental/tests/model/test_marginal_model.py

+44
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,47 @@ def test_is_conditional_dependent_static_shape():
598598
x2 = pt.matrix("x2", shape=(9, 5))
599599
y2 = pt.random.normal(size=pt.shape(x2))
600600
assert not is_conditional_dependent(y2, x2, [x2, y2])
601+
602+
603+
@pytest.mark.parametrize("univariate", (True, False))
604+
def test_vector_univariate_mixture(univariate):
605+
606+
with MarginalModel() as m:
607+
idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
608+
609+
def dist(idx, size):
610+
return pm.math.switch(
611+
pm.math.eq(idx, 0),
612+
pm.Normal.dist([-10, -10], 1),
613+
pm.Normal.dist([10, 10], 1),
614+
)
615+
616+
pm.CustomDist("norm", idx, dist=dist)
617+
618+
m.marginalize(idx)
619+
logp_fn = m.compile_logp()
620+
621+
if univariate:
622+
with pm.Model() as ref_m:
623+
pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
624+
else:
625+
with pm.Model() as ref_m:
626+
pm.Mixture(
627+
"norm",
628+
w=[0.5, 0.5],
629+
comp_dists=[
630+
pm.MvNormal.dist([-10, -10], np.eye(2)),
631+
pm.MvNormal.dist([10, 10], np.eye(2)),
632+
],
633+
shape=(2,),
634+
)
635+
ref_logp_fn = ref_m.compile_logp()
636+
637+
for test_value in (
638+
[-10, -10],
639+
[10, 10],
640+
[-10, 10],
641+
[-10, 10],
642+
):
643+
pt = {"norm": test_value}
644+
np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))

0 commit comments

Comments
 (0)