Skip to content

Rename _replace_rvs_in_graphs and fix bug when replacing input #6720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def ignore_logprob_multiple_vars(
making each "unmeasurable", whereas a sequential call to `ignore_logprob`
would not do this correctly.
"""
from pymc.pytensorf import _replace_rvs_in_graphs
from pymc.pytensorf import _replace_vars_in_graphs

measurable_vars_to_unmeasurable_vars = {
measurable_var: ignore_logprob(measurable_var) for measurable_var in vars
Expand All @@ -353,5 +353,5 @@ def replacement_fn(var, replacements):

return []

unmeasurable_vars, _ = _replace_rvs_in_graphs(graphs=vars, replacement_fn=replacement_fn)
unmeasurable_vars, _ = _replace_vars_in_graphs(graphs=vars, replacement_fn=replacement_fn)
return unmeasurable_vars
14 changes: 9 additions & 5 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,19 +205,22 @@ def expand(var):
yield from walk(graphs, expand, bfs=False)


def _replace_rvs_in_graphs(
def _replace_vars_in_graphs(
graphs: Iterable[TensorVariable],
replacement_fn: Callable[[TensorVariable], Dict[TensorVariable, TensorVariable]],
**kwargs,
) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]:
"""Replace random variables in graphs
"""Replace variables in graphs.

This will *not* recompute test values.

Parameters
----------
graphs
The graphs in which random variables are to be replaced.
replacement_fn
A callable called on each graph output that populates a replacement dictionary and returns
nodes that should be investigated further.

Returns
-------
Expand Down Expand Up @@ -256,7 +259,8 @@ def expand_replace(var):
toposort = fg.toposort()
sorted_replacements = sorted(
tuple(replacements.items()),
key=lambda pair: toposort.index(pair[0].owner),
# Root inputs don't have owner, we give them negative priority -1
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner is not None else -1,
reverse=True,
)
fg.replace_all(sorted_replacements, import_missing=True)
Expand Down Expand Up @@ -317,7 +321,7 @@ def populate_replacements(
equiv = clone_get_equiv(inputs, graphs, False, False, {})
graphs = [equiv[n] for n in graphs]

graphs, _ = _replace_rvs_in_graphs(
graphs, _ = _replace_vars_in_graphs(
graphs,
replacement_fn=populate_replacements,
**kwargs,
Expand Down Expand Up @@ -385,7 +389,7 @@ def poulate_replacements(rv, replacements):
# replacements if that is not a simple input variable
return [value]

graphs, _ = _replace_rvs_in_graphs(
graphs, _ = _replace_vars_in_graphs(
graphs,
replacement_fn=poulate_replacements,
**kwargs,
Expand Down
20 changes: 20 additions & 0 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pytest
import scipy.sparse as sps

from pytensor import shared
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import Variable, equal_computations
from pytensor.tensor.random.basic import normal, uniform
Expand All @@ -40,6 +41,7 @@
from pymc.exceptions import NotConstantValueError
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import (
_replace_vars_in_graphs,
collect_default_updates,
compile_pymc,
constant_fold,
Expand Down Expand Up @@ -821,3 +823,21 @@ def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn):
),
[expected_x, expected_y, expected_z, expected_w],
)

def test_replace_input(self):
inp = shared(0.0, name="inp")
x = pm.Normal.dist(inp)

assert x.eval() < 50

new_inp = inp + 100

def replacement_fn(var, replacements):
if var is x:
replacements[x.owner.inputs[3]] = new_inp

return []

[new_x], _ = _replace_vars_in_graphs([x], replacement_fn=replacement_fn)

assert new_x.eval() > 50