
Commit 8bf7472

ricardoV94 authored and Ch0ronomato committed
Fix bug in local_useless_slice rewrite
Canonical slice start and stop values depend on the sign of the step. The rewrite wrongly assumed they were always 0:len(dim)
1 parent cdf3b62 commit 8bf7472
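The claim in the commit message is easy to check in plain NumPy. A minimal sketch (the array v is only an illustration) showing that the "do nothing" start/stop values flip from 0 and len(dim) to -1 and -len(dim) - 1 once the step is negative:

import numpy as np

v = np.arange(7)  # a vector of length 7

# Positive step: the full slice is the same as start=0, stop=len(v)
assert np.array_equal(v[0:7:1], v[::])

# Negative step: the full reversed slice is start=-1, stop=-len(v) - 1;
# reusing 0:len(v) with step -1 instead selects nothing
assert np.array_equal(v[-1:-8:-1], v[::-1])
assert v[0:7:-1].size == 0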

File tree

2 files changed: +94 -51 lines changed


pytensor/tensor/rewriting/subtensor.py

Lines changed: 28 additions & 17 deletions
@@ -342,14 +342,18 @@ def local_subtensor_of_dot(fgraph, node):
 @node_rewriter([Subtensor])
 def local_useless_slice(fgraph, node):
     """
-    Remove Subtensor of the form:
+    Remove useless slice(None) of the form:
     1. X[0, :] -> X[0]
     2. X[:] -> X

-    Also, rewrite Subtensor of the form:
+    Also, canonicalize slices of the form:
     X[0:7:1] -> X[None:None:None]
     where X is a vector of length 7

+    And:
+    X[-1:-8:-1] -> X[::-1]
+    where x is a vector of length 7
+
     """
     idxs = get_idx_list(node.inputs, node.op.idx_list)
     x = node.inputs[0]
@@ -368,32 +372,40 @@ def local_useless_slice(fgraph, node):
         if s == slice(None):
             continue

+        step = s.step
+
+        if step is None:
+            positive_step = True
+        elif isinstance(step, Constant):
+            step_value = step.data
+            positive_step = step.data > 0
+            if step_value == 1:
+                change_flag = True
+                step = None
+        else:
+            # We can only canonicalize start and stop if we know the sign of step
+            last_useful_idx = dim
+            continue
+
         start = s.start
         stop = s.stop
-        step = s.step
-        if (
-            start is not None
-            and extract_constant(start, only_process_constants=True) == 0
-        ):
+
+        if start is not None and extract_constant(
+            start, only_process_constants=True
+        ) == (0 if positive_step else -1):
             change_flag = True
             start = None

         if (
             stop is not None
             and x.type.shape[dim] is not None
-            and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
+            and extract_constant(stop, only_process_constants=True)
+            == (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1)
         ):
             change_flag = True
             stop = None

-        if (
-            step is not None
-            and extract_constant(step, only_process_constants=True) == 1
-        ):
-            change_flag = True
-            step = None
-
-        if not (start is None and stop is None and step is None):
+        if start is not None or stop is not None or step is not None:
             last_useful_idx = dim

         new_idxs[dim] = slice(start, stop, step)
@@ -402,7 +414,6 @@ def local_useless_slice(fgraph, node):
     out = x[tuple(new_idxs[: last_useful_idx + 1])]
     # Copy over previous output stacktrace
     copy_stack_trace(node.outputs, out)
-
     return [out]

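Reduced to plain Python over constant indices, the per-dimension decision made by the rewrite above looks roughly like the sketch below. The helper name canonicalize_slice is hypothetical and only illustrates the rule; the real rewrite operates on symbolic PyTensor inputs via extract_constant.

def canonicalize_slice(s: slice, length: int | None) -> slice:
    # Illustrative sketch of local_useless_slice for one dimension with
    # constant start/stop/step; not part of PyTensor itself.
    start, stop, step = s.start, s.stop, s.step

    if step is None or step == 1:
        positive_step = True
        step = None
    elif isinstance(step, int):
        positive_step = step > 0
    else:
        # Unknown (symbolic) step: leave start and stop untouched
        return slice(start, stop, step)

    # Default start is 0 for positive steps, -1 for negative steps
    if start == (0 if positive_step else -1):
        start = None

    # Default stop is len(dim) for positive steps, -len(dim) - 1 for negative steps
    if length is not None and stop == (length if positive_step else -length - 1):
        stop = None

    return slice(start, stop, step)


# For a dimension of known length 7:
assert canonicalize_slice(slice(0, 7, 1), 7) == slice(None, None, None)
assert canonicalize_slice(slice(-1, -8, -1), 7) == slice(None, None, -1)
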
tests/tensor/rewriting/test_subtensor.py

Lines changed: 66 additions & 34 deletions
@@ -2404,42 +2404,74 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
     np.testing.assert_allclose(fn(test_x, test_y), expected_out)


-def test_slice_canonicalize():
-    rng = np.random.default_rng(43)
-    x = tensor(shape=(3, 5, None, 9))
-    test_x = rng.normal(size=(3, 5, 8, 9))
-    # Test case 1
-    y = x[0:None, 0:5, 0:7, 0:9:1]
-    f = pytensor.function([x], y, allow_input_downcast=True)
-
-    # Get the DeepCopy input and assert that the Op is a DeepCopy
-    test_y = f.maker.fgraph.outputs[0].owner.inputs[0]
-    assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
-
-    expected_y = x[None:None:None, None:None:None, None:7:None]
-
-    assert equal_computations([test_y], [expected_y])
-
-    np.testing.assert_allclose(
-        f(test_x),
-        test_x[
-            0:None, 0:5, 0:7, 0:9:1
-        ],  # Use the unoptimized slice to make sure our rewrite logic is correct
-    )
+class TestUselessSlice:
+    def test_positive_step(self):
+        # When steps are positive, default start and end are `0` and `len(dim)`
+        x = tensor(shape=(3, 5, None, 9), dtype="float64")
+        test_x = np.random.normal(size=(3, 5, 8, 9))
+
+        y = x[0:3:1, 1:5:2, 0:7:1, 0:9:1]
+        f = pytensor.function([x], y)
+
+        # Get the DeepCopy input and assert that the Op is a DeepCopy
+        deep_copy_node = f.maker.fgraph.outputs[0].owner
+        assert isinstance(deep_copy_node.op, DeepCopyOp)
+
+        rewritten_y = deep_copy_node.inputs[0]
+        expected_y = x[None:None:None, 1:None:2, None:7:None]
+        assert equal_computations([rewritten_y], [expected_y])
+
+        np.testing.assert_allclose(
+            f(test_x),
+            # Use the unoptimized slice to make sure our rewrite logic is correct
+            test_x[0:3:1, 1:5:2, 0:7:1, 0:9:1],
+        )

-    # Test case 2
-    y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
-    f1 = pytensor.function([x], y1, allow_input_downcast=True)
+    def test_negative_step(self):
+        # When steps are negative, default start and end are `-1` and `-len(dim) - 1`
+        x = tensor(shape=(3, 5, None, 9), dtype="float64")
+        test_x = np.random.normal(size=(3, 5, 8, 9))

-    # Get the DeepCopy input and assert that the Op is a DeepCopy
-    test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0]
-    assert isinstance(f1.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
+        y = x[-1:-4:-1, 0:5:-2, -1:-9:-1, 0:9:None]
+        f = pytensor.function([x], y)

-    expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
+        # Get the DeepCopy input and assert that the Op is a DeepCopy
+        deep_copy_node = f.maker.fgraph.outputs[0].owner
+        assert isinstance(deep_copy_node.op, DeepCopyOp)

-    assert equal_computations([test_y1], [expected_y1])
+        rewritten_y = deep_copy_node.inputs[0]
+        expected_y = x[None:None:-1, 0:5:-2, None:-9:-1]
+        assert equal_computations([rewritten_y], [expected_y])

-    np.testing.assert_allclose(
-        f1(test_x),
-        test_x[0:-1, 0:5, 0:7, 0:-1:-1],
-    )
+        np.testing.assert_allclose(
+            f(test_x),
+            test_x[-1:-4:-1, 0:5:-2, -1:-9:-1, 0:9:None],
+        )
+
+    def test_unknown_step(self):
+        # If step isn't known, we can't canonicalize start and stop points
+        step = pt.scalar("step", dtype=int)
+        x = tensor(shape=(3, 5, None), dtype="float64")
+        test_x = np.random.normal(size=(3, 5, 7))
+
+        y = x[0:3:step, -1:-6:-step, ::]
+        # Need this rewrite when `FAST_COMPILE` otherwise step = -1 * step instead of neg(step)
+        mode = get_default_mode().including("local_mul_specialize")
+        f = pytensor.function([x, step], y, mode=mode)
+
+        # Get the DeepCopy input and assert that the Op is a DeepCopy
+        deep_copy_node = f.maker.fgraph.outputs[0].owner
+        assert isinstance(deep_copy_node.op, DeepCopyOp)
+
+        rewritten_y = deep_copy_node.inputs[0]
+        expected_y = x[0:3:step, -1:-6:-step]
+        assert equal_computations([rewritten_y], [expected_y])
+
+        np.testing.assert_allclose(
+            f(test_x, 1),
+            test_x[0:3:1, -1:-6:-1, ::],
+        )
+        np.testing.assert_allclose(
+            f(test_x, -2),
+            test_x[0:3:-2, -1:-6:2, ::],
+        )
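As a usage sketch of the rewrite itself (assuming a recent PyTensor where pt.tensor, pytensor.function, and pytensor.dprint behave as in the tests above), compiling a function with an explicitly bounded negative-step slice should end up with the canonical x[::-1] form:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(7,), dtype="float64")
y = x[-1:-8:-1]  # full reverse of a length-7 vector, written with explicit bounds

f = pytensor.function([x], y)
pytensor.dprint(f)  # the compiled graph is expected to show the x[::-1] slice

test_x = np.arange(7.0)
np.testing.assert_allclose(f(test_x), test_x[-1:-8:-1])

This mirrors TestUselessSlice.test_negative_step, reduced to a single dimension.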
