Skip to content

Commit 01b0219

Browse files
committed
Align model dim_lengths with cloned variables in model_from_fgraph
1 parent 881030a commit 01b0219

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

pymc/model/fgraph.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,20 @@ def first_non_model_var(var):
296296
model = Model()
297297
if model.parent is not None:
298298
raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context")
299-
model._coords = getattr(fgraph, "_coords", {})
300-
model._dim_lengths = getattr(fgraph, "_dim_lengths", {})
299+
300+
_coords = getattr(fgraph, "_coords", {})
301+
_dim_lengths = getattr(fgraph, "_dim_lengths", {})
302+
303+
fgraph, memo = fgraph.clone_get_equiv(check_integrity=False, attach_feature=False)
304+
# Shared dim lengths are not extracted from the fgraph representation,
305+
# so we need to update after we clone the fgraph
306+
# TODO: Consider representing/extracting them from the fgraph!
307+
_dim_lengths = {k: memo.get(v, v) for k, v in _dim_lengths.items()}
308+
309+
model._coords = _coords
310+
model._dim_lengths = _dim_lengths
301311

302312
# Replace dummy `ModelVar` Ops by the underlying variables,
303-
fgraph = fgraph.clone()
304313
model_dummy_vars = [
305314
model_node.outputs[0]
306315
for model_node in fgraph.toposort()

tests/model/test_fgraph.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pytest
1717

1818
from pytensor import config, shared
19-
from pytensor.graph import Constant, FunctionGraph, node_rewriter
19+
from pytensor.graph import Constant, FunctionGraph, graph_inputs, node_rewriter
2020
from pytensor.graph.rewriting.basic import in2out
2121
from pytensor.tensor.exceptions import NotScalarConstantError
2222

@@ -164,6 +164,11 @@ def test_data(inline_views):
164164
np.testing.assert_allclose(pm.draw(m_new["mu"]), [100.0, 200.0])
165165
np.testing.assert_allclose(pm.draw(m_old["mu"]), [0.0, 1.0, 2.0], atol=1e-6)
166166

167+
# Check model dim_lengths contains the exact variables used in the graph of RVs
168+
m_new_size_param = m_new["obs"].owner.inputs[1]
169+
[m_new_dim_len] = graph_inputs([m_new_size_param])
170+
assert m_new.dim_lengths["test_dim"] is m_new_dim_len
171+
167172

168173
@config.change_flags(floatX="float64") # Avoid downcasting Ops in the graph
169174
def test_shared_variable():

tests/model/transform/test_optimization.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from pytensor.compile import SharedVariable
1818
from pytensor.graph import Constant
1919

20-
from pymc import Deterministic
20+
from pymc import Deterministic, do
2121
from pymc.data import Data
2222
from pymc.distributions import HalfNormal, Normal
2323
from pymc.model import Model
@@ -132,3 +132,15 @@ def test_freeze_dims_and_data_subset():
132132
assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
133133
assert isinstance(new_m["data1"], SharedVariable)
134134
assert isinstance(new_m["data2"], Constant) and np.all(new_m["data2"].data == [1, 2, 3, 4, 5])
135+
136+
137+
def test_freeze_dim_after_do_intervention():
138+
with Model(coords={"test_dim": range(5)}) as m:
139+
mu = Data("mu", [0, 1, 2, 3, 4], dims="test_dim")
140+
x = Normal("x", mu=mu, dims="test_dim")
141+
142+
do_m = do(m, {mu: mu * 100})
143+
assert do_m["x"].type.shape == (None,)
144+
145+
frozen_do_m = freeze_dims_and_data(do_m)
146+
assert frozen_do_m["x"].type.shape == (5,)

0 commit comments

Comments (0)