Skip to content

Commit 4522258

Browse files
committed
Rename variables
1 parent 818448b commit 4522258

File tree

1 file changed

+30
-25
lines changed

1 file changed

+30
-25
lines changed

maths/cholesky_decomposition.py

+30-25
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import numpy as np
22

33

4-
# ruff: noqa: N803,N806
5-
def cholesky_decomposition(A: np.ndarray) -> np.ndarray:
4+
def cholesky_decomposition(a: np.ndarray) -> np.ndarray:
65
"""Return a Cholesky decomposition of the matrix A.
76
87
The Cholesky decomposition decomposes the square, positive definite matrix A
@@ -26,7 +25,7 @@ def cholesky_decomposition(A: np.ndarray) -> np.ndarray:
2625
>>> np.allclose(np.tril(L), L)
2726
True
2827
29-
The Cholesky decomposition can be used to solve the system of equations A x = y.
28+
The Cholesky decomposition can be used to solve the linear system A x = y.
3029
3130
>>> x_true = np.array([1, 2, 3], dtype=float)
3231
>>> y = A @ x_true
@@ -43,28 +42,30 @@ def cholesky_decomposition(A: np.ndarray) -> np.ndarray:
4342
True
4443
"""
4544

46-
assert A.shape[0] == A.shape[1], f"Matrix A is not square, {A.shape=}"
47-
assert np.allclose(A, A.T), "Matrix A must be symmetric"
45+
assert a.shape[0] == a.shape[1], f"Matrix A is not square, {a.shape=}"
46+
assert np.allclose(a, a.T), "Matrix A must be symmetric"
4847

49-
n = A.shape[0]
50-
L = np.tril(A)
48+
n = a.shape[0]
49+
lower_triangle = np.tril(a)
5150

5251
for i in range(n):
5352
for j in range(i + 1):
54-
L[i, j] -= np.sum(L[i, :j] * L[j, :j])
53+
lower_triangle[i, j] -= np.sum(
54+
lower_triangle[i, :j] * lower_triangle[j, :j]
55+
)
5556

5657
if i == j:
57-
if L[i, i] <= 0:
58+
if lower_triangle[i, i] <= 0:
5859
raise ValueError("Matrix A is not positive definite")
5960

60-
L[i, i] = np.sqrt(L[i, i])
61+
lower_triangle[i, i] = np.sqrt(lower_triangle[i, i])
6162
else:
62-
L[i, j] /= L[j, j]
63+
lower_triangle[i, j] /= lower_triangle[j, j]
6364

64-
return L
65+
return lower_triangle
6566

6667

67-
def solve_cholesky(L: np.ndarray, Y: np.ndarray) -> np.ndarray:
68+
def solve_cholesky(lower_triangle: np.ndarray, y: np.ndarray) -> np.ndarray:
6869
"""Given a Cholesky decomposition L L^T = A of a matrix A, solve the
6970
system of equations A X = Y where Y is either a matrix or a vector.
7071
@@ -75,32 +76,36 @@ def solve_cholesky(L: np.ndarray, Y: np.ndarray) -> np.ndarray:
7576
True
7677
"""
7778

78-
assert L.shape[0] == L.shape[1], f"Matrix L is not square, {L.shape=}"
79-
assert np.allclose(np.tril(L), L), "Matrix L is not lower triangular"
79+
assert (
80+
lower_triangle.shape[0] == lower_triangle.shape[1]
81+
), f"Matrix L is not square, {lower_triangle.shape=}"
82+
assert np.allclose(
83+
np.tril(lower_triangle), lower_triangle
84+
), "Matrix L is not lower triangular"
8085

8186
# Handle vector case by reshaping to matrix and then flattening again
82-
if len(Y.shape) == 1:
83-
return solve_cholesky(L, Y.reshape(-1, 1)).ravel()
87+
if len(y.shape) == 1:
88+
return solve_cholesky(lower_triangle, y.reshape(-1, 1)).ravel()
8489

85-
n = Y.shape[0]
90+
n = y.shape[0]
8691

8792
# Solve L W = B for W
88-
W = Y.copy()
93+
w = y.copy()
8994
for i in range(n):
9095
for j in range(i):
91-
W[i] -= L[i, j] * W[j]
96+
w[i] -= lower_triangle[i, j] * w[j]
9297

93-
W[i] /= L[i, i]
98+
w[i] /= lower_triangle[i, i]
9499

95100
# Solve L^T X = W for X
96-
X = W
101+
x = w
97102
for i in reversed(range(n)):
98103
for j in range(i + 1, n):
99-
X[i] -= L[j, i] * X[j]
104+
x[i] -= lower_triangle[j, i] * x[j]
100105

101-
X[i] /= L[i, i]
106+
x[i] /= lower_triangle[i, i]
102107

103-
return X
108+
return x
104109

105110

106111
if __name__ == "__main__":

0 commit comments

Comments
 (0)