@@ -46,6 +46,7 @@ def cholesky(x: Array, /, *, upper: bool = False) -> Array:
46
46
return Array ._new (L ).mT
47
47
return Array ._new (L )
48
48
49
+ # Note: cross is the numpy top-level namespace, not np.linalg
49
50
def cross (x1 : Array , x2 : Array , / , * , axis : int = - 1 ) -> Array :
50
51
"""
51
52
Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.
@@ -76,6 +77,7 @@ def det(x: Array, /) -> Array:
76
77
raise TypeError ('Only floating-point dtypes are allowed in det' )
77
78
return Array ._new (np .linalg .det (x ._array ))
78
79
80
+ # Note: diagonal is the numpy top-level namespace, not np.linalg
79
81
def diagonal (x : Array , / , * , offset : int = 0 ) -> Array :
80
82
"""
81
83
Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.
@@ -130,6 +132,7 @@ def inv(x: Array, /) -> Array:
130
132
return Array ._new (np .linalg .inv (x ._array ))
131
133
132
134
135
+ # Note: matmul is the numpy top-level namespace but not in np.linalg
133
136
def matmul (x1 : Array , x2 : Array , / ) -> Array :
134
137
"""
135
138
Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
@@ -198,13 +201,15 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A
198
201
tol = S .max (axis = - 1 , keepdims = True )* np .asarray (rtol )[..., np .newaxis ]
199
202
return Array ._new (np .count_nonzero (S > tol , axis = - 1 ))
200
203
204
+
201
205
# Note: this function is new in the array API spec. Unlike transpose, it only
202
206
# transposes the last two axes.
203
207
def matrix_transpose (x : Array , / ) -> Array :
204
208
if x .ndim < 2 :
205
209
raise ValueError ("x must be at least 2-dimensional for matrix_transpose" )
206
210
return Array ._new (np .swapaxes (x ._array , - 1 , - 2 ))
207
211
212
+ # Note: outer is the numpy top-level namespace, not np.linalg
208
213
def outer (x1 : Array , x2 : Array , / ) -> Array :
209
214
"""
210
215
Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.
@@ -306,6 +311,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
306
311
def svdvals (x : Array , / ) -> Union [Array , Tuple [Array , ...]]:
307
312
return Array ._new (np .linalg .svd (x ._array , compute_uv = False ))
308
313
314
+ # Note: tensordot is the numpy top-level namespace but not in np.linalg
315
+
309
316
# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
310
317
def tensordot (x1 : Array , x2 : Array , / , * , axes : Union [int , Tuple [Sequence [int ], Sequence [int ]]] = 2 ) -> Array :
311
318
# Note: the restriction to numeric dtypes only is different from
@@ -315,6 +322,7 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int],
315
322
316
323
return Array ._new (np .tensordot (x1 ._array , x2 ._array , axes = axes ))
317
324
325
+ # Note: trace is the numpy top-level namespace, not np.linalg
318
326
def trace (x : Array , / , * , offset : int = 0 ) -> Array :
319
327
"""
320
328
Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
0 commit comments