Skip to content

Commit d681749

Browse files
Merge Canonicalize slice and useless slice rewrites
1 parent c7767e8 commit d681749

File tree

2 files changed

+53
-69
lines changed

2 files changed

+53
-69
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 28 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -336,98 +336,68 @@ def local_subtensor_of_dot(fgraph, node):
336336
@register_useless
337337
@register_canonicalize
338338
@register_specialize
339+
@register_stabilize
339340
@node_rewriter([Subtensor])
340341
def local_useless_slice(fgraph, node):
341342
"""
342343
Remove Subtensor of the form:
343344
1. X[0, :] -> X[0]
344345
2. X[:] -> X
345346
346-
"""
347-
idxs = get_idx_list(node.inputs, node.op.idx_list)
348-
349-
if not idxs:
350-
return [node.inputs[0]]
351-
352-
last_useless_slice = len(idxs)
353-
for s in idxs[::-1]:
354-
# check if slice and then check slice indices
355-
if (
356-
isinstance(s, slice)
357-
and s.start is None
358-
and s.stop is None
359-
and (
360-
s.step is None
361-
or extract_constant(s.step, only_process_constants=True) == 1
362-
)
363-
):
364-
last_useless_slice -= 1
365-
else:
366-
break
367-
# check if we removed something
368-
if last_useless_slice < len(idxs):
369-
new_idxs = idxs[:last_useless_slice]
370-
371-
if new_idxs:
372-
new_subtensor = Subtensor(new_idxs)
373-
new_subtensor_inputs = get_slice_elements(
374-
new_idxs, lambda x: isinstance(x, Variable)
375-
)
376-
out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
377-
# Copy over previous output stacktrace
378-
copy_stack_trace(node.outputs, out)
379-
return [out]
380-
else:
381-
# Subtensor is not needed at all
382-
return [node.inputs[0]]
383-
384-
385-
@register_useless
386-
@register_canonicalize
387-
@register_stabilize
388-
@register_specialize
389-
@node_rewriter([Subtensor])
390-
def local_replace_slice(fgraph, node):
391-
"""
392-
Rewrite Subtensor of the form:
347+
Also, rewrite Subtensor of the form:
393348
X[0:7:1] -> X[None:None:None]
394-
where X is a vector of length 7
349+
where X is a vector of length 7
395350
396351
"""
397352
idxs = get_idx_list(node.inputs, node.op.idx_list)
398353
x = node.inputs[0]
399354

400355
if not idxs:
401-
return
356+
return [node.inputs[0]]
402357

403358
new_idxs = list(idxs)
404-
idx_flag = False
359+
change_flag = False
360+
last_useful_idx = -1
405361
for dim, s in enumerate(new_idxs):
406362
if not isinstance(s, slice):
363+
last_useful_idx = dim
364+
continue
365+
366+
if s == slice(None):
407367
continue
408368

409369
start = s.start
410370
stop = s.stop
411371
step = s.step
412-
if extract_constant(start, only_process_constants=True) == 0:
413-
idx_flag = True
372+
if (
373+
start is not None
374+
and extract_constant(start, only_process_constants=True) == 0
375+
):
376+
change_flag = True
414377
start = None
415378

416379
if (
417-
x.type.shape[dim] is not None
380+
stop is not None
381+
and x.type.shape[dim] is not None
418382
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
419383
):
420-
idx_flag = True
384+
change_flag = True
421385
stop = None
422386

423-
if extract_constant(step, only_process_constants=True) == 1:
424-
idx_flag = True
387+
if (
388+
step is not None
389+
and extract_constant(step, only_process_constants=True) == 1
390+
):
391+
change_flag = True
425392
step = None
426393

394+
if not (start is None and stop is None and step is None):
395+
last_useful_idx = dim
396+
427397
new_idxs[dim] = slice(start, stop, step)
428398

429-
if idx_flag is True:
430-
out = x[tuple(new_idxs)]
399+
if change_flag or ((last_useful_idx + 1) < len(idxs)):
400+
out = x[tuple(new_idxs[: last_useful_idx + 1])]
431401
# Copy over previous output stacktrace
432402
copy_stack_trace(node.outputs, out)
433403

tests/tensor/rewriting/test_subtensor.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2404,22 +2404,36 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
24042404
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
24052405

24062406

2407-
@pytest.mark.parametrize("fstop, lstop, lstep", [(None, 9, 1), (-1, -1, -1)])
2408-
def test_slice_canonicalize(fstop, lstop, lstep):
2407+
def test_slice_canonicalize():
2408+
rng = np.random.default_rng(43)
24092409
x = tensor(shape=(3, 5, None, 9))
2410-
y = x[0:fstop, 0:5, 0:7, 0:lstop:lstep]
2410+
test_x = rng.normal(size=(3, 5, 8, 9))
2411+
# Test case 1
2412+
y = x[0:None, 0:5, 0:7, 0:9:1]
24112413
f = pytensor.function([x], y)
2412-
test_y = f.maker.fgraph.toposort()
2414+
test_y = f.maker.fgraph.outputs[0].owner.inputs[0]
24132415

2414-
y1 = x[None:None:None, None:None:None, None:7:None, None:None:None]
2416+
expected_y = x[None:None:None, None:None:None, None:7:None]
24152417

2416-
if fstop == -1 and lstop == -1 and lstep == -1:
2417-
y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
2418+
assert equal_computations([test_y], [expected_y])
24182419

2420+
np.testing.assert_allclose(
2421+
f(test_x),
2422+
test_x[
2423+
0:None, 0:5, 0:7, 0:9:1
2424+
], # Use the unoptimized slice to make sure our rewrite logic is correct
2425+
)
2426+
2427+
# Test case 2
2428+
y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
24192429
f1 = pytensor.function([x], y1)
2420-
expected_y = f1.maker.fgraph.toposort()
2430+
test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0]
2431+
2432+
expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
2433+
2434+
assert equal_computations([test_y1], [expected_y1])
24212435

2422-
assert all(
2423-
equal_computations([x1], [y1])
2424-
for x1, y1 in zip(test_y[0].inputs, expected_y[0].inputs)
2436+
np.testing.assert_allclose(
2437+
f1(test_x),
2438+
test_x[0:-1, 0:5, 0:7, 0:-1:-1],
24252439
)

0 commit comments

Comments
 (0)