Skip to content

Commit 7358bd5

Browse files
Changes for rewriting redundant slices
1 parent 4eb7828 commit 7358bd5

File tree

2 files changed

+44
-43
lines changed

2 files changed

+44
-43
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -389,56 +389,51 @@ def local_useless_slice(fgraph, node):
389389
def local_replace_slice(fgraph, node):
390390
"""
391391
Rewrite Subtensor of the form:
392-
1. X[0:-1:1] -> X[:-1]
393-
2. X[0:-1] -> X[:-1]
394-
3. X[:-1] -> X[:-1]
392+
1. X[0:-1:1] -> X[None:None:None]
393+
2. X[0:-1:2] -> X[None:None:2]
394+
3. X[3:-1] -> X[3:None:None]
395395
396396
"""
397397
idxs = get_idx_list(node.inputs, node.op.idx_list)
398398
x = node.inputs[0]
399399

400400
if not idxs:
401-
return [node.inputs[0]]
402-
403-
# pytensor.dprint(node)
404-
# print(node.inputs[0].type.shape)
401+
return [x]
405402

406-
# last_slice = len(idxs)
407403
new_idxs = list(idxs)
404+
idx_flag = False
405+
for dim, s in enumerate(new_idxs):
406+
if not isinstance(s, slice):
407+
continue
408+
409+
flag = False
410+
start = s.start
411+
stop = s.stop
412+
step = s.step
413+
if start is None or extract_constant(start, only_process_constants=True) == 0:
414+
flag = True
415+
start = None
408416

409-
# flag = False
410-
# index = -1
411-
# call s slice
412-
for dim, s in enumerate(idxs):
413417
if (
414-
isinstance(s, slice)
415-
and (s.start is None or extract_constant(s.start, only_process_constants=True) == 0)
416-
and (extract_constant(s.stop, only_process_constants=True) == -1 or extract_constant(s.stop, only_process_constants=True) == node.inputs[0].type.shape[dim])
417-
and (s.step is None or extract_constant(s.step, only_process_constants=True) == 1)
418+
extract_constant(stop, only_process_constants=True) == -1
419+
or extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
418420
):
419-
# break
420-
if index != -1:
421-
new_idxs[dim] = slice(None, None, None)
422-
else:
423-
# exchange with if
424-
continue
425-
# if nothing changewd, return None
426-
# if index != -1:
427-
# new_idxs[dim] = slice(None, None, None)
428-
429-
# new_subtensor = Subtensor(tuple(new_idxs))
430-
# new_subtensor_inputs = get_slice_elements(
431-
# new_idxs, lambda x: isinstance(x, Variable)
432-
# )
433-
# out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
434-
# # Copy over previous output stacktrace
435-
# copy_stack_trace(node.outputs, out)
436-
# return [out]
437-
if change:
438-
x[tuple(new_idxs)]
421+
flag = True
422+
stop = None
423+
424+
if step is None or extract_constant(step, only_process_constants=True) == 1:
425+
flag = True
426+
step = None
427+
428+
if flag:
429+
idx_flag = True
430+
new_idxs[dim] = slice(start, stop, step)
431+
432+
if idx_flag is True:
433+
return [x[tuple(new_idxs)]]
439434
else:
440435
# Subtensor is not needed at all
441-
return None
436+
return [x]
442437

443438

444439
# fast_compile to allow opt subtensor(cast{float32}(make_vector))

tests/tensor/rewriting/test_subtensor.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pytensor.compile.ops import DeepCopyOp
1111
from pytensor.configdefaults import config
1212
from pytensor.graph import FunctionGraph, vectorize_graph
13-
from pytensor.graph.basic import Constant, Variable, ancestors
13+
from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
1414
from pytensor.graph.rewriting.basic import check_stack_trace
1515
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1616
from pytensor.graph.rewriting.utils import rewrite_graph
@@ -2405,11 +2405,17 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
24052405

24062406

24072407
def test_slice_canonicalize():
2408-
# Dummy test. Need to and will be changed
24092408
x = tensor("x", shape=(3, 5, None, 9))
24102409
y = x[:-1, 0:5, 0:7, 0:-1:-1]
24112410

2412-
f = pytensor.function([x], [y])
2413-
# use equal_computations
2414-
pytensor.dprint(f)
2415-
assert 0
2411+
f = pytensor.function([x], y)
2412+
2413+
test_y = f.maker.fgraph.toposort()
2414+
2415+
y1 = x[None:None:None, None:None:None, None:7:None, None:None:-1]
2416+
2417+
f = pytensor.function([x], y1)
2418+
2419+
expected_y = f.maker.fgraph.toposort()
2420+
2421+
assert equal_computations([test_y], [expected_y])

0 commit comments

Comments
 (0)