File tree Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change @@ -658,6 +658,25 @@ def test_iterable_single_component_warning(self):
658
658
with pytest .warns (UserWarning , match = "Single component will be treated as a mixture" ):
659
659
Mixture .dist (w = [0.5 , 0.5 ], comp_dists = [Normal .dist (size = 2 )])
660
660
661
+ def test_mixture_dtype (self ):
662
+ mix_dtype = Mixture .dist (
663
+ w = [0.5 , 0.5 ],
664
+ comp_dists = [
665
+ Multinomial .dist (n = 5 , p = [0.5 , 0.5 ]),
666
+ Multinomial .dist (n = 5 , p = [0.5 , 0.5 ]),
667
+ ],
668
+ ).dtype
669
+ assert mix_dtype == "int64"
670
+
671
+ mix_dtype = Mixture .dist (
672
+ w = [0.5 , 0.5 ],
673
+ comp_dists = [
674
+ Dirichlet .dist (a = [0.5 , 0.5 ]),
675
+ Dirichlet .dist (a = [0.5 , 0.5 ]),
676
+ ],
677
+ ).dtype
678
+ assert mix_dtype == aesara .config .floatX
679
+
661
680
662
681
class TestNormalMixture (SeededTest ):
663
682
def test_normal_mixture_sampling (self ):
You can’t perform that action at this time.
0 commit comments