@@ -46,7 +46,7 @@
 from aesara.tensor.random.opt import local_subtensor_rv_lift
 from aesara.tensor.random.var import RandomStateSharedVariable
 from aesara.tensor.sharedvar import ScalarSharedVariable
-from aesara.tensor.var import TensorVariable
+from aesara.tensor.var import TensorConstant, TensorVariable
 
 from pymc.aesaraf import (
     compile_pymc,
@@ -61,7 +61,7 @@
 from pymc.distributions import joint_logpt
 from pymc.distributions.logprob import _get_scaling
 from pymc.distributions.transforms import _default_transform
-from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
+from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning
 from pymc.initial_point import make_initial_point_fn
 from pymc.math import flatten_list
 from pymc.util import (
@@ -1179,23 +1179,48 @@ def set_data(
             # Reject resizing if we already know that it would create shape problems.
             # NOTE: If there are multiple pm.MutableData containers sharing this dim, but the user only
             #       changes the values for one of them, they will run into shape problems nonetheless.
-            length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
-            if not isinstance(length_belongs_to, SharedVariable) and length_changed:
-                raise ShapeError(
-                    f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, "
-                    f"because the dimension was initialized from '{length_belongs_to}' which is not a shared variable. "
-                    f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, "
-                    f"for example by a model variable.",
-                    actual=new_length,
-                    expected=old_length,
-                )
-            if original_coords is not None and length_changed:
-                if length_changed and new_coords is None:
-                    raise ValueError(
-                        f"The '{name}' variable already had {len(original_coords)} coord values defined for"
-                        f"its {dname} dimension. With the new values this dimension changes to length "
-                        f"{new_length}, so new coord values for the {dname} dimension are required."
+            if length_changed:
+                if isinstance(length_tensor, TensorConstant):
+                    raise ShapeError(
+                        f"Resizing dimension '{dname}' is impossible, because "
+                        f"a 'TensorConstant' stores its length. To be able "
+                        f"to change the dimension length, 'fixed' in "
+                        f"'model.add_coord' must be set to `False`."
                     )
+                if length_tensor.owner is None:
+                    # This is the case if the dimension was initialized
+                    # from custom coords, but the dimension length was not
+                    # stored in a TensorConstant, e.g. by 'fixed' set to False
+
+                    warnings.warn(
+                        f"You're changing the shape of a variable "
+                        f"in the '{dname}' dimension which was initialized "
+                        f"from coords. Make sure to update the corresponding "
+                        f"coords, otherwise you'll get shape issues.",
+                        ShapeWarning,
+                    )
+                else:
+                    length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
+                    if not isinstance(length_belongs_to, SharedVariable):
+                        raise ShapeError(
+                            f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, "
+                            f"because the dimension was initialized from '{length_belongs_to}' which is not a shared variable. "
+                            f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, "
+                            f"for example by another model variable.",
+                            actual=new_length,
+                            expected=old_length,
+                        )
+                if original_coords is not None:
+                    if new_coords is None:
+                        raise ValueError(
+                            f"The '{name}' variable already had {len(original_coords)} coord values defined for "
+                            f"its {dname} dimension. With the new values this dimension changes to length "
+                            f"{new_length}, so new coord values for the {dname} dimension are required."
+                        )
+                if isinstance(length_tensor, ScalarSharedVariable):
+                    # Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
+                    length_tensor.set_value(new_length)
+
             if new_coords is not None:
                 # Update the registered coord values (also if they were None)
                 if len(new_coords) != new_length:
@@ -1204,10 +1229,8 @@ def set_data(
                         actual=len(new_coords),
                         expected=new_length,
                     )
-                self._coords[dname] = new_coords
-            if isinstance(length_tensor, ScalarSharedVariable) and new_length != old_length:
-                # Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
-                length_tensor.set_value(new_length)
+                # store it as tuple for immutability as in add_coord
+                self._coords[dname] = tuple(new_coords)
 
         shared_object.set_value(values)