Skip to content

Commit 78f1cf8

Browse files
committed
Avoid copy of zeros in AdvancedIncSubtensor1
1 parent b7d1a52 commit 78f1cf8

File tree

1 file changed: 20 additions (+20) and 6 deletions (-6).

pytensor/tensor/rewriting/subtensor.py

+20-6
Original file line number | Diff line number | Diff line change
@@ -1295,12 +1295,26 @@ def local_inplace_setsubtensor(fgraph, node):
12951295

12961296
# Rewriter that swaps a non-inplace AdvancedIncSubtensor1 for its inplace
# clone (`node.op.clone_inplace()`).
#
# Removed version (`-` lines): rewrote whenever the op was not already
# inplace, reusing `node.inputs` unchanged, and returned False otherwise.
#
# Added version (`+` lines): returns without rewriting when the op is already
# inplace, or when `fgraph.has_destroyers([x])` is true and `x` is not the
# output of an Alloc of a scalar zero. In the zeros-Alloc case the Alloc node
# is cloned and the inplace op is applied to the clone's output instead.
#
# Returns a one-element list holding the result of the inplace op when a
# rewrite is made.
# NOTE(review): presumably `has_destroyers` reports that some other node in
# the graph already overwrites `x` inplace -- confirm against FunctionGraph.
@node_rewriter([AdvancedIncSubtensor1], inplace=True)
12971297
def local_inplace_AdvancedIncSubtensor1(fgraph, node):
1298-
if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
1299-
new_op = node.op.clone_inplace()
1300-
new_node = new_op(*node.inputs)
1301-
copy_stack_trace(node.outputs, new_node)
1302-
return [new_node]
1303-
return False
1298+
if node.op.inplace:
1299+
return
1300+
1301+
x, y, idx = node.inputs
1302+
if fgraph.has_destroyers([x]):
1303+
# x already has a destroyer, so we can't operate on it inplace. But if x
1304+
# is just an Alloc of zeros, we're better off duplicating that Alloc and
# then acting inplace on the duplicate.
1305+
if (
1306+
x.owner is not None
1307+
and isinstance(x.owner.op, Alloc)
1308+
and x.owner.op.value_is_scalar_zero(x.owner.inputs[0])
1309+
):
1310+
x = x.owner.clone().outputs[0]
1311+
else:
1312+
return None # Inplace isn't valid: x has a destroyer and is not a zeros Alloc
1313+
1314+
new_op = node.op.clone_inplace()
1315+
new_node = new_op(x, y, idx)
1316+
copy_stack_trace(node.outputs, new_node)
1317+
return [new_node]
13041318

13051319

13061320
compile.optdb.register(

0 commit comments

Comments
 (0)