1
1
import numpy as np
2
2
3
3
4
- def cholesky_decomposition (a : np .ndarray ) -> np .ndarray :
4
+ # ruff: noqa: N803,N806
5
+ def cholesky_decomposition (A : np .ndarray ) -> np .ndarray :
5
6
"""Return a Cholesky decomposition of the matrix A.
6
7
7
8
The Cholesky decomposition decomposes the square, positive definite matrix A
@@ -41,25 +42,28 @@ def cholesky_decomposition(a: np.ndarray) -> np.ndarray:
41
42
>>> np.allclose(X, X_true)
42
43
True
43
44
"""
44
- assert a .shape [0 ] == a .shape [1 ]
45
- n = a .shape [0 ]
46
- lo = np .tril (a )
47
45
48
- for i in range (n ):
49
- for j in range (i ):
50
- lo [i , j ] = (lo [i , j ] - np .sum (lo [i , :j ] * lo [j , :j ])) / lo [j , j ]
46
+ assert A .shape [0 ] == A .shape [1 ], f"A is not square, { A .shape = } "
51
47
52
- s = lo [i , i ] - np .sum (lo [i , :i ] * lo [i , :i ])
48
+ n = A .shape [0 ]
49
+ L = np .tril (A )
53
50
54
- if s <= 0 :
55
- raise ValueError ("Matrix A is not positive definite" )
51
+ for i in range (n ):
52
+ for j in range (i + 1 ):
53
+ L [i , j ] -= np .sum (L [i , :j ] * L [j , :j ])
56
54
57
- lo [i , i ] = np .sqrt (s )
55
+ if i == j :
56
+ if L [i , i ] <= 0 :
57
+ raise ValueError ("Matrix A is not positive definite" )
58
58
59
- return lo
59
+ L [i , i ] = np .sqrt (L [i , i ])
60
+ else :
61
+ L [i , j ] /= L [j , j ]
60
62
63
+ return L
61
64
62
- def solve_cholesky (lo : np .ndarray , y : np .ndarray ) -> np .ndarray :
65
+
66
+ def solve_cholesky (L : np .ndarray , Y : np .ndarray ) -> np .ndarray :
63
67
"""Given a Cholesky decomposition L L^T = A of a matrix A, solve the
64
68
system of equations A X = Y where B is either a matrix or a vector.
65
69
@@ -70,30 +74,32 @@ def solve_cholesky(lo: np.ndarray, y: np.ndarray) -> np.ndarray:
70
74
True
71
75
"""
72
76
77
+ assert L .shape [0 ] == L .shape [1 ], f"L is not square, { L .shape = } "
78
+ assert np .allclose (np .tril (L ), L ), "L is not lower triangular"
79
+
73
80
# Handle vector case by reshaping to matrix and then flattening again
74
- if len (y .shape ) == 1 :
75
- return solve_cholesky (lo , y .reshape (- 1 , 1 )).ravel ()
81
+ if len (Y .shape ) == 1 :
82
+ return solve_cholesky (L , Y .reshape (- 1 , 1 )).ravel ()
76
83
77
- n , m = y .shape
84
+ n = Y .shape [ 0 ]
78
85
79
- # Backsubstitute L X = B
80
- x = y .copy ()
86
+ # Solve L W = B for W
87
+ W = Y .copy ()
81
88
for i in range (n ):
82
89
for j in range (i ):
83
- x [ i , : ] -= lo [i , j ] * x [ j , : ]
90
+ W [ i ] -= L [i , j ] * W [ j ]
84
91
85
- for k in range (m ):
86
- x [i , k ] /= lo [i , i ]
92
+ W [i ] /= L [i , i ]
87
93
88
- # Backsubstitute L^T
94
+ # Solve L^T X = W for X
95
+ X = W
89
96
for i in reversed (range (n )):
90
97
for j in range (i + 1 , n ):
91
- x [ i , : ] -= lo [j , i ] * x [ j , : ]
98
+ X [ i ] -= L [j , i ] * X [ j ]
92
99
93
- for k in range (m ):
94
- x [i , k ] /= lo [i , i ]
100
+ X [i ] /= L [i , i ]
95
101
96
- return x
102
+ return X
97
103
98
104
99
105
if __name__ == "__main__" :
0 commit comments