Skip to content

Commit 472e03f

Browse files
authored
Merge pull request #20499 from asmeurer/array_api_T
BUG: Fix the .T attribute in the array_api namespace Original NumPy Commit: 742f13f7ee6ff8ed56fc468c9ef57b3853141768
2 parents 11ebcb6 + 0ecf503 commit 472e03f

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

array_api_strict/_array_object.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1072,4 +1072,4 @@ def T(self) -> Array:
10721072
# https://data-apis.org/array-api/latest/API_specification/array_object.html#t
10731073
if self.ndim != 2:
10741074
raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.")
1075-
return self._array.T
1075+
return self.__class__._new(self._array.T)

array_api_strict/tests/test_array_object.py

+14
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from numpy. import ones, asarray, result_type, all, equal
7+
from numpy._array_object import Array
78
from numpy._dtypes import (
89
_all_dtypes,
910
_boolean_dtypes,
@@ -301,3 +302,16 @@ def test_device_property():
301302

302303
assert all(equal(asarray(a, device='cpu'), a))
303304
assert_raises(ValueError, lambda: asarray(a, device='gpu'))
305+
306+
def test_array_properties():
307+
a = ones((1, 2, 3))
308+
b = ones((2, 3))
309+
assert_raises(ValueError, lambda: a.T)
310+
311+
assert isinstance(b.T, Array)
312+
assert b.T.shape == (3, 2)
313+
314+
assert isinstance(a.mT, Array)
315+
assert a.mT.shape == (1, 3, 2)
316+
assert isinstance(b.mT, Array)
317+
assert b.mT.shape == (3, 2)

0 commit comments

Comments
 (0)