Skip to content

Commit edccd3b

Browse files
committed
Smoke test for advance integer indexing
1 parent 1960e94 commit edccd3b

File tree

5 files changed

+66
-9
lines changed

5 files changed

+66
-9
lines changed

torch_np/_helpers.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,23 @@ def result_or_out(result_tensor, out_array=None):
5656

5757
def ndarrays_to_tensors(*inputs):
5858
"""Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
59-
return tuple(
60-
[value.get() if isinstance(value, ndarray) else value for value in inputs]
61-
)
59+
if len(inputs) == 0:
60+
return ValueError()
61+
elif len(inputs) == 1:
62+
input_ = inputs[0]
63+
if isinstance(input_, ndarray):
64+
return input_.get()
65+
elif isinstance(input_, tuple):
66+
result = []
67+
for sub_input in input_:
68+
sub_result = ndarrays_to_tensors(sub_input)
69+
result.append(sub_result)
70+
return tuple(result)
71+
else:
72+
return input_
73+
else:
74+
assert isinstance(inputs, tuple) # sanity check
75+
return ndarrays_to_tensors(inputs)
6276

6377

6478
def to_tensors(*inputs):

torch_np/_ndarray.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,9 @@ def nonzero(self):
282282
std = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.std)))
283283

284284
### indexing ###
285-
def __getitem__(self, *args, **kwds):
286-
t_args = _helpers.ndarrays_to_tensors(*args)
287-
return ndarray._from_tensor_and_base(
288-
self._tensor.__getitem__(*t_args, **kwds), self
289-
)
285+
def __getitem__(self, index):
286+
t_index = _helpers.ndarrays_to_tensors(index)
287+
return ndarray._from_tensor_and_base(self._tensor.__getitem__(t_index), self)
290288

291289
def __setitem__(self, index, value):
292290
value = asarray(value).get()

torch_np/_wrapper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ def empty_like(prototype, dtype=None, order="K", subok=False, shape=None):
233233

234234
def full(shape, fill_value, dtype=None, order="C", *, like=None):
235235
_util.subok_not_ok(like)
236+
if isinstance(shape, int):
237+
shape = (shape,)
236238
if order != "C":
237239
raise NotImplementedError
238240
if isinstance(fill_value, ndarray):

torch_np/tests/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import warnings
2+
3+
from hypothesis.errors import HypothesisWarning
4+
from hypothesis.extra.array_api import make_strategies_namespace
5+
6+
import torch_np
7+
8+
__all__ = ["xps"]
9+
10+
with warnings.catch_warnings():
11+
warnings.filterwarnings("ignore", category=HypothesisWarning)
12+
xps = make_strategies_namespace(torch_np, api_version="2021.12")

torch_np/tests/test_ndarray_methods.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
import itertools
22

33
import pytest
4-
from pytest import raises as assert_raises
54

65
# import numpy as np
6+
import torch
7+
from hypothesis import given
8+
from hypothesis import strategies as st
9+
from pytest import raises as assert_raises
10+
711
import torch_np as np
12+
from torch_np._ndarray import ndarray
813
from torch_np.testing import assert_equal
914

15+
from . import xps
16+
1017

1118
class TestIndexing:
1219
def test_indexing_simple(self):
@@ -23,6 +30,30 @@ def test_setitem(self):
2330
assert_equal(a, [[8, 2, 3], [4, 5, 6]])
2431

2532

33+
def integer_array_indices(shape, result_shape) -> st.SearchStrategy[tuple]:
34+
# See hypothesis.extra.numpy.integer_array_indices()
35+
# n.b. result_shape only accepts a shape, as opposed to only accepting a strategy
36+
def array_for(index_shape, size):
37+
return xps.arrays(
38+
dtype=xps.integer_dtypes(),
39+
shape=index_shape,
40+
elements=st.integers(-size, size - 1),
41+
)
42+
43+
return st.tuples(*(array_for(result_shape, size) for size in shape))
44+
45+
46+
@given(
47+
x=xps.arrays(dtype=xps.integer_dtypes(), shape=xps.array_shapes()),
48+
data=st.data(),
49+
)
50+
def test_integer_indexing(x, data):
51+
result_shape = data.draw(xps.array_shapes(), label="result_shape")
52+
idx = data.draw(integer_array_indices(x.shape, result_shape), label="idx")
53+
result = x[idx]
54+
assert result.shape == result_shape
55+
56+
2657
class TestReshape:
2758
def test_reshape_function(self):
2859
arr = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]

0 commit comments

Comments
 (0)