Skip to content

Commit 838c0d7

Browse files
authored
Assume default_output is the only measurable output in SymbolicRandomVariables (#6161)
1 parent 310a4d9 commit 838c0d7

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

pymc/distributions/distribution.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,14 @@ def dist(
381381
@_get_measurable_outputs.register(SymbolicRandomVariable)
382382
def _get_measurable_outputs_symbolic_random_variable(op, node):
383383
# This tells Aeppl that any non RandomType outputs are measurable
384+
385+
# Assume that if there is one default_output, that's the only one that is measurable
386+
# In the rare case this is not what one wants, a specialized _get_measuarable_outputs
387+
# can dispatch for a subclassed Op
388+
if op.default_output is not None:
389+
return [node.default_output()]
390+
391+
# Otherwise assume that any outputs that are not of RandomType are measurable
384392
return [out for out in node.outputs if not isinstance(out.type, RandomType)]
385393

386394

pymc/distributions/timeseries.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import aesara.tensor as at
2020
import numpy as np
2121

22-
from aeppl.abstract import _get_measurable_outputs
2322
from aeppl.logprob import _logprob
2423
from aesara.graph.basic import Node, clone_replace
2524
from aesara.raise_op import Assert
@@ -203,12 +202,6 @@ def rv_op(cls, init_dist, innovation_dist, steps, size=None):
203202
)(init_dist, innovation_dist, steps)
204203

205204

206-
@_get_measurable_outputs.register(RandomWalkRV)
207-
def _get_measurable_outputs_random_walk(op, node):
208-
# Ignore steps output
209-
return [node.default_output()]
210-
211-
212205
@_change_dist_size.register(RandomWalkRV)
213206
def change_random_walk_size(op, dist, new_size, expand):
214207
init_dist, innovation_dist, steps = dist.owner.inputs

pymc/tests/distributions/test_distribution.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,25 @@ class TestInlinedSymbolicRV(SymbolicRandomVariable):
339339
x_inline = TestInlinedSymbolicRV([], [Flat.dist()], ndim_supp=0)()
340340
assert np.isclose(logp(x_inline, 0).eval(), 0)
341341

342-
def test_measurable_outputs(self):
342+
def test_measurable_outputs_rng_ignored(self):
343+
"""Test that any RandomType outputs are ignored as a measurable_outputs"""
344+
343345
class TestSymbolicRV(SymbolicRandomVariable):
344346
pass
345347

346348
next_rng_, dirac_delta_ = DiracDelta.dist(5).owner.outputs
347349
next_rng, dirac_delta = TestSymbolicRV([], [next_rng_, dirac_delta_], ndim_supp=0)()
348350
node = dirac_delta.owner
349351
assert get_measurable_outputs(node.op, node) == [dirac_delta]
352+
353+
@pytest.mark.parametrize("default_output_idx", (0, 1))
354+
def test_measurable_outputs_default_output(self, default_output_idx):
355+
"""Test that if provided, a default output is considered the only measurable_output"""
356+
357+
class TestSymbolicRV(SymbolicRandomVariable):
358+
default_output = default_output_idx
359+
360+
dirac_delta_1_ = DiracDelta.dist(5)
361+
dirac_delta_2_ = DiracDelta.dist(10)
362+
node = TestSymbolicRV([], [dirac_delta_1_, dirac_delta_2_], ndim_supp=0)().owner
363+
assert get_measurable_outputs(node.op, node) == [node.outputs[default_output_idx]]

0 commit comments

Comments
 (0)