From 1f6319b42620978f42a387ddd636ba2687d0add0 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Mon, 6 May 2024 08:32:16 +0530 Subject: [PATCH 1/3] Canonicalize subtensor slices --- pytensor/tensor/rewriting/subtensor.py | 53 ++++++++++++++++++++++++ tests/tensor/rewriting/test_subtensor.py | 23 +++++++++- 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 8ee86e6021..6434d883b6 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -368,6 +368,7 @@ def local_useless_slice(fgraph, node): # check if we removed something if last_useless_slice < len(idxs): new_idxs = idxs[:last_useless_slice] + if new_idxs: new_subtensor = Subtensor(new_idxs) new_subtensor_inputs = get_slice_elements( @@ -382,6 +383,58 @@ def local_useless_slice(fgraph, node): return [node.inputs[0]] +@register_useless +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([Subtensor]) +def local_replace_slice(fgraph, node): + """ + Rewrite Subtensor of the form: + X[0:7:1] -> X[None:None:None] + where X is a vector of length 7 + + """ + idxs = get_idx_list(node.inputs, node.op.idx_list) + x = node.inputs[0] + + if not idxs: + return + + new_idxs = list(idxs) + idx_flag = False + for dim, s in enumerate(new_idxs): + if not isinstance(s, slice): + continue + + start = s.start + stop = s.stop + step = s.step + if extract_constant(start, only_process_constants=True) == 0: + idx_flag = True + start = None + + if ( + x.type.shape[dim] is not None + and extract_constant(stop, only_process_constants=True) == x.type.shape[dim] + ): + idx_flag = True + stop = None + + if extract_constant(step, only_process_constants=True) == 1: + idx_flag = True + step = None + + new_idxs[dim] = slice(start, stop, step) + + if idx_flag is True: + out = x[tuple(new_idxs)] + # Copy over previous output stacktrace + copy_stack_trace(node.outputs, out) + + return [out] + + # fast_compile to allow opt subtensor(cast{float32}(make_vector)) @register_canonicalize("fast_compile") @node_rewriter([Subtensor]) diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index f7ea7cdce4..a2b6fec428 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -10,7 +10,7 @@ from pytensor.compile.ops import DeepCopyOp from pytensor.configdefaults import config from pytensor.graph import FunctionGraph, vectorize_graph -from pytensor.graph.basic import Constant, Variable, ancestors +from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations from pytensor.graph.rewriting.basic import check_stack_trace from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.utils import rewrite_graph @@ -2402,3 +2402,24 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc): else: expected_out[:, :, core_idxs] += test_y np.testing.assert_allclose(fn(test_x, test_y), expected_out) + + +@pytest.mark.parametrize("fstop, lstop, lstep", [(None, 9, 1), (-1, -1, -1)]) +def test_slice_canonicalize(fstop, lstop, lstep): + x = tensor(shape=(3, 5, None, 9)) + y = x[0:fstop, 0:5, 0:7, 0:lstop:lstep] + f = pytensor.function([x], y) + test_y = f.maker.fgraph.toposort() + + y1 = x[None:None:None, None:None:None, None:7:None, None:None:None] + + if fstop == -1 and lstop == -1 and lstep == -1: + y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1] + + f1 = pytensor.function([x], y1) + expected_y = f1.maker.fgraph.toposort() + + assert all( + equal_computations([x1], [y1]) + for x1, y1 in zip(test_y[0].inputs, expected_y[0].inputs) + ) From 479bd7b92192f036362e4d14f718688a2910bac3 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Wed, 29 May 2024 18:35:17 +0530 Subject: [PATCH 2/3] Merge Canonicalize slice and useless slice rewrites --- pytensor/tensor/rewriting/subtensor.py | 86 ++++++++---------------- tests/tensor/rewriting/test_subtensor.py | 40 +++++++---- 2 files changed, 55 insertions(+), 71 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 6434d883b6..f234b46804 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -337,6 +337,7 @@ def local_subtensor_of_dot(fgraph, node): @register_useless @register_canonicalize @register_specialize +@register_stabilize @node_rewriter([Subtensor]) def local_useless_slice(fgraph, node): """ @@ -344,91 +345,60 @@ def local_useless_slice(fgraph, node): 1. X[0, :] -> X[0] 2. X[:] -> X - """ - idxs = get_idx_list(node.inputs, node.op.idx_list) - - if not idxs: - return [node.inputs[0]] - - last_useless_slice = len(idxs) - for s in idxs[::-1]: - # check if slice and then check slice indices - if ( - isinstance(s, slice) - and s.start is None - and s.stop is None - and ( - s.step is None - or extract_constant(s.step, only_process_constants=True) == 1 - ) - ): - last_useless_slice -= 1 - else: - break - # check if we removed something - if last_useless_slice < len(idxs): - new_idxs = idxs[:last_useless_slice] - - if new_idxs: - new_subtensor = Subtensor(new_idxs) - new_subtensor_inputs = get_slice_elements( - new_idxs, lambda x: isinstance(x, Variable) - ) - out = new_subtensor(node.inputs[0], *new_subtensor_inputs) - # Copy over previous output stacktrace - copy_stack_trace(node.outputs, out) - return [out] - else: - # Subtensor is not needed at all - return [node.inputs[0]] - - -@register_useless -@register_canonicalize -@register_stabilize -@register_specialize -@node_rewriter([Subtensor]) -def local_replace_slice(fgraph, node): - """ - Rewrite Subtensor of the form: + Also, rewrite Subtensor of the form: X[0:7:1] -> X[None:None:None] - where X is a vector of length 7 + where X is a vector of length 7 """ idxs = get_idx_list(node.inputs, node.op.idx_list) x = node.inputs[0] if not idxs: - return + return [node.inputs[0]] new_idxs = list(idxs) - idx_flag = False + change_flag = False + last_useful_idx = -1 for dim, s in enumerate(new_idxs): if not isinstance(s, slice): + last_useful_idx = dim + continue + + if s == slice(None): continue start = s.start stop = s.stop step = s.step - if extract_constant(start, only_process_constants=True) == 0: - idx_flag = True + if ( + start is not None + and extract_constant(start, only_process_constants=True) == 0 + ): + change_flag = True start = None if ( - x.type.shape[dim] is not None + stop is not None + and x.type.shape[dim] is not None and extract_constant(stop, only_process_constants=True) == x.type.shape[dim] ): - idx_flag = True + change_flag = True stop = None - if extract_constant(step, only_process_constants=True) == 1: - idx_flag = True + if ( + step is not None + and extract_constant(step, only_process_constants=True) == 1 + ): + change_flag = True step = None + if not (start is None and stop is None and step is None): + last_useful_idx = dim + new_idxs[dim] = slice(start, stop, step) - if idx_flag is True: - out = x[tuple(new_idxs)] + if change_flag or ((last_useful_idx + 1) < len(idxs)): + out = x[tuple(new_idxs[: last_useful_idx + 1])] # Copy over previous output stacktrace copy_stack_trace(node.outputs, out) diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index a2b6fec428..f8e5d27f29 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -2404,22 +2404,36 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc): np.testing.assert_allclose(fn(test_x, test_y), expected_out) -@pytest.mark.parametrize("fstop, lstop, lstep", [(None, 9, 1), (-1, -1, -1)]) -def test_slice_canonicalize(fstop, lstop, lstep): +def test_slice_canonicalize(): + rng = np.random.default_rng(43) x = tensor(shape=(3, 5, None, 9)) - y = x[0:fstop, 0:5, 0:7, 0:lstop:lstep] - f = pytensor.function([x], y) - test_y = f.maker.fgraph.toposort() + test_x = rng.normal(size=(3, 5, 8, 9)) + # Test case 1 + y = x[0:None, 0:5, 0:7, 0:9:1] + f = pytensor.function([x], y, allow_input_downcast=True) + test_y = f.maker.fgraph.outputs[0].owner.inputs[0] - y1 = x[None:None:None, None:None:None, None:7:None, None:None:None] + expected_y = x[None:None:None, None:None:None, None:7:None] - if fstop == -1 and lstop == -1 and lstep == -1: - y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1] + assert equal_computations([test_y], [expected_y]) - f1 = pytensor.function([x], y1) - expected_y = f1.maker.fgraph.toposort() + np.testing.assert_allclose( + f(test_x), + test_x[ + 0:None, 0:5, 0:7, 0:9:1 + ], # Use the unoptimized slice to make sure our rewrite logic is correct + ) + + # Test case 2 + y1 = x[0:-1, 0:5, 0:7, 0:-1:-1] + f1 = pytensor.function([x], y1, allow_input_downcast=True) + test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0] + + expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1] + + assert equal_computations([test_y1], [expected_y1]) - assert all( - equal_computations([x1], [y1]) - for x1, y1 in zip(test_y[0].inputs, expected_y[0].inputs) + np.testing.assert_allclose( + f1(test_x), + test_x[0:-1, 0:5, 0:7, 0:-1:-1], ) From 0f6a7a0ec1b14f57f949bc977781bd090978e746 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Thu, 26 Sep 2024 15:28:34 +0530 Subject: [PATCH 3/3] assert Deepcopy input while canonicalising subtensor slices --- tests/tensor/rewriting/test_subtensor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index f8e5d27f29..91575bc7da 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -2411,7 +2411,10 @@ def test_slice_canonicalize(): # Test case 1 y = x[0:None, 0:5, 0:7, 0:9:1] f = pytensor.function([x], y, allow_input_downcast=True) + + # Get the DeepCopy input and assert that the Op is a DeepCopy test_y = f.maker.fgraph.outputs[0].owner.inputs[0] + assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp) expected_y = x[None:None:None, None:None:None, None:7:None] @@ -2427,7 +2430,10 @@ def test_slice_canonicalize(): # Test case 2 y1 = x[0:-1, 0:5, 0:7, 0:-1:-1] f1 = pytensor.function([x], y1, allow_input_downcast=True) + + # Get the DeepCopy input and assert that the Op is a DeepCopy test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0] + assert isinstance(f1.maker.fgraph.outputs[0].owner.op, DeepCopyOp) expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]