Skip to content

Commit beccca4

Browse files
committed
Rework transform
1 parent 66f61d0 commit beccca4

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

pymc/model/transform/basic.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
# limitations under the License.
1414
from collections.abc import Sequence
1515

16-
from pytensor import Variable
17-
from pytensor.graph import ancestors, node_rewriter
18-
from pytensor.graph.rewriting.basic import out2in
16+
from pytensor import Variable, clone_replace
17+
from pytensor.graph import ancestors
18+
from pytensor.graph.fg import FunctionGraph
1919

2020
from pymc.data import MinibatchOp
2121
from pymc.model.core import Model
@@ -62,14 +62,23 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
6262
return [model[var] if isinstance(var, str) else var for var in vars_seq]
6363

6464

65-
def remove_minibatched_nodes(model: Model):
65+
def remove_minibatched_nodes(model: Model) -> Model:
6666
"""Remove all uses of pm.Minibatch in the Model."""
67+
fgraph, _ = fgraph_from_model(model)
6768

68-
@node_rewriter([MinibatchOp])
69-
def local_remove_minibatch(fgraph, node):
70-
return node.inputs
69+
replacements = {}
70+
for var in fgraph.apply_nodes:
71+
if isinstance(var.op, MinibatchOp):
72+
for inp, out in zip(var.inputs, var.outputs):
73+
replacements[out] = inp
7174

72-
remove_minibatch = out2in(local_remove_minibatch)
73-
fgraph, _ = fgraph_from_model(model)
74-
remove_minibatch.apply(fgraph)
75+
old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths # type: ignore[attr-defined]
76+
# Using `rebuild_strict=False` means all coords, names, and dim information is lost
77+
# So we need to restore it from the old fgraph
78+
new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type]
79+
for old_out, new_out in zip(old_outs, new_outs):
80+
new_out.name = old_out.name
81+
fgraph = FunctionGraph(outputs=new_outs, clone=False)
82+
fgraph._coords = old_coords # type: ignore[attr-defined]
83+
fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined]
7584
return model_from_fgraph(fgraph, mutate_fgraph=True)

0 commit comments

Comments
 (0)