@@ -598,3 +598,47 @@ def test_is_conditional_dependent_static_shape():
598
598
x2 = pt .matrix ("x2" , shape = (9 , 5 ))
599
599
y2 = pt .random .normal (size = pt .shape (x2 ))
600
600
assert not is_conditional_dependent (y2 , x2 , [x2 , y2 ])
601
+
602
+
603
@pytest.mark.parametrize("univariate", (True, False))
def test_vector_univariate_mixture(univariate):
    """Marginalizing a Bernoulli switch over Normal components matches an
    explicit mixture logp.

    With ``univariate=True`` the Bernoulli index has shape (2,), so each
    vector entry switches independently (equivalent to a NormalMixture).
    With a scalar index the whole length-2 vector switches together, which
    is equivalent to a two-component mixture of MvNormals with identity
    covariance.
    """
    with MarginalModel() as m:
        idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())

        def dist(idx, size):
            # Components differ only in the sign of the mean; unit scale.
            return pm.math.switch(
                pm.math.eq(idx, 0),
                pm.Normal.dist([-10, -10], 1),
                pm.Normal.dist([10, 10], 1),
            )

        pm.CustomDist("norm", idx, dist=dist)

    m.marginalize(idx)
    logp_fn = m.compile_logp()

    if univariate:
        with pm.Model() as ref_m:
            pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
    else:
        with pm.Model() as ref_m:
            pm.Mixture(
                "norm",
                w=[0.5, 0.5],
                comp_dists=[
                    pm.MvNormal.dist([-10, -10], np.eye(2)),
                    pm.MvNormal.dist([10, 10], np.eye(2)),
                ],
                shape=(2,),
            )
    ref_logp_fn = ref_m.compile_logp()

    # Cover all four sign combinations of the two entries. The original
    # listed [-10, 10] twice and never tested [10, -10]; the dict is named
    # `test_point` to avoid shadowing the `pt` (pytensor.tensor) alias.
    for test_value in (
        [-10, -10],
        [10, 10],
        [-10, 10],
        [10, -10],
    ):
        test_point = {"norm": test_value}
        np.testing.assert_allclose(logp_fn(test_point), ref_logp_fn(test_point))
0 commit comments