Skip to content

Commit 5167e48

Browse files
committed
ENH: test with ND index arrays
1 parent f54180c commit 5167e48

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

array_api_tests/test_array_object.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -242,25 +242,29 @@ def test_setitem_masking(shape, data):
242242
)
243243

244244

245+
# ### Fancy indexing ###
246+
245247
@pytest.mark.min_version("2024.12")
246248
@pytest.mark.unvectorized
249+
@pytest.mark.parametrize("idx_max_dims", [1, None])
247250
@given(shape=hh.shapes(min_dims=2), data=st.data())
248-
def test_getitem_arrays_and_ints_1(shape, data):
251+
def test_getitem_arrays_and_ints_1(shape, data, idx_max_dims):
249252
# min_dims=2 : test multidim `x` arrays
250-
# index arrays are all 1D
251-
_test_getitem_arrays_and_ints_1D(shape, data)
253+
# index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
254+
_test_getitem_arrays_and_ints(shape, data, idx_max_dims)
252255

253256

254257
@pytest.mark.min_version("2024.12")
255258
@pytest.mark.unvectorized
259+
@pytest.mark.parametrize("idx_max_dims", [1, None])
256260
@given(shape=hh.shapes(min_dims=1), data=st.data())
257-
def test_getitem_arrays_and_ints_2(shape, data):
261+
def test_getitem_arrays_and_ints_2(shape, data, idx_max_dims):
258262
# min_dims=1 : favor 1D `x` arrays
259-
# index arrays are all 1D
260-
_test_getitem_arrays_and_ints_1D(shape, data)
263+
# index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
264+
_test_getitem_arrays_and_ints(shape, data, idx_max_dims)
261265

262266

263-
def _test_getitem_arrays_and_ints_1D(shape, data):
267+
def _test_getitem_arrays_and_ints(shape, data, idx_max_dims):
264268
assume((len(shape) > 0) and all(sh > 0 for sh in shape))
265269

266270
dtype = xp.int32
@@ -271,11 +275,12 @@ def _test_getitem_arrays_and_ints_1D(shape, data):
271275
arr_index = [data.draw(st.booleans()) for _ in range(len(shape))]
272276
assume(sum(arr_index) > 0)
273277

274-
# draw shapes for index arrays: NB max_dims=1 ==> 1D indexing arrays ONLY
278+
# draw shapes for index arrays: max_dims=1 ==> 1D indexing arrays ONLY
279+
# max_dims=None ==> multidim indexing arrays
275280
if sum(arr_index) > 0:
276281
index_shapes = data.draw(
277282
hh.mutually_broadcastable_shapes(
278-
sum(arr_index), min_dims=1, max_dims=1, min_side=1
283+
sum(arr_index), min_dims=1, max_dims=idx_max_dims, min_side=1
279284
)
280285
)
281286
index_shapes = list(index_shapes)
@@ -298,7 +303,7 @@ def _test_getitem_arrays_and_ints_1D(shape, data):
298303
# draw an integer
299304
key.append(data.draw(st.integers(-shape[i], shape[i]-1)))
300305

301-
# print(f"??? {x.shape = } {key = }")
306+
print(f"??? {x.shape = } {len(key) = } {[xp.asarray(k).shape for k in key]}")
302307

303308
key = tuple(key)
304309
out = x[key]

0 commit comments

Comments
 (0)