28
28
29
29
30
30
class Cholesky (Op ):
31
- """
32
- Return a triangular matrix square root of positive semi-definite `x`.
33
-
34
- L = cholesky(X, lower=True) implies dot(L, L.T) == X.
35
-
36
- Parameters
37
- ----------
38
- lower : bool, default=True
39
- Whether to return the lower or upper cholesky factor
40
- on_error : ['raise', 'nan']
41
- If on_error is set to 'raise', this Op will raise a
42
- `scipy.linalg.LinAlgError` if the matrix is not positive definite.
43
- If on_error is set to 'nan', it will return a matrix containing
44
- nans instead.
45
- """
46
-
47
- # TODO: inplace
48
31
# TODO: for specific dtypes
49
32
# TODO: LAPACK wrapper with in-place behavior, for solve also
50
33
51
- __props__ = ("lower" , "destructive " , "on_error" )
34
+ __props__ = ("lower" , "overwrite_a " , "on_error" )
52
35
gufunc_signature = "(m,m)->(m,m)"
53
36
54
- def __init__ (self , * , lower = True , on_error = "raise" ):
37
+ def __init__ (self , * , lower = True , on_error = "raise" , overwrite_a = False ):
55
38
self .lower = lower
56
- self . destructive = False
39
+
57
40
if on_error not in ("raise" , "nan" ):
58
41
raise ValueError ('on_error must be one of "raise" or ""nan"' )
59
42
self .on_error = on_error
60
43
44
+ self .overwrite_a = overwrite_a
45
+ if self .overwrite_a :
46
+ self .destroy_map = {0 : [0 ]}
47
+
61
48
def infer_shape (self , fgraph , node , shapes ):
62
49
return [shapes [0 ]]
63
50
@@ -67,15 +54,27 @@ def make_node(self, x):
67
54
return Apply (self , [x ], [x .type ()])
68
55
69
56
def perform (self , node , inputs , outputs ):
70
- x = inputs [0 ]
71
- z = outputs [0 ]
57
+ (x ,) = inputs
58
+ (z ,) = outputs
59
+ input_dtype = x .dtype
72
60
try :
73
- z [0 ] = scipy .linalg .cholesky (x , lower = self .lower ).astype (x .dtype )
61
+ if x .flags ["C_CONTIGUOUS" ] and self .overwrite_a :
62
+ # Inputs to the LAPACK functions need to be exactly as expected for overwrite_a to work correctly,
63
+ # see https://github.com/scipy/scipy/issues/8155#issuecomment-343996798
64
+ x = scipy .linalg .cholesky (
65
+ x .T , lower = not self .lower , overwrite_a = self .overwrite_a
66
+ ).T
67
+ else :
68
+ x = scipy .linalg .cholesky (
69
+ x , lower = self .lower , overwrite_a = self .overwrite_a
70
+ )
71
+
74
72
except scipy .linalg .LinAlgError :
75
73
if self .on_error == "raise" :
76
74
raise
77
75
else :
78
- z [0 ] = (np .zeros (x .shape ) * np .nan ).astype (x .dtype )
76
+ x = np .full_like (x , np .nan )
77
+ z [0 ] = x .astype (input_dtype )
79
78
80
79
def L_op (self , inputs , outputs , gradients ):
81
80
"""
@@ -129,7 +128,49 @@ def conjugate_solve_triangular(outer, inner):
129
128
return [grad ]
130
129
131
130
132
- def cholesky (x , lower = True , on_error = "raise" ):
131
+ def cholesky (x , lower = True , on_error = "raise" , overwrite_a = False ):
132
+ """
133
+ Return a triangular matrix square root of positive semi-definite `x`.
134
+
135
+ L = cholesky(X, lower=True) implies dot(L, L.T) == X.
136
+
137
+ Parameters
138
+ ----------
139
+ lower : bool, default=True
140
+ Whether to return the lower or upper cholesky factor
141
+ on_error : ['raise', 'nan']
142
+ If on_error is set to 'raise', this Op will raise a
143
+ `scipy.linalg.LinAlgError` if the matrix is not positive definite.
144
+ If on_error is set to 'nan', it will return a matrix containing
145
+ nans instead.
146
+ overwrite_a: bool, ignored
147
+ Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
148
+ for consistency with scipy.linalg.cholesky.
149
+
150
+ Returns
151
+ -------
152
+ TensorVariable
153
+ Lower or upper triangular Cholesky factor of `x`
154
+
155
+ Example
156
+ -------
157
+ .. code-block:: python
158
+
159
+ import pytensor
160
+ import pytensor.tensor as pt
161
+ import numpy as np
162
+
163
+ x = pt.tensor('x', size=(5, 5), dtype='float64')
164
+ L = pt.linalg.cholesky(x)
165
+
166
+ f = pytensor.function([x], L)
167
+ x_value = np.random.normal(size=(5, 5))
168
+ x_value = x_value @ x_value.T # Ensures x is positive definite
169
+ L_value = f(x_value)
170
+ print(np.allclose(L_value @ L_value.T, x_value))
171
+ >>> True
172
+ """
173
+
133
174
return Blockwise (Cholesky (lower = lower , on_error = on_error ))(x )
134
175
135
176
0 commit comments