Skip to content

Commit a681a0e

Browse files
Canonicalize subtensor slices
1 parent 30b760f commit a681a0e

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def local_useless_slice(fgraph, node):
367367
# check if we removed something
368368
if last_useless_slice < len(idxs):
369369
new_idxs = idxs[:last_useless_slice]
370+
370371
if new_idxs:
371372
new_subtensor = Subtensor(new_idxs)
372373
new_subtensor_inputs = get_slice_elements(
@@ -381,6 +382,58 @@ def local_useless_slice(fgraph, node):
381382
return [node.inputs[0]]
382383

383384

385+
@register_useless
386+
@register_canonicalize
387+
@register_stabilize
388+
@register_specialize
389+
@node_rewriter([Subtensor])
390+
def local_replace_slice(fgraph, node):
391+
"""
392+
Rewrite Subtensor of the form:
393+
X[0:7:1] -> X[None:None:None]
394+
where X is a vector of length 7
395+
396+
"""
397+
idxs = get_idx_list(node.inputs, node.op.idx_list)
398+
x = node.inputs[0]
399+
400+
if not idxs:
401+
return
402+
403+
new_idxs = list(idxs)
404+
idx_flag = False
405+
for dim, s in enumerate(new_idxs):
406+
if not isinstance(s, slice):
407+
continue
408+
409+
start = s.start
410+
stop = s.stop
411+
step = s.step
412+
if extract_constant(start, only_process_constants=True) == 0:
413+
idx_flag = True
414+
start = None
415+
416+
if (
417+
x.type.shape[dim] is not None
418+
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
419+
):
420+
idx_flag = True
421+
stop = None
422+
423+
if extract_constant(step, only_process_constants=True) == 1:
424+
idx_flag = True
425+
step = None
426+
427+
new_idxs[dim] = slice(start, stop, step)
428+
429+
if idx_flag is True:
430+
out = x[tuple(new_idxs)]
431+
# Copy over previous output stacktrace
432+
copy_stack_trace(node.outputs, out)
433+
434+
return [out]
435+
436+
384437
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
385438
@register_canonicalize("fast_compile")
386439
@node_rewriter([Subtensor])

tests/tensor/rewriting/test_subtensor.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pytensor.compile.ops import DeepCopyOp
1111
from pytensor.configdefaults import config
1212
from pytensor.graph import FunctionGraph, vectorize_graph
13-
from pytensor.graph.basic import Constant, Variable, ancestors
13+
from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
1414
from pytensor.graph.rewriting.basic import check_stack_trace
1515
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1616
from pytensor.graph.rewriting.utils import rewrite_graph
@@ -2402,3 +2402,24 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
24022402
else:
24032403
expected_out[:, :, core_idxs] += test_y
24042404
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
2405+
2406+
2407+
@pytest.mark.parametrize("fstop, lstop, lstep", [(None, 9, 1), (-1, -1, -1)])
2408+
def test_slice_canonicalize(fstop, lstop, lstep):
2409+
x = tensor(shape=(3, 5, None, 9))
2410+
y = x[0:fstop, 0:5, 0:7, 0:lstop:lstep]
2411+
f = pytensor.function([x], y)
2412+
test_y = f.maker.fgraph.toposort()
2413+
2414+
y1 = x[None:None:None, None:None:None, None:7:None, None:None:None]
2415+
2416+
if fstop == -1 and lstop == -1 and lstep == -1:
2417+
y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
2418+
2419+
f1 = pytensor.function([x], y1)
2420+
expected_y = f1.maker.fgraph.toposort()
2421+
2422+
assert all(
2423+
equal_computations([x1], [y1])
2424+
for x1, y1 in zip(test_y[0].inputs, expected_y[0].inputs)
2425+
)

0 commit comments

Comments
 (0)