Skip to content

Commit b83a378

Browse files
committed
Workaround np.linalg.solve ambiguity
NumPy's solve() does not handle the ambiguity around x2 being 1-D vector vs. an n-D stack of matrices in the way that the standard specifies. Namely, x2 should be treated as a 1-D vector iff it is 1-dimensional, and a stack of matrices in all other cases. This workaround is borrowed from array-api-strict. See numpy/numpy#15349 and data-apis/array-api#285. Note that this workaround only works for NumPy. CuPy currently does not support stacked vectors for solve() (see https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43), and the workaround in cupy.array_api.linalg does not seem to actually function.
1 parent 645f9a8 commit b83a378

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

array_api_compat/numpy/linalg.py

+48
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,53 @@
3333
vector_norm,
3434
)
3535

36+
# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
37+
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
38+
# of matrices. The np.linalg.solve behavior of allowing stacks of both
39+
# matrices and vectors is ambiguous c.f.
40+
# https://github.com/numpy/numpy/issues/15349 and
41+
# https://github.com/data-apis/array-api/issues/285.
42+
43+
# To workaround this, the below is the code from np.linalg.solve except
44+
# only calling solve1 in the exactly 1D case.
45+
46+
# This code is here instead of in common because it is numpy specific. Also
47+
# note that CuPy's solve() does not currently support broadcasting (see
48+
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
49+
def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
50+
try:
51+
from numpy.linalg._linalg import (
52+
_makearray, _assert_stacked_2d, _assert_stacked_square,
53+
_commonType, isComplexType, _raise_linalgerror_singular
54+
)
55+
except ImportError:
56+
from numpy.linalg.linalg import (
57+
_makearray, _assert_stacked_2d, _assert_stacked_square,
58+
_commonType, isComplexType, _raise_linalgerror_singular
59+
)
60+
from numpy.linalg import _umath_linalg
61+
62+
x1, _ = _makearray(x1)
63+
_assert_stacked_2d(x1)
64+
_assert_stacked_square(x1)
65+
x2, wrap = _makearray(x2)
66+
t, result_t = _commonType(x1, x2)
67+
68+
# This part is different from np.linalg.solve
69+
if x2.ndim == 1:
70+
gufunc = _umath_linalg.solve1
71+
else:
72+
gufunc = _umath_linalg.solve
73+
74+
# This does nothing currently but is left in because it will be relevant
75+
# when complex dtype support is added to the spec in 2022.
76+
signature = 'DD->D' if isComplexType(t) else 'dd->d'
77+
with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
78+
over='ignore', divide='ignore', under='ignore'):
79+
r = gufunc(x1, x2, signature=signature)
80+
81+
return wrap(r.astype(result_t, copy=False))
82+
3683
__all__ = []
3784

3885
__all__ += _numpy_linalg_all
@@ -54,6 +101,7 @@
54101
"pinv",
55102
"qr",
56103
"slogdet",
104+
"solve",
57105
"svd",
58106
"svdvals",
59107
"tensordot",

0 commit comments

Comments
 (0)