@@ -278,87 +278,73 @@ def test_asarray_copy(library):
278
278
xp = import_ (library , wrapper = True )
279
279
asarray = xp .asarray
280
280
is_lib_func = globals ()[is_array_functions [library ]]
281
- all = xp .all if library != 'dask.array' else lambda x : xp .all (x ).compute ()
282
-
283
- if library == 'cupy' :
284
- supports_copy_false_other_ns = False
285
- supports_copy_false_same_ns = False
286
- elif library == 'dask.array' :
287
- supports_copy_false_other_ns = False
288
- supports_copy_false_same_ns = True
289
- else :
290
- supports_copy_false_other_ns = True
291
- supports_copy_false_same_ns = True
292
281
293
282
a = asarray ([1 ])
294
283
b = asarray (a , copy = True )
295
284
assert is_lib_func (b )
296
285
a [0 ] = 0
297
- assert all ( b [0 ] == 1 )
298
- assert all ( a [0 ] == 0 )
286
+ assert b [0 ] == 1
287
+ assert a [0 ] == 0
299
288
300
289
a = asarray ([1 ])
301
- if supports_copy_false_same_ns :
302
- b = asarray (a , copy = False )
303
- assert is_lib_func (b )
304
- a [0 ] = 0
305
- assert all (b [0 ] == 0 )
306
- else :
307
- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False ))
308
290
309
- a = asarray ([1 ])
310
- if supports_copy_false_same_ns :
311
- pytest .raises (ValueError , lambda : asarray (a , copy = False ,
312
- dtype = xp .float64 ))
313
- else :
314
- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False , dtype = xp .float64 ))
291
+ # Test copy=False within the same namespace
292
+ b = asarray (a , copy = False )
293
+ assert is_lib_func (b )
294
+ a [0 ] = 0
295
+ assert b [0 ] == 0
296
+ with pytest .raises (ValueError ):
297
+ asarray (a , copy = False , dtype = xp .float64 )
315
298
299
+ # copy=None defaults to False when possible
316
300
a = asarray ([1 ])
317
301
b = asarray (a , copy = None )
318
302
assert is_lib_func (b )
319
303
a [0 ] = 0
320
- assert all ( b [0 ] == 0 )
304
+ assert b [0 ] == 0
321
305
306
+ # copy=None defaults to True when impossible
322
307
a = asarray ([1.0 ], dtype = xp .float32 )
323
308
assert a .dtype == xp .float32
324
309
b = asarray (a , dtype = xp .float64 , copy = None )
325
310
assert is_lib_func (b )
326
311
assert b .dtype == xp .float64
327
312
a [0 ] = 0.0
328
- assert all ( b [0 ] == 1.0 )
313
+ assert b [0 ] == 1.0
329
314
315
+ # copy=None defaults to False when possible
330
316
a = asarray ([1.0 ], dtype = xp .float64 )
331
317
assert a .dtype == xp .float64
332
318
b = asarray (a , dtype = xp .float64 , copy = None )
333
319
assert is_lib_func (b )
334
320
assert b .dtype == xp .float64
335
321
a [0 ] = 0.0
336
- assert all ( b [0 ] == 0.0 )
322
+ assert b [0 ] == 0.0
337
323
338
324
# Python built-in types
339
325
for obj in [True , 0 , 0.0 , 0j , [0 ], [[0 ]]]:
340
326
asarray (obj , copy = True ) # No error
341
327
asarray (obj , copy = None ) # No error
342
- if supports_copy_false_other_ns :
343
- pytest .raises (ValueError , lambda : asarray (obj , copy = False ))
344
- else :
345
- pytest .raises (NotImplementedError , lambda : asarray (obj , copy = False ))
328
+
329
+ with pytest .raises (ValueError ):
330
+ asarray (obj , copy = False )
346
331
347
332
# Use the standard library array to test the buffer protocol
348
333
a = array .array ('f' , [1.0 ])
349
334
b = asarray (a , copy = True )
350
335
assert is_lib_func (b )
351
336
a [0 ] = 0.0
352
- assert all ( b [0 ] == 1.0 )
337
+ assert b [0 ] == 1.0
353
338
354
339
a = array .array ('f' , [1.0 ])
355
- if supports_copy_false_other_ns :
340
+ if library in ('cupy' , 'dask.array' ):
341
+ with pytest .raises (ValueError ):
342
+ asarray (a , copy = False )
343
+ else :
356
344
b = asarray (a , copy = False )
357
345
assert is_lib_func (b )
358
346
a [0 ] = 0.0
359
- assert all (b [0 ] == 0.0 )
360
- else :
361
- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False ))
347
+ assert b [0 ] == 0.0
362
348
363
349
a = array .array ('f' , [1.0 ])
364
350
b = asarray (a , copy = None )
@@ -369,9 +355,10 @@ def test_asarray_copy(library):
369
355
# dask changed behaviour of copy=None in 2024.12 to copy;
370
356
# this wrapper ensures the same behaviour in older versions too.
371
357
# https://github.com/dask/dask/pull/11524/
372
- assert all ( b [0 ] == 1.0 )
358
+ assert b [0 ] == 1.0
373
359
else :
374
- assert all (b [0 ] == 0.0 )
360
+ # copy=None defaults to False when possible
361
+ assert b [0 ] == 0.0
375
362
376
363
377
364
@pytest .mark .parametrize ("library" , ["numpy" , "cupy" , "torch" ])
0 commit comments