Skip to content

Commit 7808f21

Browse files
committed
Rename variables
1 parent 4522258 commit 7808f21

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

maths/cholesky_decomposition.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

33

4-
def cholesky_decomposition(a: np.ndarray) -> np.ndarray:
4+
def cholesky_decomposition(matrix: np.ndarray) -> np.ndarray:
55
"""Return a Cholesky decomposition of the matrix A.
66
77
The Cholesky decomposition decomposes the square, positive definite matrix A
@@ -42,11 +42,13 @@ def cholesky_decomposition(a: np.ndarray) -> np.ndarray:
4242
True
4343
"""
4444

45-
assert a.shape[0] == a.shape[1], f"Matrix A is not square, {a.shape=}"
46-
assert np.allclose(a, a.T), "Matrix A must be symmetric"
45+
assert (
46+
matrix.shape[0] == matrix.shape[1]
47+
), f"Input matrix is not square, {matrix.shape=}"
48+
assert np.allclose(matrix, matrix.T), "Input matrix must be symmetric"
4749

48-
n = a.shape[0]
49-
lower_triangle = np.tril(a)
50+
n = matrix.shape[0]
51+
lower_triangle = np.tril(matrix)
5052

5153
for i in range(n):
5254
for j in range(i + 1):
@@ -65,9 +67,13 @@ def cholesky_decomposition(a: np.ndarray) -> np.ndarray:
6567
return lower_triangle
6668

6769

68-
def solve_cholesky(lower_triangle: np.ndarray, y: np.ndarray) -> np.ndarray:
70+
def solve_cholesky(
71+
lower_triangle: np.ndarray,
72+
right_hand_side: np.ndarray,
73+
) -> np.ndarray:
6974
"""Given a Cholesky decomposition L L^T = A of a matrix A, solve the
70-
system of equations A X = Y where Y is either a matrix or a vector.
75+
system of equations A X = Y where the right-hand side Y is either
76+
a matrix or a vector.
7177
7278
>>> L = np.array([[2, 0], [3, 4]], dtype=float)
7379
>>> Y = np.array([[22, 54], [81, 193]], dtype=float)
@@ -84,13 +90,13 @@ def solve_cholesky(lower_triangle: np.ndarray, y: np.ndarray) -> np.ndarray:
8490
), "Matrix L is not lower triangular"
8591

8692
# Handle vector case by reshaping to matrix and then flattening again
87-
if len(y.shape) == 1:
88-
return solve_cholesky(lower_triangle, y.reshape(-1, 1)).ravel()
93+
if len(right_hand_side.shape) == 1:
94+
return solve_cholesky(lower_triangle, right_hand_side.reshape(-1, 1)).ravel()
8995

90-
n = y.shape[0]
96+
n = right_hand_side.shape[0]
9197

92-
# Solve L W = B for W
93-
w = y.copy()
98+
# Solve L W = Y for W
99+
w = right_hand_side.copy()
94100
for i in range(n):
95101
for j in range(i):
96102
w[i] -= lower_triangle[i, j] * w[j]

0 commit comments

Comments
 (0)