Skip to content

Commit 470ea60

Browse files
Introduce make_inplace helper function for destructive rewrites
Refactor cholesky destructive re-write to use `make_inplace` helper
1 parent 5d13604 commit 470ea60

File tree

4 files changed

+39
-58
lines changed

4 files changed

+39
-58
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,18 @@ def alloc_like(
135135
return rval
136136

137137

138+
def make_inplace(node, inplace_prop="inplace"):
139+
op = getattr(node.op, "core_op", node.op)
140+
props = op._props_dict()
141+
if props[inplace_prop]:
142+
return False
143+
144+
props[inplace_prop] = True
145+
inplace_op = type(op)(**props)
146+
147+
return inplace_op.make_node(*node.inputs).outputs
148+
149+
138150
def register_useless(
139151
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags, **kwargs
140152
):

pytensor/tensor/rewriting/linalg.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
1111
from pytensor.tensor.nlinalg import MatrixInverse, det
1212
from pytensor.tensor.rewriting.basic import (
13+
make_inplace,
1314
register_canonicalize,
1415
register_specialize,
1516
register_stabilize,
@@ -313,15 +314,9 @@ def local_log_prod_sqr(fgraph, node):
313314
# returns the sign of the prod multiplication.
314315

315316

316-
cholesky_no_inplace = Cholesky(overwrite_a=False)
317-
cholesky_inplace = Cholesky(overwrite_a=True)
318-
319-
320-
@node_rewriter([cholesky_no_inplace], inplace=True)
317+
@node_rewriter([Cholesky], inplace=True)
321318
def local_inplace_cholesky(fgraph, node):
322-
new_out = [cholesky_inplace(*node.inputs)]
323-
copy_stack_trace(node.outputs, new_out)
324-
return new_out
319+
return make_inplace(node, "overwrite_a")
325320

326321

327322
# After destroyhandler(49.5) but before we try to make elemwise things

pytensor/tensor/slinalg.py

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,22 @@
2828

2929

3030
class Cholesky(Op):
31+
"""
32+
Return a triangular matrix square root of positive semi-definite `x`.
33+
34+
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
35+
36+
Parameters
37+
----------
38+
lower : bool, default=True
39+
Whether to return the lower or upper cholesky factor
40+
on_error : ['raise', 'nan']
41+
If on_error is set to 'raise', this Op will raise a
42+
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
43+
If on_error is set to 'nan', it will return a matrix containing
44+
nans instead.
45+
"""
46+
3147
# TODO: for specific dtypes
3248
# TODO: LAPACK wrapper with in-place behavior, for solve also
3349

@@ -36,13 +52,11 @@ class Cholesky(Op):
3652

3753
def __init__(self, *, lower=True, on_error="raise", overwrite_a=False):
3854
self.lower = lower
39-
55+
self.overwrite_a = overwrite_a
4056
if on_error not in ("raise", "nan"):
4157
raise ValueError('on_error must be one of "raise" or ""nan"')
4258
self.on_error = on_error
43-
44-
self.overwrite_a = overwrite_a
45-
if self.overwrite_a:
59+
if overwrite_a:
4660
self.destroy_map = {0: [0]}
4761

4862
def infer_shape(self, fgraph, node, shapes):
@@ -73,7 +87,7 @@ def perform(self, node, inputs, outputs):
7387
if self.on_error == "raise":
7488
raise
7589
else:
76-
x = np.full_like(x, np.nan)
90+
x = np.zeros(x.shape) * np.nan
7791
z[0] = x.astype(input_dtype)
7892

7993
def L_op(self, inputs, outputs, gradients):
@@ -129,49 +143,9 @@ def conjugate_solve_triangular(outer, inner):
129143

130144

131145
def cholesky(x, lower=True, on_error="raise", overwrite_a=False):
132-
"""
133-
Return a triangular matrix square root of positive semi-definite `x`.
134-
135-
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
136-
137-
Parameters
138-
----------
139-
lower : bool, default=True
140-
Whether to return the lower or upper cholesky factor
141-
on_error : ['raise', 'nan']
142-
If on_error is set to 'raise', this Op will raise a
143-
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
144-
If on_error is set to 'nan', it will return a matrix containing
145-
nans instead.
146-
overwrite_a: bool, ignored
147-
Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
148-
for consistency with scipy.linalg.cholesky.
149-
150-
Returns
151-
-------
152-
TensorVariable
153-
Lower or upper triangular Cholesky factor of `x`
154-
155-
Example
156-
-------
157-
.. code-block:: python
158-
159-
import pytensor
160-
import pytensor.tensor as pt
161-
import numpy as np
162-
163-
x = pt.tensor('x', size=(5, 5), dtype='float64')
164-
L = pt.linalg.cholesky(x)
165-
166-
f = pytensor.function([x], L)
167-
x_value = np.random.normal(size=(5, 5))
168-
x_value = x_value @ x_value.T # Ensures x is positive definite
169-
L_value = f(x_value)
170-
print(np.allclose(L_value @ L_value.T, x_value))
171-
>>> True
172-
"""
173-
174-
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
146+
return Blockwise(Cholesky(lower=lower, on_error=on_error, overwrite_a=overwrite_a))(
147+
x
148+
)
175149

176150

177151
class SolveBase(Op):

tests/tensor/rewriting/test_linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,12 +317,12 @@ def test_local_inplace_cholesky():
317317
L = cholesky(X, overwrite_a=False, lower=True)
318318
f = function([pytensor.In(X, mutable=True)], L)
319319

320-
assert not L.owner.op.core_op.destructive
320+
assert not L.owner.op.core_op.overwrite_a
321321

322322
nodes = f.maker.fgraph.toposort()
323323
for node in nodes:
324324
if isinstance(node, Cholesky):
325-
assert node.destructive
325+
assert node.overwrite_a
326326
break
327327

328328
X_val = np.random.normal(size=(10, 10)).astype(config.floatX)

0 commit comments

Comments
 (0)