@@ -458,16 +458,16 @@ def convert_shape(shape: Shape) -> Optional[WeakShape]:
458
458
"""Process a user-provided shape variable into None or a valid shape object."""
459
459
if shape is None :
460
460
return None
461
-
462
- if isinstance (shape , int ) or (isinstance (shape , TensorVariable ) and shape .ndim == 0 ):
461
+ elif isinstance (shape , int ) or (isinstance (shape , TensorVariable ) and shape .ndim == 0 ):
463
462
shape = (shape ,)
463
+ elif isinstance (shape , TensorVariable ) and shape .ndim == 1 :
464
+ shape = tuple (shape )
464
465
elif isinstance (shape , (list , tuple )):
465
466
shape = tuple (shape )
466
467
else :
467
468
raise ValueError (
468
469
f"The `shape` parameter must be a tuple, TensorVariable, int or list. Actual: { type (shape )} "
469
470
)
470
-
471
471
if isinstance (shape , tuple ) and any (s == Ellipsis for s in shape [:- 1 ]):
472
472
raise ValueError (
473
473
f"Ellipsis in `shape` may only appear in the last position. Actual: { shape } "
@@ -480,16 +480,16 @@ def convert_size(size: Size) -> Optional[StrongSize]:
480
480
"""Process a user-provided size variable into None or a valid size object."""
481
481
if size is None :
482
482
return None
483
-
484
- if isinstance (size , int ) or (isinstance (size , TensorVariable ) and size .ndim == 0 ):
483
+ elif isinstance (size , int ) or (isinstance (size , TensorVariable ) and size .ndim == 0 ):
485
484
size = (size ,)
485
+ elif isinstance (size , TensorVariable ) and size .ndim == 1 :
486
+ size = tuple (size )
486
487
elif isinstance (size , (list , tuple )):
487
488
size = tuple (size )
488
489
else :
489
490
raise ValueError (
490
491
f"The `size` parameter must be a tuple, TensorVariable, int or list. Actual: { type (size )} "
491
492
)
492
-
493
493
if isinstance (size , tuple ) and Ellipsis in size :
494
494
raise ValueError (f"The `size` parameter cannot contain an Ellipsis. Actual: { size } " )
495
495
0 commit comments