Skip to content

Commit 72acac4

Browse files
committed
Allow rebuilding a graph in toposort_replace
1 parent dd3c44d commit 72acac4

File tree

4 files changed

+186
-20
lines changed

4 files changed

+186
-20
lines changed

pymc_experimental/model_transform/conditioning.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
model_from_fgraph,
1515
model_named,
1616
model_observed_rv,
17-
toposort_replace,
1817
)
19-
from pymc_experimental.utils.pytensorf import rvs_in_graph
18+
from pymc_experimental.utils.pytensorf import rvs_in_graph, toposort_replace
2019

2120

2221
def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import pytensor.tensor as pt
2+
import pytest
3+
from pytensor.graph import FunctionGraph
4+
from pytensor.graph.basic import equal_computations
5+
6+
from pymc_experimental.utils.pytensorf import toposort_replace
7+
8+
9+
class TestToposortReplace:
10+
@pytest.mark.parametrize("compatible_type", (True, False))
11+
@pytest.mark.parametrize("num_replacements", (1, 2))
12+
@pytest.mark.parametrize("rebuild", (True, False))
13+
def test_horizontal_dependency(self, compatible_type, num_replacements, rebuild):
14+
x = pt.vector("x", shape=(5,))
15+
y = pt.vector("y", shape=(5,))
16+
17+
out1 = pt.exp(x + y) + pt.log(x + y)
18+
out2 = pt.cos(out1)
19+
20+
new_shape = (5,) if compatible_type else (10,)
21+
new_x = pt.vector("new_x", shape=new_shape)
22+
new_y = pt.vector("new_y", shape=new_shape)
23+
if num_replacements == 1:
24+
replacements = [(y, new_y)]
25+
else:
26+
replacements = [(x, new_x), (y, new_y)]
27+
28+
fg = FunctionGraph([x, y], [out1, out2], clone=False)
29+
30+
# If types are incompatible, and we don't rebuild or only replace one of the variables,
31+
# The function should fail
32+
if not compatible_type and (not rebuild or num_replacements == 1):
33+
with pytest.raises((TypeError, ValueError)):
34+
toposort_replace(fg, replacements, rebuild=rebuild)
35+
return
36+
toposort_replace(fg, replacements, rebuild=rebuild)
37+
38+
if num_replacements == 1:
39+
expected_out1 = pt.exp(x + new_y) + pt.log(x + new_y)
40+
else:
41+
expected_out1 = pt.exp(new_x + new_y) + pt.log(new_x + new_y)
42+
expected_out2 = pt.cos(expected_out1)
43+
assert equal_computations(fg.outputs, [expected_out1, expected_out2])
44+
45+
@pytest.mark.parametrize("compatible_type", (True, False))
46+
@pytest.mark.parametrize("num_replacements", (2, 3))
47+
@pytest.mark.parametrize("rebuild", (True, False))
48+
def test_vertical_dependency(self, compatible_type, num_replacements, rebuild):
49+
x = pt.vector("x", shape=(5,))
50+
a1 = pt.exp(x)
51+
a2 = pt.log(a1)
52+
out = a1 + a2
53+
54+
new_x = pt.vector("new_x", shape=(5 if compatible_type else 10,))
55+
if num_replacements == 2:
56+
replacements = [(x, new_x), (a1, pt.cos(a1)), (a2, pt.sin(a2 + 5))]
57+
else:
58+
replacements = [(a1, pt.cos(pt.exp(new_x))), (a2, pt.sin(a2 + 5))]
59+
60+
fg = FunctionGraph([x], [out], clone=False)
61+
62+
if not compatible_type and not rebuild:
63+
with pytest.raises(TypeError):
64+
toposort_replace(fg, replacements, rebuild=rebuild)
65+
return
66+
toposort_replace(fg, replacements, rebuild=rebuild)
67+
68+
expected_a1 = pt.cos(pt.exp(new_x))
69+
expected_a2 = pt.sin(pt.log(expected_a1) + 5)
70+
expected_out = expected_a1 + expected_a2
71+
assert equal_computations(fg.outputs, [expected_out])

pymc_experimental/utils/model_fgraph.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional, Sequence, Tuple
1+
from typing import Dict, Optional, Tuple
22

33
import pytensor
44
from pymc.logprob.transforms import RVTransform
@@ -10,7 +10,7 @@
1010
from pytensor.scalar import Identity
1111
from pytensor.tensor.elemwise import Elemwise
1212

13-
from pymc_experimental.utils.pytensorf import StringType
13+
from pymc_experimental.utils.pytensorf import StringType, toposort_replace
1414

1515

1616
class ModelVar(Op):
@@ -89,19 +89,6 @@ def model_free_rv(rv, value, transform, *dims):
8989
model_named = ModelNamed()
9090

9191

92-
def toposort_replace(
93-
fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]], reverse: bool = False
94-
) -> None:
95-
"""Replace multiple variables in topological order."""
96-
toposort = fgraph.toposort()
97-
sorted_replacements = sorted(
98-
replacements,
99-
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
100-
reverse=reverse,
101-
)
102-
fgraph.replace_all(sorted_replacements, import_missing=True)
103-
104-
10592
@node_rewriter([Elemwise])
10693
def local_remove_identity(fgraph, node):
10794
if isinstance(node.op.scalar_op, Identity):

