diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 9687a9cd..d43881ab 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -171,7 +171,7 @@ def asarray( return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue] if copy is False: - raise NotImplementedError( + raise ValueError( "Unable to avoid copy when converting a non-dask object to dask" ) diff --git a/tests/test_common.py b/tests/test_common.py index d1933899..54b5ed69 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -268,7 +268,6 @@ def test_asarray_cross_library(source_library, target_library, request): assert b.dtype == tgt_lib.int32 - @pytest.mark.parametrize("library", wrapped_libraries) def test_asarray_copy(library): # Note, we have this test here because the test suite currently doesn't @@ -278,100 +277,87 @@ def test_asarray_copy(library): xp = import_(library, wrapper=True) asarray = xp.asarray is_lib_func = globals()[is_array_functions[library]] - all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() - - if library == 'cupy': - supports_copy_false_other_ns = False - supports_copy_false_same_ns = False - elif library == 'dask.array': - supports_copy_false_other_ns = False - supports_copy_false_same_ns = True - else: - supports_copy_false_other_ns = True - supports_copy_false_same_ns = True a = asarray([1]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 1) - assert all(a[0] == 0) + assert b[0] == 1 + assert a[0] == 0 a = asarray([1]) - if supports_copy_false_same_ns: - b = asarray(a, copy=False) - assert is_lib_func(b) - a[0] = 0 - assert all(b[0] == 0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) - a = asarray([1]) - if supports_copy_false_same_ns: - pytest.raises(ValueError, lambda: asarray(a, copy=False, - dtype=xp.float64)) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64)) + # Test copy=False within the same namespace + b = asarray(a, copy=False) + assert is_lib_func(b) + a[0] = 0 + assert b[0] == 0 + with pytest.raises(ValueError): + asarray(a, copy=False, dtype=xp.float64) + # copy=None defaults to False when possible a = asarray([1]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 0) + assert b[0] == 0 + # copy=None defaults to True when impossible a = asarray([1.0], dtype=xp.float32) assert a.dtype == xp.float32 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 + # copy=None defaults to False when possible a = asarray([1.0], dtype=xp.float64) assert a.dtype == xp.float64 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 0.0) + assert b[0] == 0.0 # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: - asarray(obj, copy=True) # No error - asarray(obj, copy=None) # No error - if supports_copy_false_other_ns: - pytest.raises(ValueError, lambda: asarray(obj, copy=False)) - else: - pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False)) + asarray(obj, copy=True) # No error + asarray(obj, copy=None) # No error + + with pytest.raises(ValueError): + asarray(obj, copy=False) # Use the standard library array to test the buffer protocol - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 - a = array.array('f', [1.0]) - if supports_copy_false_other_ns: + a = array.array("f", [1.0]) + if library in ("cupy", "dask.array"): + with pytest.raises(ValueError): + asarray(a, copy=False) + else: b = asarray(a, copy=False) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 0.0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) + assert b[0] == 0.0 - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0.0 - if library in ('cupy', 'dask.array'): + if library in ("cupy", "dask.array"): # A copy is required for libraries where the default device is not CPU # dask changed behaviour of copy=None in 2024.12 to copy; # this wrapper ensures the same behaviour in older versions too. # https://github.com/dask/dask/pull/11524/ - assert all(b[0] == 1.0) + assert b[0] == 1.0 else: - assert all(b[0] == 0.0) + # copy=None defaults to False when possible + assert b[0] == 0.0 @pytest.mark.parametrize("library", ["numpy", "cupy", "torch"])