Skip to content

Commit 4ee45a0

Browse files
authored
Merge pull request #343 from ev-br/test_getitem_arrays_1D
Test fancy indexing with integers and index arrays
2 parents 28e1982 + bb5a6a5 commit 4ee45a0

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

array_api_tests/test_array_object.py

+75
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,81 @@ def test_setitem_masking(shape, data):
242242
)
243243

244244

245+
# ### Fancy indexing ###
246+
247+
@pytest.mark.min_version("2024.12")
248+
@pytest.mark.unvectorized
249+
@pytest.mark.parametrize("idx_max_dims", [1, None])
250+
@given(shape=hh.shapes(min_dims=2), data=st.data())
251+
def test_getitem_arrays_and_ints_1(shape, data, idx_max_dims):
252+
# min_dims=2 : test multidim `x` arrays
253+
# index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
254+
_test_getitem_arrays_and_ints(shape, data, idx_max_dims)
255+
256+
257+
@pytest.mark.min_version("2024.12")
258+
@pytest.mark.unvectorized
259+
@pytest.mark.parametrize("idx_max_dims", [1, None])
260+
@given(shape=hh.shapes(min_dims=1), data=st.data())
261+
def test_getitem_arrays_and_ints_2(shape, data, idx_max_dims):
262+
# min_dims=1 : favor 1D `x` arrays
263+
# index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
264+
_test_getitem_arrays_and_ints(shape, data, idx_max_dims)
265+
266+
267+
def _test_getitem_arrays_and_ints(shape, data, idx_max_dims):
268+
assume((len(shape) > 0) and all(sh > 0 for sh in shape))
269+
270+
dtype = xp.int32
271+
obj = data.draw(scalar_objects(dtype, shape), label="obj")
272+
x = xp.asarray(obj, dtype=dtype)
273+
274+
# draw a mix of ints and index arrays
275+
arr_index = [data.draw(st.booleans()) for _ in range(len(shape))]
276+
assume(sum(arr_index) > 0)
277+
278+
# draw shapes for index arrays: max_dims=1 ==> 1D indexing arrays ONLY
279+
# max_dims=None ==> multidim indexing arrays
280+
if sum(arr_index) > 0:
281+
index_shapes = data.draw(
282+
hh.mutually_broadcastable_shapes(
283+
sum(arr_index), min_dims=1, max_dims=idx_max_dims, min_side=1
284+
)
285+
)
286+
index_shapes = list(index_shapes)
287+
288+
# prepare the indexing tuple, a mix of integer indices and index arrays
289+
key = []
290+
for i,typ in enumerate(arr_index):
291+
if typ:
292+
# draw an array index
293+
this_idx = data.draw(
294+
xps.arrays(
295+
dtype,
296+
shape=index_shapes.pop(),
297+
elements=st.integers(0, shape[i]-1)
298+
)
299+
)
300+
key.append(this_idx)
301+
302+
else:
303+
# draw an integer
304+
key.append(data.draw(st.integers(-shape[i], shape[i]-1)))
305+
306+
print(f"??? {x.shape = } {len(key) = } {[xp.asarray(k).shape for k in key]}")
307+
308+
key = tuple(key)
309+
out = x[key]
310+
311+
arrays = [xp.asarray(k) for k in key]
312+
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
313+
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
314+
315+
for idx in sh.ndindex(bcast_shape):
316+
tpl = tuple(k[idx] for k in bcast_key)
317+
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
318+
319+
245320
def make_scalar_casting_param(
246321
method_name: str, dtype: DataType, stype: ScalarType
247322
) -> Param:

0 commit comments

Comments
 (0)