Skip to content

Commit dc2ce68

Browse files
Move docstring from Cholesky to cholesky
Match `cholesky` function signature to `scipy.linalg.cholesky`
1 parent d546e5b commit dc2ce68

File tree

1 file changed

+43
-22
lines changed

1 file changed

+43
-22
lines changed

pytensor/tensor/slinalg.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,6 @@
2828

2929

3030
class Cholesky(Op):
31-
"""
32-
Return a triangular matrix square root of positive semi-definite `x`.
33-
34-
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
35-
36-
Parameters
37-
----------
38-
lower : bool, default=True
39-
Whether to return the lower or upper cholesky factor
40-
on_error : ['raise', 'nan']
41-
If on_error is set to 'raise', this Op will raise a
42-
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
43-
If on_error is set to 'nan', it will return a matrix containing
44-
nans instead.
45-
overwrite_a: bool, ignored
46-
Whether to use the same memory for the output as `a`. This argument is ignored, and
47-
included only for consistency with scipy.linalg.cholesky.
48-
"""
49-
5031
# TODO: for specific dtypes
5132
# TODO: LAPACK wrapper with in-place behavior, for solve also
5233

@@ -148,9 +129,49 @@ def conjugate_solve_triangular(outer, inner):
148129

149130

150131
def cholesky(x, lower=True, on_error="raise", overwrite_a=False):
151-
return Blockwise(Cholesky(lower=lower, on_error=on_error, overwrite_a=overwrite_a))(
152-
x
153-
)
132+
"""
133+
Return a triangular matrix square root of positive semi-definite `x`.
134+
135+
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
136+
137+
Parameters
138+
----------
139+
lower : bool, default=True
140+
Whether to return the lower or upper cholesky factor
141+
on_error : ['raise', 'nan']
142+
If on_error is set to 'raise', this Op will raise a
143+
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
144+
If on_error is set to 'nan', it will return a matrix containing
145+
nans instead.
146+
overwrite_a: bool, ignored
147+
Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
148+
for consistency with scipy.linalg.cholesky.
149+
150+
Returns
151+
-------
152+
TensorVariable
153+
Lower or upper triangular Cholesky factor of `x`
154+
155+
Example
156+
-------
157+
.. code-block:: python
158+
159+
import pytensor
160+
import pytensor.tensor as pt
161+
import numpy as np
162+
163+
x = pt.tensor('x', size=(5, 5), dtype='float64')
164+
L = pt.linalg.cholesky(x)
165+
166+
f = pytensor.function([x], L)
167+
x_value = np.random.normal(size=(5, 5))
168+
x_value = x_value @ x_value.T # Ensures x is positive definite
169+
L_value = f(x_value)
170+
print(np.allclose(L_value @ L_value.T, x_value))
171+
>>> True
172+
"""
173+
174+
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
154175

155176

156177
class SolveBase(Op):

0 commit comments

Comments
 (0)