Skip to content

Commit 015bc44

Browse files
committed
BUG: allow ndarray.transpose(axes: List[Int])
1 parent 579c619 commit 015bc44

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torch_np/_funcs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1358,9 +1358,12 @@ def reshape(a: ArrayLike, newshape, order="C"):
13581358
@normalizer
13591359
def transpose(a: ArrayLike, axes=None):
13601360
# numpy allows both .tranpose(sh) and .transpose(*sh)
1361-
axes = axes[0] if len(axes) == 1 else axes
1361+
# also older code uses axes being a list
13621362
if axes in [(), None, (None,)]:
13631363
axes = tuple(range(a.ndim))[::-1]
1364+
elif len(axes) == 1:
1365+
axes = axes[0]
1366+
13641367
try:
13651368
result = a.permute(axes)
13661369
except RuntimeError:

0 commit comments

Comments
 (0)