Skip to content

Commit db3172e

Browse files
committed
Fix bug in model_fgraph where RNGs weren't being copied as advertised
1 parent 430c3c8 commit db3172e

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

pymc_experimental/tests/utils/test_model_fgraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_data(inline_views):
118118
assert m_new.rvs_to_values[m_new["obs"]] is m_new["y"]
119119

120120
# Shared rng shared variables are not preserved
121-
m_new["b1"].owner.inputs[0].container is not m_old["b1"].owner.inputs[0].container
121+
assert m_new["b1"].owner.inputs[0].container is not m_old["b1"].owner.inputs[0].container
122122

123123
with m_old:
124124
pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})

pymc_experimental/utils/model_fgraph.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pymc.logprob.transforms import RVTransform
55
from pymc.model import Model
66
from pymc.pytensorf import find_rng_nodes
7-
from pytensor import Variable
7+
from pytensor import Variable, shared
88
from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter
99
from pytensor.graph.rewriting.basic import out2in
1010
from pytensor.scalar import Identity
@@ -184,8 +184,7 @@ def fgraph_from_model(
184184

185185
# Replace RNG nodes so that seeding does not interfere with old model
186186
for rng in find_rng_nodes(model_vars):
187-
new_rng = rng.clone()
188-
new_rng.set_value(rng.get_value(borrow=False))
187+
new_rng = shared(rng.get_value(borrow=False))
189188
memo[rng] = new_rng
190189

191190
fgraph = FunctionGraph(

0 commit comments

Comments
 (0)