diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index 4f8288cf..37e8e69e 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -39,7 +39,7 @@ def override(func: object) -> object: def lazy_xp_function( # type: ignore[explicit-any] func: Callable[..., Any], *, - allow_dask_compute: int = 0, + allow_dask_compute: bool | int = False, jax_jit: bool = True, static_argnums: int | Sequence[int] | None = None, static_argnames: str | Iterable[str] | None = None, @@ -59,9 +59,10 @@ def lazy_xp_function( # type: ignore[explicit-any] ---------- func : callable Function to be tested. - allow_dask_compute : int, optional - Number of times `func` is allowed to internally materialize the Dask graph. This - is typically triggered by ``bool()``, ``float()``, or ``np.asarray()``. + allow_dask_compute : bool | int, optional + Whether `func` is allowed to internally materialize the Dask graph, or maximum + number of times it is allowed to do so. This is typically triggered by + ``bool()``, ``float()``, or ``np.asarray()``. Set to 1 if you are aware that `func` converts the input parameters to NumPy and want to let it do so at least for the time being, knowing that it is going to be @@ -75,7 +76,10 @@ def lazy_xp_function( # type: ignore[explicit-any] a test function that invokes `func` multiple times should still work with this parameter set to 1. - Default: 0, meaning that `func` must be fully lazy and never materialize the + Set to True to allow `func` to materialize the graph an unlimited number + of times. + + Default: False, meaning that `func` must be fully lazy and never materialize the graph. jax_jit : bool, optional Set to True to replace `func` with ``jax.jit(func)`` after calling the @@ -235,6 +239,10 @@ def iter_tagged() -> ( # type: ignore[explicit-any] if is_dask_namespace(xp): for mod, name, func, tags in iter_tagged(): n = tags["allow_dask_compute"] + if n is True: + n = 1_000_000 + elif n is False: + n = 0 wrapped = _dask_wrap(func, n) monkeypatch.setattr(mod, name, wrapped) diff --git a/tests/test_testing.py b/tests/test_testing.py index ff67121b..fb9ba581 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -130,13 +130,18 @@ def non_materializable4(x: Array) -> Array: return non_materializable(x) +def non_materializable5(x: Array) -> Array: + return non_materializable(x) + + lazy_xp_function(good_lazy) # Works on JAX and Dask lazy_xp_function(non_materializable2, jax_jit=False, allow_dask_compute=2) +lazy_xp_function(non_materializable3, jax_jit=False, allow_dask_compute=True) # Works on JAX, but not Dask -lazy_xp_function(non_materializable3, jax_jit=False, allow_dask_compute=1) +lazy_xp_function(non_materializable4, jax_jit=False, allow_dask_compute=1) # Works neither on Dask nor JAX -lazy_xp_function(non_materializable4) +lazy_xp_function(non_materializable5) def test_lazy_xp_function(xp: ModuleType): @@ -147,29 +152,30 @@ def test_lazy_xp_function(xp: ModuleType): xp_assert_equal(non_materializable(x), xp.asarray([1.0, 2.0])) # Wrapping explicitly disabled xp_assert_equal(non_materializable2(x), xp.asarray([1.0, 2.0])) + xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0])) if is_jax_namespace(xp): - xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0])) + xp_assert_equal(non_materializable4(x), xp.asarray([1.0, 2.0])) with pytest.raises( TypeError, match="Attempted boolean conversion of traced array" ): - _ = non_materializable4(x) # Wrapped + _ = non_materializable5(x) # Wrapped elif is_dask_namespace(xp): with pytest.raises( AssertionError, match=r"dask\.compute.* 2 times, but only up to 1 calls are allowed", ): - _ = non_materializable3(x) + _ = non_materializable4(x) with pytest.raises( AssertionError, match=r"dask\.compute.* 1 times, but no calls are allowed", ): - _ = non_materializable4(x) + _ = non_materializable5(x) else: - xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0])) xp_assert_equal(non_materializable4(x), xp.asarray([1.0, 2.0])) + xp_assert_equal(non_materializable5(x), xp.asarray([1.0, 2.0])) def static_params(x: Array, n: int, flag: bool = False) -> Array: