diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 706821c4..27d8ef54 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -788,19 +788,24 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] return x.to_device(device, stream=stream) -def size(x): +def size(x: Array) -> int | None: """ Return the total number of elements of x. This is equivalent to `x.size` according to the `standard `__. + This helper is included because PyTorch defines `size` in an :external+torch:meth:`incompatible way `. - + It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas + the standard requires None. """ + # Lazy API compliant arrays, such as ndonnx, can contain None in their shape if None in x.shape: return None - return math.prod(x.shape) + out = math.prod(x.shape) + # dask.array.Array.shape can contain NaN + return None if math.isnan(out) else out def is_writeable_array(x) -> bool: diff --git a/tests/test_common.py b/tests/test_common.py index 7503481e..1a4a32dc 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,8 +5,9 @@ is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, ) -from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device - +from array_api_compat import ( + device, is_array_api_obj, is_writeable_array, size, to_device +) from ._helpers import import_, wrapped_libraries, all_libraries import pytest @@ -92,6 +93,28 @@ def test_is_writeable_array_numpy(): assert not is_writeable_array(x) +@pytest.mark.parametrize("library", all_libraries) +def test_size(library): + xp = import_(library) + x = xp.asarray([1, 2, 3]) + assert size(x) == 3 + + +@pytest.mark.parametrize("library", all_libraries) +def test_size_none(library): + if library == "sparse": + pytest.skip("No arange(); no indexing by sparse arrays") + + xp = import_(library) + x = xp.arange(10) + x = x[x < 5] + + # dask.array now has shape=(nan, ) and size=nan + # ndonnx now has shape=(None, ) and size=None + # Eager libraries have shape=(5, ) and size=5 + assert size(x) in (None, 5) + + @pytest.mark.parametrize("library", all_libraries) def test_device(library): xp = import_(library, wrapper=True)