Skip to content

Commit 1f20998

Browse files
committed
Remove remaining uses of default_updates in codebase
1 parent 11f08db commit 1f20998

File tree

4 files changed

+46
-19
lines changed

4 files changed

+46
-19
lines changed

pymc/aesaraf.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import numpy as np
3131
import scipy.sparse as sps
3232

33+
from aeppl.abstract import MeasurableVariable
3334
from aeppl.logprob import CheckParameterValue
3435
from aesara import config, scalar
3536
from aesara.compile.mode import Mode, get_mode
@@ -978,14 +979,21 @@ def compile_pymc(
978979
# TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
979980
rng_updates = {}
980981
output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
981-
for rv in (
982-
node
983-
for node in vars_between(inputs, output_to_list)
984-
if node.owner and isinstance(node.owner.op, RandomVariable) and node not in inputs
982+
for random_var in (
983+
var
984+
for var in vars_between(inputs, output_to_list)
985+
if var.owner
986+
and isinstance(var.owner.op, (RandomVariable, MeasurableVariable))
987+
and var not in inputs
985988
):
986-
rng = rv.owner.inputs[0]
987-
if not hasattr(rng, "default_update"):
988-
rng_updates[rng] = rv.owner.outputs[0]
989+
if isinstance(random_var.owner.op, RandomVariable):
990+
rng = random_var.owner.inputs[0]
991+
if not hasattr(rng, "default_update"):
992+
rng_updates[rng] = random_var.owner.outputs[0]
993+
else:
994+
update_fn = getattr(random_var.owner.op, "update", None)
995+
if update_fn is not None:
996+
rng_updates.update(update_fn(random_var.owner))
989997

990998
# If called inside a model context, see if check_bounds flag is set to False
991999
try:

pymc/distributions/mixture.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from aeppl.logprob import _logcdf, _logprob
2222
from aeppl.transforms import IntervalTransform
2323
from aesara.compile.builders import OpFromGraph
24-
from aesara.graph.basic import equal_computations
24+
from aesara.graph.basic import Node, equal_computations
2525
from aesara.tensor import TensorVariable
2626
from aesara.tensor.random.op import RandomVariable
2727

@@ -44,6 +44,10 @@ class MarginalMixtureRV(OpFromGraph):
4444

4545
default_output = 1
4646

47+
def update(self, node: Node):
48+
# Update for the internal mix_indexes RV
49+
return {node.inputs[0]: node.outputs[0]}
50+
4751

4852
MeasurableVariable.register(MarginalMixtureRV)
4953

@@ -294,10 +298,6 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
294298
# Create the actual MarginalMixture variable
295299
mix_out = mix_op(mix_indexes_rng, weights, *components)
296300

297-
# We need to set_default_updates ourselves, because the choices RV is hidden
298-
# inside OpFromGraph and PyMC will never find it otherwise
299-
mix_indexes_rng.default_update = mix_out.owner.outputs[0]
300-
301301
# Reference nodes to facilitate identification in other classmethods
302302
mix_out.tag.weights = weights
303303
mix_out.tag.components = components

pymc/model.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,13 +1363,9 @@ def make_obs_var(
13631363
# size of the masked and unmasked array happened to coincide
13641364
_, size, _, *inps = observed_rv_var.owner.inputs
13651365
rng = self.model.next_rng()
1366-
observed_rv_var = observed_rv_var.owner.op(*inps, size=size, rng=rng)
1367-
# Add default_update to new rng
1368-
new_rng = observed_rv_var.owner.outputs[0]
1369-
observed_rv_var.update = (rng, new_rng)
1370-
rng.default_update = new_rng
1371-
observed_rv_var.name = f"{name}_observed"
1372-
1366+
observed_rv_var = observed_rv_var.owner.op(
1367+
*inps, size=size, rng=rng, name=f"{name}_observed"
1368+
)
13731369
observed_rv_var.tag.observations = nonmissing_data
13741370

13751371
self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)

pymc/tests/test_aesaraf.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
import pytest
2323
import scipy.sparse as sps
2424

25+
from aeppl.abstract import MeasurableVariable
2526
from aeppl.logprob import ParameterValueError
27+
from aesara.compile.builders import OpFromGraph
2628
from aesara.graph.basic import Constant, Variable, ancestors, equal_computations
2729
from aesara.tensor.random.basic import normal, uniform
2830
from aesara.tensor.random.op import RandomVariable
@@ -681,3 +683,24 @@ def test_compile_pymc_updates_inputs(self):
681683
assert len(fn_fgraph.apply_nodes) == max(rvs_in_graph, 1)
682684
# Each RV adds a shared output for its rng
683685
assert len(fn_fgraph.outputs) == 1 + rvs_in_graph
686+
687+
def test_compile_pymc_custom_update_op(self):
688+
"""Test that custom MeasurableVariable Op updates are used by compile_pymc"""
689+
690+
class UnmeasurableOp(OpFromGraph):
691+
def update(self, node):
692+
return {node.inputs[0]: node.inputs[0] + 1}
693+
694+
dummy_inputs = [at.scalar(), at.scalar()]
695+
dummy_outputs = [at.add(*dummy_inputs)]
696+
dummy_x = UnmeasurableOp(dummy_inputs, dummy_outputs)(aesara.shared(1.0), 1.0)
697+
698+
# Check that there are no updates at first
699+
fn = compile_pymc(inputs=[], outputs=dummy_x)
700+
assert fn() == fn() == 2.0
701+
702+
# And they are enabled once the Op is registered as Measurable
703+
MeasurableVariable.register(UnmeasurableOp)
704+
fn = compile_pymc(inputs=[], outputs=dummy_x)
705+
assert fn() == 2.0
706+
assert fn() == 3.0

0 commit comments

Comments
 (0)