Skip to content

Commit a2ced28

Browse files
committed
Make fgraph Deterministic conversion logic more robust
1 parent 0ef82df commit a2ced28

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

pymc_experimental/utils/model_fgraph.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -237,30 +237,36 @@ def model_from_fgraph(fgraph: FunctionGraph) -> Model:
237237
238238
See: fgraph_from_model
239239
"""
240+
241+
def first_non_model_var(var):
242+
if var.owner and isinstance(var.owner.op, ModelVar):
243+
new_var = var.owner.inputs[0]
244+
return first_non_model_var(new_var)
245+
else:
246+
return var
247+
240248
model = Model()
241249
if model.parent is not None:
242250
raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context")
243251
model._coords = getattr(fgraph, "_coords", {})
244252
model._dim_lengths = getattr(fgraph, "_dim_lengths", {})
245253

246254
# Replace dummy `ModelVar` Ops by the underlying variables,
247-
# Except for Deterministics which could reintroduce the old graphs
248255
fgraph = fgraph.clone()
249256
model_dummy_vars = [
250257
model_node.outputs[0]
251258
for model_node in fgraph.toposort()
252259
if isinstance(model_node.op, ModelVar)
253260
]
254261
model_dummy_vars_to_vars = {
255-
dummy_var: dummy_var.owner.inputs[0]
262+
# Deterministics could refer to other model variables directly,
263+
# We make sure to replace them by the first non-model variable
264+
dummy_var: first_non_model_var(dummy_var.owner.inputs[0])
256265
for dummy_var in model_dummy_vars
257-
# Don't include Deterministics!
258-
if not isinstance(dummy_var.owner.op, ModelDeterministic)
259266
}
260267
toposort_replace(fgraph, tuple(model_dummy_vars_to_vars.items()))
261268

262269
# Populate new PyMC model mappings
263-
non_det_model_vars = set(model_dummy_vars_to_vars.values())
264270
for model_var in model_dummy_vars:
265271
if isinstance(model_var.owner.op, ModelFreeRV):
266272
var, value, *dims = model_var.owner.inputs
@@ -279,10 +285,8 @@ def model_from_fgraph(fgraph: FunctionGraph) -> Model:
279285
model.potentials.append(var)
280286
elif isinstance(model_var.owner.op, ModelDeterministic):
281287
var, *dims = model_var.owner.inputs
282-
# Register the original var (not the copy) as the Deterministic
283-
# So it shows in the expected place in graphviz.
284-
# unless it's another model var, in which case we need a copy!
285-
if var in non_det_model_vars:
288+
# If a Deterministic is a direct view on an RV, copy it
289+
if var in model.basic_RVs:
286290
var = var.copy()
287291
model.deterministics.append(var)
288292
elif isinstance(model_var.owner.op, ModelNamed):

0 commit comments

Comments
 (0)