Skip to content

Commit b01a53b

Browse files
committed
Fix Gaussian elimination pivoting
1 parent c03f16d commit b01a53b

File tree

1 file changed

+11
-23
lines changed

1 file changed

+11
-23
lines changed

linear_algebra/src/gaussian_elimination_pivoting/gaussian_elimination_pivoting.py

+11-23
Original file line numberDiff line numberDiff line change
@@ -32,40 +32,28 @@ def solve_linear_system(matrix: np.ndarray) -> np.ndarray:
3232
>>> solution = solve_linear_system(np.column_stack((A, B)))
3333
>>> np.allclose(solution, np.array([2., 3., -1.]))
3434
True
35-
>>> solve_linear_system(np.array([[0, 0], [0, 0]], dtype=float))
36-
array([nan, nan])
35+
>>> solve_linear_system(np.array([[0, 0, 0], [0, 0, 0]], dtype=float))
36+
Traceback (most recent call last):
37+
...
38+
ValueError: Matrix is not correct
3739
"""
3840
ab = np.copy(matrix)
3941
num_of_rows = ab.shape[0]
4042
num_of_columns = ab.shape[1] - 1
4143
x_lst: list[float] = []
4244

43-
# Lead element search
44-
for column_num in range(num_of_rows):
45-
for i in range(column_num, num_of_columns):
46-
if abs(ab[i][column_num]) > abs(ab[column_num][column_num]):
47-
ab[[column_num, i]] = ab[[i, column_num]]
48-
if ab[column_num, column_num] == 0.0:
49-
raise ValueError("Matrix is not correct")
50-
else:
51-
pass
52-
if column_num != 0:
53-
for i in range(column_num, num_of_rows):
54-
ab[i, :] -= (
55-
ab[i, column_num - 1]
56-
/ ab[column_num - 1, column_num - 1]
57-
* ab[column_num - 1, :]
58-
)
45+
assert num_of_rows == num_of_columns
5946

60-
# Upper triangular matrix
6147
for column_num in range(num_of_rows):
48+
# Lead element search
6249
for i in range(column_num, num_of_columns):
6350
if abs(ab[i][column_num]) > abs(ab[column_num][column_num]):
6451
ab[[column_num, i]] = ab[[i, column_num]]
65-
if ab[column_num, column_num] == 0.0:
66-
raise ValueError("Matrix is not correct")
67-
else:
68-
pass
52+
53+
# Upper triangular matrix
54+
if ab[column_num, column_num] == 0.0:
55+
raise ValueError("Matrix is not correct")
56+
6957
if column_num != 0:
7058
for i in range(column_num, num_of_rows):
7159
ab[i, :] -= (

0 commit comments

Comments
 (0)