|
13 | 13 | # limitations under the License.
|
14 | 14 | from collections.abc import Sequence
|
15 | 15 |
|
16 |
| -from pytensor import Variable |
17 |
| -from pytensor.graph import ancestors, node_rewriter |
18 |
| -from pytensor.graph.rewriting.basic import out2in |
| 16 | +from pytensor import Variable, clone_replace |
| 17 | +from pytensor.graph import ancestors |
| 18 | +from pytensor.graph.fg import FunctionGraph |
19 | 19 |
|
20 | 20 | from pymc.data import MinibatchOp
|
21 | 21 | from pymc.model.core import Model
|
@@ -62,14 +62,23 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
|
62 | 62 | return [model[var] if isinstance(var, str) else var for var in vars_seq]
|
63 | 63 |
|
64 | 64 |
|
65 |
| -def remove_minibatched_nodes(model: Model): |
| 65 | +def remove_minibatched_nodes(model: Model) -> Model: |
66 | 66 | """Remove all uses of pm.Minibatch in the Model."""
|
| 67 | + fgraph, _ = fgraph_from_model(model) |
67 | 68 |
|
68 |
| - @node_rewriter([MinibatchOp]) |
69 |
| - def local_remove_minibatch(fgraph, node): |
70 |
| - return node.inputs |
| 69 | + replacements = {} |
| 70 | + for var in fgraph.apply_nodes: |
| 71 | + if isinstance(var.op, MinibatchOp): |
| 72 | + for inp, out in zip(var.inputs, var.outputs): |
| 73 | + replacements[out] = inp |
71 | 74 |
|
72 |
| - remove_minibatch = out2in(local_remove_minibatch) |
73 |
| - fgraph, _ = fgraph_from_model(model) |
74 |
| - remove_minibatch.apply(fgraph) |
| 75 | + old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths # type: ignore[attr-defined] |
| 76 | + # Using `rebuild_strict=False` means all coords, names, and dim information is lost |
| 77 | + # So we need to restore it from the old fgraph |
| 78 | + new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] |
| 79 | + for old_out, new_out in zip(old_outs, new_outs): |
| 80 | + new_out.name = old_out.name |
| 81 | + fgraph = FunctionGraph(outputs=new_outs, clone=False) |
| 82 | + fgraph._coords = old_coords # type: ignore[attr-defined] |
| 83 | + fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] |
75 | 84 | return model_from_fgraph(fgraph, mutate_fgraph=True)
|
0 commit comments