Skip to content

Commit adb3f18

Browse files
committed
API: Remove broadcasting ambiguity from np.linalg.solve
Previously the np.linalg.solve documentation stated: a : (..., M, M) array_like Coefficient matrix. b : {(..., M,), (..., M, K)}, array_like however, this is inherently ambiguous. For example, if a has shape (2, 2, 2) and b has shape (2, 2), b could be treated as a (2,) stack of (2,) column vectors, in which case the result should have shape (2, 2), or as a single 2x2 matrix, in which case, the result should have shape (2, 2, 2). NumPy resolved this ambiguity in a confusing way, which was to treat b as (..., M) whenever b.ndim == a.ndim - 1, and as (..., M, K) otherwise. A much more consistent way to handle this ambiguity is to treat b as a single vector if and only if it is 1-dimensional, i.e., use b : {(M,), (..., M, K)}, array_like This is consistent with, for instance, the matmul operator, which only uses the special 1-D vector logic if an operand is exactly 1-dimensional, and treats the operands as (stacks of) 2-D matrices otherwise. This updates np.linalg.solve() to use this behavior. This is a backwards compatibility break, as any instance where the b array has more than one dimension and exactly one fewer dimension than the a array will now use the matrix logic, potentially returning a different result with a different shape. The previous behavior can be manually emulated with something like np.solve(a, b[..., None])[..., 0] since b as a (M,) vector is equivalent to b as a (M, 1) matrix (or the user could manually import and use the internal solve1() gufunc which implements the b-as-vector logic). This change aligns the solve() function with the array API, which resolves this broadcasting ambiguity in this way. See https://data-apis.org/array-api/latest/extensions/generated/array_api.linalg.solve.html#array_api.linalg.solve and data-apis/array-api#285. Fixes numpy#15349 Fixes numpy#25583
1 parent 0b43a0e commit adb3f18

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

numpy/linalg/_linalg.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def solve(a, b):
327327
----------
328328
a : (..., M, M) array_like
329329
Coefficient matrix.
330-
b : {(..., M,), (..., M, K)}, array_like
330+
b : {(M,), (..., M, K)}, array_like
331331
Ordinate or "dependent variable" values.
332332
333333
Returns
@@ -359,6 +359,13 @@ def solve(a, b):
359359
`lstsq` for the least-squares best "solution" of the
360360
system/equation.
361361
362+
.. versionchanged:: 2.0
363+
364+
The b array is only treated as a shape (M,) column vector if it is
365+
exactly 1-dimensional. In all other instances it is treated as a stack
366+
of (M, K) matrices. Previously b would be treated as a stack of (M,)
367+
vectors if b.ndim was equal to a.ndim - 1.
368+
362369
References
363370
----------
364371
.. [1] G. Strang, *Linear Algebra and Its Applications*, 2nd Ed., Orlando,
@@ -390,7 +397,7 @@ def solve(a, b):
390397

391398
# We use the b = (..., M,) logic, only if the number of extra dimensions
392399
# match exactly
393-
if b.ndim == a.ndim - 1:
400+
if b.ndim == 1:
394401
gufunc = _umath_linalg.solve1
395402
else:
396403
gufunc = _umath_linalg.solve

numpy/linalg/tests/test_linalg.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,23 @@ def test_types(self, dtype):
475475
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
476476
assert_equal(linalg.solve(x, x).dtype, dtype)
477477

478+
def test_1_d(self):
479+
class ArraySubclass(np.ndarray):
480+
pass
481+
a = np.arange(8).reshape(2, 2, 2)
482+
b = np.arange(2).view(ArraySubclass)
483+
result = linalg.solve(a, b)
484+
assert result.shape == (2, 2)
485+
486+
# If b is anything other than 1-D it should be treated as a stack of
487+
# matrices
488+
b = np.arange(4).reshape(2, 2).view(ArraySubclass)
489+
result = linalg.solve(a, b)
490+
assert result.shape == (2, 2, 2)
491+
492+
b = np.arange(2).reshape(1, 2).view(ArraySubclass)
493+
assert_raises(ValueError, linalg.solve, a, b)
494+
478495
def test_0_size(self):
479496
class ArraySubclass(np.ndarray):
480497
pass
@@ -497,9 +514,9 @@ class ArraySubclass(np.ndarray):
497514
assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
498515

499516
# Test zero "single equations" with 0x0 matrices.
500-
b = np.arange(2).reshape(1, 2).view(ArraySubclass)
517+
b = np.arange(2).view(ArraySubclass)
501518
expected = linalg.solve(a, b)[:, 0:0]
502-
result = linalg.solve(a[:, 0:0, 0:0], b[:, 0:0])
519+
result = linalg.solve(a[:, 0:0, 0:0], b[0:0])
503520
assert_array_equal(result, expected)
504521
assert_(isinstance(result, ArraySubclass))
505522

0 commit comments

Comments
 (0)