@@ -242,6 +242,76 @@ def test_setitem_masking(shape, data):
242
242
)
243
243
244
244
245
+ @pytest .mark .min_version ("2024.12" )
246
+ @pytest .mark .unvectorized
247
+ @given (shape = hh .shapes (min_dims = 2 ), data = st .data ())
248
+ def test_getitem_arrays_and_ints_1 (shape , data ):
249
+ # min_dims=2 : test multidim `x` arrays
250
+ # index arrays are all 1D
251
+ _test_getitem_arrays_and_ints_1D (shape , data )
252
+
253
+
254
+ @pytest .mark .min_version ("2024.12" )
255
+ @pytest .mark .unvectorized
256
+ @given (shape = hh .shapes (min_dims = 1 ), data = st .data ())
257
+ def test_getitem_arrays_and_ints_2 (shape , data ):
258
+ # min_dims=1 : favor 1D `x` arrays
259
+ # index arrays are all 1D
260
+ _test_getitem_arrays_and_ints_1D (shape , data )
261
+
262
+
263
+ def _test_getitem_arrays_and_ints_1D (shape , data ):
264
+ assume ((len (shape ) > 0 ) and all (sh > 0 for sh in shape ))
265
+
266
+ dtype = xp .int32
267
+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
268
+ x = xp .asarray (obj , dtype = dtype )
269
+
270
+ # draw a mix of ints and index arrays
271
+ arr_index = [data .draw (st .booleans ()) for _ in range (len (shape ))]
272
+ assume (sum (arr_index ) > 0 )
273
+
274
+ # draw shapes for index arrays: NB max_dims=1 ==> 1D indexing arrays ONLY
275
+ if sum (arr_index ) > 0 :
276
+ index_shapes = data .draw (
277
+ hh .mutually_broadcastable_shapes (
278
+ sum (arr_index ), min_dims = 1 , max_dims = 1 , min_side = 1
279
+ )
280
+ )
281
+ index_shapes = list (index_shapes )
282
+
283
+ # prepare the indexing tuple, a mix of integer indices and index arrays
284
+ key = []
285
+ for i ,typ in enumerate (arr_index ):
286
+ if typ :
287
+ # draw an array index
288
+ this_idx = data .draw (
289
+ xps .arrays (
290
+ dtype ,
291
+ shape = index_shapes .pop (),
292
+ elements = st .integers (0 , shape [i ]- 1 )
293
+ )
294
+ )
295
+ key .append (this_idx )
296
+
297
+ else :
298
+ # draw an integer
299
+ key .append (data .draw (st .integers (- shape [i ], shape [i ]- 1 )))
300
+
301
+ # print(f"??? {x.shape = } {key = }")
302
+
303
+ key = tuple (key )
304
+ out = x [key ]
305
+
306
+ arrays = [xp .asarray (k ) for k in key ]
307
+ bcast_shape = sh .broadcast_shapes (* [arr .shape for arr in arrays ])
308
+ bcast_key = [xp .broadcast_to (arr , bcast_shape ) for arr in arrays ]
309
+
310
+ for idx in sh .ndindex (bcast_shape ):
311
+ tpl = tuple (k [idx ] for k in bcast_key )
312
+ assert out [idx ] == x [tpl ], f"failing at { idx = } w/ { key = } "
313
+
314
+
245
315
def make_scalar_casting_param (
246
316
method_name : str , dtype : DataType , stype : ScalarType
247
317
) -> Param :
0 commit comments