Skip to content

Commit 8711041

Browse files
committed
Merge branch 'main' into typ_v4
2 parents 0172300 + 52e01be commit 8711041

File tree

1 file changed

+108
-83
lines changed

1 file changed

+108
-83
lines changed

array_api_compat/common/_helpers.py

Lines changed: 108 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import math
1313
import sys
1414
import warnings
15+
from collections.abc import Collection, Hashable
16+
from functools import lru_cache
1517
from types import NoneType
1618
from typing import (
1719
TYPE_CHECKING,
@@ -56,23 +58,37 @@
5658
_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})
5759

5860

61+
@lru_cache(100)
def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool:
    """Return True if ``cls`` is a subclass of ``modname.clsname``.

    The module is looked up in ``sys.modules`` only and is never imported
    here. Results are memoised by ``lru_cache`` (at most 100 entries), keyed
    on all three arguments.

    Parameters
    ----------
    cls : type
        The class to test. It is used as a cache key, so it must be hashable.
    modname : str
        Fully qualified module name, e.g. ``"dask.array"``.
    clsname : str
        Name of the candidate parent class inside that module.

    Returns
    -------
    bool
        False if the module is not in ``sys.modules`` or does not expose
        ``clsname``; otherwise the result of ``issubclass(cls, parent)``.
    """
    try:
        mod = sys.modules[modname]
    except KeyError:
        return False
    # An unrelated module may be registered under the same name (e.g. a local
    # ``sparse.py`` shadowing the real package). Treat a missing attribute as
    # "not a subclass" instead of letting AttributeError escape to callers.
    parent_cls = getattr(mod, clsname, None)
    if parent_cls is None:
        return False
    return issubclass(cls, parent_cls)
69+
70+
5971
def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
    """Tell whether `x` is a JAX zero-gradient array.

    These arrays are a design quirk of JAX that may one day be removed.
    See https://github.com/google/jax/issues/20620.
    """
    # Cheap rejection first: an object without a ``dtype`` cannot qualify.
    try:
        dtype = x.dtype  # type: ignore[attr-defined]
    except AttributeError:
        return False

    # Only a NumPy void dtype can compare equal to ``jax.float0``.
    dtype_cls = cast(Hashable, type(dtype))
    is_void_dtype = _issubclass_fast(dtype_cls, "numpy.dtypes", "VoidDType")
    if not is_void_dtype:
        return False

    # JAX is only touched when it has already been imported elsewhere.
    if "jax" in sys.modules:
        import jax

        # jax.float0 is a np.dtype([('float0', 'V')])
        return dtype == jax.float0
    return False
7692

7793

7894
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
@@ -96,15 +112,12 @@ def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
96112
is_jax_array
97113
is_pydata_sparse_array
98114
"""
99-
# Avoid importing NumPy if it isn't already
100-
if "numpy" not in sys.modules:
101-
return False
102-
103-
import numpy as np
104-
105115
# TODO: Should we reject ndarray subclasses?
106-
return (isinstance(x, (np.ndarray, np.generic))
107-
and not _is_jax_zero_gradient_array(x)) # pyright: ignore[reportUnknownArgumentType] # fmt: skip
116+
cls = cast(Hashable, type(x))
117+
return (
118+
_issubclass_fast(cls, "numpy", "ndarray")
119+
or _issubclass_fast(cls, "numpy", "generic")
120+
) and not _is_jax_zero_gradient_array(x)
108121

109122

110123
def is_cupy_array(x: object) -> bool:
@@ -128,14 +141,8 @@ def is_cupy_array(x: object) -> bool:
128141
is_jax_array
129142
is_pydata_sparse_array
130143
"""
131-
# Avoid importing CuPy if it isn't already
132-
if "cupy" not in sys.modules:
133-
return False
134-
135-
import cupy as cp
136-
137-
# TODO: Should we reject ndarray subclasses?
138-
return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType]
144+
cls = cast(Hashable, type(x))
145+
return _issubclass_fast(cls, "cupy", "ndarray")
139146

140147

