@@ -242,8 +242,25 @@ def test_setitem_masking(shape, data):
242
242
)
243
243
244
244
245
- @given (shape = hh .shapes (), data = st .data ())
246
- def test_getitem_arrays_and_ints (shape , data ):
245
+ @pytest .mark .min_version ("2024.12" )
246
+ @pytest .mark .unvectorized
247
+ @given (shape = hh .shapes (min_dims = 2 ), data = st .data ())
248
+ def test_getitem_arrays_and_ints_1 (shape , data ):
249
+ # min_dims=2 : test multidim `x` arrays
250
+ # index arrays are all 1D
251
+ _test_getitem_arrays_and_ints_1D (shape , data )
252
+
253
+
254
+ @pytest .mark .min_version ("2024.12" )
255
+ @pytest .mark .unvectorized
256
+ @given (shape = hh .shapes (min_dims = 1 ), data = st .data ())
257
+ def test_getitem_arrays_and_ints_2 (shape , data ):
258
+ # min_dims=1 : favor 1D `x` arrays
259
+ # index arrays are all 1D
260
+ _test_getitem_arrays_and_ints_1D (shape , data )
261
+
262
+
263
+ def _test_getitem_arrays_and_ints_1D (shape , data ):
247
264
assume ((len (shape ) > 0 ) and all (sh > 0 for sh in shape ))
248
265
249
266
dtype = xp .int32
@@ -254,10 +271,12 @@ def test_getitem_arrays_and_ints(shape, data):
254
271
arr_index = [data .draw (st .booleans ()) for _ in range (len (shape ))]
255
272
assume (sum (arr_index ) > 0 )
256
273
257
- # draw shapes for index arrays
274
+ # draw shapes for index arrays: NB max_dims=1 ==> 1D indexing arrays ONLY
258
275
if sum (arr_index ) > 0 :
259
276
index_shapes = data .draw (
260
- hh .mutually_broadcastable_shapes (sum (arr_index ), min_dims = 1 , min_side = 1 )
277
+ hh .mutually_broadcastable_shapes (
278
+ sum (arr_index ), min_dims = 1 , max_dims = 1 , min_side = 1
279
+ )
261
280
)
262
281
index_shapes = list (index_shapes )
263
282
@@ -279,19 +298,18 @@ def test_getitem_arrays_and_ints(shape, data):
279
298
# draw an integer
280
299
key .append (data .draw (st .integers (- shape [i ], shape [i ]- 1 )))
281
300
282
-
283
- print (f"??? { x .shape = } { key = } -- { [k if isinstance (k , int ) else k .shape for k in key ]} " )
301
+ # print(f"??? {x.shape = } {key = }")
284
302
285
303
key = tuple (key )
286
304
out = x [key ]
287
305
288
- # XXX: how to properly check
289
- import numpy as np
290
- x_np = np .asarray (x )
291
- out_np = np .asarray (out )
292
- key_np = tuple (k if isinstance (k , int ) else np .asarray (k ) for k in key )
306
+ arrays = [xp .asarray (k ) for k in key ]
307
+ bcast_shape = sh .broadcast_shapes (* [arr .shape for arr in arrays ])
308
+ bcast_key = [xp .broadcast_to (arr , bcast_shape ) for arr in arrays ]
293
309
294
- np .testing .assert_equal (out_np , x_np [key_np ])
310
+ for idx in sh .ndindex (bcast_shape ):
311
+ tpl = tuple (k [idx ] for k in bcast_key )
312
+ assert out [idx ] == x [tpl ], f"failing at { idx = } w/ { key = } "
295
313
296
314
297
315
def make_scalar_casting_param (
0 commit comments