Commit 82a5757

Implement vectorize_node dispatch for some forms of AdvancedSubtensor

ricardoV94 authored and lucianopaz committed

1 parent 56637af commit 82a5757

File tree: 2 files changed (+152 -2 lines changed)


pytensor/tensor/subtensor.py (+67 -2)

@@ -47,6 +47,7 @@
     zscalar,
 )
 from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice
+from pytensor.tensor.variable import TensorVariable


 _logger = logging.getLogger("pytensor.tensor.subtensor")
@@ -473,6 +474,13 @@ def group_indices(indices):
     return idx_groups


+def _non_contiguous_adv_indexing(indices) -> bool:
+    """Check if the advanced indexing is non-contiguous (i.e., split by basic indexing)."""
+    idx_groups = group_indices(indices)
+    # This means that there are at least two groups of advanced indexing separated by basic indexing
+    return len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0])
+
+
 def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
     """Compute the symbolic shape resulting from `a[indices]` for `a.shape == array_shape`.
@@ -497,8 +505,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
     remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
     idx_groups = group_indices(indices)

-    if len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0]):
-        # This means that there are at least two groups of advanced indexing separated by basic indexing
+    if _non_contiguous_adv_indexing(indices):
         # In this case NumPy places the advanced index groups in the front of the array
         # https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
         idx_groups = sorted(idx_groups, key=lambda x: x[0])
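
Not part of the diff: a minimal NumPy sketch (shapes chosen arbitrarily for illustration) of the contiguity rule that `_non_contiguous_adv_indexing` encodes. When basic indexing splits the advanced indices, NumPy moves the broadcast index dimensions to the front of the result, which is exactly the case `indexed_result_shape` has to special-case:

import numpy as np

x = np.empty((7, 5, 5))
idx = np.array([0, 1])  # two integer indices, shape (2,)

# Contiguous advanced indexing: the broadcast index dims replace the
# indexed dims in place
assert x[:, idx, idx].shape == (7, 2)

# Non-contiguous: the middle slice splits the two advanced indices,
# so the broadcast index dims move to the front of the result
assert x[idx, :, idx].shape == (2, 5)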
@@ -2682,10 +2689,68 @@ def grad(self, inputs, grads):
             rest
         )

+    @staticmethod
+    def non_contiguous_adv_indexing(node: Apply) -> bool:
+        """
+        Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
+
+        This function checks if the advanced indexing is non-contiguous,
+        in which case the advanced index dimensions are placed on the left of the
+        output array, regardless of their original position.
+
+        See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
+
+        Parameters
+        ----------
+        node : Apply
+            The node of the AdvancedSubtensor operation.
+
+        Returns
+        -------
+        bool
+            True if the advanced indexing is non-contiguous, False otherwise.
+        """
+        _, *idxs = node.inputs
+        return _non_contiguous_adv_indexing(idxs)
+

 advanced_subtensor = AdvancedSubtensor()


+@_vectorize_node.register(AdvancedSubtensor)
+def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
+    x, *idxs = node.inputs
+    batch_x, *batch_idxs = batch_inputs
+
+    x_is_batched = x.type.ndim < batch_x.type.ndim
+    idxs_are_batched = any(
+        batch_idx.type.ndim > idx.type.ndim
+        for batch_idx, idx in zip(batch_idxs, idxs)
+        if isinstance(batch_idx, TensorVariable)
+    )
+
+    if idxs_are_batched or (x_is_batched and op.non_contiguous_adv_indexing(node)):
+        # Fall back to Blockwise if the idxs are batched or if we have non-contiguous
+        # advanced indexing, which would put the indexed results to the left of the batch dimensions!
+        # TODO: Not all cases must be handled by Blockwise, but the logic is complex
+
+        # Blockwise doesn't accept None or Slice types, so we raise an informative error here
+        # TODO: Implement these internally, so Blockwise is always a safe fallback
+        if any(not isinstance(idx, TensorVariable) for idx in idxs):
+            raise NotImplementedError(
+                "Vectorized AdvancedSubtensor with batched indexes or non-contiguous advanced indexing "
+                "and slices or newaxis is currently not supported."
+            )
+        else:
+            return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
+
+    # Otherwise we just need to add None slices for every new batch dim
+    x_batch_ndim = batch_x.type.ndim - x.type.ndim
+    empty_slices = (slice(None),) * x_batch_ndim
+    return op.make_node(batch_x, *empty_slices, *batch_idxs)
+
+
 class AdvancedIncSubtensor(Op):
     """Increments a subtensor using advanced indexing."""
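
Not part of the diff: a sketch of what the new dispatch enables, assuming the `vectorize` helper exercised by the tests below (`pytensor.tensor.vectorize`, mirroring `np.vectorize`). Batching `x` with non-batched indices no longer requires a `Blockwise` wrapper; the node is rebuilt with one empty slice per new batch dimension, so the whole batch is indexed in a single call:

import numpy as np
import pytensor
from pytensor.tensor import tensor, vectorize


def core_fn(x, idx):
    # Core case: gather two rows of a (5, 3) matrix
    return x[idx]


# x gains a leading batch dimension of 11; idx stays core
x = tensor(shape=(11, 5, 3), dtype="float64")
idx = tensor(shape=(2,), dtype="int64")
out = vectorize(core_fn, signature="(5,3),(2)->(2,3)")(x, idx)

# With the dispatch above, the graph indexes as x[:, idx]
# rather than looping over the batch via Blockwise
fn = pytensor.function([x, idx], out)
x_test = np.random.normal(size=(11, 5, 3))
idx_test = np.array([0, 4])
np.testing.assert_allclose(fn(x_test, idx_test), x_test[:, idx_test])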

tests/tensor/test_subtensor.py (+85)

@@ -2751,3 +2751,88 @@ def core_fn(x, start):
         vectorize_pt(x_test, start_test),
         vectorize_np(x_test, start_test),
     )
+
+
+@pytest.mark.parametrize(
+    "core_idx_fn, signature, x_shape, idx_shape, uses_blockwise",
+    [
+        # Core case
+        ((lambda x, idx: x[:, idx, :]), "(7,5,3),(2)->(7,2,3)", (7, 5, 3), (2,), False),
+        # Batched x, core idx
+        (
+            (lambda x, idx: x[:, idx, :]),
+            "(7,5,3),(2)->(7,2,3)",
+            (11, 7, 5, 3),
+            (2,),
+            False,
+        ),
+        (
+            (lambda x, idx: x[idx, None]),
+            "(5,7,3),(2)->(2,1,7,3)",
+            (11, 5, 7, 3),
+            (2,),
+            False,
+        ),
+        # (this is currently failing because PyTensor tries to vectorize the slice(None)
+        # operation, due to the exact same None constant being used there and in the np.newaxis)
+        pytest.param(
+            (lambda x, idx: x[:, idx, None]),
+            "(7,5,3),(2)->(7,2,1,3)",
+            (11, 7, 5, 3),
+            (2,),
+            False,
+            marks=pytest.mark.xfail(raises=NotImplementedError),
+        ),
+        (
+            (lambda x, idx: x[:, idx, idx, :]),
+            "(7,5,5,3),(2)->(7,2,3)",
+            (11, 7, 5, 5, 3),
+            (2,),
+            False,
+        ),
+        # (not supported, because the fallback Blockwise can't handle slices)
+        pytest.param(
+            (lambda x, idx: x[:, idx, :, idx]),
+            "(7,5,3,5),(2)->(2,7,3)",
+            (11, 7, 5, 3, 5),
+            (2,),
+            True,
+            marks=pytest.mark.xfail(raises=NotImplementedError),
+        ),
+        # Core x, batched idx
+        ((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (7,), (11, 2), True),
+        # Batched x, batched idx
+        ((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (11, 7), (11, 2), True),
+        # (not supported, because the fallback Blockwise can't handle slices)
+        pytest.param(
+            (lambda x, idx: x[:, idx, :]),
+            "(t1,t2,t3),(idx)->(t1,tx,t3)",
+            (11, 7, 5, 3),
+            (11, 2),
+            True,
+            marks=pytest.mark.xfail(raises=NotImplementedError),
+        ),
+    ],
+)
+def test_vectorize_adv_subtensor(
+    core_idx_fn, signature, x_shape, idx_shape, uses_blockwise
+):
+    x = tensor(shape=x_shape, dtype="float64")
+    idx = tensor(shape=idx_shape, dtype="int64")
+    vectorize_pt = function(
+        [x, idx], vectorize(core_idx_fn, signature=signature)(x, idx)
+    )
+
+    has_blockwise = any(
+        isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
+    )
+    assert has_blockwise == uses_blockwise
+
+    x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
+    # Indexed dimensions have length at least 5, so indices in [0, 5) are valid
+    idx_test = np.random.randint(0, 5, size=idx.type.shape)
+    vectorize_np = np.vectorize(core_idx_fn, signature=signature)
+    np.testing.assert_allclose(
+        vectorize_pt(x_test, idx_test),
+        vectorize_np(x_test, idx_test),
+    )
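
Not part of the diff: in the `uses_blockwise=True` cases the indices themselves carry a batch dimension, so each batch row gathers independently and the op must fall back to `Blockwise`. A small NumPy check of the reference semantics the test compares against:

import numpy as np

# "Core x, batched idx": one length-7 vector, 11 independent index pairs
x = np.arange(7.0)
idx = np.random.randint(0, 5, size=(11, 2))

ref = np.vectorize(lambda x_, idx_: x_[idx_], signature="(t1),(idx)->(tx)")(x, idx)
assert ref.shape == (11, 2)
# With a shared core x this coincides with plain fancy indexing
np.testing.assert_allclose(ref, x[idx])

The parametrized cases can be run directly with pytest tests/tensor/test_subtensor.py -k test_vectorize_adv_subtensor.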
