
Commit 0ad689c

Allow freezing only subset of data and dims
* Also fix dim_lengths not being returned
1 parent 04b6881 commit 0ad689c
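
The new keyword arguments let callers freeze only a subset of a model. A minimal usage sketch based on the signature introduced here (the model and variable names are hypothetical, not part of the commit):

    import pymc as pm
    from pymc.model.transform.optimization import freeze_dims_and_data

    with pm.Model(coords={"trial": range(10)}) as model:
        x = pm.Data("x", [0.0] * 10, dims="trial")
        y = pm.Normal("y", mu=x, dims="trial")

    # Freeze only the "trial" dim length; leave the data variable updatable.
    frozen = freeze_dims_and_data(model, dims=["trial"], data=[])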

File tree

3 files changed: +146 -26 lines changed

    pymc/model/core.py
    pymc/model/transform/optimization.py
    tests/model/transform/test_optimization.py

pymc/model/core.py

Lines changed: 10 additions & 3 deletions
@@ -1050,7 +1050,14 @@ def set_dim(self, name: str, new_length: int, coord_values: Sequence | None = None):
                     expected=new_length,
                 )
             self._coords[name] = tuple(coord_values)
-        self.dim_lengths[name].set_value(new_length)
+        dim_length = self.dim_lengths[name]
+        if not isinstance(dim_length, SharedVariable):
+            raise TypeError(
+                f"The dim_length of `{name}` must be a `SharedVariable` "
+                "(created through `coords` to allow updating). "
+                f"The current type is: {type(dim_length)}"
+            )
+        dim_length.set_value(new_length)
         return
 
     def initial_point(self, random_seed: SeedSequenceSeed = None) -> dict[str, np.ndarray]:
@@ -1102,8 +1109,8 @@ def set_data(
         shared_object = self[name]
         if not isinstance(shared_object, SharedVariable):
             raise TypeError(
-                f"The variable `{name}` must be a `SharedVariable`"
-                " (created through `pm.Data()` or `pm.Data(mutable=True)`) to allow updating. "
+                f"The variable `{name}` must be a `SharedVariable` "
+                "(created through `pm.Data()` to allow updating.) "
                 f"The current type is: {type(shared_object)}"
             )
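
With the guard above, calling set_dim on a dimension that was frozen now raises an informative TypeError instead of the AttributeError that calling set_value on a constant would presumably produce. A rough sketch of the expected behavior (hypothetical model; the real code path is exercised by the tests below):

    with pm.Model(coords={"test_dim": range(5)}) as m:
        x = pm.Normal("x", dims="test_dim")

    frozen_m = freeze_dims_and_data(m)
    # TypeError: The dim_length of `test_dim` must be a `SharedVariable` ...
    frozen_m.set_dim("test_dim", new_length=6, coord_values=range(6))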

pymc/model/transform/optimization.py

Lines changed: 48 additions & 18 deletions
@@ -11,16 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Sequence
+
 from pytensor import clone_replace
 from pytensor.compile import SharedVariable
 from pytensor.graph import FunctionGraph
 from pytensor.tensor import constant
+from pytensor.tensor.sharedvar import TensorSharedVariable
+from pytensor.tensor.variable import TensorConstant
 
 from pymc import Model
 from pymc.model.fgraph import ModelFreeRV, fgraph_from_model, model_from_fgraph
 
 
-def freeze_dims_and_data(model: Model) -> Model:
+def _constant_from_shared(shared: SharedVariable) -> TensorConstant:
+    assert isinstance(shared, TensorSharedVariable)
+    return constant(shared.get_value(), name=shared.name, dtype=shared.type.dtype)
+
+
+def freeze_dims_and_data(
+    model: Model, dims: Sequence[str] | None = None, data: Sequence[str] | None = None
+) -> Model:
     """Recreate a Model with fixed RV dimensions and Data values.
 
     The dimensions of the pre-existing RVs will no longer follow changes to the coordinates.
@@ -30,41 +41,60 @@ def freeze_dims_and_data(model: Model) -> Model:
     This transformation may allow more performant sampling, or compiling model functions to backends that
     are more restrictive about dynamic shapes such as JAX.
+
+    Parameters
+    ----------
+    model : Model
+        The model where to freeze dims and data.
+    dims : Sequence of str, optional
+        The dimensions to freeze.
+        If None, all dimensions are frozen. Pass an empty list to avoid freezing any dimension.
+    data : Sequence of str, optional
+        The data to freeze.
+        If None, all data are frozen. Pass an empty list to avoid freezing any data.
+
+    Returns
+    -------
+    Model
+        A new model with the specified dimensions and data frozen.
     """
     fg, memo = fgraph_from_model(model)
 
+    if dims is None:
+        dims = tuple(model.dim_lengths.keys())
+    if data is None:
+        data = tuple(model.named_vars.keys())
+
     # Replace mutable dim lengths and data by constants
-    frozen_vars = {
-        memo[dim_length]: constant(
-            dim_length.get_value(), name=dim_length.name, dtype=dim_length.type.dtype
-        )
-        for dim_length in model.dim_lengths.values()
+    frozen_replacements = {
+        memo[dim_length]: _constant_from_shared(dim_length)
+        for dim_length in (model.dim_lengths[dim_name] for dim_name in dims)
         if isinstance(dim_length, SharedVariable)
     }
-    frozen_vars |= {
-        memo[data_var].owner.inputs[0]: constant(
-            data_var.get_value(), name=data_var.name, dtype=data_var.type.dtype
-        )
-        for data_var in model.named_vars.values()
-        if isinstance(data_var, SharedVariable)
+    frozen_replacements |= {
+        memo[datum].owner.inputs[0]: _constant_from_shared(datum)
+        for datum in (model.named_vars[datum_name] for datum_name in data)
+        if isinstance(datum, SharedVariable)
     }
 
-    old_outs, coords = fg.outputs, fg._coords  # type: ignore
+    old_outs, old_coords, old_dim_lenghts = fg.outputs, fg._coords, fg._dim_lengths  # type: ignore
     # Rebuild strict will force the recreation of RV nodes with updated static types
-    new_outs = clone_replace(old_outs, replace=frozen_vars, rebuild_strict=False)  # type: ignore
+    new_outs = clone_replace(old_outs, replace=frozen_replacements, rebuild_strict=False)  # type: ignore
     for old_out, new_out in zip(old_outs, new_outs):
         new_out.name = old_out.name
     fg = FunctionGraph(outputs=new_outs, clone=False)
-    fg._coords = coords  # type: ignore
+    fg._coords = old_coords  # type: ignore
+    fg._dim_lengths = {  # type: ignore
+        dim: frozen_replacements.get(dim_length, dim_length)
+        for dim, dim_length in old_dim_lenghts.items()
+    }
 
     # Recreate value variables from new RVs to propagate static types to logp graphs
     replacements = {}
     for node in fg.apply_nodes:
         if not isinstance(node.op, ModelFreeRV):
             continue
-        rv, old_value, *dims = node.inputs
-        if dims is None:
-            continue
+        rv, old_value, *_ = node.inputs
         transform = node.op.transform
         if transform is None:
             new_value = rv.type()
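
The _constant_from_shared helper factored out above simply snapshots a shared variable's current value into a graph constant. A standalone sketch of the same idea (the shared variable here is hypothetical):

    import pytensor
    from pytensor.tensor import constant

    dim_length = pytensor.shared(5, name="test_dim_length")
    frozen = constant(dim_length.get_value(), name=dim_length.name, dtype=dim_length.type.dtype)

Note also that the rebuilt fg._dim_lengths maps each frozen dim to its replacement constant via frozen_replacements.get, so dims that were not frozen keep pointing at their original shared length variables.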

tests/model/transform/test_optimization.py

Lines changed: 88 additions & 5 deletions
@@ -11,17 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import numpy as np
+import pytest
+
+from pytensor.compile import SharedVariable
 from pytensor.graph import Constant
 
+from pymc import Deterministic
 from pymc.data import Data
 from pymc.distributions import HalfNormal, Normal
 from pymc.model import Model
 from pymc.model.transform.optimization import freeze_dims_and_data
 
 
-def test_freeze_existing_rv_dims_and_data():
+def test_freeze_dims_and_data():
     with Model(coords={"test_dim": range(5)}) as m:
-        std = Data("std", [1])
+        std = Data("test_data", [1])
         x = HalfNormal("x", std, dims=("test_dim",))
         y = Normal("y", shape=x.shape[0] + 1)
@@ -34,18 +39,96 @@ def test_freeze_existing_rv_dims_and_data():
     assert y_logp.type.shape == (None,)
 
     frozen_m = freeze_dims_and_data(m)
-    std, x, y = frozen_m["std"], frozen_m["x"], frozen_m["y"]
+    data, x, y = frozen_m["test_data"], frozen_m["x"], frozen_m["y"]
     x_logp, y_logp = frozen_m.logp(sum=False)
-    assert isinstance(std, Constant)
+    assert isinstance(data, Constant)
     assert x.type.shape == (5,)
     assert y.type.shape == (6,)
     assert x_logp.type.shape == (5,)
     assert y_logp.type.shape == (6,)
 
+    # Test trying to update a frozen data or dim raises an informative error
+    with frozen_m:
+        with pytest.raises(TypeError, match="The variable `test_data` must be a `SharedVariable`"):
+            frozen_m.set_data("test_data", values=[2])
+        with pytest.raises(
+            TypeError, match="The dim_length of `test_dim` must be a `SharedVariable`"
+        ):
+            frozen_m.set_dim("test_dim", new_length=6, coord_values=range(6))
+
+    # Test we can still update original model
+    with m:
+        m.set_data("test_data", values=[2])
+        m.set_dim("test_dim", new_length=6, coord_values=range(6))
+    assert m["test_data"].get_value() == [2]
+    assert m.dim_lengths["test_dim"].get_value() == 6
 
-def test_freeze_rv_dims_nothing_to_change():
+
+def test_freeze_dims_nothing_to_change():
     with Model(coords={"test_dim": range(5)}) as m:
         x = HalfNormal("x", shape=(5,))
         y = Normal("y", shape=x.shape[0] + 1)
 
     assert m.point_logps() == freeze_dims_and_data(m).point_logps()
+
+
+def test_freeze_dims_and_data_subset():
+    with Model(coords={"dim1": range(3), "dim2": range(5)}) as m:
+        data1 = Data("data1", [1, 2, 3], dims="dim1")
+        data2 = Data("data2", [1, 2, 3, 4, 5], dims="dim2")
+        var1 = Normal("var1", dims="dim1")
+        var2 = Normal("var2", dims="dim2")
+        x = data1 * var1
+        y = data2 * var2
+        det = Deterministic("det", x[:, None] + y[None, :])
+
+    assert det.type.shape == (None, None)
+
+    new_m = freeze_dims_and_data(m, dims=["dim1"], data=[])
+    assert new_m["det"].type.shape == (3, None)
+    assert isinstance(new_m.dim_lengths["dim1"], Constant) and new_m.dim_lengths["dim1"].data == 3
+    assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], SharedVariable)
+
+    new_m = freeze_dims_and_data(m, dims=["dim2"], data=[])
+    assert new_m["det"].type.shape == (None, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], SharedVariable)
+    assert isinstance(new_m.dim_lengths["dim2"], Constant) and new_m.dim_lengths["dim2"].data == 5
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], SharedVariable)
+
+    new_m = freeze_dims_and_data(m, dims=["dim1", "dim2"], data=[])
+    assert new_m["det"].type.shape == (3, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], Constant) and new_m.dim_lengths["dim1"].data == 3
+    assert isinstance(new_m.dim_lengths["dim2"], Constant) and new_m.dim_lengths["dim2"].data == 5
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], SharedVariable)
+
+    new_m = freeze_dims_and_data(m, dims=[], data=["data1"])
+    assert new_m["det"].type.shape == (3, None)
+    assert isinstance(new_m.dim_lengths["dim1"], SharedVariable)
+    assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
+    assert isinstance(new_m["data1"], Constant) and np.all(new_m["data1"].data == [1, 2, 3])
+    assert isinstance(new_m["data2"], SharedVariable)
+
+    new_m = freeze_dims_and_data(m, dims=[], data=["data2"])
+    assert new_m["det"].type.shape == (None, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], SharedVariable)
+    assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], Constant) and np.all(new_m["data2"].data == [1, 2, 3, 4, 5])
+
+    new_m = freeze_dims_and_data(m, dims=[], data=["data1", "data2"])
+    assert new_m["det"].type.shape == (3, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], SharedVariable)
+    assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
+    assert isinstance(new_m["data1"], Constant) and np.all(new_m["data1"].data == [1, 2, 3])
+    assert isinstance(new_m["data2"], Constant) and np.all(new_m["data2"].data == [1, 2, 3, 4, 5])
+
+    new_m = freeze_dims_and_data(m, dims=["dim1"], data=["data2"])
+    assert new_m["det"].type.shape == (3, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], Constant) and new_m.dim_lengths["dim1"].data == 3
+    assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], Constant) and np.all(new_m["data2"].data == [1, 2, 3, 4, 5])
