Skip to content

Commit 307ce1c

Browse files
committed
Simplify equations, rename variables
1 parent 8019213 commit 307ce1c

File tree

1 file changed

+32
-26
lines changed

1 file changed

+32
-26
lines changed

maths/cholesky_decomposition.py

+32-26
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
22

33

4-
def cholesky_decomposition(a: np.ndarray) -> np.ndarray:
4+
# ruff: noqa: N803,N806
5+
def cholesky_decomposition(A: np.ndarray) -> np.ndarray:
56
"""Return a Cholesky decomposition of the matrix A.
67
78
The Cholesky decomposition decomposes the square, positive definite matrix A
@@ -41,25 +42,28 @@ def cholesky_decomposition(a: np.ndarray) -> np.ndarray:
4142
>>> np.allclose(X, X_true)
4243
True
4344
"""
44-
assert a.shape[0] == a.shape[1]
45-
n = a.shape[0]
46-
lo = np.tril(a)
4745

48-
for i in range(n):
49-
for j in range(i):
50-
lo[i, j] = (lo[i, j] - np.sum(lo[i, :j] * lo[j, :j])) / lo[j, j]
46+
assert A.shape[0] == A.shape[1], f"A is not square, {A.shape=}"
5147

52-
s = lo[i, i] - np.sum(lo[i, :i] * lo[i, :i])
48+
n = A.shape[0]
49+
L = np.tril(A)
5350

54-
if s <= 0:
55-
raise ValueError("Matrix A is not positive definite")
51+
for i in range(n):
52+
for j in range(i + 1):
53+
L[i, j] -= np.sum(L[i, :j] * L[j, :j])
5654

57-
lo[i, i] = np.sqrt(s)
55+
if i == j:
56+
if L[i, i] <= 0:
57+
raise ValueError("Matrix A is not positive definite")
5858

59-
return lo
59+
L[i, i] = np.sqrt(L[i, i])
60+
else:
61+
L[i, j] /= L[j, j]
6062

63+
return L
6164

62-
def solve_cholesky(lo: np.ndarray, y: np.ndarray) -> np.ndarray:
65+
66+
def solve_cholesky(L: np.ndarray, Y: np.ndarray) -> np.ndarray:
6367
"""Given a Cholesky decomposition L L^T = A of a matrix A, solve the
6468
system of equations A X = Y where B is either a matrix or a vector.
6569
@@ -70,30 +74,32 @@ def solve_cholesky(lo: np.ndarray, y: np.ndarray) -> np.ndarray:
7074
True
7175
"""
7276

77+
assert L.shape[0] == L.shape[1], f"L is not square, {L.shape=}"
78+
assert np.allclose(np.tril(L), L), "L is not lower triangular"
79+
7380
# Handle vector case by reshaping to matrix and then flattening again
74-
if len(y.shape) == 1:
75-
return solve_cholesky(lo, y.reshape(-1, 1)).ravel()
81+
if len(Y.shape) == 1:
82+
return solve_cholesky(L, Y.reshape(-1, 1)).ravel()
7683

77-
n, m = y.shape
84+
n = Y.shape[0]
7885

79-
# Backsubstitute L X = B
80-
x = y.copy()
86+
# Solve L W = B for W
87+
W = Y.copy()
8188
for i in range(n):
8289
for j in range(i):
83-
x[i, :] -= lo[i, j] * x[j, :]
90+
W[i] -= L[i, j] * W[j]
8491

85-
for k in range(m):
86-
x[i, k] /= lo[i, i]
92+
W[i] /= L[i, i]
8793

88-
# Backsubstitute L^T
94+
# Solve L^T X = W for X
95+
X = W
8996
for i in reversed(range(n)):
9097
for j in range(i + 1, n):
91-
x[i, :] -= lo[j, i] * x[j, :]
98+
X[i] -= L[j, i] * X[j]
9299

93-
for k in range(m):
94-
x[i, k] /= lo[i, i]
100+
X[i] /= L[i, i]
95101

96-
return x
102+
return X
97103

98104

99105
if __name__ == "__main__":

0 commit comments

Comments
 (0)