|
1 |
| -from typing import Sequence |
| 1 | +from collections import deque |
| 2 | +from itertools import chain |
| 3 | +from typing import Iterable, Sequence, Set, Tuple |
2 | 4 |
|
3 | 5 | import pytensor
|
4 | 6 | from pymc import SymbolicRandomVariable
|
5 | 7 | from pytensor import Variable
|
6 |
| -from pytensor.graph import Constant, Type |
7 |
| -from pytensor.graph.basic import walk |
| 8 | +from pytensor.graph import Constant, FunctionGraph, Type |
| 9 | +from pytensor.graph.basic import Apply, walk |
8 | 10 | from pytensor.graph.op import HasInnerGraph
|
9 | 11 | from pytensor.tensor.random.op import RandomVariable
|
10 | 12 |
|
@@ -58,3 +60,183 @@ def expand(r):
|
58 | 60 | for node in walk(vars, expand, False)
|
59 | 61 | if node.owner and isinstance(node.owner.op, (RandomVariable, SymbolicRandomVariable))
|
60 | 62 | )
|
| 63 | + |
| 64 | + |
| 65 | +# def replace_rebuild_all( |
| 66 | +# fgraph: FunctionGraph, |
| 67 | +# replacements: Sequence[Tuple[Variable, Variable]], |
| 68 | +# reason: Optional[str] = None, |
| 69 | +# verbose: Optional[bool] = None, |
| 70 | +# import_missing: bool = False, |
| 71 | +# ) -> None: |
| 72 | +# """Replace a variable in the `FunctionGraph` and rebuild the graph. |
| 73 | +# |
| 74 | +# This is the main interface to manipulate the subgraph in `FunctionGraph`. |
| 75 | +# For every node that uses `var` as input, makes it use `new_var` instead. |
| 76 | +# |
| 77 | +# Parameters |
| 78 | +# ---------- |
| 79 | +# fgraph |
| 80 | +# The FunctionGraph where replacements are performed |
| 81 | +# var |
| 82 | +# The variable to be replaced. |
| 83 | +# new_var |
| 84 | +# The variable to replace `var`. |
| 85 | +# reason |
| 86 | +# The name of the optimization or operation in progress. |
| 87 | +# verbose |
| 88 | +# Print `reason`, `var`, and `new_var`. |
| 89 | +# import_missing |
| 90 | +# Import missing variables. |
| 91 | +# |
| 92 | +# """ |
| 93 | +# # if verbose is None: |
| 94 | +# # verbose = config.optimizer_verbose |
| 95 | +# # if verbose: |
| 96 | +# # print( |
| 97 | +# # f"rewriting: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}" |
| 98 | +# # ) |
| 99 | +# # |
| 100 | +# # if var not in fgraph.variables: |
| 101 | +# # return |
| 102 | +# |
| 103 | +# |
| 104 | +# |
| 105 | +# def get_client_nodes(vars): |
| 106 | +# nodes = set() |
| 107 | +# d = deque(chain.from_iterable(fgraph.clients[var] for var in vars)) |
| 108 | +# while d: |
| 109 | +# node, _ = d.pop() |
| 110 | +# if node == "output": |
| 111 | +# continue |
| 112 | +# if node in nodes: |
| 113 | +# continue |
| 114 | +# nodes.add(node) |
| 115 | +# d.extend(chain.from_iterable(fgraph.clients[out] for out in node.outputs)) |
| 116 | +# return nodes |
| 117 | +# |
| 118 | +# topo_order = {node: order for order, node in enumerate(fgraph.toposort())} |
| 119 | +# old_vars = [old for old, _ in replacements] |
| 120 | +# d = deque(sorted(get_client_nodes(old_vars), key=lambda node: topo_order[node])) |
| 121 | +# |
| 122 | +# repl_dict = {old: new for old, new in replacements} |
| 123 | +# outputs = set(fgraph.outputs) |
| 124 | +# while d: |
| 125 | +# node: Apply = d.popleft() |
| 126 | +# |
| 127 | +# new_inputs = [repl_dict.get(i, i) for i in node.inputs] |
| 128 | +# if new_inputs == node.inputs: |
| 129 | +# continue |
| 130 | +# |
| 131 | +# new_node = node.clone_with_new_inputs(new_inputs, strict=False) |
| 132 | +# for out, new_out in zip(node.outputs, new_node.outputs): |
| 133 | +# repl_dict[out] = new_out |
| 134 | +# |
| 135 | +# return FunctionGraph(outputs=outputs, clone=False) |
| 136 | + |
| 137 | + |
| 138 | +def _replace_rebuild_all( |
| 139 | + fgraph: FunctionGraph, replacements: Iterable[Tuple[Variable, Variable]], **kwargs |
| 140 | +) -> FunctionGraph: |
| 141 | + """Replace variables and rebuild dependent graph if needed. |
| 142 | +
|
| 143 | + Rebuilding allows for replacements that change the semantics of the graph |
| 144 | + (different types), which may not be possible for all Ops. |
| 145 | + """ |
| 146 | + |
| 147 | + def get_client_nodes(vars) -> Set[Apply]: |
| 148 | + nodes = set() |
| 149 | + d = deque( |
| 150 | + chain.from_iterable(fgraph.clients[var] for var in vars if var in fgraph.variables) |
| 151 | + ) |
| 152 | + while d: |
| 153 | + node, _ = d.pop() |
| 154 | + if node in nodes or node == "output": |
| 155 | + continue |
| 156 | + nodes.add(node) |
| 157 | + d.extend(chain.from_iterable(fgraph.clients[out] for out in node.outputs)) |
| 158 | + return nodes |
| 159 | + |
| 160 | + repl_dict = {old: new for old, new in replacements} |
| 161 | + root_nodes = {var.owner for var in repl_dict.keys()} |
| 162 | + |
| 163 | + # Build sorted queue with all nodes that depend on replaced variables |
| 164 | + topo_order = {node: order for order, node in enumerate(fgraph.toposort())} |
| 165 | + client_nodes = get_client_nodes(repl_dict.keys()) |
| 166 | + d = deque(sorted(client_nodes, key=lambda node: topo_order[node])) |
| 167 | + while d: |
| 168 | + node = d.popleft() |
| 169 | + if node in root_nodes: |
| 170 | + continue |
| 171 | + |
| 172 | + new_inputs = [repl_dict.get(i, i) for i in node.inputs] |
| 173 | + if new_inputs == node.inputs: |
| 174 | + continue |
| 175 | + |
| 176 | + # Either remake the node or do a simple inplace replacement |
| 177 | + # This property is not yet present in PyTensor |
| 178 | + if getattr(node.op, "_output_type_depends_on_input_value", False): |
| 179 | + remake_node = True |
| 180 | + else: |
| 181 | + remake_node = any( |
| 182 | + not inp.type == new_inp.type for inp, new_inp in zip(node.inputs, new_inputs) |
| 183 | + ) |
| 184 | + |
| 185 | + if remake_node: |
| 186 | + new_node = node.clone_with_new_inputs(new_inputs, strict=False) |
| 187 | + fgraph.import_node(new_node, import_missing=True) |
| 188 | + for out, new_out in zip(node.outputs, new_node.outputs): |
| 189 | + repl_dict[out] = new_out |
| 190 | + else: |
| 191 | + replace = list(zip(node.inputs, new_inputs)) |
| 192 | + fgraph.replace_all(replace, import_missing=True) |
| 193 | + |
| 194 | + # We need special logic for the cases where we had to rebuild the output nodes |
| 195 | + for i, (new_output, old_output) in enumerate( |
| 196 | + zip( |
| 197 | + (repl_dict.get(out, out) for out in fgraph.outputs), |
| 198 | + fgraph.outputs, |
| 199 | + ) |
| 200 | + ): |
| 201 | + if new_output is old_output: |
| 202 | + continue |
| 203 | + fgraph.outputs[i] = new_output |
| 204 | + fgraph.import_var(new_output, import_missing=True) |
| 205 | + client = ("output", i) |
| 206 | + fgraph.add_client(new_output, client) |
| 207 | + fgraph.remove_client(old_output, client) |
| 208 | + fgraph.execute_callbacks("on_change_input", "output", i, old_output, new_output) |
| 209 | + |
| 210 | + |
| 211 | +def toposort_replace( |
| 212 | + fgraph: FunctionGraph, |
| 213 | + replacements: Sequence[Tuple[Variable, Variable]], |
| 214 | + reverse: bool = False, |
| 215 | + rebuild: bool = False, |
| 216 | +) -> None: |
| 217 | + """Replace multiple variables in topological order.""" |
| 218 | + if rebuild and reverse: |
| 219 | + raise NotImplementedError("reverse rebuild not supported") |
| 220 | + |
| 221 | + toposort = fgraph.toposort() |
| 222 | + sorted_replacements = sorted( |
| 223 | + replacements, |
| 224 | + key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1, |
| 225 | + reverse=reverse, |
| 226 | + ) |
| 227 | + |
| 228 | + if rebuild: |
| 229 | + if len(replacements) > 1: |
| 230 | + # In this case we need to introduce the replacements inside each other |
| 231 | + # To avoid undoing previous changes |
| 232 | + sorted_replacements = [list(pairs) for pairs in sorted_replacements] |
| 233 | + for i in range(1, len(replacements)): |
| 234 | + # Replace-rebuild each successive replacement with the previous replacements (in topological order) |
| 235 | + temp_fgraph = FunctionGraph( |
| 236 | + outputs=[repl for _, repl in sorted_replacements[i:]], clone=False |
| 237 | + ) |
| 238 | + _replace_rebuild_all(temp_fgraph, replacements=sorted_replacements[:i]) |
| 239 | + sorted_replacements[i][1] = temp_fgraph.outputs[0] |
| 240 | + _replace_rebuild_all(fgraph, sorted_replacements, import_missing=True) |
| 241 | + else: |
| 242 | + fgraph.replace_all(sorted_replacements, import_missing=True) |
0 commit comments