diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index 1007fa1a..550665ae 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -200,7 +200,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: if is_dask_namespace(xp): for name, func, tags in iter_tagged(): n = tags["allow_dask_compute"] - wrapped = _allow_dask_compute(func, n) + wrapped = _dask_wrap(func, n) monkeypatch.setitem(globals_, name, wrapped) elif is_jax_namespace(xp): @@ -256,13 +256,15 @@ def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any: return dask.get(dsk, keys, **kwargs) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage] -def _allow_dask_compute( +def _dask_wrap( func: Callable[P, T], n: int ) -> Callable[P, T]: # numpydoc ignore=PR01,RT01 """ Wrap `func` to raise if it attempts to call `dask.compute` more than `n` times. + + After the function returns, materialize the graph in order to re-raise exceptions. """ - import dask.config + import dask func_name = getattr(func, "__name__", str(func)) n_str = f"only up to {n}" if n else "no" @@ -276,7 +278,12 @@ def _allow_dask_compute( @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 scheduler = CountingDaskScheduler(n, msg) - with dask.config.set({"scheduler": scheduler}): - return func(*args, **kwargs) + with dask.config.set({"scheduler": scheduler}): # pyright: ignore[reportPrivateImportUsage] + out = func(*args, **kwargs) + + # Block until the graph materializes and reraise exceptions. This allows + # `pytest.raises` and `pytest.warns` to work as expected. Note that this would + # not work on scheduler='distributed', as it would not block. + return dask.persist(out, scheduler="threads")[0] # type: ignore[no-any-return,attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage] return wrapper diff --git a/tests/test_testing.py b/tests/test_testing.py index aa7faaf8..1649dd86 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -232,3 +232,32 @@ def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend): # note that when sparse reduces to scalar it returns a np.generic, which # would make xp_assert_equal fail. xp_assert_equal(erf(x), xp.asarray([1.0, 1.0])) + + +def dask_raises(x: Array) -> Array: + def _raises(x: Array) -> Array: + # Test that map_blocks doesn't eagerly call the function; + # dtype and meta should be sufficient to skip the trial run. + assert x.shape == (3,) + msg = "Hello world" + raise ValueError(msg) + + return x.map_blocks(_raises, dtype=x.dtype, meta=x._meta) + + +lazy_xp_function(dask_raises) + + +def test_lazy_xp_function_eagerly_raises(da: ModuleType): + """Test that the pattern:: + + with pytest.raises(Exception): + func(x) + + works with Dask, even though it normally wouldn't as we're disregarding the func + output so the graph would not be ordinarily materialized. + lazy_xp_function contains ad-hoc code to materialize and reraise exceptions. + """ + x = da.arange(3) + with pytest.raises(ValueError, match="Hello world"): + dask_raises(x)