Skip to content

Commit 6488ad8

Browse files
committed
TST: revisit test for asarray copy= parameter
1 parent 2b5e289 commit 6488ad8

File tree

2 files changed

+28
-41
lines changed

2 files changed

+28
-41
lines changed

array_api_compat/dask/array/_aliases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def asarray(
171171
return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue]
172172

173173
if copy is False:
174-
raise NotImplementedError(
174+
raise ValueError(
175175
"Unable to avoid copy when converting a non-dask object to dask"
176176
)
177177

tests/test_common.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -278,87 +278,73 @@ def test_asarray_copy(library):
278278
xp = import_(library, wrapper=True)
279279
asarray = xp.asarray
280280
is_lib_func = globals()[is_array_functions[library]]
281-
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
282-
283-
if library == 'cupy':
284-
supports_copy_false_other_ns = False
285-
supports_copy_false_same_ns = False
286-
elif library == 'dask.array':
287-
supports_copy_false_other_ns = False
288-
supports_copy_false_same_ns = True
289-
else:
290-
supports_copy_false_other_ns = True
291-
supports_copy_false_same_ns = True
292281

293282
a = asarray([1])
294283
b = asarray(a, copy=True)
295284
assert is_lib_func(b)
296285
a[0] = 0
297-
assert all(b[0] == 1)
298-
assert all(a[0] == 0)
286+
assert b[0] == 1
287+
assert a[0] == 0
299288

300289
a = asarray([1])
301-
if supports_copy_false_same_ns:
302-
b = asarray(a, copy=False)
303-
assert is_lib_func(b)
304-
a[0] = 0
305-
assert all(b[0] == 0)
306-
else:
307-
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
308290

309-
a = asarray([1])
310-
if supports_copy_false_same_ns:
311-
pytest.raises(ValueError, lambda: asarray(a, copy=False,
312-
dtype=xp.float64))
313-
else:
314-
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64))
291+
# Test copy=False within the same namespace
292+
b = asarray(a, copy=False)
293+
assert is_lib_func(b)
294+
a[0] = 0
295+
assert b[0] == 0
296+
with pytest.raises(ValueError):
297+
asarray(a, copy=False, dtype=xp.float64)
315298

299+
# copy=None defaults to False when possible
316300
a = asarray([1])
317301
b = asarray(a, copy=None)
318302
assert is_lib_func(b)
319303
a[0] = 0
320-
assert all(b[0] == 0)
304+
assert b[0] == 0
321305

306+
# copy=None defaults to True when impossible
322307
a = asarray([1.0], dtype=xp.float32)
323308
assert a.dtype == xp.float32
324309
b = asarray(a, dtype=xp.float64, copy=None)
325310
assert is_lib_func(b)
326311
assert b.dtype == xp.float64
327312
a[0] = 0.0
328-
assert all(b[0] == 1.0)
313+
assert b[0] == 1.0
329314

315+
# copy=None defaults to False when possible
330316
a = asarray([1.0], dtype=xp.float64)
331317
assert a.dtype == xp.float64
332318
b = asarray(a, dtype=xp.float64, copy=None)
333319
assert is_lib_func(b)
334320
assert b.dtype == xp.float64
335321
a[0] = 0.0
336-
assert all(b[0] == 0.0)
322+
assert b[0] == 0.0
337323

338324
# Python built-in types
339325
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
340326
asarray(obj, copy=True) # No error
341327
asarray(obj, copy=None) # No error
342-
if supports_copy_false_other_ns:
343-
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
344-
else:
345-
pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
328+
329+
with pytest.raises(ValueError):
330+
asarray(obj, copy=False)
346331

347332
# Use the standard library array to test the buffer protocol
348333
a = array.array('f', [1.0])
349334
b = asarray(a, copy=True)
350335
assert is_lib_func(b)
351336
a[0] = 0.0
352-
assert all(b[0] == 1.0)
337+
assert b[0] == 1.0
353338

354339
a = array.array('f', [1.0])
355-
if supports_copy_false_other_ns:
340+
if library in ('cupy', 'dask.array'):
341+
with pytest.raises(ValueError):
342+
asarray(a, copy=False)
343+
else:
356344
b = asarray(a, copy=False)
357345
assert is_lib_func(b)
358346
a[0] = 0.0
359-
assert all(b[0] == 0.0)
360-
else:
361-
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
347+
assert b[0] == 0.0
362348

363349
a = array.array('f', [1.0])
364350
b = asarray(a, copy=None)
@@ -369,9 +355,10 @@ def test_asarray_copy(library):
369355
# dask changed behaviour of copy=None in 2024.12 to copy;
370356
# this wrapper ensures the same behaviour in older versions too.
371357
# https://github.com/dask/dask/pull/11524/
372-
assert all(b[0] == 1.0)
358+
assert b[0] == 1.0
373359
else:
374-
assert all(b[0] == 0.0)
360+
# copy=None defaults to False when possible
361+
assert b[0] == 0.0
375362

376363

377364
@pytest.mark.parametrize("library", ["numpy", "cupy", "torch"])

0 commit comments

Comments
 (0)