1
1
import numpy as np
2
2
3
3
4
- def cholesky_decomposition (a : np .ndarray ) -> np .ndarray :
4
+ def cholesky_decomposition (matrix : np .ndarray ) -> np .ndarray :
5
5
"""Return a Cholesky decomposition of the matrix A.
6
6
7
7
The Cholesky decomposition decomposes the square, positive definite matrix A
@@ -42,11 +42,13 @@ def cholesky_decomposition(a: np.ndarray) -> np.ndarray:
42
42
True
43
43
"""
44
44
45
- assert a .shape [0 ] == a .shape [1 ], f"Matrix A is not square, { a .shape = } "
46
- assert np .allclose (a , a .T ), "Matrix A must be symmetric"
45
+ assert (
46
+ matrix .shape [0 ] == matrix .shape [1 ]
47
+ ), f"Input matrix is not square, { matrix .shape = } "
48
+ assert np .allclose (matrix , matrix .T ), "Input matrix must be symmetric"
47
49
48
- n = a .shape [0 ]
49
- lower_triangle = np .tril (a )
50
+ n = matrix .shape [0 ]
51
+ lower_triangle = np .tril (matrix )
50
52
51
53
for i in range (n ):
52
54
for j in range (i + 1 ):
@@ -65,9 +67,13 @@ def cholesky_decomposition(a: np.ndarray) -> np.ndarray:
65
67
return lower_triangle
66
68
67
69
68
- def solve_cholesky (lower_triangle : np .ndarray , y : np .ndarray ) -> np .ndarray :
70
+ def solve_cholesky (
71
+ lower_triangle : np .ndarray ,
72
+ right_hand_side : np .ndarray ,
73
+ ) -> np .ndarray :
69
74
"""Given a Cholesky decomposition L L^T = A of a matrix A, solve the
70
- system of equations A X = Y where Y is either a matrix or a vector.
75
+ system of equations A X = Y where the right-hand side Y is either
76
+ a matrix or a vector.
71
77
72
78
>>> L = np.array([[2, 0], [3, 4]], dtype=float)
73
79
>>> Y = np.array([[22, 54], [81, 193]], dtype=float)
@@ -84,13 +90,13 @@ def solve_cholesky(lower_triangle: np.ndarray, y: np.ndarray) -> np.ndarray:
84
90
), "Matrix L is not lower triangular"
85
91
86
92
# Handle vector case by reshaping to matrix and then flattening again
87
- if len (y .shape ) == 1 :
88
- return solve_cholesky (lower_triangle , y .reshape (- 1 , 1 )).ravel ()
93
+ if len (right_hand_side .shape ) == 1 :
94
+ return solve_cholesky (lower_triangle , right_hand_side .reshape (- 1 , 1 )).ravel ()
89
95
90
- n = y .shape [0 ]
96
+ n = right_hand_side .shape [0 ]
91
97
92
- # Solve L W = B for W
93
- w = y .copy ()
98
+ # Solve L W = Y for W
99
+ w = right_hand_side .copy ()
94
100
for i in range (n ):
95
101
for j in range (i ):
96
102
w [i ] -= lower_triangle [i , j ] * w [j ]
0 commit comments