18
18
import numpy as np
19
19
20
20
from aeppl .abstract import MeasurableVariable , _get_measurable_outputs
21
- from aeppl .logprob import _logprob
21
+ from aeppl .logprob import _logcdf , _logprob
22
22
from aesara .compile .builders import OpFromGraph
23
23
from aesara .tensor import TensorVariable
24
24
from aesara .tensor .random .op import RandomVariable
33
33
_get_moment ,
34
34
get_moment ,
35
35
)
36
- from pymc .distributions .logprob import logp
36
+ from pymc .distributions .logprob import logcdf , logp
37
37
from pymc .distributions .shape_utils import to_tuple
38
38
from pymc .util import check_dist_not_registered
39
39
@@ -275,7 +275,8 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
275
275
276
276
# Index components and squeeze mixture dimension
277
277
mix_out_ = at .take_along_axis (stacked_components_ , mix_indexes_padded_ , axis = mix_axis )
278
- # There is a Aeasara bug in squeeze with negative axis
278
+ # There is a Aesara bug in squeeze with negative axis
279
+ # https://github.com/aesara-devs/aesara/issues/830
279
280
# this is equivalent to np.squeeze(mix_out_, axis=mix_axis)
280
281
mix_out_ = at .squeeze (mix_out_ , axis = mix_out_ .ndim + mix_axis )
281
282
@@ -389,7 +390,8 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
389
390
mix_logp = at .logsumexp (at .log (weights ) + components_logp , axis = - 1 )
390
391
391
392
# Squeeze stack dimension
392
- # There is a Aeasara bug in squeeze with negative axis
393
+ # There is a Aesara bug in squeeze with negative axis
394
+ # https://github.com/aesara-devs/aesara/issues/830
393
395
# mix_logp = at.squeeze(mix_logp, axis=-1)
394
396
mix_logp = at .squeeze (mix_logp , axis = mix_logp .ndim - 1 )
395
397
@@ -404,6 +406,39 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
404
406
return mix_logp
405
407
406
408
409
+ @_logcdf .register (MarginalMixtureRV )
410
+ def marginal_mixture_logcdf (op , value , rng , weights , * components , ** kwargs ):
411
+
412
+ # single component
413
+ if len (components ) == 1 :
414
+ # Need to broadcast value across mixture axis
415
+ mix_axis = - components [0 ].owner .op .ndim_supp - 1
416
+ components_logcdf = logcdf (components [0 ], at .expand_dims (value , mix_axis ))
417
+ else :
418
+ components_logcdf = at .stack (
419
+ [logcdf (component , value ) for component in components ],
420
+ axis = - 1 ,
421
+ )
422
+
423
+ mix_logcdf = at .logsumexp (at .log (weights ) + components_logcdf , axis = - 1 )
424
+
425
+ # Squeeze stack dimension
426
+ # There is a Aesara bug in squeeze with negative axis
427
+ # https://github.com/aesara-devs/aesara/issues/830
428
+ # mix_logp = at.squeeze(mix_logp, axis=-1)
429
+ mix_logcdf = at .squeeze (mix_logcdf , axis = mix_logcdf .ndim - 1 )
430
+
431
+ mix_logcdf = check_parameters (
432
+ mix_logcdf ,
433
+ 0 <= weights ,
434
+ weights <= 1 ,
435
+ at .isclose (at .sum (weights , axis = - 1 ), 1 ),
436
+ msg = "0 <= weights <= 1, sum(weights) == 1" ,
437
+ )
438
+
439
+ return mix_logcdf
440
+
441
+
407
442
@_get_moment .register (MarginalMixtureRV )
408
443
def get_moment_marginal_mixture (op , rv , rng , weights , * components ):
409
444
ndim_supp = components [0 ].owner .op .ndim_supp
0 commit comments