141148
def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
@@ -156,14 +163,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
156163
is_jax_array
157164
is_pydata_sparse_array
158165
"""
159-
# Avoid importing torch if it isn't already
160-
if "torch" not in sys.modules:
161-
return False
162-
163-
import torch
164-
165-
# TODO: Should we reject ndarray subclasses?
166-
return isinstance(x, torch.Tensor)
166+
cls = cast(Hashable, type(x))
167+
return _issubclass_fast(cls, "torch", "Tensor")
167168

168169

169170
def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
@@ -185,13 +186,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
185186
is_jax_array
186187
is_pydata_sparse_array
187188
"""
188-
# Avoid importing torch if it isn't already
189-
if "ndonnx" not in sys.modules:
190-
return False
191-
192-
import ndonnx as ndx
193-
194-
return isinstance(x, ndx.Array)
189+
cls = cast(Hashable, type(x))
190+
return _issubclass_fast(cls, "ndonnx", "Array")
195191

196192

197193
def is_dask_array(x: object) -> TypeIs[da.Array]:
@@ -213,13 +209,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]:
213209
is_jax_array
214210
is_pydata_sparse_array
215211
"""
216-
# Avoid importing dask if it isn't already
217-
if "dask.array" not in sys.modules:
218-
return False
219-
220-
import dask.array
221-
222-
return isinstance(x, dask.array.Array)
212+
cls = cast(Hashable, type(x))
213+
return _issubclass_fast(cls, "dask.array", "Array")
223214

224215

225216
def is_jax_array(x: object) -> TypeIs[jax.Array]:
@@ -242,13 +233,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
242233
is_dask_array
243234
is_pydata_sparse_array
244235
"""
245-
# Avoid importing jax if it isn't already
246-
if "jax" not in sys.modules:
247-
return False
248-
249-
import jax
250-
251-
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
236+
cls = cast(Hashable, type(x))
237+
return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)
252238

253239

254240
def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
@@ -271,14 +257,9 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
271257
is_dask_array
272258
is_jax_array
273259
"""
274-
# Avoid importing jax if it isn't already
275-
if "sparse" not in sys.modules:
276-
return False
277-
278-
import sparse
279-
280260
# TODO: Account for other backends.
281-
return isinstance(x, sparse.SparseArray)
261+
cls = cast(Hashable, type(x))
262+
return _issubclass_fast(cls, "sparse", "SparseArray")
282263

283264

284265
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
@@ -297,13 +278,23 @@ def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
297278
is_jax_array
298279
"""
299280
return (
300-
is_numpy_array(x)
301-
or is_cupy_array(x)
302-
or is_torch_array(x)
303-
or is_dask_array(x)
304-
or is_jax_array(x)
305-
or is_pydata_sparse_array(x)
306-
or hasattr(x, "__array_namespace__")
281+
hasattr(x, '__array_namespace__')
282+
or _is_array_api_cls(cast(Hashable, type(x)))
283+
)
284+
285+
286+
@lru_cache(100)
def _is_array_api_cls(cls: type) -> bool:
    """Return True if ``cls`` subclasses one of the known backend array types.

    Each candidate is tested with ``_issubclass_fast``; the candidates are
    tried in the order listed and the search stops at the first match.
    """
    known_array_types = (
        # TODO: drop support for numpy<2 which didn't have __array_namespace__
        ("numpy", "ndarray"),
        ("numpy", "generic"),
        ("cupy", "ndarray"),
        ("torch", "Tensor"),
        ("dask.array", "Array"),
        ("sparse", "SparseArray"),
        # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
        ("jax", "Array"),
    )
    return any(
        _issubclass_fast(cls, modname, clsname)
        for modname, clsname in known_array_types
    )
308299

309300

@@ -312,6 +303,7 @@ def _compat_module_name() -> str:
312303
return __name__.removesuffix(".common._helpers")
313304

314305

306+
@lru_cache(100)
315307
def is_numpy_namespace(xp: Namespace) -> bool:
316308
"""
317309
Returns True if `xp` is a NumPy namespace.
@@ -333,6 +325,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
333325
return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"}
334326

335327

328+
@lru_cache(100)
336329
def is_cupy_namespace(xp: Namespace) -> bool:
337330
"""
338331
Returns True if `xp` is a CuPy namespace.
@@ -354,6 +347,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
354347
return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"}
355348

356349

350+
@lru_cache(100)
357351
def is_torch_namespace(xp: Namespace) -> bool:
358352
"""
359353
Returns True if `xp` is a PyTorch namespace.
@@ -394,6 +388,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
394388
return xp.__name__ == "ndonnx"
395389

396390

391+
@lru_cache(100)
397392
def is_dask_namespace(xp: Namespace) -> bool:
398393
"""
399394
Returns True if `xp` is a Dask namespace.
@@ -934,6 +929,19 @@ def size(x: HasShape[float | None]) -> int | None:
934929
return None if math.isnan(out) else cast(int, out)
935930

