@@ -629,3 +629,47 @@ def test_data_container():
629
629
630
630
ip = marginal_m .initial_point ()
631
631
np .testing .assert_allclose (logp_fn (ip ), ref_logp_fn (ip ))
632
+
633
+
634
+ @pytest .mark .parametrize ("univariate" , (True , False ))
635
+ def test_vector_univariate_mixture (univariate ):
636
+
637
+ with MarginalModel () as m :
638
+ idx = pm .Bernoulli ("idx" , p = 0.5 , shape = (2 ,) if univariate else ())
639
+
640
+ def dist (idx , size ):
641
+ return pm .math .switch (
642
+ pm .math .eq (idx , 0 ),
643
+ pm .Normal .dist ([- 10 , - 10 ], 1 ),
644
+ pm .Normal .dist ([10 , 10 ], 1 ),
645
+ )
646
+
647
+ pm .CustomDist ("norm" , idx , dist = dist )
648
+
649
+ m .marginalize (idx )
650
+ logp_fn = m .compile_logp ()
651
+
652
+ if univariate :
653
+ with pm .Model () as ref_m :
654
+ pm .NormalMixture ("norm" , w = [0.5 , 0.5 ], mu = [[- 10 , 10 ], [- 10 , 10 ]], shape = (2 ,))
655
+ else :
656
+ with pm .Model () as ref_m :
657
+ pm .Mixture (
658
+ "norm" ,
659
+ w = [0.5 , 0.5 ],
660
+ comp_dists = [
661
+ pm .MvNormal .dist ([- 10 , - 10 ], np .eye (2 )),
662
+ pm .MvNormal .dist ([10 , 10 ], np .eye (2 )),
663
+ ],
664
+ shape = (2 ,),
665
+ )
666
+ ref_logp_fn = ref_m .compile_logp ()
667
+
668
+ for test_value in (
669
+ [- 10 , - 10 ],
670
+ [10 , 10 ],
671
+ [- 10 , 10 ],
672
+ [- 10 , 10 ],
673
+ ):
674
+ pt = {"norm" : test_value }
675
+ np .testing .assert_allclose (logp_fn (pt ), ref_logp_fn (pt ))
0 commit comments