@@ -211,9 +211,7 @@ class TestShapeDimsSize:
211
211
[
212
212
"implicit" ,
213
213
"shape" ,
214
- "shape..." ,
215
214
"dims" ,
216
- "dims..." ,
217
215
"size" ,
218
216
],
219
217
)
@@ -249,65 +247,36 @@ def test_param_and_batch_shape_combos(
249
247
if parametrization == "shape" :
250
248
rv = pm .Normal ("rv" , mu = mu , shape = batch_shape + param_shape )
251
249
assert rv .eval ().shape == expected_shape
252
- elif parametrization == "shape..." :
253
- rv = pm .Normal ("rv" , mu = mu , shape = (* batch_shape , ...))
254
- assert rv .eval ().shape == batch_shape + param_shape
255
250
elif parametrization == "dims" :
256
251
rv = pm .Normal ("rv" , mu = mu , dims = batch_dims + param_dims )
257
252
assert rv .eval ().shape == expected_shape
258
- elif parametrization == "dims..." :
259
- rv = pm .Normal ("rv" , mu = mu , dims = (* batch_dims , ...))
260
- n_size = len (batch_shape )
261
- n_implied = len (param_shape )
262
- ndim = n_size + n_implied
263
- assert len (pmodel .RV_dims ["rv" ]) == ndim , pmodel .RV_dims
264
- assert len (pmodel .RV_dims ["rv" ][:n_size ]) == len (batch_dims )
265
- assert len (pmodel .RV_dims ["rv" ][n_size :]) == len (param_dims )
266
- if n_implied > 0 :
267
- assert pmodel .RV_dims ["rv" ][- 1 ] is None
268
253
elif parametrization == "size" :
269
254
rv = pm .Normal ("rv" , mu = mu , size = batch_shape + param_shape )
270
255
assert rv .eval ().shape == expected_shape
271
256
else :
272
257
raise NotImplementedError ("Invalid test case parametrization." )
273
258
274
- @pytest .mark .parametrize ("ellipsis_in" , ["none" , "shape" , "dims" , "both" ])
275
- def test_simultaneous_shape_and_dims (self , ellipsis_in ):
259
+ def test_simultaneous_shape_and_dims (self ):
276
260
with pm .Model () as pmodel :
277
261
x = pm .ConstantData ("x" , [1 , 2 , 3 ], dims = "ddata" )
278
262
279
- if ellipsis_in == "none" :
280
- # The shape and dims tuples correspond to each other.
281
- # Note: No checks are performed that implied shape (x), shape and dims actually match.
282
- y = pm .Normal ("y" , mu = x , shape = (2 , 3 ), dims = ("dshape" , "ddata" ))
283
- assert pmodel .RV_dims ["y" ] == ("dshape" , "ddata" )
284
- elif ellipsis_in == "shape" :
285
- y = pm .Normal ("y" , mu = x , shape = (2 , ...), dims = ("dshape" , "ddata" ))
286
- assert pmodel .RV_dims ["y" ] == ("dshape" , "ddata" )
287
- elif ellipsis_in == "dims" :
288
- y = pm .Normal ("y" , mu = x , shape = (2 , 3 ), dims = ("dshape" , ...))
289
- assert pmodel .RV_dims ["y" ] == ("dshape" , None )
290
- elif ellipsis_in == "both" :
291
- y = pm .Normal ("y" , mu = x , shape = (2 , ...), dims = ("dshape" , ...))
292
- assert pmodel .RV_dims ["y" ] == ("dshape" , None )
263
+ # The shape and dims tuples correspond to each other.
264
+ # Note: No checks are performed that implied shape (x), shape and dims actually match.
265
+ y = pm .Normal ("y" , mu = x , shape = (2 , 3 ), dims = ("dshape" , "ddata" ))
266
+ assert pmodel .RV_dims ["y" ] == ("dshape" , "ddata" )
293
267
294
268
assert "dshape" in pmodel .dim_lengths
295
269
assert y .eval ().shape == (2 , 3 )
296
270
297
- @pytest .mark .parametrize ("with_dims_ellipsis" , [False , True ])
298
- def test_simultaneous_size_and_dims (self , with_dims_ellipsis ):
271
+ def test_simultaneous_size_and_dims (self ):
299
272
with pm .Model () as pmodel :
300
273
x = pm .ConstantData ("x" , [1 , 2 , 3 ], dims = "ddata" )
301
274
assert "ddata" in pmodel .dim_lengths
302
275
303
276
# Size does not include support dims, so this test must use a dist with support dims.
304
277
kwargs = dict (name = "y" , size = (2 , 3 ), mu = at .ones ((3 , 4 )), cov = at .eye (4 ))
305
- if with_dims_ellipsis :
306
- y = pm .MvNormal (** kwargs , dims = ("dsize" , ...))
307
- assert pmodel .RV_dims ["y" ] == ("dsize" , None , None )
308
- else :
309
- y = pm .MvNormal (** kwargs , dims = ("dsize" , "ddata" , "dsupport" ))
310
- assert pmodel .RV_dims ["y" ] == ("dsize" , "ddata" , "dsupport" )
278
+ y = pm .MvNormal (** kwargs , dims = ("dsize" , "ddata" , "dsupport" ))
279
+ assert pmodel .RV_dims ["y" ] == ("dsize" , "ddata" , "dsupport" )
311
280
312
281
assert "dsize" in pmodel .dim_lengths
313
282
assert y .eval ().shape == (2 , 3 , 4 )
@@ -382,7 +351,6 @@ def test_dist_api_works(self):
382
351
pm .Normal .dist (mu = mu , dims = ("town" ,))
383
352
assert pm .Normal .dist (mu = mu , shape = (3 ,)).eval ().shape == (3 ,)
384
353
assert pm .Normal .dist (mu = mu , shape = (5 , 3 )).eval ().shape == (5 , 3 )
385
- assert pm .Normal .dist (mu = mu , shape = (7 , ...)).eval ().shape == (7 , 3 )
386
354
assert pm .Normal .dist (mu = mu , size = (3 ,)).eval ().shape == (3 ,)
387
355
assert pm .Normal .dist (mu = mu , size = (4 , 3 )).eval ().shape == (4 , 3 )
388
356
@@ -408,10 +376,6 @@ def test_mvnormal_shape_size_difference(self):
408
376
assert rv .ndim == 3
409
377
assert tuple (rv .shape .eval ()) == (5 , 4 , 3 )
410
378
411
- rv = pm .MvNormal .dist (mu = np .ones ((4 , 3 , 2 )), cov = np .eye (2 ), shape = (6 , 5 , ...))
412
- assert rv .ndim == 5
413
- assert tuple (rv .shape .eval ()) == (6 , 5 , 4 , 3 , 2 )
414
-
415
379
rv = pm .MvNormal .dist (mu = [1 , 2 , 3 ], cov = np .eye (3 ), size = (5 , 4 ))
416
380
assert tuple (rv .shape .eval ()) == (5 , 4 , 3 )
417
381
@@ -422,22 +386,16 @@ def test_convert_dims(self):
422
386
assert convert_dims (dims = "town" ) == ("town" ,)
423
387
with pytest .raises (ValueError , match = "must be a tuple, str or list" ):
424
388
convert_dims (3 )
425
- with pytest .raises (ValueError , match = "may only appear in the last position" ):
426
- convert_dims (dims = (..., "town" ))
427
389
428
390
def test_convert_shape (self ):
429
391
assert convert_shape (5 ) == (5 ,)
430
392
with pytest .raises (ValueError , match = "tuple, TensorVariable, int or list" ):
431
393
convert_shape (shape = "notashape" )
432
- with pytest .raises (ValueError , match = "may only appear in the last position" ):
433
- convert_shape (shape = (3 , ..., 2 ))
434
394
435
395
def test_convert_size (self ):
436
396
assert convert_size (7 ) == (7 ,)
437
397
with pytest .raises (ValueError , match = "tuple, TensorVariable, int or list" ):
438
398
convert_size (size = "notasize" )
439
- with pytest .raises (ValueError , match = "cannot contain" ):
440
- convert_size (size = (3 , ...))
441
399
442
400
def test_lazy_flavors (self ):
443
401
assert pm .Uniform .dist (2 , [4 , 5 ], size = [3 , 2 ]).eval ().shape == (3 , 2 )
0 commit comments