|
7 | 7 | # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
|
8 | 8 | from __future__ import annotations
|
9 | 9 |
|
10 |
| -from collections.abc import Callable, Iterable, Sequence |
| 10 | +import sys |
| 11 | +from collections.abc import Callable, Iterable, Iterator, Sequence |
11 | 12 | from functools import wraps
|
12 | 13 | from types import ModuleType
|
13 | 14 | from typing import TYPE_CHECKING, Any, TypeVar, cast
|
@@ -42,6 +43,8 @@ def override(func: Callable[P, T]) -> Callable[P, T]:
|
42 | 43 |
|
43 | 44 | T = TypeVar("T")
|
44 | 45 |
|
| 46 | +_ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[no-any-explicit] |
| 47 | + |
45 | 48 |
|
46 | 49 | def lazy_xp_function( # type: ignore[no-any-explicit]
|
47 | 50 | func: Callable[..., Any],
|
@@ -132,12 +135,30 @@ def test_myfunc(xp):
|
132 | 135 | a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
|
133 | 136 | mymodule.myfunc(a) # This is not
|
134 | 137 | """
|
135 |
| - func.allow_dask_compute = allow_dask_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] |
136 |
| - if jax_jit: |
137 |
| - func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] |
138 |
| - "static_argnums": static_argnums, |
139 |
| - "static_argnames": static_argnames, |
140 |
| - } |
| 138 | + tags = { |
| 139 | + "allow_dask_compute": allow_dask_compute, |
| 140 | + "jax_jit": jax_jit, |
| 141 | + "static_argnums": static_argnums, |
| 142 | + "static_argnames": static_argnames, |
| 143 | + } |
| 144 | + if _is_numpy_ufunc(func): |
| 145 | + _ufuncs_tags[func] = tags # Can't assign attributes |
| 146 | + else: |
| 147 | + func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess] |
| 148 | + |
| 149 | + |
| 150 | +def _is_numpy_ufunc(f: Callable[..., Any]) -> bool: # type: ignore[no-any-explicit] |
| 151 | + """Check if a function is a numpy ufunc |
| 152 | + native from numpy, CPython extension, or Cythonized). |
| 153 | + """ |
| 154 | + if "numpy" not in sys.modules: |
| 155 | + # FIXME: is it possible to import a CPython or Cython module |
| 156 | + # that defines a ufunc without importing numpy first? |
| 157 | + return False |
| 158 | + |
| 159 | + import numpy as np |
| 160 | + |
| 161 | + return isinstance(f, np.ufunc) |
141 | 162 |
|
142 | 163 |
|
143 | 164 | def patch_lazy_xp_functions(
|
@@ -179,24 +200,35 @@ def xp(request, monkeypatch):
|
179 | 200 | """
|
180 | 201 | globals_ = cast("dict[str, Any]", request.module.__dict__) # type: ignore[no-any-explicit]
|
181 | 202 |
|
182 |
| - if is_dask_namespace(xp): |
| 203 | + def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: # type: ignore[no-any-explicit] |
183 | 204 | for name, func in globals_.items():
|
184 |
| - n = getattr(func, "allow_dask_compute", None) |
185 |
| - if n is not None: |
186 |
| - assert isinstance(n, int) |
187 |
| - wrapped = _allow_dask_compute(func, n) |
188 |
| - monkeypatch.setitem(globals_, name, wrapped) |
| 205 | + if _is_numpy_ufunc(func): |
| 206 | + tags = _ufuncs_tags.get(func) |
| 207 | + else: |
| 208 | + tags = getattr(func, "_lazy_xp_function", None) # type: ignore[assignment] |
| 209 | + if tags: |
| 210 | + yield name, func, tags |
| 211 | + |
| 212 | + if is_dask_namespace(xp): |
| 213 | + for name, func, tags in iter_tagged(): |
| 214 | + n = tags["allow_dask_compute"] |
| 215 | + wrapped = _allow_dask_compute(func, n) |
| 216 | + monkeypatch.setitem(globals_, name, wrapped) |
189 | 217 |
|
190 | 218 | elif is_jax_namespace(xp):
|
191 | 219 | import jax
|
192 | 220 |
|
193 |
| - for name, func in globals_.items(): |
194 |
| - kwargs = cast( # type: ignore[no-any-explicit] |
195 |
| - "dict[str, Any] | None", getattr(func, "lazy_jax_jit_kwargs", None) |
196 |
| - ) |
197 |
| - if kwargs is not None: |
| 221 | + for name, func, tags in iter_tagged(): |
| 222 | + if tags["jax_jit"]: |
198 | 223 | # suppress unused-ignore to run mypy in -e lint as well as -e dev
|
199 |
| - wrapped = cast(Callable[..., Any], jax.jit(func, **kwargs)) # type: ignore[no-any-explicit,no-untyped-call,unused-ignore] |
| 224 | + wrapped = cast( # type: ignore[no-any-explicit] |
| 225 | + Callable[..., Any], |
| 226 | + jax.jit( |
| 227 | + func, |
| 228 | + static_argnums=tags["static_argnums"], |
| 229 | + static_argnames=tags["static_argnames"], |
| 230 | + ), |
| 231 | + ) |
200 | 232 | monkeypatch.setitem(globals_, name, wrapped)
|
201 | 233 |
|
202 | 234 |
|
|
0 commit comments