Skip to content

Commit 7d4006f

Browse files
committed
Round moment of discrete mixtures
1 parent 7ec9a24 commit 7d4006f

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

pymc/distributions/mixture.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from pymc.distributions.logprob import logcdf, logp
3737
from pymc.distributions.shape_utils import to_tuple
3838
from pymc.util import check_dist_not_registered
39+
from pymc.vartypes import discrete_types
3940

4041
__all__ = ["Mixture", "NormalMixture"]
4142

@@ -452,7 +453,10 @@ def get_moment_marginal_mixture(op, rv, rng, weights, *components):
452453
axis=mix_axis,
453454
)
454455

455-
return at.sum(weights * moment_components, axis=mix_axis)
456+
moment = at.sum(weights * moment_components, axis=mix_axis)
457+
if components[0].dtype in discrete_types:
458+
moment = at.round(moment)
459+
return moment
456460

457461

458462
class NormalMixture:

0 commit comments

Comments
 (0)