Skip to content

Commit 5d13604

Browse files
Add destructive in-place rewrite for pt.linalg.cholesky
1 parent e180927 commit 5d13604

File tree

4 files changed

+128
-27
lines changed

4 files changed

+128
-27
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,6 @@ def softplus(x):
812812
@numba_funcify.register(Cholesky)
813813
def numba_funcify_Cholesky(op, node, **kwargs):
814814
lower = op.lower
815-
816815
out_dtype = node.outputs[0].type.numpy_dtype
817816

818817
if lower:

pytensor/tensor/rewriting/linalg.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import logging
22
from typing import cast
33

4-
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
4+
from pytensor.compile import optdb
5+
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
56
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
67
from pytensor.tensor.blas import Dot22
78
from pytensor.tensor.blockwise import Blockwise
@@ -310,3 +311,27 @@ def local_log_prod_sqr(fgraph, node):
310311

311312
# TODO: have a reduction like prod and sum that simply
312313
# returns the sign of the prod multiplication.
314+
315+
316+
cholesky_no_inplace = Cholesky(overwrite_a=False)
317+
cholesky_inplace = Cholesky(overwrite_a=True)
318+
319+
320+
@node_rewriter([cholesky_no_inplace], inplace=True)
321+
def local_inplace_cholesky(fgraph, node):
322+
new_out = [cholesky_inplace(*node.inputs)]
323+
copy_stack_trace(node.outputs, new_out)
324+
return new_out
325+
326+
327+
# After destroyhandler(49.5) but before we try to make elemwise things
328+
# inplace (75)
329+
linalg_opt_inplace = in2out(local_inplace_cholesky, name="linalg_opt_inplace")
330+
optdb.register(
331+
"InplaceLinalgOpt",
332+
linalg_opt_inplace,
333+
"fast_run",
334+
"inplace",
335+
"linalg_opt_inplace",
336+
position=69.0,
337+
)

pytensor/tensor/slinalg.py

Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,36 +28,23 @@
2828

2929

3030
class Cholesky(Op):
31-
"""
32-
Return a triangular matrix square root of positive semi-definite `x`.
33-
34-
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
35-
36-
Parameters
37-
----------
38-
lower : bool, default=True
39-
Whether to return the lower or upper cholesky factor
40-
on_error : ['raise', 'nan']
41-
If on_error is set to 'raise', this Op will raise a
42-
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
43-
If on_error is set to 'nan', it will return a matrix containing
44-
nans instead.
45-
"""
46-
47-
# TODO: inplace
4831
# TODO: for specific dtypes
4932
# TODO: LAPACK wrapper with in-place behavior, for solve also
5033

51-
__props__ = ("lower", "destructive", "on_error")
34+
__props__ = ("lower", "overwrite_a", "on_error")
5235
gufunc_signature = "(m,m)->(m,m)"
5336

54-
def __init__(self, *, lower=True, on_error="raise"):
37+
def __init__(self, *, lower=True, on_error="raise", overwrite_a=False):
5538
self.lower = lower
56-
self.destructive = False
39+
5740
if on_error not in ("raise", "nan"):
5841
raise ValueError('on_error must be one of "raise" or ""nan"')
5942
self.on_error = on_error
6043

44+
self.overwrite_a = overwrite_a
45+
if self.overwrite_a:
46+
self.destroy_map = {0: [0]}
47+
6148
def infer_shape(self, fgraph, node, shapes):
6249
return [shapes[0]]
6350

@@ -67,15 +54,27 @@ def make_node(self, x):
6754
return Apply(self, [x], [x.type()])
6855

