Skip to content

Commit 037ad59

Browse files
Merge Canonicalize slice and useless slice rewrites
1 parent b54bbcd commit 037ad59

File tree

2 files changed

+55
-71
lines changed

2 files changed

+55
-71
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 28 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -337,98 +337,68 @@ def local_subtensor_of_dot(fgraph, node):
337337
@register_useless
338338
@register_canonicalize
339339
@register_specialize
340+
@register_stabilize
340341
@node_rewriter([Subtensor])
341342
def local_useless_slice(fgraph, node):
342343
"""
343344
Remove Subtensor of the form:
344345
1. X[0, :] -> X[0]
345346
2. X[:] -> X
346347
347-
"""
348-
idxs = get_idx_list(node.inputs, node.op.idx_list)
349-
350-
if not idxs:
351-
return [node.inputs[0]]
352-
353-
last_useless_slice = len(idxs)
354-
for s in idxs[::-1]:
355-
# check if slice and then check slice indices
356-
if (
357-
isinstance(s, slice)
358-
and s.start is None
359-
and s.stop is None
360-
and (
361-
s.step is None
362-
or extract_constant(s.step, only_process_constants=True) == 1
363-
)
364-
):
365-
last_useless_slice -= 1
366-
else:
367-
break
368-
# check if we removed something
369-
if last_useless_slice < len(idxs):
370-
new_idxs = idxs[:last_useless_slice]
371-
372-
if new_idxs:
373-
new_subtensor = Subtensor(new_idxs)
374-
new_subtensor_inputs = get_slice_elements(
375-
new_idxs, lambda x: isinstance(x, Variable)
376-
)
377-
out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
378-
# Copy over previous output stacktrace
379-
copy_stack_trace(node.outputs, out)
380-
return [out]
381-
else:
382-
# Subtensor is not needed at all
383-
return [node.inputs[0]]
384-
385-
386-
@register_useless
387-
@register_canonicalize
388-
@register_stabilize
389-
@register_specialize
390-
@node_rewriter([Subtensor])
391-
def local_replace_slice(fgraph, node):
392-
"""
393-
Rewrite Subtensor of the form:
348+
Also, rewrite Subtensor of the form:
394349
X[0:7:1] -> X[None:None:None]
395-
where X is a vector of length 7
350+
where X is a vector of length 7
396351
397352
"""
398353
idxs = get_idx_list(node.inputs, node.op.idx_list)
399354
x = node.inputs[0]
400355

401356
if not idxs:
402-
return
357+
return [node.inputs[0]]
403358

404359
new_idxs = list(idxs)
405-
idx_flag = False
360+
change_flag = False
361+
last_useful_idx = -1
406362
for dim, s in enumerate(new_idxs):
407363
if not isinstance(s, slice):
364+
last_useful_idx = dim
365+
continue
366+
367+
if s == slice(None):
408368
continue
409369

410370
start = s.start
411371
stop = s.stop
412372
step = s.step
413-
if extract_constant(start, only_process_constants=True) == 0:
414-
idx_flag = True
373+
if (
374+
start is not None
375+
and extract_constant(start, only_process_constants=True) == 0
376+
):
377+
change_flag = True
415378
start = None
416379

417380
if (
418-
x.type.shape[dim] is not None
381+
stop is not None
382+
and x.type.shape[dim] is not None
419383
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
420384
):
421-
idx_flag = True
385+
change_flag = True
422386
stop = None
423387

424-
if extract_constant(step, only_process_constants=True) == 1:
425-
idx_flag = True
388+
if (
389+
step is not None
390+
and extract_constant(step, only_process_constants=True) == 1
391+
):
392+
change_flag = True
426393
step = None
427394

395+
if not (start is None and stop is None and step is None):
396+
last_useful_idx = dim
397+
428398
new_idxs[dim] = slice(start, stop, step)
429399

430-
if idx_flag is True:
431-
out = x[tuple(new_idxs)]
400+
if change_flag or ((last_useful_idx + 1) < len(idxs)):
401+
out = x[tuple(new_idxs[: last_useful_idx + 1])]
432402
# Copy over previous output stacktrace
433403
copy_stack_trace(node.outputs, out)
434404

tests/tensor/rewriting/test_subtensor.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2404,22 +2404,36 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
24042404
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
24052405

24062406

2407-
@pytest.mark.parametrize("fstop, lstop, lstep", [(None, 9, 1), (-1, -1, -1)])
2408-
def test_slice_canonicalize(fstop, lstop, lstep):
2407+
def test_slice_canonicalize():
2408+
rng = np.random.default_rng(43)
24092409
x = tensor(shape=(3, 5, None, 9))
2410-
y = x[0:fstop, 0:5, 0:7, 0:lstop:lstep]
2411-
f = pytensor.function([x], y)
2412-
test_y = f.maker.fgraph.toposort()
2410+
test_x = rng.normal(size=(3, 5, 8, 9))
2411+
# Test case 1
2412+
y = x[0:None, 0:5, 0:7, 0:9:1]
2413+
f = pytensor.function([x], y, allow_input_downcast=True)
2414+
test_y = f.maker.fgraph.outputs[0].owner.inputs[0]
24132415

2414-
y1 = x[None:None:None, None:None:None, None:7:None, None:None:None]
2416+
expected_y = x[None:None:None, None:None:None, None:7:None]
24152417

2416-
if fstop == -1 and lstop == -1 and lstep == -1:
2417-
y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
2418+
assert equal_computations([test_y], [expected_y])
24182419

2419-
f1 = pytensor.function([x], y1)
2420-
expected_y = f1.maker.fgraph.toposort()
2420+
np.testing.assert_allclose(
2421+
f(test_x),
2422+
test_x[
2423+
0:None, 0:5, 0:7, 0:9:1
2424+
], # Use the unoptimized slice to make sure our rewrite logic is correct
2425+
)
2426+
2427+
# Test case 2
2428+
y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
2429+
f1 = pytensor.function([x], y1, allow_input_downcast=True)
2430+
test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0]
2431+
2432+
expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
2433+
2434+
assert equal_computations([test_y1], [expected_y1])
24212435

2422-
assert all(
2423-
equal_computations([x1], [y1])
2424-
for x1, y1 in zip(test_y[0].inputs, expected_y[0].inputs)
2436+
np.testing.assert_allclose(
2437+
f1(test_x),
2438+
test_x[0:-1, 0:5, 0:7, 0:-1:-1],
24252439
)

0 commit comments

Comments
 (0)