Skip to content

Commit 416d49f

Browse files
5hv5hvnkmichaelosthege
authored andcommitted
Updated convert_shape and convert_size
Closes #5394
1 parent 09bddde commit 416d49f

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

pymc/distributions/shape_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -458,16 +458,16 @@ def convert_shape(shape: Shape) -> Optional[WeakShape]:
458458
"""Process a user-provided shape variable into None or a valid shape object."""
459459
if shape is None:
460460
return None
461-
462-
if isinstance(shape, int) or (isinstance(shape, TensorVariable) and shape.ndim == 0):
461+
elif isinstance(shape, int) or (isinstance(shape, TensorVariable) and shape.ndim == 0):
463462
shape = (shape,)
463+
elif isinstance(shape, TensorVariable) and shape.ndim == 1:
464+
shape = tuple(shape)
464465
elif isinstance(shape, (list, tuple)):
465466
shape = tuple(shape)
466467
else:
467468
raise ValueError(
468469
f"The `shape` parameter must be a tuple, TensorVariable, int or list. Actual: {type(shape)}"
469470
)
470-
471471
if isinstance(shape, tuple) and any(s == Ellipsis for s in shape[:-1]):
472472
raise ValueError(
473473
f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
@@ -480,16 +480,16 @@ def convert_size(size: Size) -> Optional[StrongSize]:
480480
"""Process a user-provided size variable into None or a valid size object."""
481481
if size is None:
482482
return None
483-
484-
if isinstance(size, int) or (isinstance(size, TensorVariable) and size.ndim == 0):
483+
elif isinstance(size, int) or (isinstance(size, TensorVariable) and size.ndim == 0):
485484
size = (size,)
485+
elif isinstance(size, TensorVariable) and size.ndim == 1:
486+
size = tuple(size)
486487
elif isinstance(size, (list, tuple)):
487488
size = tuple(size)
488489
else:
489490
raise ValueError(
490491
f"The `size` parameter must be a tuple, TensorVariable, int or list. Actual: {type(size)}"
491492
)
492-
493493
if isinstance(size, tuple) and Ellipsis in size:
494494
raise ValueError(f"The `size` parameter cannot contain an Ellipsis. Actual: {size}")
495495

0 commit comments

Comments
 (0)