Skip to content

Commit 73abd65

Browse files
committed
ENH: only test with 1D index arrays, make test unvecorized
1 parent d5d3080 commit 73abd65

File tree

1 file changed

+30
-12
lines changed

1 file changed

+30
-12
lines changed

array_api_tests/test_array_object.py

+30-12
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,25 @@ def test_setitem_masking(shape, data):
242242
)
243243

244244

245-
@given(shape=hh.shapes(), data=st.data())
246-
def test_getitem_arrays_and_ints(shape, data):
245+
@pytest.mark.min_version("2024.12")
246+
@pytest.mark.unvectorized
247+
@given(shape=hh.shapes(min_dims=2), data=st.data())
248+
def test_getitem_arrays_and_ints_1(shape, data):
249+
# min_dims=2 : test multidim `x` arrays
250+
# index arrays are all 1D
251+
_test_getitem_arrays_and_ints_1D(shape, data)
252+
253+
254+
@pytest.mark.min_version("2024.12")
255+
@pytest.mark.unvectorized
256+
@given(shape=hh.shapes(min_dims=1), data=st.data())
257+
def test_getitem_arrays_and_ints_2(shape, data):
258+
# min_dims=1 : favor 1D `x` arrays
259+
# index arrays are all 1D
260+
_test_getitem_arrays_and_ints_1D(shape, data)
261+
262+
263+
def _test_getitem_arrays_and_ints_1D(shape, data):
247264
assume((len(shape) > 0) and all(sh > 0 for sh in shape))
248265

249266
dtype = xp.int32
@@ -254,10 +271,12 @@ def test_getitem_arrays_and_ints(shape, data):
254271
arr_index = [data.draw(st.booleans()) for _ in range(len(shape))]
255272
assume(sum(arr_index) > 0)
256273

257-
# draw shapes for index arrays
274+
# draw shapes for index arrays: NB max_dims=1 ==> 1D indexing arrays ONLY
258275
if sum(arr_index) > 0:
259276
index_shapes = data.draw(
260-
hh.mutually_broadcastable_shapes(sum(arr_index), min_dims=1, min_side=1)
277+
hh.mutually_broadcastable_shapes(
278+
sum(arr_index), min_dims=1, max_dims=1, min_side=1
279+
)
261280
)
262281
index_shapes = list(index_shapes)
263282

@@ -279,19 +298,18 @@ def test_getitem_arrays_and_ints(shape, data):
279298
# draw an integer
280299
key.append(data.draw(st.integers(-shape[i], shape[i]-1)))
281300

282-
283-
print(f"??? {x.shape = } {key = } -- {[k if isinstance(k, int) else k.shape for k in key]}")
301+
# print(f"??? {x.shape = } {key = }")
284302

285303
key = tuple(key)
286304
out = x[key]
287305

288-
# XXX: how to properly check
289-
import numpy as np
290-
x_np = np.asarray(x)
291-
out_np = np.asarray(out)
292-
key_np = tuple(k if isinstance(k, int) else np.asarray(k) for k in key)
306+
arrays = [xp.asarray(k) for k in key]
307+
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
308+
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
293309

294-
np.testing.assert_equal(out_np, x_np[key_np])
310+
for idx in sh.ndindex(bcast_shape):
311+
tpl = tuple(k[idx] for k in bcast_key)
312+
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
295313

296314

297315
def make_scalar_casting_param(

0 commit comments

Comments
 (0)