Commit 0d8ddba

Fix copying of shared variables in fgraph_from_model (#7153)
* Do not use deprecated ScalarSharedVariable
* Recreate SharedVariables with exact type in fgraph_from_model
1 parent ff99e3b commit 0d8ddba
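
The bug being fixed: fgraph_from_model used to copy shared variables by round-tripping their values through pytensor.shared(var.get_value()), which does not preserve the exact variable type. A minimal sketch of the failure mode described in the FIXME comment removed in pymc/model/fgraph.py below (exact behavior depends on the pytensor version):

import pytensor
from pytensor import shared

old = shared(5)                              # on older pytensor, a ScalarSharedVariable
new = shared(old.get_value(borrow=False))    # value comes back as a 0d numpy array,
                                             # so the type is re-inferred: TensorSharedVariable
print(type(old).__name__, type(new).__name__)
# Before this commit, `new.type == old.type` could fail, breaking the copy invariant
# that fgraph_from_model relies on.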

File tree

4 files changed: +55 −30 lines changed


pymc/model/core.py

Lines changed: 3 additions & 3 deletions
@@ -44,7 +44,6 @@
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.type import RandomType
-from pytensor.tensor.sharedvar import ScalarSharedVariable
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 from typing_extensions import Self

@@ -999,6 +998,7 @@ def add_coord(
             length = pytensor.shared(length, name=name)
         else:
             length = pytensor.tensor.constant(length)
+        assert length.type.ndim == 0
         self._dim_lengths[name] = length
         self._coords[name] = values

@@ -1028,7 +1028,7 @@ def set_dim(self, name: str, new_length: int, coord_values: Optional[Sequence] =
         coord_values
             Optional sequence of coordinate values.
         """
-        if not isinstance(self.dim_lengths[name], ScalarSharedVariable):
+        if not isinstance(self.dim_lengths[name], SharedVariable):
             raise ValueError(f"The dimension '{name}' is immutable.")
         if coord_values is None and self.coords.get(name, None) is not None:
             raise ValueError(

@@ -1188,7 +1188,7 @@ def set_data(
                             actual=new_length,
                             expected=old_length,
                         )
-                if isinstance(length_tensor, ScalarSharedVariable):
+                if isinstance(length_tensor, SharedVariable):
                     # The dimension is mutable, but was defined without being linked
                     # to a shared variable. This is allowed, but a little less robust.
                     self.set_dim(dname, new_length, coord_values=new_coords)
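
With the generic SharedVariable check above, resizing a mutable dim works regardless of which SharedVariable subclass backs its length. A hedged usage sketch, assuming the add_coord/set_dim API as of this commit's pymc version:

import pymc as pm

with pm.Model() as m:
    # mutable=True stores the dim length in a shared variable (now asserted to be 0d)
    m.add_coord("group", values=["a", "b"], mutable=True)
    m.set_dim("group", new_length=3, coord_values=["a", "b", "c"])  # resizable

# An immutable dim stores its length in a TensorConstant, so set_dim raises:
# ValueError: The dimension 'group' is immutable.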

pymc/model/fgraph.py

Lines changed: 18 additions & 15 deletions
@@ -11,18 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from copy import copy
+from copy import copy, deepcopy
 from typing import Optional

 import pytensor

-from pytensor import Variable, shared
+from pytensor import Variable
 from pytensor.compile import SharedVariable
 from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter
 from pytensor.graph.rewriting.basic import out2in
 from pytensor.scalar import Identity
 from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.sharedvar import ScalarSharedVariable

 from pymc.logprob.transforms import Transform
 from pymc.model.core import Model

@@ -113,6 +112,21 @@ def local_remove_identity(fgraph, node):
 remove_identity_rewrite = out2in(local_remove_identity)


+def deepcopy_shared_variable(var: SharedVariable) -> SharedVariable:
+    # Shared variables don't have a deepcopy method (SharedVariable.clone reuses the old container and contents).
+    # We recreate Shared Variables manually after deepcopying their container.
+    new_var = type(var)(
+        type=var.type,
+        value=None,
+        strict=None,
+        container=deepcopy(var.container),
+        name=var.name,
+    )
+    assert new_var.type == var.type
+    new_var.tag = copy(var.tag)
+    return new_var
+
+
 def fgraph_from_model(
     model: Model, inlined_views=False
 ) -> tuple[FunctionGraph, dict[Variable, Variable]]:

@@ -192,18 +206,7 @@ def fgraph_from_model(
     shared_vars_to_copy += [v for v in model.dim_lengths.values() if isinstance(v, SharedVariable)]
     shared_vars_to_copy += [v for v in model.named_vars.values() if isinstance(v, SharedVariable)]
     for var in shared_vars_to_copy:
-        # FIXME: ScalarSharedVariables are converted to 0d numpy arrays internally,
-        # so calling shared(shared(5).get_value()) returns a different type: TensorSharedVariables!
-        # Furthermore, PyMC silently ignores mutable dim changes that are SharedTensorVariables...
-        # https://github.com/pymc-devs/pytensor/issues/396
-        if isinstance(var, ScalarSharedVariable):
-            new_var = shared(var.get_value(borrow=False).item())
-        else:
-            new_var = shared(var.get_value(borrow=False))
-
-        assert new_var.type == var.type
-        new_var.name = var.name
-        new_var.tag = copy(var.tag)
+        new_var = deepcopy_shared_variable(var)
         # We can replace input variables by placing them in the memo
         memo[var] = new_var

tests/model/test_core.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@
 from pytensor.raise_op import Assert
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
-from pytensor.tensor.sharedvar import ScalarSharedVariable
+from pytensor.tensor.sharedvar import TensorSharedVariable
 from pytensor.tensor.variable import TensorConstant

 import pymc as pm

@@ -823,7 +823,7 @@ def test_add_coord_mutable_kwarg():
     m.add_coord("fixed", values=[1], mutable=False)
     m.add_coord("mutable1", values=[1, 2], mutable=True)
     assert isinstance(m._dim_lengths["fixed"], TensorConstant)
-    assert isinstance(m._dim_lengths["mutable1"], ScalarSharedVariable)
+    assert isinstance(m._dim_lengths["mutable1"], TensorSharedVariable)
     pm.MutableData("mdata", np.ones((1, 2, 3)), dims=("fixed", "mutable1", "mutable2"))
     assert isinstance(m._dim_lengths["mutable2"], TensorVariable)

tests/model/test_fgraph.py

Lines changed: 32 additions & 10 deletions
@@ -107,14 +107,16 @@ def test_data(inline_views):
     with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
         x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
         y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
+        sigma = pm.MutableData("sigma", [1.0], shape=(1,))
         b0 = pm.ConstantData("b0", np.zeros((1,)))
         b1 = pm.DiracDelta("b1", 1.0)
         mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
-        obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
+        obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=y, dims=("test_dim",))

     m_fgraph, memo = fgraph_from_model(m_old, inlined_views=inline_views)
     assert isinstance(memo[x].owner.op, ModelNamed)
     assert isinstance(memo[y].owner.op, ModelNamed)
+    assert isinstance(memo[sigma].owner.op, ModelNamed)
     assert isinstance(memo[b0].owner.op, ModelNamed)
     mu_inp = memo[mu].owner.inputs[0]
     obs = memo[obs]

@@ -124,10 +126,13 @@ def test_data(inline_views):
         assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x].owner.inputs[0]
         # ObservedRV(obs, y, *dims) not ObservedRV(obs, Named(y), *dims)
         assert obs.owner.inputs[1] is memo[y].owner.inputs[0]
+        # ObservedRV(Normal(..., sigma), ...) not ObservedRV(Normal(..., Named(sigma)), ...)
+        assert obs.owner.inputs[0].owner.inputs[4] is memo[sigma].owner.inputs[0]
     else:
         assert mu_inp.owner.inputs[0] is memo[b0]
         assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x]
         assert obs.owner.inputs[1] is memo[y]
+        assert obs.owner.inputs[0].owner.inputs[4] is memo[sigma]

     m_new = model_from_fgraph(m_fgraph)

@@ -140,9 +145,17 @@ def test_data(inline_views):
     # Shared model variables, dim lengths, and rngs are copied and no longer point to the same memory
     assert not same_storage(m_new["x"], x)
     assert not same_storage(m_new["y"], y)
+    assert not same_storage(m_new["sigma"], sigma)
     assert not same_storage(m_new["b1"].owner.inputs[0], b1.owner.inputs[0])
     assert not same_storage(m_new.dim_lengths["test_dim"], m_old.dim_lengths["test_dim"])

+    # Check they have the same type
+    assert m_new["x"].type == x.type
+    assert m_new["y"].type == y.type
+    assert m_new["sigma"].type == sigma.type
+    assert m_new["b1"].owner.inputs[0].type == b1.owner.inputs[0].type
+    assert m_new.dim_lengths["test_dim"].type == m_old.dim_lengths["test_dim"].type
+
     # Updating model shared variables in new model, doesn't affect old one
     with m_new:
         pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})

@@ -155,22 +168,31 @@
 @config.change_flags(floatX="float64")  # Avoid downcasting Ops in the graph
 def test_shared_variable():
     """Test that user defined shared variables (other than RNGs) aren't copied."""
-    x = shared(np.array([1, 2, 3.0]), name="x")
-    y = shared(np.array([1, 2, 3.0]), name="y")
+    mu = shared(np.array([1, 2, 3.0]), shape=(None,), name="mu")
+    sigma = shared(np.array([1.0]), shape=(1,), name="sigma")
+    obs = shared(np.array([1, 2, 3.0]), shape=(3,), name="obs")

     with pm.Model() as m_old:
-        test = pm.Normal("test", mu=x, observed=y)
+        test = pm.Normal("test", mu=mu, sigma=sigma, observed=obs)

-    assert test.owner.inputs[3] is x
-    assert m_old.rvs_to_values[test] is y
+    assert test.owner.inputs[3] is mu
+    assert test.owner.inputs[4] is sigma
+    assert m_old.rvs_to_values[test] is obs

     m_new = clone_model(m_old)
     test_new = m_new["test"]
     # Shared Variables are cloned but still point to the same memory
-    assert test_new.owner.inputs[3] is not x
-    assert m_new.rvs_to_values[test_new] is not y
-    assert same_storage(test_new.owner.inputs[3], x)
-    assert same_storage(m_new.rvs_to_values[test_new], y)
+    mu_new, sigma_new = test_new.owner.inputs[3:5]
+    obs_new = m_new.rvs_to_values[test_new]
+    assert mu_new is not mu
+    assert sigma_new is not sigma
+    assert obs_new is not obs
+    assert mu_new.type == mu.type
+    assert sigma_new.type == sigma.type
+    assert obs_new.type == obs.type
+    assert same_storage(mu, mu_new)
+    assert same_storage(sigma, sigma_new)
+    assert same_storage(obs, obs_new)


 @pytest.mark.parametrize("inline_views", (False, True))
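
These tests rely on a same_storage helper defined elsewhere in the test module and not shown in this diff. A minimal sketch of what such a check might look like, assuming it compares the variables' underlying storage containers (the names and body here are hypothetical):

from pytensor.compile import SharedVariable

def same_storage(var1: SharedVariable, var2: SharedVariable) -> bool:
    # Hypothetical sketch: two shared variables share memory iff their
    # containers hold the same storage list. SharedVariable.clone reuses
    # the container, while deepcopy_shared_variable creates a new one.
    return var1.container.storage is var2.container.storage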
