Skip to content

Commit 4785093

Browse files
committed
Normalise indices for setitem
Effort to test thoroughly, so trusting in NumPy's own tests down-the-line
1 parent 43b16c0 commit 4785093

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

torch_np/_ndarray.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,8 @@ def __getitem__(self, index):
297297
return ndarray._from_tensor_and_base(self._tensor.__getitem__(index), self)
298298

299299
def __setitem__(self, index, value):
300+
index = _helpers.to_tensors(index)
301+
index = ndarray._upcast_int_indices(index)
300302
value = asarray(value).get()
301303
return self._tensor.__setitem__(index, value)
302304

0 commit comments

Comments
 (0)