Skip to content

Commit 8267d0e

Browse files
committed
More robust check for multiple integer indices in numba ravel_multidimensional_idx rewrites
1 parent 4e85676 commit 8267d0e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytensor/tensor/rewriting/subtensor.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
inc_subtensor,
8686
indices_from_subtensor,
8787
)
88-
from pytensor.tensor.type import TensorType
88+
from pytensor.tensor.type import TensorType, integer_dtypes
8989
from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType
9090
from pytensor.tensor.variable import TensorConstant, TensorVariable
9191

@@ -1981,7 +1981,7 @@ def ravel_multidimensional_bool_idx(fgraph, node):
19811981

19821982
if any(
19831983
(
1984-
(isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int"))
1984+
(isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes)
19851985
or isinstance(idx.type, NoneTypeT)
19861986
)
19871987
for idx in idxs
@@ -2052,7 +2052,7 @@ def ravel_multidimensional_int_idx(fgraph, node):
20522052
int_idxs = [
20532053
(i, idx)
20542054
for i, idx in enumerate(idxs)
2055-
if (isinstance(idx.type, TensorType) and idx.dtype.startswith("int"))
2055+
if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
20562056
]
20572057

20582058
if len(int_idxs) != 1:

0 commit comments

Comments
 (0)