Skip to content

Commit b36ccee

Browse files
committed
Implement Mixture logcdf
1 parent 8a2d2bf commit b36ccee

File tree

1 file changed

+39
-4
lines changed

1 file changed

+39
-4
lines changed

pymc/distributions/mixture.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import numpy as np
1919

2020
from aeppl.abstract import MeasurableVariable, _get_measurable_outputs
21-
from aeppl.logprob import _logprob
21+
from aeppl.logprob import _logcdf, _logprob
2222
from aesara.compile.builders import OpFromGraph
2323
from aesara.tensor import TensorVariable
2424
from aesara.tensor.random.op import RandomVariable
@@ -33,7 +33,7 @@
3333
_get_moment,
3434
get_moment,
3535
)
36-
from pymc.distributions.logprob import logp
36+
from pymc.distributions.logprob import logcdf, logp
3737
from pymc.distributions.shape_utils import to_tuple
3838
from pymc.util import check_dist_not_registered
3939

@@ -275,7 +275,8 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
275275

276276
# Index components and squeeze mixture dimension
277277
mix_out_ = at.take_along_axis(stacked_components_, mix_indexes_padded_, axis=mix_axis)
278-
# There is a Aeasara bug in squeeze with negative axis
278+
# There is a Aesara bug in squeeze with negative axis
279+
# https://github.com/aesara-devs/aesara/issues/830
279280
# this is equivalent to np.squeeze(mix_out_, axis=mix_axis)
280281
mix_out_ = at.squeeze(mix_out_, axis=mix_out_.ndim + mix_axis)
281282

@@ -389,7 +390,8 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
389390
mix_logp = at.logsumexp(at.log(weights) + components_logp, axis=-1)
390391

391392
# Squeeze stack dimension
392-
# There is a Aeasara bug in squeeze with negative axis
393+
# There is a Aesara bug in squeeze with negative axis
394+
# https://github.com/aesara-devs/aesara/issues/830
393395
# mix_logp = at.squeeze(mix_logp, axis=-1)
394396
mix_logp = at.squeeze(mix_logp, axis=mix_logp.ndim - 1)
395397

@@ -404,6 +406,39 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
404406
return mix_logp
405407

406408

409+
@_logcdf.register(MarginalMixtureRV)
410+
def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs):
411+
412+
# single component
413+
if len(components) == 1:
414+
# Need to broadcast value across mixture axis
415+
mix_axis = -components[0].owner.op.ndim_supp - 1
416+
components_logcdf = logcdf(components[0], at.expand_dims(value, mix_axis))
417+
else:
418+
components_logcdf = at.stack(
419+
[logcdf(component, value) for component in components],
420+
axis=-1,
421+
)
422+
423+
mix_logcdf = at.logsumexp(at.log(weights) + components_logcdf, axis=-1)
424+
425+
# Squeeze stack dimension
426+
# There is a Aesara bug in squeeze with negative axis
427+
# https://github.com/aesara-devs/aesara/issues/830
428+
# mix_logp = at.squeeze(mix_logp, axis=-1)
429+
mix_logcdf = at.squeeze(mix_logcdf, axis=mix_logcdf.ndim - 1)
430+
431+
mix_logcdf = check_parameters(
432+
mix_logcdf,
433+
0 <= weights,
434+
weights <= 1,
435+
at.isclose(at.sum(weights, axis=-1), 1),
436+
msg="0 <= weights <= 1, sum(weights) == 1",
437+
)
438+
439+
return mix_logcdf
440+
441+
407442
@_get_moment.register(MarginalMixtureRV)
408443
def get_moment_marginal_mixture(op, rv, rng, weights, *components):
409444
ndim_supp = components[0].owner.op.ndim_supp

0 commit comments

Comments
 (0)