Skip to content

Commit c414b23

Browse files
committed
BUG: lazy_xp_function crashes with ufuncs
1 parent 8077623 commit c414b23

File tree

2 files changed

+72
-19
lines changed

2 files changed

+72
-19
lines changed

src/array_api_extra/testing.py

+51-19
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
88
from __future__ import annotations
99

10-
from collections.abc import Callable, Iterable, Sequence
10+
import sys
11+
from collections.abc import Callable, Iterable, Iterator, Sequence
1112
from functools import wraps
1213
from types import ModuleType
1314
from typing import TYPE_CHECKING, Any, TypeVar, cast
@@ -42,6 +43,8 @@ def override(func: Callable[P, T]) -> Callable[P, T]:
4243

4344
T = TypeVar("T")
4445

46+
_ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[no-any-explicit]
47+
4548

4649
def lazy_xp_function( # type: ignore[no-any-explicit]
4750
func: Callable[..., Any],
@@ -132,12 +135,30 @@ def test_myfunc(xp):
132135
a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
133136
mymodule.myfunc(a) # This is not
134137
"""
135-
func.allow_dask_compute = allow_dask_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
136-
if jax_jit:
137-
func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
138-
"static_argnums": static_argnums,
139-
"static_argnames": static_argnames,
140-
}
138+
tags = {
139+
"allow_dask_compute": allow_dask_compute,
140+
"jax_jit": jax_jit,
141+
"static_argnums": static_argnums,
142+
"static_argnames": static_argnames,
143+
}
144+
if _is_numpy_ufunc(func):
145+
_ufuncs_tags[func] = tags # Can't assign attributes
146+
else:
147+
func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
148+
149+
150+
def _is_numpy_ufunc(f: Callable[..., Any]) -> bool: # type: ignore[no-any-explicit]
151+
"""Check if a function is a numpy ufunc
152+
native from numpy, CPython extension, or Cythonized).
153+
"""
154+
if "numpy" not in sys.modules:
155+
# FIXME: is it possible to import a CPython or Cython module
156+
# that defines a ufunc without importing numpy first?
157+
return False
158+
159+
import numpy as np
160+
161+
return isinstance(f, np.ufunc)
141162

142163

143164
def patch_lazy_xp_functions(
@@ -179,24 +200,35 @@ def xp(request, monkeypatch):
179200
"""
180201
globals_ = cast("dict[str, Any]", request.module.__dict__) # type: ignore[no-any-explicit]
181202

182-
if is_dask_namespace(xp):
203+
def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: # type: ignore[no-any-explicit]
183204
for name, func in globals_.items():
184-
n = getattr(func, "allow_dask_compute", None)
185-
if n is not None:
186-
assert isinstance(n, int)
187-
wrapped = _allow_dask_compute(func, n)
188-
monkeypatch.setitem(globals_, name, wrapped)
205+
if _is_numpy_ufunc(func):
206+
tags = _ufuncs_tags.get(func)
207+
else:
208+
tags = getattr(func, "_lazy_xp_function", None) # type: ignore[assignment]
209+
if tags:
210+
yield name, func, tags
211+
212+
if is_dask_namespace(xp):
213+
for name, func, tags in iter_tagged():
214+
n = tags["allow_dask_compute"]
215+
wrapped = _allow_dask_compute(func, n)
216+
monkeypatch.setitem(globals_, name, wrapped)
189217

190218
elif is_jax_namespace(xp):
191219
import jax
192220

193-
for name, func in globals_.items():
194-
kwargs = cast( # type: ignore[no-any-explicit]
195-
"dict[str, Any] | None", getattr(func, "lazy_jax_jit_kwargs", None)
196-
)
197-
if kwargs is not None:
221+
for name, func, tags in iter_tagged():
222+
if tags["jax_jit"]:
198223
# suppress unused-ignore to run mypy in -e lint as well as -e dev
199-
wrapped = cast(Callable[..., Any], jax.jit(func, **kwargs)) # type: ignore[no-any-explicit,no-untyped-call,unused-ignore]
224+
wrapped = cast( # type: ignore[no-any-explicit]
225+
Callable[..., Any],
226+
jax.jit(
227+
func,
228+
static_argnums=tags["static_argnums"],
229+
static_argnames=tags["static_argnames"],
230+
),
231+
)
200232
monkeypatch.setitem(globals_, name, wrapped)
201233

202234

tests/test_testing.py

+21
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import pytest
6+
from numpy import min as numpy_min
67

78
from array_api_extra._lib import Backend
89
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
@@ -202,3 +203,23 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Arra
202203
xp_assert_equal(func(x, 0, False), xp.asarray([3.0, 6.0]))
203204
xp_assert_equal(func(x, 1, flag=True), xp.asarray([2.0, 4.0]))
204205
xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0]))
206+
207+
208+
lazy_xp_function(numpy_min, static_argnames="axis")
209+
210+
211+
def test_lazy_xp_function_ufunc(xp: ModuleType, library: Backend):
212+
x = xp.asarray([[1, 4], [3, 2]])
213+
if library in (Backend.ARRAY_API_STRICT, Backend.TORCH, Backend.JAX):
214+
# array-api-strict, torch and jax don't define __array_ufunc__
215+
# numpy ufuncs can't auto-convert to numpy from torch
216+
# array-api-strict arrays are auto-converted to numpy
217+
# eager jax arrays are auto-converted to numpy in eager jax
218+
# and fail in jax.jit (which lazy_xp_function tests here)
219+
with pytest.raises((TypeError, AssertionError)):
220+
xp_assert_equal(numpy_min(x, axis=0), xp.asarray([1, 2]))
221+
else:
222+
# cupy, dask and sparse define __array_ufunc__ and dispatch accordingly
223+
# note that when sparse reduces to scalar it returns a np.generic, which
224+
# would make xp_assert_equal fail.
225+
xp_assert_equal(numpy_min(x, axis=0), xp.asarray([1, 2]))

0 commit comments

Comments
 (0)