Skip to content

Commit b184194

Browse files
committed
ENH: test fancy indexing with intex arrays and integers
Only test with 1D index arrays, for now
1 parent 7946772 commit b184194

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

array_api_tests/test_array_object.py

+70
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,76 @@ def test_setitem_masking(shape, data):
242242
)
243243

244244

245+
@pytest.mark.min_version("2024.12")
246+
@pytest.mark.unvectorized
247+
@given(shape=hh.shapes(min_dims=2), data=st.data())
248+
def test_getitem_arrays_and_ints_1(shape, data):
249+
# min_dims=2 : test multidim `x` arrays
250+
# index arrays are all 1D
251+
_test_getitem_arrays_and_ints_1D(shape, data)
252+
253+
254+
@pytest.mark.min_version("2024.12")
255+
@pytest.mark.unvectorized
256+
@given(shape=hh.shapes(min_dims=1), data=st.data())
257+
def test_getitem_arrays_and_ints_2(shape, data):
258+
# min_dims=1 : favor 1D `x` arrays
259+
# index arrays are all 1D
260+
_test_getitem_arrays_and_ints_1D(shape, data)
261+
262+
263+
def _test_getitem_arrays_and_ints_1D(shape, data):
264+
assume((len(shape) > 0) and all(sh > 0 for sh in shape))
265+
266+
dtype = xp.int32
267+
obj = data.draw(scalar_objects(dtype, shape), label="obj")
268+
x = xp.asarray(obj, dtype=dtype)
269+
270+
# draw a mix of ints and index arrays
271+
arr_index = [data.draw(st.booleans()) for _ in range(len(shape))]
272+
assume(sum(arr_index) > 0)
273+
274+
# draw shapes for index arrays: NB max_dims=1 ==> 1D indexing arrays ONLY
275+
if sum(arr_index) > 0:
276+
index_shapes = data.draw(
277+
hh.mutually_broadcastable_shapes(
278+
sum(arr_index), min_dims=1, max_dims=1, min_side=1
279+
)
280+
)
281+
index_shapes = list(index_shapes)
282+
283+
# prepare the indexing tuple, a mix of integer indices and index arrays
284+
key = []
285+
for i,typ in enumerate(arr_index):
286+
if typ:
287+
# draw an array index
288+
this_idx = data.draw(
289+
xps.arrays(
290+
dtype,
291+
shape=index_shapes.pop(),
292+
elements=st.integers(0, shape[i]-1)
293+
)
294+
)
295+
key.append(this_idx)
296+
297+
else:
298+
# draw an integer
299+
key.append(data.draw(st.integers(-shape[i], shape[i]-1)))
300+
301+
# print(f"??? {x.shape = } {key = }")
302+
303+
key = tuple(key)
304+
out = x[key]
305+
306+
arrays = [xp.asarray(k) for k in key]
307+
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
308+
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
309+
310+
for idx in sh.ndindex(bcast_shape):
311+
tpl = tuple(k[idx] for k in bcast_key)
312+
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
313+
314+
245315
def make_scalar_casting_param(
246316
method_name: str, dtype: DataType, stype: ScalarType
247317
) -> Param:

0 commit comments

Comments
 (0)