Skip to content

Commit 89dbcb3

Browse files
committed
Fix the array API diagonal and trace implementations
The NumPy versions operate on the first two axes, but the array API diagonal should operate on the last two, so that they work correctly on stacks of matrices. See data-apis/array-api#241.
1 parent ee49a7c commit 89dbcb3

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

numpy/array_api/linalg.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array:
8484
8585
See its docstring for more information.
8686
"""
87-
return Array._new(np.diagonal(x._array, offset=offset))
87+
# Note: diagonal always operates on the last two axes, whereas np.diagonal
88+
# operates on the first two axes by default
89+
return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1))
8890

8991

9092
# Note: the keyword argument name upper is different from np.linalg.eigh
@@ -329,7 +331,9 @@ def trace(x: Array, /, *, offset: int = 0) -> Array:
329331
330332
See its docstring for more information.
331333
"""
332-
return Array._new(np.asarray(np.trace(x._array, offset=offset)))
334+
# Note: trace always operates on the last two axes, whereas np.trace
335+
# operates on the first two axes by default
336+
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1)))
333337

334338
# Note: vecdot is not in NumPy
335339
def vecdot(x1: Array, x2: Array, /, *, axis: Optional[int] = None) -> Array:

0 commit comments

Comments
 (0)