Skip to content

Commit 85468b0

Browse files
Improve errors and warnings related to dimension resizing (#5761)
* Support TensorConstant entries in `Model.dim_lengths`. * Raise errors on attempts to resize dims with `TensorConstant` length. * Raise errors on attempts to resize dims that are linked to non-shared variables. * Warn about resizing dims that weren't initialized from a shared variable (supposedly via `add_coord(..., fixed=False)`, see #5763). * Raise errors when attempting to resize a dim that had coord values without providing new coord values. Closes #5760 by anticipating that not all symbolic dim lengths originate from RVs. Co-authored-by: Michael Osthege <[email protected]>
1 parent 3b4da05 commit 85468b0

File tree

2 files changed

+49
-23
lines changed

2 files changed

+49
-23
lines changed

pymc/model.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from aesara.tensor.random.opt import local_subtensor_rv_lift
4747
from aesara.tensor.random.var import RandomStateSharedVariable
4848
from aesara.tensor.sharedvar import ScalarSharedVariable
49-
from aesara.tensor.var import TensorVariable
49+
from aesara.tensor.var import TensorConstant, TensorVariable
5050

5151
from pymc.aesaraf import (
5252
compile_pymc,
@@ -61,7 +61,7 @@
6161
from pymc.distributions import joint_logpt
6262
from pymc.distributions.logprob import _get_scaling
6363
from pymc.distributions.transforms import _default_transform
64-
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
64+
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning
6565
from pymc.initial_point import make_initial_point_fn
6666
from pymc.math import flatten_list
6767
from pymc.util import (
@@ -1179,23 +1179,48 @@ def set_data(
11791179
# Reject resizing if we already know that it would create shape problems.
11801180
# NOTE: If there are multiple pm.MutableData containers sharing this dim, but the user only
11811181
# changes the values for one of them, they will run into shape problems nonetheless.
1182-
length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
1183-
if not isinstance(length_belongs_to, SharedVariable) and length_changed:
1184-
raise ShapeError(
1185-
f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, "
1186-
f"because the dimension was initialized from '{length_belongs_to}' which is not a shared variable. "
1187-
f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, "
1188-
f"for example by a model variable.",
1189-
actual=new_length,
1190-
expected=old_length,
1191-
)
1192-
if original_coords is not None and length_changed:
1193-
if length_changed and new_coords is None:
1194-
raise ValueError(
1195-
f"The '{name}' variable already had {len(original_coords)} coord values defined for"
1196-
f"its {dname} dimension. With the new values this dimension changes to length "
1197-
f"{new_length}, so new coord values for the {dname} dimension are required."
1182+
if length_changed:
1183+
if isinstance(length_tensor, TensorConstant):
1184+
raise ShapeError(
1185+
f"Resizing dimension '{dname}' is impossible, because "
1186+
f"a 'TensorConstant' stores its length. To be able "
1187+
f"to change the dimension length, 'fixed' in "
1188+
f"'model.add_coord' must be set to `False`."
11981189
)
1190+
if length_tensor.owner is None:
1191+
# This is the case if the dimension was initialized
1192+
# from custom coords, but dimension length was not
1193+
# stored in TensorConstant e.g by 'fixed' set to False
1194+
1195+
warnings.warn(
1196+
f"You're changing the shape of a variable "
1197+
f"in the '{dname}' dimension which was initialized "
1198+
f"from coords. Make sure to update the corresponding "
1199+
f"coords, otherwise you'll get shape issues.",
1200+
ShapeWarning,
1201+
)
1202+
else:
1203+
length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
1204+
if not isinstance(length_belongs_to, SharedVariable):
1205+
raise ShapeError(
1206+
f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, "
1207+
f"because the dimension was initialized from '{length_belongs_to}' which is not a shared variable. "
1208+
f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, "
1209+
f"for example by another model variable.",
1210+
actual=new_length,
1211+
expected=old_length,
1212+
)
1213+
if original_coords is not None:
1214+
if new_coords is None:
1215+
raise ValueError(
1216+
f"The '{name}' variable already had {len(original_coords)} coord values defined for "
1217+
f"its {dname} dimension. With the new values this dimension changes to length "
1218+
f"{new_length}, so new coord values for the {dname} dimension are required."
1219+
)
1220+
if isinstance(length_tensor, ScalarSharedVariable):
1221+
# Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
1222+
length_tensor.set_value(new_length)
1223+
11991224
if new_coords is not None:
12001225
# Update the registered coord values (also if they were None)
12011226
if len(new_coords) != new_length:
@@ -1204,10 +1229,8 @@ def set_data(
12041229
actual=len(new_coords),
12051230
expected=new_length,
12061231
)
1207-
self._coords[dname] = new_coords
1208-
if isinstance(length_tensor, ScalarSharedVariable) and new_length != old_length:
1209-
# Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
1210-
length_tensor.set_value(new_length)
1232+
# store it as tuple for immutability as in add_coord
1233+
self._coords[dname] = tuple(new_coords)
12111234

12121235
shared_object.set_value(values)
12131236

pymc/tests/test_data_container.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,10 @@ def test_explicit_coords(self):
316316
# pass coordinates explicitly, use numpy array in Data container
317317
with pm.Model(coords=coords) as pmodel:
318318
pm.MutableData("observations", data, dims=("rows", "columns"))
319-
319+
# new data with same shape
320+
pm.set_data({"observations": data + 1})
321+
# new data with same shape and coords
322+
pm.set_data({"observations": data}, coords=coords)
320323
assert "rows" in pmodel.coords
321324
assert pmodel.coords["rows"] == ("R1", "R2", "R3", "R4", "R5")
322325
assert "rows" in pmodel.dim_lengths

0 commit comments

Comments
 (0)