Skip to content

Commit fa0ab9d

Browse files
Canonicalize Subtensor slices (#761)
1 parent 117f80d commit fa0ab9d

File tree

2 files changed

+93
-29
lines changed

2 files changed

+93
-29
lines changed

pytensor/tensor/rewriting/subtensor.py

+51-28
Original file line numberDiff line numberDiff line change
@@ -337,49 +337,72 @@ def local_subtensor_of_dot(fgraph, node):
337337
@register_useless
338338
@register_canonicalize
339339
@register_specialize
340+
@register_stabilize
340341
@node_rewriter([Subtensor])
341342
def local_useless_slice(fgraph, node):
342343
"""
343344
Remove Subtensor of the form:
344345
1. X[0, :] -> X[0]
345346
2. X[:] -> X
346347
348+
Also, rewrite Subtensor of the form:
349+
X[0:7:1] -> X[None:None:None]
350+
where X is a vector of length 7
351+
347352
"""
348353
idxs = get_idx_list(node.inputs, node.op.idx_list)
354+
x = node.inputs[0]
349355

350356
if not idxs:
351357
return [node.inputs[0]]
352358

353-
last_useless_slice = len(idxs)
354-
for s in idxs[::-1]:
355-
# check if slice and then check slice indices
359+
new_idxs = list(idxs)
360+
change_flag = False
361+
last_useful_idx = -1
362+
for dim, s in enumerate(new_idxs):
363+
if not isinstance(s, slice):
364+
last_useful_idx = dim
365+
continue
366+
367+
if s == slice(None):
368+
continue
369+
370+
start = s.start
371+
stop = s.stop
372+
step = s.step
356373
if (
357-
isinstance(s, slice)
358-
and s.start is None
359-
and s.stop is None
360-
and (
361-
s.step is None
362-
or extract_constant(s.step, only_process_constants=True) == 1
363-
)
374+
start is not None
375+
and extract_constant(start, only_process_constants=True) == 0
364376
):
365-
last_useless_slice -= 1
366-
else:
367-
break
368-
# check if we removed something
369-
if last_useless_slice < len(idxs):
370-
new_idxs = idxs[:last_useless_slice]
371-
if new_idxs:
372-
new_subtensor = Subtensor(new_idxs)
373-
new_subtensor_inputs = get_slice_elements(
374-
new_idxs, lambda x: isinstance(x, Variable)
375-
)
376-
out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
377-
# Copy over previous output stacktrace
378-
copy_stack_trace(node.outputs, out)
379-
return [out]
380-
else:
381-
# Subtensor is not needed at all
382-
return [node.inputs[0]]
377+
change_flag = True
378+
start = None
379+
380+
if (
381+
stop is not None
382+
and x.type.shape[dim] is not None
383+
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
384+
):
385+
change_flag = True
386+
stop = None
387+
388+
if (
389+
step is not None
390+
and extract_constant(step, only_process_constants=True) == 1
391+
):
392+
change_flag = True
393+
step = None
394+
395+
if not (start is None and stop is None and step is None):
396+
last_useful_idx = dim
397+
398+
new_idxs[dim] = slice(start, stop, step)
399+
400+
if change_flag or ((last_useful_idx + 1) < len(idxs)):
401+
out = x[tuple(new_idxs[: last_useful_idx + 1])]
402+
# Copy over previous output stacktrace
403+
copy_stack_trace(node.outputs, out)
404+
405+
return [out]
383406

384407

385408
# fast_compile to allow opt subtensor(cast{float32}(make_vector))

tests/tensor/rewriting/test_subtensor.py

+42-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pytensor.compile.ops import DeepCopyOp
1111
from pytensor.configdefaults import config
1212
from pytensor.graph import FunctionGraph, vectorize_graph
13-
from pytensor.graph.basic import Constant, Variable, ancestors
13+
from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
1414
from pytensor.graph.rewriting.basic import check_stack_trace
1515
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1616
from pytensor.graph.rewriting.utils import rewrite_graph
@@ -2402,3 +2402,44 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
24022402
else:
24032403
expected_out[:, :, core_idxs] += test_y
24042404
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
2405+
2406+
2407+
def test_slice_canonicalize():
2408+
rng = np.random.default_rng(43)
2409+
x = tensor(shape=(3, 5, None, 9))
2410+
test_x = rng.normal(size=(3, 5, 8, 9))
2411+
# Test case 1
2412+
y = x[0:None, 0:5, 0:7, 0:9:1]
2413+
f = pytensor.function([x], y, allow_input_downcast=True)
2414+
2415+
# Get the DeepCopy input and assert that the Op is a DeepCopy
2416+
test_y = f.maker.fgraph.outputs[0].owner.inputs[0]
2417+
assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
2418+
2419+
expected_y = x[None:None:None, None:None:None, None:7:None]
2420+
2421+
assert equal_computations([test_y], [expected_y])
2422+
2423+
np.testing.assert_allclose(
2424+
f(test_x),
2425+
test_x[
2426+
0:None, 0:5, 0:7, 0:9:1
2427+
], # Use the unoptimized slice to make sure our rewrite logic is correct
2428+
)
2429+
2430+
# Test case 2
2431+
y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
2432+
f1 = pytensor.function([x], y1, allow_input_downcast=True)
2433+
2434+
# Get the DeepCopy input and assert that the Op is a DeepCopy
2435+
test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0]
2436+
assert isinstance(f1.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
2437+
2438+
expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
2439+
2440+
assert equal_computations([test_y1], [expected_y1])
2441+
2442+
np.testing.assert_allclose(
2443+
f1(test_x),
2444+
test_x[0:-1, 0:5, 0:7, 0:-1:-1],
2445+
)

0 commit comments

Comments (0)