Skip to content

Commit 8a14e7e

Browse files
committed
Allow rebuilding a graph in toposort_replace
1 parent dd3c44d commit 8a14e7e

File tree

4 files changed

+259
-20
lines changed

4 files changed

+259
-20
lines changed

pymc_experimental/model_transform/conditioning.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
model_from_fgraph,
1515
model_named,
1616
model_observed_rv,
17-
toposort_replace,
1817
)
19-
from pymc_experimental.utils.pytensorf import rvs_in_graph
18+
from pymc_experimental.utils.pytensorf import rvs_in_graph, toposort_replace
2019

2120

2221
def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import pytensor.tensor as pt
2+
import pytest
3+
from pytensor.graph import FunctionGraph
4+
from pytensor.graph.basic import equal_computations
5+
6+
from pymc_experimental.utils.pytensorf import toposort_replace
7+
8+
9+
class TestToposortReplace:
10+
@pytest.mark.parametrize("compatible_type", (True, False))
11+
@pytest.mark.parametrize("num_replacements", (1, 2))
12+
@pytest.mark.parametrize("rebuild", (True, False))
13+
def test_horizontal_dependency(self, compatible_type, num_replacements, rebuild):
14+
x = pt.vector("x", shape=(5,))
15+
y = pt.vector("y", shape=(5,))
16+
17+
out1 = pt.exp(x + y) + pt.log(x + y)
18+
out2 = pt.cos(out1)
19+
20+
new_shape = (5,) if compatible_type else (10,)
21+
new_x = pt.vector("new_x", shape=new_shape)
22+
new_y = pt.vector("new_y", shape=new_shape)
23+
if num_replacements == 1:
24+
replacements = [(y, new_y)]
25+
else:
26+
replacements = [(x, new_x), (y, new_y)]
27+
28+
fg = FunctionGraph([x, y], [out1, out2], clone=False)
29+
30+
# If types are incompatible, and we don't rebuild or only replace one of the variables,
31+
# The function should fail
32+
if not compatible_type and (not rebuild or num_replacements == 1):
33+
with pytest.raises((TypeError, ValueError)):
34+
toposort_replace(fg, replacements, rebuild=rebuild)
35+
return
36+
toposort_replace(fg, replacements, rebuild=rebuild)
37+
38+
if num_replacements == 1:
39+
expected_out1 = pt.exp(x + new_y) + pt.log(x + new_y)
40+
else:
41+
expected_out1 = pt.exp(new_x + new_y) + pt.log(new_x + new_y)
42+
expected_out2 = pt.cos(expected_out1)
43+
assert equal_computations(fg.outputs, [expected_out1, expected_out2])
44+
45+
@pytest.mark.parametrize("compatible_type", (True, False))
46+
@pytest.mark.parametrize("num_replacements", (2, 3))
47+
@pytest.mark.parametrize("rebuild", (True, False))
48+
def test_vertical_dependency(self, compatible_type, num_replacements, rebuild):
49+
x = pt.vector("x", shape=(5,))
50+
a1 = pt.exp(x)
51+
a2 = pt.log(a1)
52+
out = a1 + a2
53+
54+
new_x = pt.vector("new_x", shape=(5 if compatible_type else 10,))
55+
if num_replacements == 2:
56+
replacements = [(x, new_x), (a1, pt.cos(a1)), (a2, pt.sin(a2 + 5))]
57+
else:
58+
replacements = [(a1, pt.cos(pt.exp(new_x))), (a2, pt.sin(a2 + 5))]
59+
60+
fg = FunctionGraph([x], [out], clone=False)
61+
62+
if not compatible_type and not rebuild:
63+
with pytest.raises(TypeError):
64+
toposort_replace(fg, replacements, rebuild=rebuild)
65+
return
66+
toposort_replace(fg, replacements, rebuild=rebuild)
67+
68+
expected_a1 = pt.cos(pt.exp(new_x))
69+
expected_a2 = pt.sin(pt.log(expected_a1) + 5)
70+
expected_out = expected_a1 + expected_a2
71+
assert equal_computations(fg.outputs, [expected_out])

pymc_experimental/utils/model_fgraph.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional, Sequence, Tuple
1+
from typing import Dict, Optional, Tuple
22

33
import pytensor
44
from pymc.logprob.transforms import RVTransform
@@ -10,7 +10,7 @@
1010
from pytensor.scalar import Identity
1111
from pytensor.tensor.elemwise import Elemwise
1212

