|
28 | 28 |
|
29 | 29 |
|
30 | 30 | class Cholesky(Op):
|
31 |
| - """ |
32 |
| - Return a triangular matrix square root of positive semi-definite `x`. |
33 |
| -
|
34 |
| - L = cholesky(X, lower=True) implies dot(L, L.T) == X. |
35 |
| -
|
36 |
| - Parameters |
37 |
| - ---------- |
38 |
| - lower : bool, default=True |
39 |
| - Whether to return the lower or upper cholesky factor |
40 |
| - on_error : ['raise', 'nan'] |
41 |
| - If on_error is set to 'raise', this Op will raise a |
42 |
| - `scipy.linalg.LinAlgError` if the matrix is not positive definite. |
43 |
| - If on_error is set to 'nan', it will return a matrix containing |
44 |
| - nans instead. |
45 |
| - overwrite_a: bool, ignored |
46 |
| - Whether to use the same memory for the output as `a`. This argument is ignored, and |
47 |
| - included only for consistency with scipy.linalg.cholesky. |
48 |
| - """ |
49 |
| - |
50 | 31 | # TODO: for specific dtypes
|
51 | 32 | # TODO: LAPACK wrapper with in-place behavior, for solve also
|
52 | 33 |
|
@@ -148,9 +129,49 @@ def conjugate_solve_triangular(outer, inner):
|
148 | 129 |
|
149 | 130 |
|
150 | 131 | def cholesky(x, lower=True, on_error="raise", overwrite_a=False):
|
151 |
| - return Blockwise(Cholesky(lower=lower, on_error=on_error, overwrite_a=overwrite_a))( |
152 |
| - x |
153 |
| - ) |
| 132 | + """ |
| 133 | + Return a triangular matrix square root of positive semi-definite `x`. |
| 134 | +
|
| 135 | + L = cholesky(X, lower=True) implies dot(L, L.T) == X. |
| 136 | +
|
| 137 | + Parameters |
| 138 | + ---------- |
| 139 | + lower : bool, default=True |
| 140 | + Whether to return the lower or upper cholesky factor |
| 141 | + on_error : ['raise', 'nan'] |
| 142 | + If on_error is set to 'raise', this Op will raise a |
| 143 | + `scipy.linalg.LinAlgError` if the matrix is not positive definite. |
| 144 | + If on_error is set to 'nan', it will return a matrix containing |
| 145 | + nans instead. |
| 146 | + overwrite_a: bool, ignored |
| 147 | + Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only |
| 148 | + for consistency with scipy.linalg.cholesky. |
| 149 | +
|
| 150 | + Returns |
| 151 | + ------- |
| 152 | + TensorVariable |
| 153 | + Lower or upper triangular Cholesky factor of `x` |
| 154 | +
|
| 155 | + Example |
| 156 | + ------- |
| 157 | + .. code-block:: python |
| 158 | +
|
| 159 | + import pytensor |
| 160 | + import pytensor.tensor as pt |
| 161 | + import numpy as np |
| 162 | +
|
| 163 | + x = pt.tensor('x', size=(5, 5), dtype='float64') |
| 164 | + L = pt.linalg.cholesky(x) |
| 165 | +
|
| 166 | + f = pytensor.function([x], L) |
| 167 | + x_value = np.random.normal(size=(5, 5)) |
| 168 | + x_value = x_value @ x_value.T # Ensures x is positive definite |
| 169 | + L_value = f(x_value) |
| 170 | + print(np.allclose(L_value @ L_value.T, x_value)) |
| 171 | + >>> True |
| 172 | + """ |
| 173 | + |
| 174 | + return Blockwise(Cholesky(lower=lower, on_error=on_error))(x) |
154 | 175 |
|
155 | 176 |
|
156 | 177 | class SolveBase(Op):
|
|
0 commit comments