From 6488ad81748a1b92f7a0de42e5a10461a9df6b62 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 08:55:20 +0100 Subject: [PATCH 1/2] TST: revisit test for `asarray` `copy=` parameter --- array_api_compat/dask/array/_aliases.py | 2 +- tests/test_common.py | 67 ++++++++++--------------- 2 files changed, 28 insertions(+), 41 deletions(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 9687a9cd..d43881ab 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -171,7 +171,7 @@ def asarray( return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue] if copy is False: - raise NotImplementedError( + raise ValueError( "Unable to avoid copy when converting a non-dask object to dask" ) diff --git a/tests/test_common.py b/tests/test_common.py index d1933899..fe4fe598 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -278,87 +278,73 @@ def test_asarray_copy(library): xp = import_(library, wrapper=True) asarray = xp.asarray is_lib_func = globals()[is_array_functions[library]] - all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() - - if library == 'cupy': - supports_copy_false_other_ns = False - supports_copy_false_same_ns = False - elif library == 'dask.array': - supports_copy_false_other_ns = False - supports_copy_false_same_ns = True - else: - supports_copy_false_other_ns = True - supports_copy_false_same_ns = True a = asarray([1]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 1) - assert all(a[0] == 0) + assert b[0] == 1 + assert a[0] == 0 a = asarray([1]) - if supports_copy_false_same_ns: - b = asarray(a, copy=False) - assert is_lib_func(b) - a[0] = 0 - assert all(b[0] == 0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) - a = asarray([1]) - if supports_copy_false_same_ns: - pytest.raises(ValueError, lambda: asarray(a, copy=False, - dtype=xp.float64)) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64)) + # Test copy=False within the same namespace + b = asarray(a, copy=False) + assert is_lib_func(b) + a[0] = 0 + assert b[0] == 0 + with pytest.raises(ValueError): + asarray(a, copy=False, dtype=xp.float64) + # copy=None defaults to False when possible a = asarray([1]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 0) + assert b[0] == 0 + # copy=None defaults to True when impossible a = asarray([1.0], dtype=xp.float32) assert a.dtype == xp.float32 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 + # copy=None defaults to False when possible a = asarray([1.0], dtype=xp.float64) assert a.dtype == xp.float64 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 0.0) + assert b[0] == 0.0 # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: asarray(obj, copy=True) # No error asarray(obj, copy=None) # No error - if supports_copy_false_other_ns: - pytest.raises(ValueError, lambda: asarray(obj, copy=False)) - else: - pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False)) + + with pytest.raises(ValueError): + asarray(obj, copy=False) # Use the standard library array to test the buffer protocol a = array.array('f', [1.0]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 a = array.array('f', [1.0]) - if supports_copy_false_other_ns: + if library in ('cupy', 'dask.array'): + with pytest.raises(ValueError): + asarray(a, copy=False) + else: b = asarray(a, copy=False) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 0.0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) + assert b[0] == 0.0 a = array.array('f', [1.0]) b = asarray(a, copy=None) @@ -369,9 +355,10 @@ def test_asarray_copy(library): # dask changed behaviour of copy=None in 2024.12 to copy; # this wrapper ensures the same behaviour in older versions too. # https://github.com/dask/dask/pull/11524/ - assert all(b[0] == 1.0) + assert b[0] == 1.0 else: - assert all(b[0] == 0.0) + # copy=None defaults to False when possible + assert b[0] == 0.0 @pytest.mark.parametrize("library", ["numpy", "cupy", "torch"]) From 44e7828b0666e5edd26958fb2337d755cf1a6002 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:08:18 +0100 Subject: [PATCH 2/2] lint --- tests/test_common.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/test_common.py b/tests/test_common.py index fe4fe598..54b5ed69 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -268,7 +268,6 @@ def test_asarray_cross_library(source_library, target_library, request): assert b.dtype == tgt_lib.int32 - @pytest.mark.parametrize("library", wrapped_libraries) def test_asarray_copy(library): # Note, we have this test here because the test suite currently doesn't @@ -323,21 +322,21 @@ def test_asarray_copy(library): # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: - asarray(obj, copy=True) # No error - asarray(obj, copy=None) # No error + asarray(obj, copy=True) # No error + asarray(obj, copy=None) # No error with pytest.raises(ValueError): asarray(obj, copy=False) # Use the standard library array to test the buffer protocol - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0.0 assert b[0] == 1.0 - a = array.array('f', [1.0]) - if library in ('cupy', 'dask.array'): + a = array.array("f", [1.0]) + if library in ("cupy", "dask.array"): with pytest.raises(ValueError): asarray(a, copy=False) else: @@ -346,11 +345,11 @@ def test_asarray_copy(library): a[0] = 0.0 assert b[0] == 0.0 - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0.0 - if library in ('cupy', 'dask.array'): + if library in ("cupy", "dask.array"): # A copy is required for libraries where the default device is not CPU # dask changed behaviour of copy=None in 2024.12 to copy; # this wrapper ensures the same behaviour in older versions too.