Skip to content

Commit ba60b79

Browse files
committed
Add option to avoid cloning fgraph in model_from_fgraph
In most function transforms the caller is both creating the fgraph representation and discarding it, so it's safe to mutate the fgraph in place.
1 parent 01b0219 commit ba60b79

File tree

4 files changed

+21
-14
lines changed

4 files changed

+21
-14
lines changed

pymc/model/fgraph.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -278,12 +278,18 @@ def fgraph_from_model(
278278
return fgraph, memo
279279

280280

281-
def model_from_fgraph(fgraph: FunctionGraph) -> Model:
281+
def model_from_fgraph(fgraph: FunctionGraph, mutate_fgraph: bool = False) -> Model:
282282
"""Convert FunctionGraph to PyMC model.
283283
284-
This requires nodes to be properly tagged with `ModelVar` dummy Ops.
284+
Parameters
285+
----------
286+
fgraph: FunctionGraph
287+
fgraph representation of a PyMC model, with dummy `ModelVar` Ops.
288+
See `fgraph_from_model` for more details.
285289
286-
See: fgraph_from_model
290+
mutate_fgraph: bool, default False
291+
Whether the function is allowed to modify the fgraph (and it's variables) in place.
292+
This is useful if these are not needed anymore after the model is created.
287293
"""
288294

289295
def first_non_model_var(var):
@@ -300,11 +306,12 @@ def first_non_model_var(var):
300306
_coords = getattr(fgraph, "_coords", {})
301307
_dim_lengths = getattr(fgraph, "_dim_lengths", {})
302308

303-
fgraph, memo = fgraph.clone_get_equiv(check_integrity=False, attach_feature=False)
304-
# Shared dim lengths are not extracted from the fgraph representation,
305-
# so we need to update after we clone the fgraph
306-
# TODO: Consider representing/extracting them from the fgraph!
307-
_dim_lengths = {k: memo.get(v, v) for k, v in _dim_lengths.items()}
309+
if not mutate_fgraph:
310+
fgraph, memo = fgraph.clone_get_equiv(check_integrity=False, attach_feature=False)
311+
# Shared dim lengths are not extracted from the fgraph representation,
312+
# so we need to update after we clone the fgraph
313+
# TODO: Consider representing/extracting them from the fgraph!
314+
_dim_lengths = {k: memo.get(v, v) for k, v in _dim_lengths.items()}
308315

309316
model._coords = _coords
310317
model._dim_lengths = _dim_lengths
@@ -385,7 +392,7 @@ def clone_model(model: Model) -> Model:
385392
z = pm.Deterministic("z", clone_x + 1)
386393
387394
"""
388-
return model_from_fgraph(fgraph_from_model(model)[0])
395+
return model_from_fgraph(fgraph_from_model(model)[0], mutate_fgraph=True)
389396

390397

391398
def extract_dims(var) -> tuple:

pymc/model/transform/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def prune_vars_detached_from_observed(model: Model) -> Model:
5050
}
5151
for node_to_remove in nodes_to_remove:
5252
fgraph.remove_node(node_to_remove)
53-
return model_from_fgraph(fgraph)
53+
return model_from_fgraph(fgraph, mutate_fgraph=True)
5454

5555

5656
def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> list[Variable]:

pymc/model/transform/conditioning.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def observe(
117117

118118
toposort_replace(fgraph, tuple(replacements.items()))
119119

120-
return model_from_fgraph(fgraph)
120+
return model_from_fgraph(fgraph, mutate_fgraph=True)
121121

122122

123123
def do(
@@ -215,7 +215,7 @@ def do(
215215
# Replace variables by interventions
216216
toposort_replace(fgraph, tuple(replacements.items()))
217217

218-
model = model_from_fgraph(fgraph)
218+
model = model_from_fgraph(fgraph, mutate_fgraph=True)
219219
if prune_vars:
220220
return prune_vars_detached_from_observed(model)
221221
return model
@@ -302,7 +302,7 @@ def change_value_transforms(
302302
replacements[dummy_rv] = new_dummy_rv
303303

304304
toposort_replace(fgraph, tuple(replacements.items()))
305-
return model_from_fgraph(fgraph)
305+
return model_from_fgraph(fgraph, mutate_fgraph=True)
306306

307307

308308
def remove_value_transforms(

pymc/model/transform/optimization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def freeze_dims_and_data(
104104
replacements[old_value] = new_value
105105
fg.replace_all(tuple(replacements.items()), import_missing=True)
106106

107-
return model_from_fgraph(fg)
107+
return model_from_fgraph(fg, mutate_fgraph=True)
108108

109109

110110
__all__ = ("freeze_dims_and_data",)

0 commit comments

Comments
 (0)