pymc_experimental/utils/pytensorf.py

+112-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import Sequence
1+
from collections import deque
2+
from itertools import chain
3+
from typing import Iterable, Sequence, Set, Tuple
24

35
import pytensor
46
from pymc import SymbolicRandomVariable
57
from pytensor import Variable
6-
from pytensor.graph import Constant, Type
7-
from pytensor.graph.basic import walk
8+
from pytensor.graph import Constant, FunctionGraph, Type
9+
from pytensor.graph.basic import Apply, walk
810
from pytensor.graph.op import HasInnerGraph
911
from pytensor.tensor.random.op import RandomVariable
1012

@@ -58,3 +60,110 @@ def expand(r):
5860
for node in walk(vars, expand, False)
5961
if node.owner and isinstance(node.owner.op, (RandomVariable, SymbolicRandomVariable))
6062
)
63+
64+
65+
def _replace_rebuild_all(
66+
fgraph: FunctionGraph, replacements: Iterable[Tuple[Variable, Variable]], **kwargs
67+
) -> FunctionGraph:
68+
"""Replace variables and rebuild dependent graph if needed.
69+
70+
Rebuilding allows for replacements that change the semantics of the graph
71+
(different types), which may not be possible for all Ops.
72+
"""
73+
74+
def get_client_nodes(vars) -> Set[Apply]:
75+
nodes = set()
76+
d = deque(
77+
chain.from_iterable(fgraph.clients[var] for var in vars if var in fgraph.variables)
78+
)
79+
while d:
80+
node, _ = d.pop()
81+
if node in nodes or node == "output":
82+
continue
83+
nodes.add(node)
84+
d.extend(chain.from_iterable(fgraph.clients[out] for out in node.outputs))
85+
return nodes
86+
87+
repl_dict = {old: new for old, new in replacements}
88+
root_nodes = {var.owner for var in repl_dict.keys()}
89+
90+
# Build sorted queue with all nodes that depend on replaced variables
91+
topo_order = {node: order for order, node in enumerate(fgraph.toposort())}
92+
client_nodes = get_client_nodes(repl_dict.keys())
93+
d = deque(sorted(client_nodes, key=lambda node: topo_order[node]))
94+
while d:
95+
node = d.popleft()
96+
if node in root_nodes:
97+
continue
98+
99+
new_inputs = [repl_dict.get(i, i) for i in node.inputs]
100+
if new_inputs == node.inputs:
101+
continue
102+
103+
# Either remake the node or do a simple inplace replacement
104+
# This property is not yet present in PyTensor
105+
if getattr(node.op, "_output_type_depends_on_input_value", False):
106+
remake_node = True
107+
else:
108+
remake_node = any(
109+
not inp.type == new_inp.type for inp, new_inp in zip(node.inputs, new_inputs)
110+
)
111+
112+
if remake_node:
113+
new_node = node.clone_with_new_inputs(new_inputs, strict=False)
114+
fgraph.import_node(new_node, import_missing=True)
115+
for out, new_out in zip(node.outputs, new_node.outputs):
116+
repl_dict[out] = new_out
117+
else:
118+
replace = list(zip(node.inputs, new_inputs))
119+
fgraph.replace_all(replace, import_missing=True)
120+
121+
# We need special logic for the cases where we had to rebuild the output nodes
122+
for i, (new_output, old_output) in enumerate(
123+
zip(
124+
(repl_dict.get(out, out) for out in fgraph.outputs),
125+
fgraph.outputs,
126+
)
127+
):
128+
if new_output is old_output:
129+
continue
130+
fgraph.outputs[i] = new_output
131+
fgraph.import_var(new_output, import_missing=True)
132+
client = ("output", i)
133+
fgraph.add_client(new_output, client)
134+
fgraph.remove_client(old_output, client)
135+
fgraph.execute_callbacks("on_change_input", "output", i, old_output, new_output)
136+
137+
138+
def toposort_replace(
139+
fgraph: FunctionGraph,
140+
replacements: Sequence[Tuple[Variable, Variable]],
141+
reverse: bool = False,
142+
rebuild: bool = False,
143+
) -> None:
144+
"""Replace multiple variables in topological order."""
145+
if rebuild and reverse:
146+
raise NotImplementedError("reverse rebuild not supported")
147+
148+
toposort = fgraph.toposort()
149+
sorted_replacements = sorted(
150+
replacements,
151+
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
152+
reverse=reverse,
153+
)
154+
155+
if rebuild:
156+
if len(replacements) > 1:
157+
# In this case we need to introduce the replacements inside each other
158+
# To avoid undoing previous changes
159+
sorted_replacements = [list(pairs) for pairs in sorted_replacements]
160+
for i in range(1, len(replacements)):
161+
# Replace-rebuild each successive replacement with the previous replacements (in topological order)
162+
temp_fgraph = FunctionGraph(
163+
outputs=[repl for _, repl in sorted_replacements[i:]], clone=False
164+
)
165+
_replace_rebuild_all(temp_fgraph, replacements=sorted_replacements[:i])
166+
sorted_replacements[i][1] = temp_fgraph.outputs[0]
167+
_replace_rebuild_all(fgraph, sorted_replacements, import_missing=True)
168+
else:
169+
fgraph.replace_all(sorted_replacements, import_missing=True)

0 commit comments

Comments
 (0)