@@ -149,7 +149,7 @@ def fn(*args, **kwargs):
149
149
return fn
150
150
151
151
152
- def _make_rv_and_resize_shape (
152
+ def _make_rv_and_resize_shape_from_dims (
153
153
* ,
154
154
cls ,
155
155
dims : Optional [StrongDims ],
@@ -159,21 +159,23 @@ def _make_rv_and_resize_shape(
159
159
** kwargs ,
160
160
) -> Tuple [Variable , StrongShape ]:
161
161
"""Creates the RV, possibly using dims or observed to determine a resize shape (if needed)."""
162
- resize_shape = None
162
+ resize_shape_from_dims = None
163
163
size_or_shape = kwargs .get ("size" ) or kwargs .get ("shape" )
164
164
165
- # Create the RV without dims or observed information
165
+ # Preference is given to size or shape. If not specified, we rely on dims and
166
+ # finally, observed, to determine the shape of the variable. Because dims can be
167
+ # specified on the fly, we need a two-step process where we first create the RV
168
+ # without dims information and then resize it.
169
+ if not size_or_shape and observed is not None :
170
+ kwargs ["shape" ] = tuple (observed .shape )
171
+
172
+ # Create the RV without dims information
166
173
rv_out = cls .dist (* args , ** kwargs )
167
174
168
- # Preference is given to size or shape, if not provided we use dims and observed
169
- # to resize the variable
170
- if not size_or_shape :
171
- if dims is not None :
172
- resize_shape = shape_from_dims (dims , tuple (rv_out .shape ), model )
173
- elif observed is not None :
174
- resize_shape = tuple (observed .shape )
175
+ if not size_or_shape and dims is not None :
176
+ resize_shape_from_dims = shape_from_dims (dims , tuple (rv_out .shape ), model )
175
177
176
- return rv_out , resize_shape
178
+ return rv_out , resize_shape_from_dims
177
179
178
180
179
181
class Distribution (metaclass = DistributionMeta ):
@@ -257,16 +259,17 @@ def __new__(
257
259
if observed is not None :
258
260
observed = convert_observed_data (observed )
259
261
260
- # Create the RV, possibly taking into consideration dims and observed to
261
- # determine its shape
262
- rv_out , resize_shape = _make_rv_and_resize_shape (
262
+ # Create the RV, without taking `dims` into consideration
263
+ rv_out , resize_shape_from_dims = _make_rv_and_resize_shape_from_dims (
263
264
cls = cls , dims = dims , model = model , observed = observed , args = args , ** kwargs
264
265
)
265
266
266
- # A shape was specified only through `dims`, or implied by `observed`.
267
- if resize_shape :
268
- resize_size = find_size (shape = resize_shape , size = None , ndim_supp = cls .rv_op .ndim_supp )
269
- rv_out = change_rv_size (rv = rv_out , new_size = resize_size , expand = False )
267
+ # Resize variable based on `dims` information
268
+ if resize_shape_from_dims :
269
+ resize_size_from_dims = find_size (
270
+ shape = resize_shape_from_dims , size = None , ndim_supp = cls .rv_op .ndim_supp
271
+ )
272
+ rv_out = change_rv_size (rv = rv_out , new_size = resize_size_from_dims , expand = False )
270
273
271
274
rv_out = model .register_rv (
272
275
rv_out ,
@@ -452,16 +455,17 @@ def __new__(
452
455
if observed is not None :
453
456
observed = convert_observed_data (observed )
454
457
455
- # Create the RV, possibly taking into consideration dims and observed to
456
- # determine its shape
457
- rv_out , resize_shape = _make_rv_and_resize_shape (
458
+ # Create the RV, without taking `dims` into consideration
459
+ rv_out , resize_shape_from_dims = _make_rv_and_resize_shape_from_dims (
458
460
cls = cls , dims = dims , model = model , observed = observed , args = args , ** kwargs
459
461
)
460
462
461
- # A shape was specified only through `dims`, or implied by `observed`.
462
- if resize_shape :
463
- resize_size = find_size (shape = resize_shape , size = None , ndim_supp = rv_out .tag .ndim_supp )
464
- rv_out = cls .change_size (rv = rv_out , new_size = resize_size , expand = False )
463
+ # Resize variable based on `dims` information
464
+ if resize_shape_from_dims :
465
+ resize_size_from_dims = find_size (
466
+ shape = resize_shape_from_dims , size = None , ndim_supp = rv_out .tag .ndim_supp
467
+ )
468
+ rv_out = cls .change_size (rv = rv_out , new_size = resize_size_from_dims , expand = False )
465
469
466
470
rv_out = model .register_rv (
467
471
rv_out ,
0 commit comments