From 7dc7a82b0619effb7d41aec20ed3919118a3e567 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 17 May 2023 15:56:15 +0200 Subject: [PATCH] Rename _replace_rvs_in_graphs and fix bug when replacing input --- pymc/logprob/utils.py | 4 ++-- pymc/pytensorf.py | 14 +++++++++----- tests/test_pytensorf.py | 20 ++++++++++++++++++++ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index c44e88a500..d095dbee59 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -337,7 +337,7 @@ def ignore_logprob_multiple_vars( making each "unmeasurable", whereas a sequential call to `ignore_logprob` would not do this correctly. """ - from pymc.pytensorf import _replace_rvs_in_graphs + from pymc.pytensorf import _replace_vars_in_graphs measurable_vars_to_unmeasurable_vars = { measurable_var: ignore_logprob(measurable_var) for measurable_var in vars @@ -353,5 +353,5 @@ def replacement_fn(var, replacements): return [] - unmeasurable_vars, _ = _replace_rvs_in_graphs(graphs=vars, replacement_fn=replacement_fn) + unmeasurable_vars, _ = _replace_vars_in_graphs(graphs=vars, replacement_fn=replacement_fn) return unmeasurable_vars diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 3c9cb945be..d675b6c040 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -205,12 +205,12 @@ def expand(var): yield from walk(graphs, expand, bfs=False) -def _replace_rvs_in_graphs( +def _replace_vars_in_graphs( graphs: Iterable[TensorVariable], replacement_fn: Callable[[TensorVariable], Dict[TensorVariable, TensorVariable]], **kwargs, ) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]: - """Replace random variables in graphs + """Replace variables in graphs. This will *not* recompute test values. @@ -218,6 +218,9 @@ def _replace_rvs_in_graphs( ---------- graphs The graphs in which random variables are to be replaced. + replacement_fn + A callable called on each graph output that populates a replacement dictionary and returns + nodes that should be investigated further. Returns ------- @@ -256,7 +259,8 @@ def expand_replace(var): toposort = fg.toposort() sorted_replacements = sorted( tuple(replacements.items()), - key=lambda pair: toposort.index(pair[0].owner), + # Root inputs don't have owner, we give them negative priority -1 + key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner is not None else -1, reverse=True, ) fg.replace_all(sorted_replacements, import_missing=True) @@ -317,7 +321,7 @@ def populate_replacements( equiv = clone_get_equiv(inputs, graphs, False, False, {}) graphs = [equiv[n] for n in graphs] - graphs, _ = _replace_rvs_in_graphs( + graphs, _ = _replace_vars_in_graphs( graphs, replacement_fn=populate_replacements, **kwargs, @@ -385,7 +389,7 @@ def poulate_replacements(rv, replacements): # replacements if that is not a simple input variable return [value] - graphs, _ = _replace_rvs_in_graphs( + graphs, _ = _replace_vars_in_graphs( graphs, replacement_fn=poulate_replacements, **kwargs, diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index 009eb88619..73912668a1 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -24,6 +24,7 @@ import pytest import scipy.sparse as sps +from pytensor import shared from pytensor.compile.builders import OpFromGraph from pytensor.graph.basic import Variable, equal_computations from pytensor.tensor.random.basic import normal, uniform @@ -40,6 +41,7 @@ from pymc.exceptions import NotConstantValueError from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import ( + _replace_vars_in_graphs, collect_default_updates, compile_pymc, constant_fold, @@ -821,3 +823,21 @@ def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn): ), [expected_x, expected_y, expected_z, expected_w], ) + + def test_replace_input(self): + inp = shared(0.0, name="inp") + x = pm.Normal.dist(inp) + + assert x.eval() < 50 + + new_inp = inp + 100 + + def replacement_fn(var, replacements): + if var is x: + replacements[x.owner.inputs[3]] = new_inp + + return [] + + [new_x], _ = _replace_vars_in_graphs([x], replacement_fn=replacement_fn) + + assert new_x.eval() > 50