
Commit 8550622

Support single multidimensional indexing in Numba via rewrites

1 parent d512271 commit 8550622

2 files changed: +157 −5 lines
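In short, the commit adds graph rewrites so that a single multidimensional (ndim >= 2) integer or boolean index, which the Numba backend previously had to run in object mode, is raveled into a vector index it can handle natively. A minimal sketch of the kind of graph this targets (illustrative only; it assumes Numba is installed and that PyTensor's "NUMBA" mode applies the newly registered "specialize" rewrites):

import numpy as np
import pytensor
import pytensor.tensor as pt

# Illustrative sketch, not from the commit: a single 2D integer index produces
# an AdvancedSubtensor node.
x = pt.matrix("x")
out = x[np.eye(3, dtype=int)]

# With this commit, compiling for Numba rewrites the graph to index with the
# raveled (1D) version of the index and reshape the result back.
fn = pytensor.function([x], out, mode="NUMBA")
print(fn(np.arange(9).reshape(3, 3).astype(pytensor.config.floatX)))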

pytensor/tensor/rewriting/subtensor.py (+109 lines)
@@ -7,6 +7,7 @@
 import pytensor
 import pytensor.scalar.basic as ps
 from pytensor import compile
+from pytensor.compile import optdb
 from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.rewriting.basic import (
     WalkingGraphRewriter,
@@ -1932,3 +1933,111 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
     new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
     copy_stack_trace(node.outputs, new_out)
     return new_out
+
+
+@node_rewriter(tracks=[AdvancedSubtensor])
+def ravel_multidimensional_bool_idx(fgraph, node):
+    """Convert multidimensional boolean indexing into an equivalent vector boolean index, supported by Numba.
+
+    x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
+    """
+    x, *idxs = node.inputs
+
+    if any(
+        isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int")
+        for idx in idxs
+    ):
+        # Get out if there are any other advanced indexes
+        return None
+
+    bool_idxs = [
+        (i, idx)
+        for i, idx in enumerate(idxs)
+        if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
+    ]
+
+    if len(bool_idxs) != 1:
+        # Get out if there are no or multiple boolean idxs
+        return None
+
+    [(bool_idx_pos, bool_idx)] = bool_idxs
+    bool_idx_ndim = bool_idx.type.ndim
+    if bool_idx.type.ndim < 2:
+        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
+        return None
+
+    x_shape = x.shape
+    raveled_x = x.reshape(
+        (*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :])
+    )
+
+    raveled_bool_idx = bool_idx.ravel()
+    new_idxs = list(idxs)
+    new_idxs[bool_idx_pos] = raveled_bool_idx
+
+    return [raveled_x[tuple(new_idxs)]]
+
+
+@node_rewriter(tracks=[AdvancedSubtensor])
+def ravel_multidimensional_int_idx(fgraph, node):
+    """Convert multidimensional integer indexing into an equivalent vector integer index, supported by Numba.
+
+    x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
+
+
+    NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor`, except it also handles non-full slices:
+
+        x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+    """
+    x, *idxs = node.inputs
+
+    if any(
+        isinstance(idx.type, TensorType) and idx.type.dtype.startswith("bool")
+        for idx in idxs
+    ):
+        # Get out if there are any other advanced indexes
+        return None
+
+    int_idxs = [
+        (i, idx)
+        for i, idx in enumerate(idxs)
+        if (isinstance(idx.type, TensorType) and idx.dtype.startswith("int"))
+    ]
+
+    if len(int_idxs) != 1:
+        # Get out if there are no or multiple integer idxs
+        return None
+
+    [(int_idx_pos, int_idx)] = int_idxs
+    if int_idx.type.ndim < 2:
+        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
+        return None
+
+    raveled_int_idx = int_idx.ravel()
+    new_idxs = list(idxs)
+    new_idxs[int_idx_pos] = raveled_int_idx
+    raveled_subtensor = x[tuple(new_idxs)]
+
+    # Reshape into the correct shape.
+    # Because we only allow one advanced index, the output dimension corresponding to the raveled integer index
+    # must match its input position. If there were multiple advanced indexes, it could have been forcefully moved to the front.
+    raveled_shape = raveled_subtensor.shape
+    unraveled_shape = (
+        *raveled_shape[:int_idx_pos],
+        *int_idx.shape,
+        *raveled_shape[int_idx_pos + 1 :],
+    )
+    return [raveled_subtensor.reshape(unraveled_shape)]
+
+
+optdb["specialize"].register(
+    ravel_multidimensional_bool_idx.__name__,
+    ravel_multidimensional_bool_idx,
+    "numba",
+)
+
+optdb["specialize"].register(
+    ravel_multidimensional_int_idx.__name__,
+    ravel_multidimensional_int_idx,
+    "numba",
+)
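The NumPy identities these two rewrites rely on can be sanity-checked directly. This is an illustrative check under plain NumPy, not part of the commit:

import numpy as np

x = np.arange(3 * 3 * 2).reshape((3, 3, 2))

# Boolean case: a 2D mask over the leading axes is equivalent to collapsing
# those axes and indexing with the raveled mask.
mask = np.eye(3, dtype=bool)
assert np.array_equal(x[mask], x.reshape((-1, 2))[mask.ravel()])

# Integer case: a 2D integer index is equivalent to indexing with its raveled
# version and reshaping so the index's shape replaces the indexed dimension.
idx = np.eye(3, dtype=int)
assert np.array_equal(x[idx], x[idx.ravel()].reshape((*idx.shape, *x.shape[1:])))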

tests/link/numba/test_subtensor.py (+48 −5 lines)
@@ -19,7 +19,7 @@
     inc_subtensor,
     set_subtensor,
 )
-from tests.link.numba.test_basic import compare_numba_and_py
+from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
 
 
 rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
@@ -74,6 +74,7 @@ def test_AdvancedSubtensor1_out_of_bounds():
 @pytest.mark.parametrize(
     "x, indices, objmode_needed",
     [
+        # Single vector indexing (supported natively by Numba)
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             (0, [1, 2, 2, 3]),
@@ -84,25 +85,63 @@ def test_AdvancedSubtensor1_out_of_bounds():
             (np.array([True, False, False])),
             False,
         ),
+        (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
+        # Single multidimensional indexing (supported after specialization rewrites)
+        (
+            as_tensor(np.arange(3 * 3).reshape((3, 3))),
+            (np.eye(3).astype(int)),
+            False,
+        ),
         (
             as_tensor(np.arange(3 * 3).reshape((3, 3))),
             (np.eye(3).astype(bool)),
+            False,
+        ),
+        (
+            as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
+            (np.eye(3).astype(int)),
+            False,
+        ),
+        (
+            as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
+            (np.eye(3).astype(bool)),
+            False,
+        ),
+        (
+            as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
+            (slice(2, None), np.eye(3).astype(int)),
+            False,
+        ),
+        (
+            as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
+            (slice(2, None), np.eye(3).astype(bool)),
+            False,
+        ),
+        # Multiple advanced indexing, only supported in obj mode
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            (slice(None), [1, 2], [3, 4]),
             True,
         ),
-        (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             ([1, 2], slice(None), [3, 4]),
             True,
         ),
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            ([[1, 2], [2, 1]], [0, 0]),
+            True,
+        ),
     ],
 )
 @pytest.mark.filterwarnings("error")
 def test_AdvancedSubtensor(x, indices, objmode_needed):
     """Test NumPy's advanced indexing in more than one dimension."""
-    out_pt = x[indices]
+    x_pt = x.type()
+    out_pt = x_pt[indices]
     assert isinstance(out_pt.owner.op, AdvancedSubtensor)
-    out_fg = FunctionGraph([], [out_pt])
+    out_fg = FunctionGraph([x_pt], [out_pt])
     with (
         pytest.warns(
             UserWarning,
@@ -111,7 +150,11 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
         if objmode_needed
         else contextlib.nullcontext()
     ):
-        compare_numba_and_py(out_fg, [])
+        compare_numba_and_py(
+            out_fg,
+            [x.data],
+            numba_mode=numba_mode.including("specialize"),
+        )
 
 
 @pytest.mark.parametrize(
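The test change feeds a symbolic input into the function graph and compiles with the "specialize" rewrites included, so the new rewrites run before the Numba backend sees the graph. A hedged standalone sketch of what one of the new parametrized cases exercises, assuming Numba is installed and that the default "NUMBA" mode applies the specialize-phase rewrites tagged "numba":

import numpy as np
import pytensor
import pytensor.tensor as pt

x_val = np.arange(3 * 3 * 2).reshape((3, 3, 2)).astype("int64")
x = pt.tensor3("x", dtype="int64")

# Single 2D boolean index: previously an object-mode fallback, now rewritten
# into a raveled boolean index that Numba supports natively.
out = x[np.eye(3, dtype=bool)]

fn = pytensor.function([x], out, mode="NUMBA")
np.testing.assert_array_equal(fn(x_val), x_val[np.eye(3, dtype=bool)])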
