diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index 65362382..1007fa1a 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -7,7 +7,8 @@ # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 from __future__ import annotations -from collections.abc import Callable, Iterable, Sequence +import contextlib +from collections.abc import Callable, Iterable, Iterator, Sequence from functools import wraps from types import ModuleType from typing import TYPE_CHECKING, Any, TypeVar, cast @@ -42,6 +43,8 @@ def override(func: Callable[P, T]) -> Callable[P, T]: T = TypeVar("T") +_ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[no-any-explicit] + def lazy_xp_function( # type: ignore[no-any-explicit] func: Callable[..., Any], @@ -132,12 +135,16 @@ def test_myfunc(xp): a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c = mymodule.myfunc(a) # This is not """ - func.allow_dask_compute = allow_dask_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] - if jax_jit: - func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] - "static_argnums": static_argnums, - "static_argnames": static_argnames, - } + tags = { + "allow_dask_compute": allow_dask_compute, + "jax_jit": jax_jit, + "static_argnums": static_argnums, + "static_argnames": static_argnames, + } + try: + func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess] + except AttributeError: # @cython.vectorize + _ufuncs_tags[func] = tags def patch_lazy_xp_functions( @@ -179,24 +186,37 @@ def xp(request, monkeypatch): """ globals_ = cast("dict[str, Any]", request.module.__dict__) # type: ignore[no-any-explicit] - if is_dask_namespace(xp): + def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: # type: ignore[no-any-explicit] for name, func in globals_.items(): - n = getattr(func, "allow_dask_compute", None) - if n is not None: - assert isinstance(n, int) - wrapped = _allow_dask_compute(func, n) - monkeypatch.setitem(globals_, name, wrapped) + tags: dict[str, Any] | None = None # type: ignore[no-any-explicit] + with contextlib.suppress(AttributeError): + tags = func._lazy_xp_function # pylint: disable=protected-access + if tags is None: + with contextlib.suppress(KeyError, TypeError): + tags = _ufuncs_tags[func] + if tags is not None: + yield name, func, tags + + if is_dask_namespace(xp): + for name, func, tags in iter_tagged(): + n = tags["allow_dask_compute"] + wrapped = _allow_dask_compute(func, n) + monkeypatch.setitem(globals_, name, wrapped) elif is_jax_namespace(xp): import jax - for name, func in globals_.items(): - kwargs = cast( # type: ignore[no-any-explicit] - "dict[str, Any] | None", getattr(func, "lazy_jax_jit_kwargs", None) - ) - if kwargs is not None: + for name, func, tags in iter_tagged(): + if tags["jax_jit"]: # suppress unused-ignore to run mypy in -e lint as well as -e dev - wrapped = cast(Callable[..., Any], jax.jit(func, **kwargs)) # type: ignore[no-any-explicit,no-untyped-call,unused-ignore] + wrapped = cast( # type: ignore[no-any-explicit] + Callable[..., Any], + jax.jit( + func, + static_argnums=tags["static_argnums"], + static_argnames=tags["static_argnames"], + ), + ) monkeypatch.setitem(globals_, name, wrapped) diff --git a/tests/test_testing.py b/tests/test_testing.py index c9a1e32f..aa7faaf8 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -202,3 +202,33 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Arra xp_assert_equal(func(x, 0, False), xp.asarray([3.0, 6.0])) xp_assert_equal(func(x, 1, flag=True), xp.asarray([2.0, 4.0])) xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0])) + + +try: + # Test an arbitrary Cython ufunc (@cython.vectorize). + # When SCIPY_ARRAY_API is not set, this is the same as + # scipy.special.erf. + from scipy.special._ufuncs import erf # type: ignore[import-not-found] + + lazy_xp_function(erf) # pyright: ignore[reportUnknownArgumentType] +except ImportError: + erf = None + + +@pytest.mark.filterwarnings("ignore:__array_wrap__:DeprecationWarning") # torch +def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend): + pytest.importorskip("scipy") + assert erf is not None + x = xp.asarray([6.0, 7.0]) + if library in (Backend.ARRAY_API_STRICT, Backend.JAX): + # array-api-strict arrays are auto-converted to numpy + # which results in an assertion error for mismatched namespaces + # eager jax arrays are auto-converted to numpy in eager jax + # and fail in jax.jit (which lazy_xp_function tests here) + with pytest.raises((TypeError, AssertionError)): + xp_assert_equal(erf(x), xp.asarray([1.0, 1.0])) + else: + # cupy, dask and sparse define __array_ufunc__ and dispatch accordingly + # note that when sparse reduces to scalar it returns a np.generic, which + # would make xp_assert_equal fail. + xp_assert_equal(erf(x), xp.asarray([1.0, 1.0]))