13-
from pymc_experimental.utils.pytensorf import StringType
13+
from pymc_experimental.utils.pytensorf import StringType, toposort_replace
1414

1515

1616
class ModelVar(Op):
@@ -89,19 +89,6 @@ def model_free_rv(rv, value, transform, *dims):
8989
model_named = ModelNamed()
9090

9191

92-
def toposort_replace(
93-
fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]], reverse: bool = False
94-
) -> None:
95-
"""Replace multiple variables in topological order."""
96-
toposort = fgraph.toposort()
97-
sorted_replacements = sorted(
98-
replacements,
99-
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
100-
reverse=reverse,
101-
)
102-
fgraph.replace_all(sorted_replacements, import_missing=True)
103-
104-
10592
@node_rewriter([Elemwise])
10693
def local_remove_identity(fgraph, node):
10794
if isinstance(node.op.scalar_op, Identity):

pymc_experimental/utils/pytensorf.py

+185-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import Sequence
1+
from collections import deque
2+
from itertools import chain
3+
from typing import Iterable, Sequence, Set, Tuple
24

35
import pytensor
46
from pymc import SymbolicRandomVariable
57
from pytensor import Variable
6-
from pytensor.graph import Constant, Type
7-
from pytensor.graph.basic import walk
8+
from pytensor.graph import Constant, FunctionGraph, Type
9+
from pytensor.graph.basic import Apply, walk
810
from pytensor.graph.op import HasInnerGraph
911
from pytensor.tensor.random.op import RandomVariable
1012