6956
def perform(self, node, inputs, outputs):
70-
x = inputs[0]
71-
z = outputs[0]
57+
(x,) = inputs
58+
(z,) = outputs
59+
input_dtype = x.dtype
7260
try:
73-
z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
61+
if x.flags["C_CONTIGUOUS"] and self.overwrite_a:
62+
# Inputs to the LAPACK functions need to be exactly as expected for overwrite_a to work correctly,
63+
# see https://github.com/scipy/scipy/issues/8155#issuecomment-343996798
64+
x = scipy.linalg.cholesky(
65+
x.T, lower=not self.lower, overwrite_a=self.overwrite_a
66+
).T
67+
else:
68+
x = scipy.linalg.cholesky(
69+
x, lower=self.lower, overwrite_a=self.overwrite_a
70+
)
71+
7472
except scipy.linalg.LinAlgError:
7573
if self.on_error == "raise":
7674
raise
7775
else:
78-
z[0] = (np.zeros(x.shape) * np.nan).astype(x.dtype)
76+
x = np.full_like(x, np.nan)
77+
z[0] = x.astype(input_dtype)
7978

8079
def L_op(self, inputs, outputs, gradients):
8180
"""
@@ -129,7 +128,49 @@ def conjugate_solve_triangular(outer, inner):
129128
return [grad]
130129

131130

132-
def cholesky(x, lower=True, on_error="raise"):
131+
def cholesky(x, lower=True, on_error="raise", overwrite_a=False):
132+
"""
133+
Return a triangular matrix square root of positive semi-definite `x`.
134+
135+
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
136+
137+
Parameters
138+
----------
139+
lower : bool, default=True
140+
Whether to return the lower or upper cholesky factor
141+
on_error : ['raise', 'nan']
142+
If on_error is set to 'raise', this Op will raise a
143+
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
144+
If on_error is set to 'nan', it will return a matrix containing
145+
nans instead.
146+
overwrite_a: bool, ignored
147+
Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
148+
for consistency with scipy.linalg.cholesky.
149+
150+
Returns
151+
-------
152+
TensorVariable
153+
Lower or upper triangular Cholesky factor of `x`
154+
155+
Example
156+
-------
157+
.. code-block:: python
158+
159+
import pytensor
160+
import pytensor.tensor as pt
161+
import numpy as np
162+
163+
x = pt.tensor('x', size=(5, 5), dtype='float64')
164+
L = pt.linalg.cholesky(x)
165+
166+
f = pytensor.function([x], L)
167+
x_value = np.random.normal(size=(5, 5))
168+
x_value = x_value @ x_value.T # Ensures x is positive definite
169+
L_value = f(x_value)
170+
print(np.allclose(L_value @ L_value.T, x_value))
171+
>>> True
172+
"""
173+
133174
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
134175

135176

tests/tensor/rewriting/test_linalg.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,3 +306,39 @@ def test_invalid_batched_a(self):
306306
ref_fn(test_a, test_b),
307307
rtol=1e-7 if config.floatX == "float64" else 1e-5,
308308
)
309+
310+
311+
@pytest.mark.skipif(
312+
config.mode == "FAST_COMPILE",
313+
reason="inplace rewrites disabled when mode is FAST_COMPILE",
314+
)
315+
def test_local_inplace_cholesky():
316+
X = matrix("X")
317+
L = cholesky(X, overwrite_a=False, lower=True)
318+
f = function([pytensor.In(X, mutable=True)], L)
319+
320+
assert not L.owner.op.core_op.destructive
321+
322+
nodes = f.maker.fgraph.toposort()
323+
for node in nodes:
324+
if isinstance(node, Cholesky):
325+
assert node.destructive
326+
break
327+
328+
X_val = np.random.normal(size=(10, 10)).astype(config.floatX)
329+
X_val_in = X_val @ X_val.T
330+
X_val_in_copy = X_val_in.copy()
331+
f(X_val_in)
332+
333+
assert_allclose(
334+
X_val_in[np.triu_indices_from(X_val_in, k=1)],
335+
0.0,
336+
atol=1e-4 if config.floatX == "float32" else 1e-8,
337+
rtol=1e-4 if config.floatX == "float32" else 1e-8,
338+
)
339+
assert_allclose(
340+
X_val_in @ X_val_in.T,
341+
X_val_in_copy,
342+
atol=1e-4 if config.floatX == "float32" else 1e-8,
343+
rtol=1e-4 if config.floatX == "float32" else 1e-8,
344+
)

0 commit comments

Comments
 (0)