Skip to content

Commit 23c2bf3

Browse files
committed
Accept uint8s as advance integer indices
1 parent 75b95f4 commit 23c2bf3

File tree

2 files changed

+3
-14
lines changed

2 files changed

+3
-14
lines changed

torch_np/_ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def clip(self, min, max, out=None):
414414
@staticmethod
415415
def _upcast_int_indices(index):
416416
if isinstance(index, torch.Tensor):
417-
if index.dtype in [torch.int8, torch.int16, torch.int32]:
417+
if index.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
418418
return index.to(torch.int64)
419419
elif isinstance(index, tuple):
420420
return tuple(ndarray._upcast_int_indices(i) for i in index)

torch_np/tests/numpy_tests/core/test_indexing.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -129,17 +129,11 @@ def test_void_scalar_empty_tuple(self):
129129
assert_equal(s[()], s)
130130
assert_equal(type(s[...]), np.ndarray)
131131

132-
@pytest.mark.xfail(
133-
reason=(
134-
"torch does not support integer indexing int tensors with uints - "
135-
"torch instead treats uint8 tensors as boolean masks (deprecated)"
136-
)
137-
)
138132
def test_same_kind_index_casting(self):
139133
# Indexes should be cast with same-kind and not safe, even if that
140134
# is somewhat unsafe. So test various different code paths.
141135
index = np.arange(5)
142-
u_index = index.astype(np.uintp) # i.e. cast to default uint indexing dtype
136+
u_index = index.astype(np.uint8) # i.e. cast to default uint indexing dtype
143137
arr = np.arange(10)
144138

145139
assert_array_equal(arr[index], arr[u_index])
@@ -150,6 +144,7 @@ def test_same_kind_index_casting(self):
150144
assert_array_equal(arr[index], arr[u_index])
151145

152146
arr[u_index] = np.arange(5)[:,None]
147+
pytest.xfail("XXX: repeat() not implemented")
153148
assert_array_equal(arr, np.arange(5)[:,None].repeat(2, axis=1))
154149

155150
arr = np.arange(25).reshape(5, 5)
@@ -488,12 +483,6 @@ def __array__(self):
488483
assert_(isinstance(a[z, np.array(0)], np.ndarray))
489484
assert_(isinstance(a[z, ArrayLike()], np.ndarray))
490485

491-
@pytest.mark.xfail(
492-
reason=(
493-
"torch does not support integer indexing int tensors with uints - "
494-
"torch instead treats uint8 tensors as boolean masks (deprecated)"
495-
)
496-
)
497486
def test_small_regressions(self):
498487
# Reference count of intp for index checks
499488
a = np.array([0])

0 commit comments

Comments
 (0)