Skip to content

Commit 8a2d2bf

Browse files
committed
Test Mixture dtype
1 parent 9b6c019 commit 8a2d2bf

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

pymc/tests/test_mixture.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,25 @@ def test_iterable_single_component_warning(self):
658658
with pytest.warns(UserWarning, match="Single component will be treated as a mixture"):
659659
Mixture.dist(w=[0.5, 0.5], comp_dists=[Normal.dist(size=2)])
660660

661+
def test_mixture_dtype(self):
662+
mix_dtype = Mixture.dist(
663+
w=[0.5, 0.5],
664+
comp_dists=[
665+
Multinomial.dist(n=5, p=[0.5, 0.5]),
666+
Multinomial.dist(n=5, p=[0.5, 0.5]),
667+
],
668+
).dtype
669+
assert mix_dtype == "int64"
670+
671+
mix_dtype = Mixture.dist(
672+
w=[0.5, 0.5],
673+
comp_dists=[
674+
Dirichlet.dist(a=[0.5, 0.5]),
675+
Dirichlet.dist(a=[0.5, 0.5]),
676+
],
677+
).dtype
678+
assert mix_dtype == aesara.config.floatX
679+
661680

662681
class TestNormalMixture(SeededTest):
663682
def test_normal_mixture_sampling(self):

0 commit comments

Comments
 (0)