Skip to content

Commit f3ed389

Browse files
committed
ENH: is_lazy_array()
1 parent beac55b commit f3ed389

File tree

3 files changed

+109
-7
lines changed

3 files changed

+109
-7
lines changed

array_api_compat/common/_helpers.py

+58
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,63 @@ def is_writeable_array(x) -> bool:
819819
return True
820820

821821

822+
def is_lazy_array(x) -> bool:
823+
"""Return True if x is potentially a future or it may be otherwise impossible or
824+
expensive to eagerly read its contents, regardless of their size, e.g. by
825+
calling ``bool(x)`` or ``float(x)``.
826+
827+
Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
828+
cheap as long as the array has the right dtype.
829+
830+
Note
831+
----
832+
This function errs on the side of caution for array types that may or may not be
833+
lazy, e.g. JAX arrays, by always returning True for them.
834+
"""
835+
if (
836+
is_numpy_array(x)
837+
or is_cupy_array(x)
838+
or is_torch_array(x)
839+
or is_pydata_sparse_array(x)
840+
):
841+
return False
842+
843+
# **JAX note:** while it is possible to determine if you're inside or outside
844+
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
845+
# as we do below for unknown arrays, this is not recommended by JAX best practices.
846+
847+
# **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
848+
# This behaviour, while impossible to change without breaking backwards
849+
# compatibility, is highly detrimental to performance as the whole graph will end
850+
# up being computed multiple times.
851+
852+
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
853+
return True
854+
855+
# Unknown Array API compatible object. Note that this test may have dire consequences
856+
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
857+
# on __bool__ (dask is one such example, which however is special-cased above).
858+
859+
# Select a single point of the array
860+
s = size(x)
861+
if s is None or math.isnan(s):
862+
return True
863+
xp = array_namespace(x)
864+
if s > 1:
865+
x = xp.reshape(x, (-1,))[0]
866+
# Cast to dtype=bool and deal with size 0 arrays
867+
x = xp.any(x)
868+
869+
try:
870+
bool(x)
871+
return False
872+
# The Array API standard dictactes that __bool__ should raise TypeError if the
873+
# output cannot be defined.
874+
# Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
875+
except Exception:
876+
return True
877+
878+
822879
__all__ = [
823880
"array_namespace",
824881
"device",
@@ -840,6 +897,7 @@ def is_writeable_array(x) -> bool:
840897
"is_pydata_sparse_array",
841898
"is_pydata_sparse_namespace",
842899
"is_writeable_array",
900+
"is_lazy_array",
843901
"size",
844902
"to_device",
845903
]

docs/helper-functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ yet.
5252
.. autofunction:: is_pydata_sparse_array
5353
.. autofunction:: is_ndonnx_array
5454
.. autofunction:: is_writeable_array
55+
.. autofunction:: is_lazy_array
5556
.. autofunction:: is_numpy_namespace
5657
.. autofunction:: is_cupy_namespace
5758
.. autofunction:: is_torch_namespace

tests/test_common.py

+50-7
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1+
import math
2+
3+
import pytest
4+
import numpy as np
5+
import array
6+
from numpy.testing import assert_allclose
7+
18
from array_api_compat import ( # noqa: F401
29
is_numpy_array, is_cupy_array, is_torch_array,
310
is_dask_array, is_jax_array, is_pydata_sparse_array,
411
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
512
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
613
)
714

8-
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
9-
15+
from array_api_compat import (
16+
device, is_array_api_obj, is_lazy_array, is_writeable_array, to_device
17+
)
1018
from ._helpers import import_, wrapped_libraries, all_libraries
1119

12-
import pytest
13-
import numpy as np
14-
import array
15-
from numpy.testing import assert_allclose
16-
1720
is_array_functions = {
1821
'numpy': 'is_numpy_array',
1922
'cupy': 'is_cupy_array',
@@ -92,6 +95,45 @@ def test_is_writeable_array_numpy():
9295
assert not is_writeable_array(x)
9396

9497

98+
@pytest.mark.parametrize("library", all_libraries)
99+
def test_is_lazy_array(library):
100+
lib = import_(library)
101+
x = lib.asarray([1, 2, 3])
102+
assert isinstance(is_lazy_array(x), bool)
103+
104+
105+
@pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan), (None, ), (1, None)])
106+
def test_is_lazy_array_nan_size(shape, monkeypatch):
107+
"""Test is_lazy_array() on an unknown Array API compliant object
108+
with NaN (like Dask) or None (like ndonnx) in its shape
109+
"""
110+
xp = import_("array_api_strict")
111+
x = xp.asarray(1)
112+
assert not is_lazy_array(x)
113+
monkeypatch.setattr(type(x), "shape", shape)
114+
assert is_lazy_array(x)
115+
116+
117+
@pytest.mark.parametrize("exc", [TypeError, AssertionError])
118+
def test_is_lazy_array_bool_raises(exc, monkeypatch):
119+
"""Test is_lazy_array() on an unknown Array API compliant object
120+
where calling bool() raises:
121+
- TypeError: e.g. like jitted JAX. This is the proper exception which
122+
lazy arrays should raise as per the Array API specification
123+
- something else: e.g. like Dask, where bool() triggers compute()
124+
which can result in any kind of exception to be raised
125+
"""
126+
xp = import_("array_api_strict")
127+
x = xp.asarray(1)
128+
assert not is_lazy_array(x)
129+
130+
def __bool__(self):
131+
raise exc("Hello world")
132+
133+
monkeypatch.setattr(type(x), "__bool__", __bool__)
134+
assert is_lazy_array(x)
135+
136+
95137
@pytest.mark.parametrize("library", all_libraries)
96138
def test_device(library):
97139
xp = import_(library, wrapper=True)
@@ -149,6 +191,7 @@ def test_asarray_cross_library(source_library, target_library, request):
149191

150192
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
151193

194+
152195
@pytest.mark.parametrize("library", wrapped_libraries)
153196
def test_asarray_copy(library):
154197
# Note, we have this test here because the test suite currently doesn't

0 commit comments

Comments
 (0)