
Commit 13f8644

Copy model-related shared variables in model_fgraph
1 parent a26d5c3 commit 13f8644
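In practice, this change means that a cloned model (via clone_model, or a fgraph_from_model/model_from_fgraph round trip) no longer shares RNGs, MutableData, or mutable dim lengths with the original, so either model can be mutated without affecting the other. A minimal sketch of the resulting behavior, assuming clone_model is importable from pymc_experimental.utils.model_fgraph (the module this commit touches); variable names are illustrative:

    import pymc as pm
    from pymc_experimental.utils.model_fgraph import clone_model  # assumed import path

    with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
        x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
        mu = pm.Deterministic("mu", 2 * x, dims=("test_dim",))

    m_new = clone_model(m_old)

    # Resizing data in the clone leaves the original model untouched
    with m_new:
        pm.set_data({"x": [10.0, 20.0]}, coords={"test_dim": range(2)})

    assert m_new.dim_lengths["test_dim"].eval() == 2
    assert m_old.dim_lengths["test_dim"].eval() == 3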

File tree

2 files changed: +72, -20 lines


pymc_experimental/tests/utils/test_model_fgraph.py

Lines changed: 46 additions & 15 deletions
@@ -2,6 +2,7 @@
 import pymc as pm
 import pytensor.tensor as pt
 import pytest
+from pytensor import config, shared
 from pytensor.graph import Constant, FunctionGraph, node_rewriter
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.tensor.exceptions import NotScalarConstantError
@@ -13,6 +14,7 @@
     ModelObservedRV,
     ModelPotential,
     ModelVar,
+    clone_model,
     fgraph_from_model,
     model_deterministic,
     model_free_rv,
@@ -76,17 +78,22 @@ def test_basic():
     )


+def same_storage(shared_1, shared_2) -> bool:
+    """Check if two shared variables have the same storage containers (i.e., they point to the same memory)."""
+    return shared_1.container.storage is shared_2.container.storage
+
+
 @pytest.mark.parametrize("inline_views", (False, True))
 def test_data(inline_views):
-    """Test shared RNGs, MutableData, ConstantData and Dim lengths are handled correctly.
+    """Test shared RNGs, MutableData, ConstantData and dim lengths are handled correctly.

-    Everything should be preserved across new and old models, except for shared RNGs
+    All model-related shared variables should be copied to become independent across models.
     """
     with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
         x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
         y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
-        b0 = pm.ConstantData("b0", np.zeros(3))
-        b1 = pm.Normal("b1")
+        b0 = pm.ConstantData("b0", np.zeros((1,)))
+        b1 = pm.DiracDelta("b1", 1.0)
         mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
         obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))

@@ -109,22 +116,46 @@ def test_data(inline_views):

     m_new = model_from_fgraph(m_fgraph)

-    # ConstantData is preserved
-    assert np.all(m_new["b0"].data == m_old["b0"].data)
-
-    # Shared non-rng shared variables are preserved
-    assert m_new["x"].container is x.container
-    assert m_new["y"].container is y.container
+    # The rv-data mapping is preserved
     assert m_new.rvs_to_values[m_new["obs"]] is m_new["y"]

-    # Shared rng shared variables are not preserved
-    assert m_new["b1"].owner.inputs[0].container is not m_old["b1"].owner.inputs[0].container
+    # ConstantData is still accessible as a model variable
+    np.testing.assert_array_equal(m_new["b0"], m_old["b0"])

-    with m_old:
-        pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})
+    # Shared model variables, dim lengths, and rngs are copied and no longer point to the same memory
+    assert not same_storage(m_new["x"], x)
+    assert not same_storage(m_new["y"], y)
+    assert not same_storage(m_new["b1"].owner.inputs[0], b1.owner.inputs[0])
+    assert not same_storage(m_new.dim_lengths["test_dim"], m_old.dim_lengths["test_dim"])

+    # Updating model shared variables in new model, doesn't affect old one
+    with m_new:
+        pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})
     assert m_new.dim_lengths["test_dim"].eval() == 2
-    np.testing.assert_array_almost_equal(pm.draw(m_new["x"], random_seed=63), [100.0, 200.0])
+    assert m_old.dim_lengths["test_dim"].eval() == 3
+    np.testing.assert_allclose(pm.draw(m_new["mu"]), [100.0, 200.0])
+    np.testing.assert_allclose(pm.draw(m_old["mu"]), [0.0, 1.0, 2.0], atol=1e-6)
+
+
+@config.change_flags(floatX="float64")  # Avoid downcasting Ops in the graph
+def test_shared_variable():
+    """Test that user defined shared variables (other than RNGs) aren't copied."""
+    x = shared(np.array([1, 2, 3.0]), name="x")
+    y = shared(np.array([1, 2, 3.0]), name="y")
+
+    with pm.Model() as m_old:
+        test = pm.Normal("test", mu=x, observed=y)
+
+    assert test.owner.inputs[3] is x
+    assert m_old.rvs_to_values[test] is y
+
+    m_new = clone_model(m_old)
+    test_new = m_new["test"]
+    # Shared Variables are cloned but still point to the same memory
+    assert test_new.owner.inputs[3] is not x
+    assert m_new.rvs_to_values[test_new] is not y
+    assert same_storage(test_new.owner.inputs[3], x)
+    assert same_storage(m_new.rvs_to_values[test_new], y)


 @pytest.mark.parametrize("inline_views", (False, True))
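The same_storage helper added above compares PyTensor storage containers: two shared variables alias the same memory exactly when their containers wrap the same storage cell. A small standalone sketch of that mechanism, assuming SharedVariable.clone() reuses the original container (which is how graph cloning ends up with aliased variables, as in test_shared_variable):

    import numpy as np
    from pytensor import shared

    a = shared(np.zeros(3), name="a")
    alias = a.clone()  # new Variable object; assumed to reuse a's container
    indep = shared(a.get_value(borrow=False), name="a_copy")  # fresh storage

    assert alias is not a
    assert alias.container.storage is a.container.storage      # aliased memory
    assert indep.container.storage is not a.container.storage  # independent memory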

pymc_experimental/utils/model_fgraph.py

Lines changed: 26 additions & 5 deletions
@@ -1,14 +1,17 @@
+from copy import copy
 from typing import Dict, Optional, Sequence, Tuple

 import pytensor
 from pymc.logprob.transforms import RVTransform
 from pymc.model import Model
 from pymc.pytensorf import find_rng_nodes
 from pytensor import Variable, shared
+from pytensor.compile import SharedVariable
 from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter
 from pytensor.graph.rewriting.basic import out2in
 from pytensor.scalar import Identity
 from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.sharedvar import ScalarSharedVariable

 from pymc_experimental.utils.pytensorf import StringType

@@ -182,10 +185,28 @@ def fgraph_from_model(

     memo = {}

-    # Replace RNG nodes so that seeding does not interfere with old model
-    for rng in find_rng_nodes(model_vars):
-        new_rng = shared(rng.get_value(borrow=False))
-        memo[rng] = new_rng
+    # Replace the following shared variables in the model:
+    # 1. RNGs
+    # 2. MutableData (could increase memory usage significantly)
+    # 3. Mutable coords dim lengths
+    shared_vars_to_copy = find_rng_nodes(model_vars)
+    shared_vars_to_copy += [v for v in model.dim_lengths.values() if isinstance(v, SharedVariable)]
+    shared_vars_to_copy += [v for v in model.named_vars.values() if isinstance(v, SharedVariable)]
+    for var in shared_vars_to_copy:
+        # FIXME: ScalarSharedVariables are converted to 0d numpy arrays internally,
+        # so calling shared(shared(5).get_value()) returns a different type: TensorSharedVariables!
+        # Furthermore, PyMC silently ignores mutable dim changes that are SharedTensorVariables...
+        # https://github.com/pymc-devs/pytensor/issues/396
+        if isinstance(var, ScalarSharedVariable):
+            new_var = shared(var.get_value(borrow=False).item())
+        else:
+            new_var = shared(var.get_value(borrow=False))
+
+        assert new_var.type == var.type
+        new_var.name = var.name
+        new_var.tag = copy(var.tag)
+        # We can replace input variables by placing them in the memo
+        memo[var] = new_var

     fgraph = FunctionGraph(
         outputs=model_vars,
@@ -196,7 +217,7 @@ def fgraph_from_model(
     )
     # Copy model meta-info to fgraph
     fgraph._coords = model._coords.copy()
-    fgraph._dim_lengths = model._dim_lengths.copy()
+    fgraph._dim_lengths = {k: memo.get(v, v) for k, v in model._dim_lengths.items()}

     rvs_to_transforms = model.rvs_to_transforms
     named_vars_to_dims = model.named_vars_to_dims
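For context, the FIXME above refers to a round-trip quirk in shared variable dispatch (pymc-devs/pytensor#396): get_value() on a ScalarSharedVariable hands back a 0d numpy array, and re-wrapping that array with shared() produces a TensorSharedVariable instead. A quick sketch of the behavior the ScalarSharedVariable branch works around, assuming shared() dispatches Python scalars to ScalarSharedVariable as the comment describes:

    from pytensor import shared
    from pytensor.tensor.sharedvar import ScalarSharedVariable

    dim_len = shared(5)  # Python scalar -> ScalarSharedVariable (per the FIXME)
    assert isinstance(dim_len, ScalarSharedVariable)

    # Naive copy: get_value() returns a 0d array, which dispatches to TensorSharedVariable
    naive = shared(dim_len.get_value(borrow=False))
    assert not isinstance(naive, ScalarSharedVariable)

    # Unwrapping with .item() restores a Python scalar, and with it the original class
    fixed = shared(dim_len.get_value(borrow=False).item())
    assert isinstance(fixed, ScalarSharedVariable)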
