|
1 |
| -from typing import Sequence |
| 1 | +from collections import deque |
| 2 | +from itertools import chain |
| 3 | +from typing import Iterable, Sequence, Set, Tuple |
2 | 4 |
|
3 | 5 | import pytensor
|
4 | 6 | from pymc import SymbolicRandomVariable
|
5 | 7 | from pytensor import Variable
|
6 |
| -from pytensor.graph import Constant, Type |
7 |
| -from pytensor.graph.basic import walk |
| 8 | +from pytensor.graph import Constant, FunctionGraph, Type |
| 9 | +from pytensor.graph.basic import Apply, walk |
8 | 10 | from pytensor.graph.op import HasInnerGraph
|
9 | 11 | from pytensor.tensor.random.op import RandomVariable
|
10 | 12 |
|
@@ -58,3 +60,110 @@ def expand(r):
|
58 | 60 | for node in walk(vars, expand, False)
|
59 | 61 | if node.owner and isinstance(node.owner.op, (RandomVariable, SymbolicRandomVariable))
|
60 | 62 | )
|
| 63 | + |
| 64 | + |
| 65 | +def _replace_rebuild_all( |
| 66 | + fgraph: FunctionGraph, replacements: Iterable[Tuple[Variable, Variable]], **kwargs |
| 67 | +) -> FunctionGraph: |
| 68 | + """Replace variables and rebuild dependent graph if needed. |
| 69 | +
|
| 70 | + Rebuilding allows for replacements that change the semantics of the graph |
| 71 | + (different types), which may not be possible for all Ops. |
| 72 | + """ |
| 73 | + |
| 74 | + def get_client_nodes(vars) -> Set[Apply]: |
| 75 | + nodes = set() |
| 76 | + d = deque( |
| 77 | + chain.from_iterable(fgraph.clients[var] for var in vars if var in fgraph.variables) |
| 78 | + ) |
| 79 | + while d: |
| 80 | + node, _ = d.pop() |
| 81 | + if node in nodes or node == "output": |
| 82 | + continue |
| 83 | + nodes.add(node) |
| 84 | + d.extend(chain.from_iterable(fgraph.clients[out] for out in node.outputs)) |
| 85 | + return nodes |
| 86 | + |
| 87 | + repl_dict = {old: new for old, new in replacements} |
| 88 | + root_nodes = {var.owner for var in repl_dict.keys()} |
| 89 | + |
| 90 | + # Build sorted queue with all nodes that depend on replaced variables |
| 91 | + topo_order = {node: order for order, node in enumerate(fgraph.toposort())} |
| 92 | + client_nodes = get_client_nodes(repl_dict.keys()) |
| 93 | + d = deque(sorted(client_nodes, key=lambda node: topo_order[node])) |
| 94 | + while d: |
| 95 | + node = d.popleft() |
| 96 | + if node in root_nodes: |
| 97 | + continue |
| 98 | + |
| 99 | + new_inputs = [repl_dict.get(i, i) for i in node.inputs] |
| 100 | + if new_inputs == node.inputs: |
| 101 | + continue |
| 102 | + |
| 103 | + # Either remake the node or do a simple inplace replacement |
| 104 | + # This property is not yet present in PyTensor |
| 105 | + if getattr(node.op, "_output_type_depends_on_input_value", False): |
| 106 | + remake_node = True |
| 107 | + else: |
| 108 | + remake_node = any( |
| 109 | + not inp.type == new_inp.type for inp, new_inp in zip(node.inputs, new_inputs) |
| 110 | + ) |
| 111 | + |
| 112 | + if remake_node: |
| 113 | + new_node = node.clone_with_new_inputs(new_inputs, strict=False) |
| 114 | + fgraph.import_node(new_node, import_missing=True) |
| 115 | + for out, new_out in zip(node.outputs, new_node.outputs): |
| 116 | + repl_dict[out] = new_out |
| 117 | + else: |
| 118 | + replace = list(zip(node.inputs, new_inputs)) |
| 119 | + fgraph.replace_all(replace, import_missing=True) |
| 120 | + |
| 121 | + # We need special logic for the cases where we had to rebuild the output nodes |
| 122 | + for i, (new_output, old_output) in enumerate( |
| 123 | + zip( |
| 124 | + (repl_dict.get(out, out) for out in fgraph.outputs), |
| 125 | + fgraph.outputs, |
| 126 | + ) |
| 127 | + ): |
| 128 | + if new_output is old_output: |
| 129 | + continue |
| 130 | + fgraph.outputs[i] = new_output |
| 131 | + fgraph.import_var(new_output, import_missing=True) |
| 132 | + client = ("output", i) |
| 133 | + fgraph.add_client(new_output, client) |
| 134 | + fgraph.remove_client(old_output, client) |
| 135 | + fgraph.execute_callbacks("on_change_input", "output", i, old_output, new_output) |
| 136 | + |
| 137 | + |
| 138 | +def toposort_replace( |
| 139 | + fgraph: FunctionGraph, |
| 140 | + replacements: Sequence[Tuple[Variable, Variable]], |
| 141 | + reverse: bool = False, |
| 142 | + rebuild: bool = False, |
| 143 | +) -> None: |
| 144 | + """Replace multiple variables in topological order.""" |
| 145 | + if rebuild and reverse: |
| 146 | + raise NotImplementedError("reverse rebuild not supported") |
| 147 | + |
| 148 | + toposort = fgraph.toposort() |
| 149 | + sorted_replacements = sorted( |
| 150 | + replacements, |
| 151 | + key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1, |
| 152 | + reverse=reverse, |
| 153 | + ) |
| 154 | + |
| 155 | + if rebuild: |
| 156 | + if len(replacements) > 1: |
| 157 | + # In this case we need to introduce the replacements inside each other |
| 158 | + # To avoid undoing previous changes |
| 159 | + sorted_replacements = [list(pairs) for pairs in sorted_replacements] |
| 160 | + for i in range(1, len(replacements)): |
| 161 | + # Replace-rebuild each successive replacement with the previous replacements (in topological order) |
| 162 | + temp_fgraph = FunctionGraph( |
| 163 | + outputs=[repl for _, repl in sorted_replacements[i:]], clone=False |
| 164 | + ) |
| 165 | + _replace_rebuild_all(temp_fgraph, replacements=sorted_replacements[:i]) |
| 166 | + sorted_replacements[i][1] = temp_fgraph.outputs[0] |
| 167 | + _replace_rebuild_all(fgraph, sorted_replacements, import_missing=True) |
| 168 | + else: |
| 169 | + fgraph.replace_all(sorted_replacements, import_missing=True) |
0 commit comments