28
28
29
29
30
30
class Cholesky (Op ):
31
+ """
32
+ Return a triangular matrix square root of positive semi-definite `x`.
33
+
34
+ L = cholesky(X, lower=True) implies dot(L, L.T) == X.
35
+
36
+ Parameters
37
+ ----------
38
+ lower : bool, default=True
39
+ Whether to return the lower or upper cholesky factor
40
+ on_error : ['raise', 'nan']
41
+ If on_error is set to 'raise', this Op will raise a
42
+ `scipy.linalg.LinAlgError` if the matrix is not positive definite.
43
+ If on_error is set to 'nan', it will return a matrix containing
44
+ nans instead.
45
+ """
46
+
31
47
# TODO: for specific dtypes
32
48
# TODO: LAPACK wrapper with in-place behavior, for solve also
33
49
@@ -36,13 +52,11 @@ class Cholesky(Op):
36
52
37
53
def __init__ (self , * , lower = True , on_error = "raise" , overwrite_a = False ):
38
54
self .lower = lower
39
-
55
+ self . overwrite_a = overwrite_a
40
56
if on_error not in ("raise" , "nan" ):
41
57
raise ValueError ('on_error must be one of "raise" or ""nan"' )
42
58
self .on_error = on_error
43
-
44
- self .overwrite_a = overwrite_a
45
- if self .overwrite_a :
59
+ if overwrite_a :
46
60
self .destroy_map = {0 : [0 ]}
47
61
48
62
def infer_shape (self , fgraph , node , shapes ):
@@ -73,7 +87,7 @@ def perform(self, node, inputs, outputs):
73
87
if self .on_error == "raise" :
74
88
raise
75
89
else :
76
- x = np .full_like ( x , np .nan )
90
+ x = np .zeros ( x . shape ) * np .nan
77
91
z [0 ] = x .astype (input_dtype )
78
92
79
93
def L_op (self , inputs , outputs , gradients ):
@@ -129,49 +143,9 @@ def conjugate_solve_triangular(outer, inner):
129
143
130
144
131
145
def cholesky (x , lower = True , on_error = "raise" , overwrite_a = False ):
132
- """
133
- Return a triangular matrix square root of positive semi-definite `x`.
134
-
135
- L = cholesky(X, lower=True) implies dot(L, L.T) == X.
136
-
137
- Parameters
138
- ----------
139
- lower : bool, default=True
140
- Whether to return the lower or upper cholesky factor
141
- on_error : ['raise', 'nan']
142
- If on_error is set to 'raise', this Op will raise a
143
- `scipy.linalg.LinAlgError` if the matrix is not positive definite.
144
- If on_error is set to 'nan', it will return a matrix containing
145
- nans instead.
146
- overwrite_a: bool, ignored
147
- Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
148
- for consistency with scipy.linalg.cholesky.
149
-
150
- Returns
151
- -------
152
- TensorVariable
153
- Lower or upper triangular Cholesky factor of `x`
154
-
155
- Example
156
- -------
157
- .. code-block:: python
158
-
159
- import pytensor
160
- import pytensor.tensor as pt
161
- import numpy as np
162
-
163
- x = pt.tensor('x', size=(5, 5), dtype='float64')
164
- L = pt.linalg.cholesky(x)
165
-
166
- f = pytensor.function([x], L)
167
- x_value = np.random.normal(size=(5, 5))
168
- x_value = x_value @ x_value.T # Ensures x is positive definite
169
- L_value = f(x_value)
170
- print(np.allclose(L_value @ L_value.T, x_value))
171
- >>> True
172
- """
173
-
174
- return Blockwise (Cholesky (lower = lower , on_error = on_error ))(x )
146
+ return Blockwise (Cholesky (lower = lower , on_error = on_error , overwrite_a = overwrite_a ))(
147
+ x
148
+ )
175
149
176
150
177
151
class SolveBase (Op ):
0 commit comments