Skip to content

Commit edc5ddd

Browse files
Merge Canonicalize slice and useless slice rewrites
1 parent a681a0e commit edc5ddd

File tree

1 file changed

+40
-45
lines changed

1 file changed

+40
-45
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -336,15 +336,21 @@ def local_subtensor_of_dot(fgraph, node):
336336
@register_useless
337337
@register_canonicalize
338338
@register_specialize
339+
@register_stabilize
339340
@node_rewriter([Subtensor])
340341
def local_useless_slice(fgraph, node):
341342
"""
342343
Remove Subtensor of the form:
343344
1. X[0, :] -> X[0]
344345
2. X[:] -> X
345346
347+
Also, rewrite Subtensor of the form:
348+
X[0:7:1] -> X[None:None:None]
349+
where X is a vector of length 7
350+
346351
"""
347352
idxs = get_idx_list(node.inputs, node.op.idx_list)
353+
x = node.inputs[0]
348354

349355
if not idxs:
350356
return [node.inputs[0]]
@@ -364,74 +370,63 @@ def local_useless_slice(fgraph, node):
364370
last_useless_slice -= 1
365371
else:
366372
break
367-
# check if we removed something
368-
if last_useless_slice < len(idxs):
369-
new_idxs = idxs[:last_useless_slice]
370-
371-
if new_idxs:
372-
new_subtensor = Subtensor(new_idxs)
373-
new_subtensor_inputs = get_slice_elements(
374-
new_idxs, lambda x: isinstance(x, Variable)
375-
)
376-
out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
377-
# Copy over previous output stacktrace
378-
copy_stack_trace(node.outputs, out)
379-
return [out]
380-
else:
381-
# Subtensor is not needed at all
382-
return [node.inputs[0]]
383373

384-
385-
@register_useless
386-
@register_canonicalize
387-
@register_stabilize
388-
@register_specialize
389-
@node_rewriter([Subtensor])
390-
def local_replace_slice(fgraph, node):
391-
"""
392-
Rewrite Subtensor of the form:
393-
X[0:7:1] -> X[None:None:None]
394-
where X is a vector of length 7
395-
396-
"""
397-
idxs = get_idx_list(node.inputs, node.op.idx_list)
398-
x = node.inputs[0]
399-
400-
if not idxs:
401-
return
402-
403-
new_idxs = list(idxs)
404-
idx_flag = False
374+
new_idxs = list(idxs)[:last_useless_slice]
375+
change_flag = False
405376
for dim, s in enumerate(new_idxs):
406-
if not isinstance(s, slice):
377+
if not isinstance(s, slice) or s == slice(None):
407378
continue
408379

409380
start = s.start
410381
stop = s.stop
411382
step = s.step
412-
if extract_constant(start, only_process_constants=True) == 0:
413-
idx_flag = True
383+
if (
384+
start is not None
385+
and extract_constant(start, only_process_constants=True) == 0
386+
):
387+
change_flag = True
414388
start = None
415389

416390
if (
417-
x.type.shape[dim] is not None
391+
stop is not None
392+
and x.type.shape[dim] is not None
418393
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
419394
):
420-
idx_flag = True
395+
change_flag = True
421396
stop = None
422397

423-
if extract_constant(step, only_process_constants=True) == 1:
424-
idx_flag = True
398+
if (
399+
step is not None
400+
and extract_constant(step, only_process_constants=True) == 1
401+
):
402+
change_flag = True
425403
step = None
426404

427405
new_idxs[dim] = slice(start, stop, step)
428406

429-
if idx_flag is True:
407+
if change_flag is True or last_useless_slice < len(idxs):
430408
out = x[tuple(new_idxs)]
431409
# Copy over previous output stacktrace
432410
copy_stack_trace(node.outputs, out)
433411

434412
return [out]
413+
# elif last_useless_slice >= len(idxs):
414+
# return [x]
415+
# check if we removed something
416+
# if last_useless_slice < len(idxs):
417+
# new_idxs = idxs[:last_useless_slice]
418+
# if new_idxs:
419+
# new_subtensor = Subtensor(new_idxs)
420+
# new_subtensor_inputs = get_slice_elements(
421+
# new_idxs, lambda x: isinstance(x, Variable)
422+
# )
423+
# out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
424+
# # Copy over previous output stacktrace
425+
# copy_stack_trace(node.outputs, out)
426+
# return [out]
427+
# else:
428+
# # Subtensor is not needed at all
429+
# return [node.inputs[0]]
435430

436431

437432
# fast_compile to allow opt subtensor(cast{float32}(make_vector))

0 commit comments

Comments
 (0)