Skip to content

Copy model-related shared variables in model_fgraph #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc_experimental/tests/utils/test_model_fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_data(inline_views):
assert m_new.rvs_to_values[m_new["obs"]] is m_new["y"]

# Shared rng shared variables are not preserved
m_new["b1"].owner.inputs[0].container is not m_old["b1"].owner.inputs[0].container
assert m_new["b1"].owner.inputs[0].container is not m_old["b1"].owner.inputs[0].container

with m_old:
pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})
Expand Down
5 changes: 2 additions & 3 deletions pymc_experimental/utils/model_fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pymc.logprob.transforms import RVTransform
from pymc.model import Model
from pymc.pytensorf import find_rng_nodes
from pytensor import Variable
from pytensor import Variable, shared
from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter
from pytensor.graph.rewriting.basic import out2in
from pytensor.scalar import Identity
Expand Down Expand Up @@ -184,8 +184,7 @@ def fgraph_from_model(

# Replace RNG nodes so that seeding does not interfere with old model
for rng in find_rng_nodes(model_vars):
new_rng = rng.clone()
new_rng.set_value(rng.get_value(borrow=False))
new_rng = shared(rng.get_value(borrow=False))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why the old approach didn't work. You were creating a clone, then setting its value to a new generator. Why would wrapping a new shared around a generator work and the old approach not.

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SharedVariable.clone reuses the same container. The only thing that's "cloned" is the Variable object: https://github.com/pymc-devs/pytensor/blob/673c1accf98659ba0457759431cb36eef4659f63/pytensor/compile/sharedvalue.py#L139-L149

memo[rng] = new_rng

fgraph = FunctionGraph(
Expand Down