Commit 7b0a392

Lift Subtensor over AdvancedSubtensor

1 parent 58f1fd2 commit 7b0a392
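
In NumPy terms, the commit exploits the fact that basic integer indices applied after an advanced index can instead be applied to the base tensor first, as long as they act on dimensions to the left of the advanced one. A minimal NumPy sketch (illustration only, not part of the commit; x, vec_idx, i, j, k are made-up names) of the two patterns from the rewrite's docstring:

    import numpy as np

    x = np.arange(7 * 5 * 3).reshape(7, 5, 3)
    vec_idx = np.array([0, 1])
    i, j, k = 2, 1, 0

    # x[:, :, vec_idx][i, j] -> x[i, j][vec_idx]
    assert np.array_equal(x[:, :, vec_idx][i, j], x[i, j][vec_idx])

    # x[:, vec_idx][i, j, k] -> x[i][vec_idx][j, k]
    assert np.array_equal(x[:, vec_idx][i, j, k], x[i][vec_idx][j, k])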

File tree

2 files changed: +125 -2

pytensor/tensor/rewriting/subtensor_lift.py
tests/tensor/rewriting/test_subtensor_lift.py

pytensor/tensor/rewriting/subtensor_lift.py

+80 -1
@@ -4,6 +4,7 @@
 import numpy as np
 
 from pytensor import Variable
+from pytensor.compile import optdb
 from pytensor.graph import Constant, FunctionGraph, node_rewriter
 from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
@@ -37,16 +38,18 @@
 )
 from pytensor.tensor.special import Softmax, softmax
 from pytensor.tensor.subtensor import (
+    AdvancedSubtensor,
     AdvancedSubtensor1,
     Subtensor,
+    _non_consecutive_adv_indexing,
     as_index_literal,
     get_canonical_form_slice,
     get_constant_idx,
     get_idx_list,
     indices_from_subtensor,
 )
 from pytensor.tensor.type import TensorType
-from pytensor.tensor.type_other import SliceType
+from pytensor.tensor.type_other import NoneTypeT, SliceType
 from pytensor.tensor.variable import TensorVariable
 
 
@@ -769,3 +772,79 @@ def local_subtensor_shape_constant(fgraph, node):
         return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)]
     elif shape_parts:
         return [as_tensor(1, dtype=np.int64)]
+
+
+@node_rewriter([Subtensor])
+def local_subtensor_of_adv_subtensor(fgraph, node):
+    """Lift a simple Subtensor through an AdvancedSubtensor, when basic index dimensions are to the left of any advanced ones.
+
+    x[:, :, vec_idx][i, j] -> x[i, j][vec_idx]
+    x[:, vec_idx][i, j, k] -> x[i][vec_idx][j, k]
+
+    Restricted to a single advanced indexing dimension.
+
+    An alternative approach could have fused the basic and advanced indices,
+    so it is not clear this rewrite should be canonical or a specialization.
+    Users must include it manually if it fits their use case.
+    """
+    adv_subtensor, *idxs = node.inputs
+
+    if not (
+        adv_subtensor.owner and isinstance(adv_subtensor.owner.op, AdvancedSubtensor)
+    ):
+        return None
+
+    if len(fgraph.clients[adv_subtensor]) > 1:
+        # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
+        return None
+
+    x, *adv_idxs = adv_subtensor.owner.inputs
+
+    # Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
+    if any(
+        (
+            isinstance(adv_idx.type, NoneTypeT)
+            or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool")
+            or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx))
+        )
+        for adv_idx in adv_idxs
+    ) or _non_consecutive_adv_indexing(adv_idxs):
+        return None
+
+    for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
+        # We already made sure there were only None slices besides integer indexes
+        if isinstance(adv_idx.type, TensorType):
+            break
+    else:  # no-break
+        # Not sure if this should ever happen, but better safe than sorry
+        return None
+
+    basic_idxs = indices_from_subtensor(idxs, node.op.idx_list)
+    basic_idxs_lifted = basic_idxs[:first_adv_idx_dim]
+    basic_idxs_kept = ((slice(None),) * len(basic_idxs_lifted)) + basic_idxs[
+        first_adv_idx_dim:
+    ]
+
+    if all(basic_idx == slice(None) for basic_idx in basic_idxs_lifted):
+        # All basic indices happen to the right of the advanced indices
+        return None
+
+    [basic_subtensor] = node.outputs
+    dropped_dims = _dims_dropped_by_basic_index(basic_idxs_lifted)
+
+    x_indexed = x[basic_idxs_lifted]
+    copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)
+
+    x_after_index_lift = expand_dims(x_indexed, dropped_dims)
+    x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs)
+    copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)
+
+    new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
+    return [new_out]
+
+
+# Rewrite will only be included if tagged by name
+r = local_subtensor_of_adv_subtensor
+optdb["canonicalize"].register(r.__name__, r, use_db_name_as_tag=False)
+optdb["specialize"].register(r.__name__, r, use_db_name_as_tag=False)
+del r
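
The expand_dims/squeeze pair in the rewrite keeps the advanced index pointing at its original dimension even after lifted integer indices have dropped dimensions. A NumPy sketch of those steps (my illustration, not the committed code; variable names mirror the rewrite, and the trivial basic_idxs_kept step is elided since it is a full slice here):

    import numpy as np

    x = np.arange(7 * 5 * 3).reshape(7, 5, 3)
    vec_idx = np.array([0, 1])

    out = x[:, :, vec_idx][2]  # original graph: advanced index, then basic index

    x_indexed = x[2]                           # lifted basic index drops dim 0
    x_expanded = np.expand_dims(x_indexed, 0)  # restore it so vec_idx still hits dim 2
    x_adv = x_expanded[:, :, vec_idx]          # re-apply the advanced index
    new_out = np.squeeze(x_adv, axis=0)        # drop the placeholder dim again

    assert np.array_equal(out, new_out)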

tests/tensor/rewriting/test_subtensor_lift.py

+45 -1
@@ -46,7 +46,7 @@
 )
 from pytensor.tensor.shape import SpecifyShape, _shape
 from pytensor.tensor.special import softmax
-from pytensor.tensor.subtensor import Subtensor
+from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
 
 
 mode_opt = config.mode
@@ -695,3 +695,47 @@ def __eq__(self, other):
     x = shape(Variable(MyType(), None, None))[0]
 
     assert not local_subtensor_shape_constant.transform(None, x.owner)
+
+
+@pytest.mark.parametrize(
+    "original_fn, supported",
+    [
+        (lambda x: x[:, [0, 1]][0], True),
+        (lambda x: x[:, [0, 1], [0, 0]][1:], True),
+        (lambda x: x[:, [[0, 1], [0, 0]]][1:], True),
+        # Not supported, basic indexing on advanced indexing dim
+        (lambda x: x[[0, 1]][0], False),
+        # Not implemented, basic indexing on the right of advanced indexing
+        (lambda x: x[[0, 1]][:, 0], False),
+        # Not implemented, complex flavors of advanced indexing
+        (lambda x: x[:, None, [0, 1]][0], False),
+        (lambda x: x[:, 5:, [0, 1]][0], False),
+        (lambda x: x[:, :, np.array([True, False, False])][0], False),
+        (lambda x: x[[0, 1], :, [0, 1]][:, 0], False),
+    ],
+)
+def test_local_subtensor_of_adv_subtensor(original_fn, supported):
+    rng = np.random.default_rng(257)
+    x = pt.tensor3("x", shape=(7, 5, 3))
+    x_test = rng.normal(size=x.type.shape).astype(x.dtype)
+
+    out = original_fn(x)
+    opt_out = rewrite_graph(
+        out, include=("canonicalize", "local_subtensor_of_adv_subtensor")
+    )
+    # The graphs generated are too complicated to assert
+    # We simply check that the Subtensor happens before the AdvancedSubtensor
+    toposort = FunctionGraph(outputs=[opt_out], clone=False).toposort()
+    [idx_subtensor] = [
+        i for i, node in enumerate(toposort) if isinstance(node.op, Subtensor)
+    ]
+    [idx_adv_subtensor] = [
+        i for i, node in enumerate(toposort) if isinstance(node.op, AdvancedSubtensor)
+    ]
+    swapped = idx_subtensor < idx_adv_subtensor
+    correct = swapped if supported else not swapped
+    assert correct, debugprint(opt_out, print_type=True)
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
