Skip to content

Commit 0b9f9cb

Browse files
committed
Specify default_output of MarginalMixtureRV
This is necessary for transforms to work with Aeppl
1 parent 21b289a commit 0b9f9cb

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

pymc/distributions/mixture.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def all_discrete(comp_dists):
5454
class MarginalMixtureRV(OpFromGraph):
5555
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""
5656

57+
default_output = 1
58+
5759

5860
MeasurableVariable.register(MarginalMixtureRV)
5961

@@ -288,11 +290,11 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
288290
)
289291

290292
# Create the actual MarginalMixture variable
291-
mix_indexes_rng_next, mix_out = mix_op(mix_indexes_rng, weights, *components)
293+
mix_out = mix_op(mix_indexes_rng, weights, *components)
292294

293295
# We need to set_default_updates ourselves, because the choices RV is hidden
294296
# inside OpFromGraph and PyMC will never find it otherwise
295-
mix_indexes_rng.default_update = mix_indexes_rng_next
297+
mix_indexes_rng.default_update = mix_out.owner.outputs[0]
296298

297299
# Reference nodes to facilitate identification in other classmethods
298300
mix_out.tag.weights = weights

0 commit comments

Comments
 (0)