936931

932+
@lru_cache(100)
def _is_writeable_cls(cls: type) -> bool | None:
    """Classify ``cls`` as writeable, read-only, or undetermined.

    Returns
    -------
    bool | None
        False when ``cls`` subclasses ``numpy.generic``, ``jax.Array`` or
        ``sparse.SparseArray``; True when it is otherwise accepted by
        ``_is_array_api_cls``; None when neither test matches.
    """
    read_only_types = (
        ("numpy", "generic"),
        ("jax", "Array"),
        ("sparse", "SparseArray"),
    )
    for modname, clsname in read_only_types:
        if _issubclass_fast(cls, modname, clsname):
            return False
    return True if _is_array_api_cls(cls) else None
943+
944+
937945
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
938946
"""
939947
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
@@ -944,11 +952,32 @@ def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
944952
As there is no standard way to check if an array is writeable without actually
945953
writing to it, this function blindly returns True for all unknown array types.
946954
"""
947-
if is_numpy_array(x):
948-
return x.flags.writeable
949-
if is_jax_array(x) or is_pydata_sparse_array(x):
955+
cls = cast(Hashable, type(x))
956+
if _issubclass_fast(cls, "numpy", "ndarray"):
957+
return cast("npt.NDArray", x).flags.writeable
958+
res = _is_writeable_cls(cls)
959+
if res is not None:
960+
return res
961+
return hasattr(x, '__array_namespace__')
962+
963+
964+
@lru_cache(100)
def _is_lazy_cls(cls: type) -> bool | None:
    """Classify ``cls`` as lazy, eager, or undetermined.

    Returns
    -------
    bool | None
        False for NumPy arrays and scalars, CuPy arrays, PyTorch tensors and
        PyData Sparse arrays; True for JAX, Dask and ndonnx arrays; None when
        ``cls`` matches none of these.
    """
    eager_types = (
        ("numpy", "ndarray"),
        ("numpy", "generic"),
        ("cupy", "ndarray"),
        ("torch", "Tensor"),
        ("sparse", "SparseArray"),
    )
    lazy_types = (
        ("jax", "Array"),
        ("dask.array", "Array"),
        ("ndonnx", "Array"),
    )
    # The eager list is consulted first, exactly as in the if/if chain it replaces.
    if any(_issubclass_fast(cls, mod, name) for mod, name in eager_types):
        return False
    if any(_issubclass_fast(cls, mod, name) for mod, name in lazy_types):
        return True
    return None
952981

953982

954983
def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
@@ -964,14 +993,6 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
964993
This function errs on the side of caution for array types that may or may not be
965994
lazy, e.g. JAX arrays, by always returning True for them.
966995
"""
967-
if (
968-
is_numpy_array(x)
969-
or is_cupy_array(x)
970-
or is_torch_array(x)
971-
or is_pydata_sparse_array(x)
972-
):
973-
return False
974-
975996
# **JAX note:** while it is possible to determine if you're inside or outside
976997
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
977998
# as we do below for unknown arrays, this is not recommended by JAX best practices.
@@ -981,10 +1002,14 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
9811002
# compatibility, is highly detrimental to performance as the whole graph will end
9821003
# up being computed multiple times.
9831004

984-
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
985-
return True
1005+
# Note: skipping reclassification of JAX zero gradient arrays, as one will
1006+
# exclusively get them once they leave a jax.grad JIT context.
1007+
cls = cast(Hashable, type(x))
1008+
res = _is_lazy_cls(cls)
1009+
if res is not None:
1010+
return res
9861011

987-
if not is_array_api_obj(x):
1012+
if not hasattr(x, "__array_namespace__"):
9881013
return False
9891014

9901015
# Unknown Array API compatible object. Note that this test may have dire consequences
@@ -1037,7 +1062,7 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
10371062
"to_device",
10381063
]
10391064

1040-
_all_ignore = ["sys", "math", "inspect", "warnings"]
1065+
_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings']
10411066

10421067
def __dir__() -> list[str]:
    """Module-level ``__dir__`` hook: report the names listed in ``__all__``."""
    return __all__

0 commit comments

Comments
 (0)