File tree 1 file changed +3
-3
lines changed
pytensor/tensor/rewriting
1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change 85
85
inc_subtensor ,
86
86
indices_from_subtensor ,
87
87
)
88
- from pytensor .tensor .type import TensorType
88
+ from pytensor .tensor .type import TensorType , integer_dtypes
89
89
from pytensor .tensor .type_other import NoneTypeT , SliceConstant , SliceType
90
90
from pytensor .tensor .variable import TensorConstant , TensorVariable
91
91
@@ -1981,7 +1981,7 @@ def ravel_multidimensional_bool_idx(fgraph, node):
1981
1981
1982
1982
if any (
1983
1983
(
1984
- (isinstance (idx .type , TensorType ) and idx .type .dtype . startswith ( "int" ) )
1984
+ (isinstance (idx .type , TensorType ) and idx .type .dtype in integer_dtypes )
1985
1985
or isinstance (idx .type , NoneTypeT )
1986
1986
)
1987
1987
for idx in idxs
@@ -2052,7 +2052,7 @@ def ravel_multidimensional_int_idx(fgraph, node):
2052
2052
int_idxs = [
2053
2053
(i , idx )
2054
2054
for i , idx in enumerate (idxs )
2055
- if (isinstance (idx .type , TensorType ) and idx .dtype . startswith ( "int" ) )
2055
+ if (isinstance (idx .type , TensorType ) and idx .dtype in integer_dtypes )
2056
2056
]
2057
2057
2058
2058
if len (int_idxs ) != 1 :
You can’t perform that action at this time.
0 commit comments