Skip to content

Commit aec7e8b

Browse files
authored
ENH: Add the linalg extension to the array_api submodule (#19980)
Original NumPy Commit: a1813504ad44b70fb139181a9df8465bcb22e24d
1 parent af9d538 commit aec7e8b

File tree

5 files changed

+419
-79
lines changed

5 files changed

+419
-79
lines changed

array_api_strict/__init__.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,6 @@
109109
- The spec is still in an RFC phase and may still have minor updates, which
110110
will need to be reflected here.
111111
112-
- The linear algebra extension in the spec will be added in a future pull
113-
request.
114-
115112
- Complex number support in array API spec is planned but not yet finalized,
116113
as are the fft extension and certain linear algebra functions such as eig
117114
that require complex dtypes.
@@ -334,12 +331,13 @@
334331
"trunc",
335332
]
336333

337-
# einsum is not yet implemented in the array API spec.
334+
# linalg is an extension in the array API spec, which is a sub-namespace. Only
335+
# a subset of functions in it are imported into the top-level namespace.
336+
from . import linalg
338337

339-
# from ._linear_algebra_functions import einsum
340-
# __all__ += ['einsum']
338+
__all__ += ["linalg"]
341339

342-
from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
340+
from .linalg import matmul, tensordot, matrix_transpose, vecdot
343341

344342
__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
345343

array_api_strict/_array_object.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ def device(self) -> Device:
10301030
# Note: mT is new in array API spec (see matrix_transpose)
10311031
@property
10321032
def mT(self) -> Array:
1033-
from ._linear_algebra_functions import matrix_transpose
1033+
from .linalg import matrix_transpose
10341034
return matrix_transpose(self)
10351035

10361036
@property

array_api_strict/_linear_algebra_functions.py

-67
This file was deleted.

array_api_strict/_statistical_functions.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,12 @@ def sum(
9393
) -> Array:
9494
if x.dtype not in _numeric_dtypes:
9595
raise TypeError("Only numeric dtypes are allowed in sum")
96-
# Note: sum() and prod() always upcast float32 to float64 for dtype=None
97-
# We need to do so here before summing to avoid overflow
96+
# Note: sum() and prod() always upcast integers to (u)int64 and float32 to
97+
# float64 for dtype=None. `np.sum` does that too for integers, but not for
98+
# float32, so we need to special-case it here
9899
if dtype is None and x.dtype == float32:
99-
x = asarray(x, dtype=float64)
100-
return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims))
100+
dtype = float64
101+
return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
101102

102103

103104
def var(

0 commit comments

Comments
 (0)