Skip to content

Commit 10b6fed

Browse files
Add test for univariate and multivariate marginal mixture
Co-authored-by: Jesse Grabowski <[email protected]>
1 parent 32bc557 commit 10b6fed

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

pymc_experimental/tests/model/test_marginal_model.py

+44
Original file line numberDiff line numberDiff line change
@@ -629,3 +629,47 @@ def test_data_container():
629629

630630
ip = marginal_m.initial_point()
631631
np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip))
632+
633+
634+
@pytest.mark.parametrize("univariate", (True, False))
635+
def test_vector_univariate_mixture(univariate):
636+
637+
with MarginalModel() as m:
638+
idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
639+
640+
def dist(idx, size):
641+
return pm.math.switch(
642+
pm.math.eq(idx, 0),
643+
pm.Normal.dist([-10, -10], 1),
644+
pm.Normal.dist([10, 10], 1),
645+
)
646+
647+
pm.CustomDist("norm", idx, dist=dist)
648+
649+
m.marginalize(idx)
650+
logp_fn = m.compile_logp()
651+
652+
if univariate:
653+
with pm.Model() as ref_m:
654+
pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
655+
else:
656+
with pm.Model() as ref_m:
657+
pm.Mixture(
658+
"norm",
659+
w=[0.5, 0.5],
660+
comp_dists=[
661+
pm.MvNormal.dist([-10, -10], np.eye(2)),
662+
pm.MvNormal.dist([10, 10], np.eye(2)),
663+
],
664+
shape=(2,),
665+
)
666+
ref_logp_fn = ref_m.compile_logp()
667+
668+
for test_value in (
669+
[-10, -10],
670+
[10, 10],
671+
[-10, 10],
672+
[-10, 10],
673+
):
674+
pt = {"norm": test_value}
675+
np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))

0 commit comments

Comments
 (0)