@@ -242,25 +242,29 @@ def test_setitem_masking(shape, data):
242
242
)
243
243
244
244
245
+ # ### Fancy indexing ###
246
+
245
247
@pytest .mark .min_version ("2024.12" )
246
248
@pytest .mark .unvectorized
249
+ @pytest .mark .parametrize ("idx_max_dims" , [1 , None ])
247
250
@given (shape = hh .shapes (min_dims = 2 ), data = st .data ())
248
- def test_getitem_arrays_and_ints_1 (shape , data ):
251
+ def test_getitem_arrays_and_ints_1 (shape , data , idx_max_dims ):
249
252
# min_dims=2 : test multidim `x` arrays
250
- # index arrays are all 1D
251
- _test_getitem_arrays_and_ints_1D (shape , data )
253
+ # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
254
+ _test_getitem_arrays_and_ints (shape , data , idx_max_dims )
252
255
253
256
254
257
@pytest .mark .min_version ("2024.12" )
255
258
@pytest .mark .unvectorized
259
+ @pytest .mark .parametrize ("idx_max_dims" , [1 , None ])
256
260
@given (shape = hh .shapes (min_dims = 1 ), data = st .data ())
257
- def test_getitem_arrays_and_ints_2 (shape , data ):
261
+ def test_getitem_arrays_and_ints_2 (shape , data , idx_max_dims ):
258
262
# min_dims=1 : favor 1D `x` arrays
259
- # index arrays are all 1D
260
- _test_getitem_arrays_and_ints_1D (shape , data )
263
+ # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
264
+ _test_getitem_arrays_and_ints (shape , data , idx_max_dims )
261
265
262
266
263
- def _test_getitem_arrays_and_ints_1D (shape , data ):
267
+ def _test_getitem_arrays_and_ints (shape , data , idx_max_dims ):
264
268
assume ((len (shape ) > 0 ) and all (sh > 0 for sh in shape ))
265
269
266
270
dtype = xp .int32
@@ -271,11 +275,12 @@ def _test_getitem_arrays_and_ints_1D(shape, data):
271
275
arr_index = [data .draw (st .booleans ()) for _ in range (len (shape ))]
272
276
assume (sum (arr_index ) > 0 )
273
277
274
- # draw shapes for index arrays: NB max_dims=1 ==> 1D indexing arrays ONLY
278
+ # draw shapes for index arrays: max_dims=1 ==> 1D indexing arrays ONLY
279
+ # max_dims=None ==> multidim indexing arrays
275
280
if sum (arr_index ) > 0 :
276
281
index_shapes = data .draw (
277
282
hh .mutually_broadcastable_shapes (
278
- sum (arr_index ), min_dims = 1 , max_dims = 1 , min_side = 1
283
+ sum (arr_index ), min_dims = 1 , max_dims = idx_max_dims , min_side = 1
279
284
)
280
285
)
281
286
index_shapes = list (index_shapes )
@@ -298,7 +303,7 @@ def _test_getitem_arrays_and_ints_1D(shape, data):
298
303
# draw an integer
299
304
key .append (data .draw (st .integers (- shape [i ], shape [i ]- 1 )))
300
305
301
- # print(f"??? {x.shape = } {key = }")
306
+ print (f"??? { x .shape = } { len ( key ) = } { [ xp . asarray ( k ). shape for k in key ] } " )
302
307
303
308
key = tuple (key )
304
309
out = x [key ]
0 commit comments