diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 42a998c7..8f01593b 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,5 +1,6 @@ from numpy.linalg import * # noqa: F403 from numpy.linalg import __all__ as linalg_all +import numpy as _np from ..common import _linalg from .._internal import get_xp @@ -27,6 +28,53 @@ diagonal = get_xp(np)(_linalg.diagonal) trace = get_xp(np)(_linalg.trace) +# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a +# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack +# of matrices. The np.linalg.solve behavior of allowing stacks of both +# matrices and vectors is ambiguous c.f. +# https://github.com/numpy/numpy/issues/15349 and +# https://github.com/data-apis/array-api/issues/285. + +# To workaround this, the below is the code from np.linalg.solve except +# only calling solve1 in the exactly 1D case. + +# This code is here instead of in common because it is numpy specific. Also +# note that CuPy's solve() does not currently support broadcasting (see +# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). +def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: + try: + from numpy.linalg._linalg import ( + _makearray, _assert_stacked_2d, _assert_stacked_square, + _commonType, isComplexType, _raise_linalgerror_singular + ) + except ImportError: + from numpy.linalg.linalg import ( + _makearray, _assert_stacked_2d, _assert_stacked_square, + _commonType, isComplexType, _raise_linalgerror_singular + ) + from numpy.linalg import _umath_linalg + + x1, _ = _makearray(x1) + _assert_stacked_2d(x1) + _assert_stacked_square(x1) + x2, wrap = _makearray(x2) + t, result_t = _commonType(x1, x2) + + # This part is different from np.linalg.solve + if x2.ndim == 1: + gufunc = _umath_linalg.solve1 + else: + gufunc = _umath_linalg.solve + + # This does nothing currently but is left in because it will be relevant + # when complex dtype support is added to the spec in 2022. + signature = 'DD->D' if isComplexType(t) else 'dd->d' + with _np.errstate(call=_raise_linalgerror_singular, invalid='call', + over='ignore', divide='ignore', under='ignore'): + r = gufunc(x1, x2, signature=signature) + + return wrap(r.astype(result_t, copy=False)) + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np.linalg, 'vector_norm'): @@ -34,7 +82,7 @@ else: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = linalg_all + _linalg.__all__ +__all__ = linalg_all + _linalg.__all__ + ['solve'] del get_xp del np diff --git a/numpy-xfails.txt b/numpy-xfails.txt index c541602b..d0be245b 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -10,12 +10,6 @@ array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] -# linalg tests require https://github.com/data-apis/array-api-tests/pull/101 -# cleanups. Also some tests are using .mT -array_api_tests/test_linalg.py::test_eigvalsh -array_api_tests/test_linalg.py::test_solve -array_api_tests/test_linalg.py::test_trace - # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] array_api_tests/test_signatures.py::test_array_method_signature[to_device]