Skip to content

Commit b54bbcd

Browse files
Canonicalize subtensor slices
1 parent c767114 commit b54bbcd

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def local_useless_slice(fgraph, node):
368368
# check if we removed something
369369
if last_useless_slice < len(idxs):
370370
new_idxs = idxs[:last_useless_slice]
371+
371372
if new_idxs:
372373
new_subtensor = Subtensor(new_idxs)
373374
new_subtensor_inputs = get_slice_elements(
@@ -382,6 +383,58 @@ def local_useless_slice(fgraph, node):
382383
return [node.inputs[0]]
383384

384385

386+
@register_useless
387+
@register_canonicalize
388+
@register_stabilize
389+
@register_specialize
390+
@node_rewriter([Subtensor])
391+
def local_replace_slice(fgraph, node):
392+
"""
393+
Rewrite Subtensor of the form:
394+
X[0:7:1] -> X[None:None:None]
395+
where X is a vector of length 7
396+
397+
"""
398+
idxs = get_idx_list(node.inputs, node.op.idx_list)
399+
x = node.inputs[0]
400+
401+
if not idxs:
402+
return
403+
404+
new_idxs = list(idxs)
405+
idx_flag = False
406+
for dim, s in enumerate(new_idxs):
407+
if not isinstance(s, slice):
408+
continue
409+
410+
start = s.start
411+
stop = s.stop
412+
step = s.step
413+
if extract_constant(start, only_process_constants=True) == 0:
414+
idx_flag = True
415+
start = None
416+
417+
if (
418+
x.type.shape[dim] is not None
419+
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
420+
):
421+
idx_flag = True
422+
stop = None
423+
424+
if extract_constant(step, only_process_constants=True) == 1:
425+
idx_flag = True
426+
step = None
427+
428+
new_idxs[dim] = slice(start, stop, step)
429+
430+
if idx_flag is True:
431+
out = x[tuple(new_idxs)]
432+
# Copy over previous output stacktrace
433+
copy_stack_trace(node.outputs, out)
434+
435+
return [out]
436+
437+
385438
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
386439
@register_canonicalize("fast_compile")
387440
@node_rewriter([Subtensor])

tests/tensor/rewriting/test_subtensor.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pytensor.compile.ops import DeepCopyOp
1111
from pytensor.configdefaults import config
1212
from pytensor.graph import FunctionGraph, vectorize_graph
13-
from pytensor.graph.basic import Constant, Variable, ancestors
13+
from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
1414
from pytensor.graph.rewriting.basic import check_stack_trace
1515
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1616
from pytensor.graph.rewriting.utils import rewrite_graph
@@ -2402,3 +2402,24 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
24022402
else:
24032403
expected_out[:, :, core_idxs] += test_y
24042404
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
2405+
2406+
2407+
@pytest.mark.parametrize("fstop, lstop, lstep", [(None, 9, 1), (-1, -1, -1)])
2408+
def test_slice_canonicalize(fstop, lstop, lstep):
2409+
x = tensor(shape=(3, 5, None, 9))
2410+
y = x[0:fstop, 0:5, 0:7, 0:lstop:lstep]
2411+
f = pytensor.function([x], y)
2412+
test_y = f.maker.fgraph.toposort()
2413+
2414+
y1 = x[None:None:None, None:None:None, None:7:None, None:None:None]
2415+
2416+
if fstop == -1 and lstop == -1 and lstep == -1:
2417+
y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
2418+
2419+
f1 = pytensor.function([x], y1)
2420+
expected_y = f1.maker.fgraph.toposort()
2421+
2422+
assert all(
2423+
equal_computations([x1], [y1])
2424+
for x1, y1 in zip(test_y[0].inputs, expected_y[0].inputs)
2425+
)

0 commit comments

Comments
 (0)