Skip to content

Commit 863105a

Browse files
Fix broken code on canonicalising slices
1 parent 422ac3c commit 863105a

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def local_replace_slice(fgraph, node):
398398
x = node.inputs[0]
399399

400400
if not idxs:
401-
return [x]
401+
return
402402

403403
new_idxs = list(idxs)
404404
idx_flag = False
@@ -433,7 +433,12 @@ def local_replace_slice(fgraph, node):
433433
raise AssertionError(
434434
f"The type of the returned variable {x[tuple(new_idxs)].type} did not match with the type of the original variable {node.outputs[0].type}"
435435
)
436-
return [x[tuple(new_idxs)]]
436+
437+
out = x[tuple(new_idxs)]
438+
# Copy over previous output stacktrace
439+
copy_stack_trace(node.outputs, out)
440+
441+
return [out]
437442

438443

439444
# fast_compile to allow opt subtensor(cast{float32}(make_vector))

0 commit comments

Comments
 (0)