Skip to content

Commit 34b084f

Browse files
Merge Canonicalize slice and useless slice rewrites
1 parent c7767e8 commit 34b084f

File tree

1 file changed

+28
-58
lines changed

1 file changed

+28
-58
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 28 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -336,98 +336,68 @@ def local_subtensor_of_dot(fgraph, node):
336336
@register_useless
337337
@register_canonicalize
338338
@register_specialize
339+
@register_stabilize
339340
@node_rewriter([Subtensor])
340341
def local_useless_slice(fgraph, node):
341342
"""
342343
Remove Subtensor of the form:
343344
1. X[0, :] -> X[0]
344345
2. X[:] -> X
345346
346-
"""
347-
idxs = get_idx_list(node.inputs, node.op.idx_list)
348-
349-
if not idxs:
350-
return [node.inputs[0]]
351-
352-
last_useless_slice = len(idxs)
353-
for s in idxs[::-1]:
354-
# check if slice and then check slice indices
355-
if (
356-
isinstance(s, slice)
357-
and s.start is None
358-
and s.stop is None
359-
and (
360-
s.step is None
361-
or extract_constant(s.step, only_process_constants=True) == 1
362-
)
363-
):
364-
last_useless_slice -= 1
365-
else:
366-
break
367-
# check if we removed something
368-
if last_useless_slice < len(idxs):
369-
new_idxs = idxs[:last_useless_slice]
370-
371-
if new_idxs:
372-
new_subtensor = Subtensor(new_idxs)
373-
new_subtensor_inputs = get_slice_elements(
374-
new_idxs, lambda x: isinstance(x, Variable)
375-
)
376-
out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
377-
# Copy over previous output stacktrace
378-
copy_stack_trace(node.outputs, out)
379-
return [out]
380-
else:
381-
# Subtensor is not needed at all
382-
return [node.inputs[0]]
383-
384-
385-
@register_useless
386-
@register_canonicalize
387-
@register_stabilize
388-
@register_specialize
389-
@node_rewriter([Subtensor])
390-
def local_replace_slice(fgraph, node):
391-
"""
392-
Rewrite Subtensor of the form:
347+
Also, rewrite Subtensor of the form:
393348
X[0:7:1] -> X[None:None:None]
394-
where X is a vector of length 7
349+
where X is a vector of length 7
395350
396351
"""
397352
idxs = get_idx_list(node.inputs, node.op.idx_list)
398353
x = node.inputs[0]
399354

400355
if not idxs:
401-
return
356+
return [node.inputs[0]]
402357

403358
new_idxs = list(idxs)
404-
idx_flag = False
359+
change_flag = False
360+
last_useful_idx = -1
405361
for dim, s in enumerate(new_idxs):
406362
if not isinstance(s, slice):
363+
last_useful_idx = dim
364+
continue
365+
366+
if s == slice(None):
407367
continue
408368

409369
start = s.start
410370
stop = s.stop
411371
step = s.step
412-
if extract_constant(start, only_process_constants=True) == 0:
413-
idx_flag = True
372+
if (
373+
start is not None
374+
and extract_constant(start, only_process_constants=True) == 0
375+
):
376+
change_flag = True
414377
start = None
415378

416379
if (
417-
x.type.shape[dim] is not None
380+
stop is not None
381+
and x.type.shape[dim] is not None
418382
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
419383
):
420-
idx_flag = True
384+
change_flag = True
421385
stop = None
422386

423-
if extract_constant(step, only_process_constants=True) == 1:
424-
idx_flag = True
387+
if (
388+
step is not None
389+
and extract_constant(step, only_process_constants=True) == 1
390+
):
391+
change_flag = True
425392
step = None
426393

394+
if not (start is None and stop is None and step is None):
395+
last_useful_idx = dim
396+
427397
new_idxs[dim] = slice(start, stop, step)
428398

429-
if idx_flag is True:
430-
out = x[tuple(new_idxs)]
399+
if change_flag or ((last_useful_idx + 1) < len(idxs)):
400+
out = x[tuple(new_idxs[: last_useful_idx + 1])]
431401
# Copy over previous output stacktrace
432402
copy_stack_trace(node.outputs, out)
433403

0 commit comments

Comments
 (0)