File tree Expand file tree Collapse file tree 1 file changed +12
-2
lines changed Expand file tree Collapse file tree 1 file changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -282,9 +282,19 @@ def nonzero(self):
282
282
std = emulate_out_arg (axis_keepdims_wrapper (dtype_to_torch (_reductions .std )))
283
283
284
284
### indexing ###
285
+ @staticmethod
286
+ def _upcast_int_indices (index ):
287
+ if isinstance (index , torch .Tensor ):
288
+ if index .dtype in [torch .int8 , torch .int16 , torch .int32 ]:
289
+ return index .type (torch .int64 )
290
+ elif isinstance (index , tuple ):
291
+ return tuple (ndarray ._upcast_int_indices (i ) for i in index )
292
+ return index
293
+
285
294
def __getitem__ (self , index ):
286
- t_index = _helpers .ndarrays_to_tensors (index )
287
- return ndarray ._from_tensor_and_base (self ._tensor .__getitem__ (t_index ), self )
295
+ index = _helpers .ndarrays_to_tensors (index )
296
+ index = ndarray ._upcast_int_indices (index )
297
+ return ndarray ._from_tensor_and_base (self ._tensor .__getitem__ (index ), self )
288
298
289
299
def __setitem__ (self , index , value ):
290
300
value = asarray (value ).get ()
You can’t perform that action at this time.
0 commit comments