@@ -237,30 +237,36 @@ def model_from_fgraph(fgraph: FunctionGraph) -> Model:
237
237
238
238
See: fgraph_from_model
239
239
"""
240
+
241
+ def first_non_model_var (var ):
242
+ if var .owner and isinstance (var .owner .op , ModelVar ):
243
+ new_var = var .owner .inputs [0 ]
244
+ return first_non_model_var (new_var )
245
+ else :
246
+ return var
247
+
240
248
model = Model ()
241
249
if model .parent is not None :
242
250
raise RuntimeError ("model_to_fgraph cannot be called inside a PyMC model context" )
243
251
model ._coords = getattr (fgraph , "_coords" , {})
244
252
model ._dim_lengths = getattr (fgraph , "_dim_lengths" , {})
245
253
246
254
# Replace dummy `ModelVar` Ops by the underlying variables,
247
- # Except for Deterministics which could reintroduce the old graphs
248
255
fgraph = fgraph .clone ()
249
256
model_dummy_vars = [
250
257
model_node .outputs [0 ]
251
258
for model_node in fgraph .toposort ()
252
259
if isinstance (model_node .op , ModelVar )
253
260
]
254
261
model_dummy_vars_to_vars = {
255
- dummy_var : dummy_var .owner .inputs [0 ]
262
+ # Deterministics could refer to other model variables directly,
263
+ # We make sure to replace them by the first non-model variable
264
+ dummy_var : first_non_model_var (dummy_var .owner .inputs [0 ])
256
265
for dummy_var in model_dummy_vars
257
- # Don't include Deterministics!
258
- if not isinstance (dummy_var .owner .op , ModelDeterministic )
259
266
}
260
267
toposort_replace (fgraph , tuple (model_dummy_vars_to_vars .items ()))
261
268
262
269
# Populate new PyMC model mappings
263
- non_det_model_vars = set (model_dummy_vars_to_vars .values ())
264
270
for model_var in model_dummy_vars :
265
271
if isinstance (model_var .owner .op , ModelFreeRV ):
266
272
var , value , * dims = model_var .owner .inputs
@@ -279,10 +285,8 @@ def model_from_fgraph(fgraph: FunctionGraph) -> Model:
279
285
model .potentials .append (var )
280
286
elif isinstance (model_var .owner .op , ModelDeterministic ):
281
287
var , * dims = model_var .owner .inputs
282
- # Register the original var (not the copy) as the Deterministic
283
- # So it shows in the expected place in graphviz.
284
- # unless it's another model var, in which case we need a copy!
285
- if var in non_det_model_vars :
288
+ # If a Deterministic is a direct view on an RV, copy it
289
+ if var in model .basic_RVs :
286
290
var = var .copy ()
287
291
model .deterministics .append (var )
288
292
elif isinstance (model_var .owner .op , ModelNamed ):
0 commit comments