Skip to content

Commit 43b16c0

Browse files
committed
Upcast int tensor indices
1 parent edccd3b commit 43b16c0

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

torch_np/_ndarray.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,19 @@ def nonzero(self):
282282
std = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.std)))
283283

284284
### indexing ###
285+
@staticmethod
286+
def _upcast_int_indices(index):
287+
if isinstance(index, torch.Tensor):
288+
if index.dtype in [torch.int8, torch.int16, torch.int32]:
289+
return index.type(torch.int64)
290+
elif isinstance(index, tuple):
291+
return tuple(ndarray._upcast_int_indices(i) for i in index)
292+
return index
293+
285294
def __getitem__(self, index):
286-
t_index = _helpers.ndarrays_to_tensors(index)
287-
return ndarray._from_tensor_and_base(self._tensor.__getitem__(t_index), self)
295+
index = _helpers.ndarrays_to_tensors(index)
296+
index = ndarray._upcast_int_indices(index)
297+
return ndarray._from_tensor_and_base(self._tensor.__getitem__(index), self)
288298

289299
def __setitem__(self, index, value):
290300
value = asarray(value).get()

0 commit comments

Comments
 (0)