Skip to content

Commit ee49a7c

Browse files
committed
Add notes for array API linalg functions that aren't in np.linalg
1 parent b099684 commit ee49a7c

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

numpy/array_api/linalg.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def cholesky(x: Array, /, *, upper: bool = False) -> Array:
4646
return Array._new(L).mT
4747
return Array._new(L)
4848

49+
# Note: cross is the numpy top-level namespace, not np.linalg
4950
def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
5051
"""
5152
Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.
@@ -76,6 +77,7 @@ def det(x: Array, /) -> Array:
7677
raise TypeError('Only floating-point dtypes are allowed in det')
7778
return Array._new(np.linalg.det(x._array))
7879

80+
# Note: diagonal is the numpy top-level namespace, not np.linalg
7981
def diagonal(x: Array, /, *, offset: int = 0) -> Array:
8082
"""
8183
Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.
@@ -130,6 +132,7 @@ def inv(x: Array, /) -> Array:
130132
return Array._new(np.linalg.inv(x._array))
131133

132134

135+
# Note: matmul is the numpy top-level namespace but not in np.linalg
133136
def matmul(x1: Array, x2: Array, /) -> Array:
134137
"""
135138
Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
@@ -198,13 +201,15 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A
198201
tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis]
199202
return Array._new(np.count_nonzero(S > tol, axis=-1))
200203

204+
201205
# Note: this function is new in the array API spec. Unlike transpose, it only
202206
# transposes the last two axes.
203207
def matrix_transpose(x: Array, /) -> Array:
204208
if x.ndim < 2:
205209
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
206210
return Array._new(np.swapaxes(x._array, -1, -2))
207211

212+
# Note: outer is the numpy top-level namespace, not np.linalg
208213
def outer(x1: Array, x2: Array, /) -> Array:
209214
"""
210215
Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.
@@ -306,6 +311,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
306311
def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
307312
return Array._new(np.linalg.svd(x._array, compute_uv=False))
308313

314+
# Note: tensordot is the numpy top-level namespace but not in np.linalg
315+
309316
# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
310317
def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
311318
# Note: the restriction to numeric dtypes only is different from
@@ -315,6 +322,7 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int],
315322

316323
return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
317324

325+
# Note: trace is the numpy top-level namespace, not np.linalg
318326
def trace(x: Array, /, *, offset: int = 0) -> Array:
319327
"""
320328
Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.

0 commit comments

Comments
 (0)