
Commit 5b96d83

Rewrite local_inplace_cholesky to use make_inplace helper function
1 parent a15e6aa commit 5b96d83

File tree

4 files changed (+22, -15 lines)

pytensor/tensor/rewriting/basic.py

Lines changed: 12 additions & 0 deletions

@@ -135,6 +135,18 @@ def alloc_like(
     return rval
 
 
+def make_inplace(node, inplace_prop="inplace"):
+    op = getattr(node.op, "core_op", node.op)
+    props = op._props_dict()
+    if props[inplace_prop]:
+        return False
+
+    props[inplace_prop] = True
+    inplace_op = type(op)(**props)
+
+    return inplace_op.make_node(*node.inputs).outputs
+
+
 def register_useless(
     node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags, **kwargs
 ):
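For context, here is a hedged sketch (not part of the commit) of how the new make_inplace helper behaves when applied to a Cholesky node, mirroring the rewritten local_inplace_cholesky shown in the next file. It assumes the Blockwise-wrapped cholesky helper and the __props__ listed in the slinalg.py diff below. Called on a regular node, the helper rebuilds it with the named flag flipped to True; called again on the rebuilt node, it returns False to signal that nothing needs to change.

import pytensor.tensor as pt
from pytensor.tensor.rewriting.basic import make_inplace
from pytensor.tensor.slinalg import cholesky

X = pt.matrix("X")
L = cholesky(X, lower=True, overwrite_a=False)

# First pass: the core op has overwrite_a=False, so make_inplace rebuilds the
# node with an op whose overwrite_a property is True and returns its outputs.
new_outputs = make_inplace(L.owner, "overwrite_a")
assert new_outputs is not False

# Second pass: the rebuilt node is already in-place, so the helper returns
# False and the rewriter leaves it alone.
assert make_inplace(new_outputs[0].owner, "overwrite_a") is False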

pytensor/tensor/rewriting/linalg.py

Lines changed: 3 additions & 8 deletions

@@ -10,6 +10,7 @@
 from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
 from pytensor.tensor.nlinalg import MatrixInverse, det
 from pytensor.tensor.rewriting.basic import (
+    make_inplace,
     register_canonicalize,
     register_specialize,
     register_stabilize,
@@ -313,15 +314,9 @@ def local_log_prod_sqr(fgraph, node):
     # returns the sign of the prod multiplication.
 
 
-cholesky_no_inplace = Cholesky(overwrite_a=False)
-cholesky_inplace = Cholesky(overwrite_a=True)
-
-
-@node_rewriter([cholesky_no_inplace], inplace=True)
+@node_rewriter([Cholesky], inplace=True)
 def local_inplace_cholesky(fgraph, node):
-    new_out = [cholesky_inplace(*node.inputs)]
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+    return make_inplace(node, "overwrite_a")
 
 
 # After destroyhandler(49.5) but before we try to make elemwise things

pytensor/tensor/slinalg.py

Lines changed: 5 additions & 5 deletions

@@ -47,12 +47,12 @@ class Cholesky(Op):
     # TODO: for specific dtypes
     # TODO: LAPACK wrapper with in-place behavior, for solve also
 
-    __props__ = ("lower", "destructive", "on_error")
+    __props__ = ("lower", "overwrite_a", "on_error")
     gufunc_signature = "(m,m)->(m,m)"
 
     def __init__(self, *, lower=True, on_error="raise", overwrite_a=False):
         self.lower = lower
-        self.destructive = overwrite_a
+        self.overwrite_a = overwrite_a
         if on_error not in ("raise", "nan"):
             raise ValueError('on_error must be one of "raise" or ""nan"')
         self.on_error = on_error
@@ -72,15 +72,15 @@ def perform(self, node, inputs, outputs):
         (z,) = outputs
 
         try:
-            if x.flags["C_CONTIGUOUS"]:
+            if x.flags["C_CONTIGUOUS"] and self.overwrite_a:
                 # Inputs to the LAPACK functions need to be exactly as expected for overwrite_a to work correctly,
                 # see https://github.com/scipy/scipy/issues/8155#issuecomment-343996798
                 x = scipy.linalg.cholesky(
-                    x.T, lower=not self.lower, overwrite_a=self.destructive
+                    x.T, lower=not self.lower, overwrite_a=self.overwrite_a
                 ).T
             else:
                 x = scipy.linalg.cholesky(
-                    x, lower=self.lower, overwrite_a=self.destructive
+                    x, lower=self.lower, overwrite_a=self.overwrite_a
                 )
 
         except scipy.linalg.LinAlgError:
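The C_CONTIGUOUS branch relies on a transpose trick so LAPACK can work on the buffer directly: x.T of a C-contiguous array is Fortran-ordered, which is the layout overwrite_a=True needs per the scipy issue linked in the diff, and flipping the lower flag plus transposing the result back recovers the requested factor. A standalone illustration with plain scipy, independent of this commit:

import numpy as np
import scipy.linalg

x = np.random.rand(5, 5)
x = x @ x.T + 5 * np.eye(5)  # symmetric positive definite, C-contiguous

# Regular call on a copy, for comparison.
L_direct = scipy.linalg.cholesky(x.copy(), lower=True)

# Transpose trick: factorize the Fortran-ordered view x.T with the opposite
# triangle, allow the buffer to be overwritten, then transpose back.
L_via_T = scipy.linalg.cholesky(x.T, lower=False, overwrite_a=True).T

np.testing.assert_allclose(L_direct, L_via_T)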

tests/tensor/rewriting/test_linalg.py

Lines changed: 2 additions & 2 deletions

@@ -313,12 +313,12 @@ def test_local_inplace_cholesky():
     L = cholesky(X, overwrite_a=False, lower=True)
     f = function([pytensor.In(X, mutable=True)], L)
 
-    assert not L.owner.op.core_op.destructive
+    assert not L.owner.op.core_op.overwrite_a
 
     nodes = f.maker.fgraph.toposort()
     for node in nodes:
         if isinstance(node, Cholesky):
-            assert node.destructive
+            assert node.overwrite_a
             break
 
     X_val = np.random.normal(size=(10, 10)).astype(config.floatX)
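Outside the test suite, the effect of the rewrite can be sketched in user code as follows. This is a hedged illustration, not part of the commit: it assumes the default compilation mode applies the in-place rewrite when the input is marked mutable, and it inspects core_op the same way the test above does.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import Cholesky, cholesky

X = pt.matrix("X")
L = cholesky(X, lower=True, overwrite_a=False)
f = pytensor.function([pytensor.In(X, mutable=True)], L)

# Collect the Cholesky ops left in the compiled graph; after the in-place
# rewrite at least one of them should carry overwrite_a=True.
chol_ops = [
    getattr(node.op, "core_op", node.op)
    for node in f.maker.fgraph.toposort()
]
assert any(isinstance(op, Cholesky) and op.overwrite_a for op in chol_ops)

# Feed a symmetric positive definite value; because the input is mutable,
# its buffer may be overwritten by the factorization.
X_val = np.random.normal(size=(10, 10)).astype(X.dtype)
X_val = (X_val @ X_val.T + 10 * np.eye(10)).astype(X.dtype)
f(X_val)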
