File tree 1 file changed +20
-6
lines changed
pytensor/tensor/rewriting
1 file changed +20
-6
lines changed Original file line number Diff line number Diff line change @@ -1295,12 +1295,26 @@ def local_inplace_setsubtensor(fgraph, node):
1295
1295
1296
1296
@node_rewriter ([AdvancedIncSubtensor1 ], inplace = True )
1297
1297
def local_inplace_AdvancedIncSubtensor1 (fgraph , node ):
1298
- if isinstance (node .op , AdvancedIncSubtensor1 ) and not node .op .inplace :
1299
- new_op = node .op .clone_inplace ()
1300
- new_node = new_op (* node .inputs )
1301
- copy_stack_trace (node .outputs , new_node )
1302
- return [new_node ]
1303
- return False
1298
+ if node .op .inplace :
1299
+ return
1300
+
1301
+ x , y , idx = node .inputs
1302
+ if fgraph .has_destroyers ([x ]):
1303
+ # In this case we can't operate inplace, but if x is just an alloc of zeros
1304
+ # We're better off duplicating it and then acting on it inplace.
1305
+ if (
1306
+ x .owner is not None
1307
+ and isinstance (x .owner .op , Alloc )
1308
+ and x .owner .op .value_is_scalar_zero (x .owner .inputs [0 ])
1309
+ ):
1310
+ x = x .owner .clone ().outputs [0 ]
1311
+ else :
1312
+ return None # Inplace isn't valid
1313
+
1314
+ new_op = node .op .clone_inplace ()
1315
+ new_node = new_op (x , y , idx )
1316
+ copy_stack_trace (node .outputs , new_node )
1317
+ return [new_node ]
1304
1318
1305
1319
1306
1320
compile .optdb .register (
You can’t perform that action at this time.
0 commit comments