@@ -242,6 +242,81 @@ def test_setitem_masking(shape, data):
242
242
)
243
243
244
244
245
+ # ### Fancy indexing ###
246
+
247
+ @pytest .mark .min_version ("2024.12" )
248
+ @pytest .mark .unvectorized
249
+ @pytest .mark .parametrize ("idx_max_dims" , [1 , None ])
250
+ @given (shape = hh .shapes (min_dims = 2 ), data = st .data ())
251
+ def test_getitem_arrays_and_ints_1 (shape , data , idx_max_dims ):
252
+ # min_dims=2 : test multidim `x` arrays
253
+ # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
254
+ _test_getitem_arrays_and_ints (shape , data , idx_max_dims )
255
+
256
+
257
+ @pytest .mark .min_version ("2024.12" )
258
+ @pytest .mark .unvectorized
259
+ @pytest .mark .parametrize ("idx_max_dims" , [1 , None ])
260
+ @given (shape = hh .shapes (min_dims = 1 ), data = st .data ())
261
+ def test_getitem_arrays_and_ints_2 (shape , data , idx_max_dims ):
262
+ # min_dims=1 : favor 1D `x` arrays
263
+ # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
264
+ _test_getitem_arrays_and_ints (shape , data , idx_max_dims )
265
+
266
+
267
+ def _test_getitem_arrays_and_ints (shape , data , idx_max_dims ):
268
+ assume ((len (shape ) > 0 ) and all (sh > 0 for sh in shape ))
269
+
270
+ dtype = xp .int32
271
+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
272
+ x = xp .asarray (obj , dtype = dtype )
273
+
274
+ # draw a mix of ints and index arrays
275
+ arr_index = [data .draw (st .booleans ()) for _ in range (len (shape ))]
276
+ assume (sum (arr_index ) > 0 )
277
+
278
+ # draw shapes for index arrays: max_dims=1 ==> 1D indexing arrays ONLY
279
+ # max_dims=None ==> multidim indexing arrays
280
+ if sum (arr_index ) > 0 :
281
+ index_shapes = data .draw (
282
+ hh .mutually_broadcastable_shapes (
283
+ sum (arr_index ), min_dims = 1 , max_dims = idx_max_dims , min_side = 1
284
+ )
285
+ )
286
+ index_shapes = list (index_shapes )
287
+
288
+ # prepare the indexing tuple, a mix of integer indices and index arrays
289
+ key = []
290
+ for i ,typ in enumerate (arr_index ):
291
+ if typ :
292
+ # draw an array index
293
+ this_idx = data .draw (
294
+ xps .arrays (
295
+ dtype ,
296
+ shape = index_shapes .pop (),
297
+ elements = st .integers (0 , shape [i ]- 1 )
298
+ )
299
+ )
300
+ key .append (this_idx )
301
+
302
+ else :
303
+ # draw an integer
304
+ key .append (data .draw (st .integers (- shape [i ], shape [i ]- 1 )))
305
+
306
+ print (f"??? { x .shape = } { len (key ) = } { [xp .asarray (k ).shape for k in key ]} " )
307
+
308
+ key = tuple (key )
309
+ out = x [key ]
310
+
311
+ arrays = [xp .asarray (k ) for k in key ]
312
+ bcast_shape = sh .broadcast_shapes (* [arr .shape for arr in arrays ])
313
+ bcast_key = [xp .broadcast_to (arr , bcast_shape ) for arr in arrays ]
314
+
315
+ for idx in sh .ndindex (bcast_shape ):
316
+ tpl = tuple (k [idx ] for k in bcast_key )
317
+ assert out [idx ] == x [tpl ], f"failing at { idx = } w/ { key = } "
318
+
319
+
245
320
def make_scalar_casting_param (
246
321
method_name : str , dtype : DataType , stype : ScalarType
247
322
) -> Param :
0 commit comments