@@ -58,3 +60,183 @@ def expand(r):
5860
for node in walk(vars, expand, False)
5961
if node.owner and isinstance(node.owner.op, (RandomVariable, SymbolicRandomVariable))
6062
)
63+
64+
65+
# def replace_rebuild_all(
66+
# fgraph: FunctionGraph,
67+
# replacements: Sequence[Tuple[Variable, Variable]],
68+
# reason: Optional[str] = None,
69+
# verbose: Optional[bool] = None,
70+
# import_missing: bool = False,
71+
# ) -> None:
72+
# """Replace a variable in the `FunctionGraph` and rebuild the graph.
73+
#
74+
# This is the main interface to manipulate the subgraph in `FunctionGraph`.
75+
# For every node that uses `var` as input, makes it use `new_var` instead.
76+
#
77+
# Parameters
78+
# ----------
79+
# fgraph
80+
# The FunctionGraph where replacements are performed
81+
# var
82+
# The variable to be replaced.
83+
# new_var
84+
# The variable to replace `var`.
85+
# reason
86+
# The name of the optimization or operation in progress.
87+
# verbose
88+
# Print `reason`, `var`, and `new_var`.
89+
# import_missing
90+
# Import missing variables.
91+
#
92+
# """
93+
# # if verbose is None:
94+
# # verbose = config.optimizer_verbose
95+
# # if verbose:
96+
# # print(
97+
# # f"rewriting: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}"
98+
# # )
99+
# #
100+
# # if var not in fgraph.variables:
101+
# # return
102+
#
103+
#
104+
#
105+
# def get_client_nodes(vars):
106+
# nodes = set()
107+
# d = deque(chain.from_iterable(fgraph.clients[var] for var in vars))
108+
# while d:
109+
# node, _ = d.pop()
110+
# if node == "output":
111+
# continue
112+
# if node in nodes:
113+
# continue
114+
# nodes.add(node)
115+
# d.extend(chain.from_iterable(fgraph.clients[out] for out in node.outputs))
116+
# return nodes
117+
#
118+
# topo_order = {node: order for order, node in enumerate(fgraph.toposort())}
119+
# old_vars = [old for old, _ in replacements]
120+
# d = deque(sorted(get_client_nodes(old_vars), key=lambda node: topo_order[node]))
121+
#
122+
# repl_dict = {old: new for old, new in replacements}
123+
# outputs = set(fgraph.outputs)
124+
# while d:
125+
# node: Apply = d.popleft()
126+
#
127+
# new_inputs = [repl_dict.get(i, i) for i in node.inputs]
128+
# if new_inputs == node.inputs:
129+
# continue
130+
#
131+
# new_node = node.clone_with_new_inputs(new_inputs, strict=False)
132+
# for out, new_out in zip(node.outputs, new_node.outputs):
133+
# repl_dict[out] = new_out
134+
#
135+
# return FunctionGraph(outputs=outputs, clone=False)
136+
137+
138+
def _replace_rebuild_all(
139+
fgraph: FunctionGraph, replacements: Iterable[Tuple[Variable, Variable]], **kwargs
140+
) -> FunctionGraph:
141+
"""Replace variables and rebuild dependent graph if needed.
142+
143+
Rebuilding allows for replacements that change the semantics of the graph
144+
(different types), which may not be possible for all Ops.
145+
"""
146+
147+
def get_client_nodes(vars) -> Set[Apply]:
148+
nodes = set()
149+
d = deque(
150+
chain.from_iterable(fgraph.clients[var] for var in vars if var in fgraph.variables)
151+
)
152+
while d:
153+
node, _ = d.pop()
154+
if node in nodes or node == "output":
155+
continue
156+
nodes.add(node)
157+
d.extend(chain.from_iterable(fgraph.clients[out] for out in node.outputs))
158+
return nodes
159+
160+
repl_dict = {old: new for old, new in replacements}
161+
root_nodes = {var.owner for var in repl_dict.keys()}
162+
163+
# Build sorted queue with all nodes that depend on replaced variables
164+
topo_order = {node: order for order, node in enumerate(fgraph.toposort())}
165+
client_nodes = get_client_nodes(repl_dict.keys())
166+
d = deque(sorted(client_nodes, key=lambda node: topo_order[node]))
167+
while d:
168+
node = d.popleft()
169+
if node in root_nodes:
170+
continue
171+
172+
new_inputs = [repl_dict.get(i, i) for i in node.inputs]
173+
if new_inputs == node.inputs:
174+
continue
175+
176+
# Either remake the node or do a simple inplace replacement
177+
# This property is not yet present in PyTensor
178+
if getattr(node.op, "_output_type_depends_on_input_value", False):
179+
remake_node = True
180+
else:
181+
remake_node = any(
182+
not inp.type == new_inp.type for inp, new_inp in zip(node.inputs, new_inputs)
183+
)
184+
185+
if remake_node:
186+
new_node = node.clone_with_new_inputs(new_inputs, strict=False)
187+
fgraph.import_node(new_node, import_missing=True)
188+
for out, new_out in zip(node.outputs, new_node.outputs):
189+
repl_dict[out] = new_out
190+
else:
191+
replace = list(zip(node.inputs, new_inputs))
192+
fgraph.replace_all(replace, import_missing=True)
193+
194+
# We need special logic for the cases where we had to rebuild the output nodes
195+
for i, (new_output, old_output) in enumerate(
196+
zip(
197+
(repl_dict.get(out, out) for out in fgraph.outputs),
198+
fgraph.outputs,
199+
)
200+
):
201+
if new_output is old_output:
202+
continue
203+
fgraph.outputs[i] = new_output
204+
fgraph.import_var(new_output, import_missing=True)
205+
client = ("output", i)
206+
fgraph.add_client(new_output, client)
207+
fgraph.remove_client(old_output, client)
208+
fgraph.execute_callbacks("on_change_input", "output", i, old_output, new_output)
209+
210+
211+
def toposort_replace(
212+
fgraph: FunctionGraph,
213+
replacements: Sequence[Tuple[Variable, Variable]],
214+
reverse: bool = False,
215+
rebuild: bool = False,
216+
) -> None:
217+
"""Replace multiple variables in topological order."""
218+
if rebuild and reverse:
219+
raise NotImplementedError("reverse rebuild not supported")
220+
221+
toposort = fgraph.toposort()
222+
sorted_replacements = sorted(
223+
replacements,
224+
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
225+
reverse=reverse,
226+
)
227+
228+
if rebuild:
229+
if len(replacements) > 1:
230+
# In this case we need to introduce the replacements inside each other
231+
# To avoid undoing previous changes
232+
sorted_replacements = [list(pairs) for pairs in sorted_replacements]
233+
for i in range(1, len(replacements)):
234+
# Replace-rebuild each successive replacement with the previous replacements (in topological order)
235+
temp_fgraph = FunctionGraph(
236+
outputs=[repl for _, repl in sorted_replacements[i:]], clone=False
237+
)
238+
_replace_rebuild_all(temp_fgraph, replacements=sorted_replacements[:i])
239+
sorted_replacements[i][1] = temp_fgraph.outputs[0]
240+
_replace_rebuild_all(fgraph, sorted_replacements, import_missing=True)
241+
else:
242+
fgraph.replace_all(sorted_replacements, import_missing=True)

0 commit comments

Comments
 (0)