@@ -12,6 +12,8 @@
import math
import sys
import warnings
+ from collections.abc import Collection, Hashable
+ from functools import lru_cache
from types import NoneType
from typing import (
    TYPE_CHECKING,
@@ -56,23 +58,36 @@
_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})


+ @lru_cache(100)
+ def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool:
+     try:
+         mod = sys.modules[modname]
+     except KeyError:
+         return False
+     parent_cls = getattr(mod, clsname)
+     return issubclass(cls, parent_cls)
+
+
def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
    """Return True if `x` is a zero-gradient array.

    These arrays are a design quirk of Jax that may one day be removed.
    See https://github.com/google/jax/issues/20620.
    """
-     if "numpy" not in sys.modules or "jax" not in sys.modules:
+     # Fast exit
+     try:
+         dtype = x.dtype  # type: ignore[attr-defined]
+     except AttributeError:
+         return False
+     cls = cast(Hashable, type(dtype))
+     if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"):
        return False

-     import jax
-     import numpy as np
+     if "jax" not in sys.modules:
+         return False

-     jax_float0 = cast("np.dtype[np.void]", jax.float0)
-     return (
-         isinstance(x, np.ndarray)
-         and cast("npt.NDArray[np.void]", x).dtype == jax_float0
-     )
+     import jax
+     # jax.float0 is a np.dtype([('float0', 'V')])
+     return dtype == jax.float0


def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
@@ -96,15 +112,12 @@ def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
    is_jax_array
    is_pydata_sparse_array
    """
-     # Avoid importing NumPy if it isn't already
-     if "numpy" not in sys.modules:
-         return False
-
-     import numpy as np
-
    # TODO: Should we reject ndarray subclasses?
-     return (isinstance(x, (np.ndarray, np.generic))
-             and not _is_jax_zero_gradient_array(x))  # pyright: ignore[reportUnknownArgumentType]  # fmt: skip
+     cls = cast(Hashable, type(x))
+     return (
+         _issubclass_fast(cls, "numpy", "ndarray")
+         or _issubclass_fast(cls, "numpy", "generic")
+     ) and not _is_jax_zero_gradient_array(x)


def is_cupy_array(x: object) -> bool:
@@ -128,14 +141,8 @@ def is_cupy_array(x: object) -> bool:
    is_jax_array
    is_pydata_sparse_array
    """
-     # Avoid importing CuPy if it isn't already
-     if "cupy" not in sys.modules:
-         return False
-
-     import cupy as cp
-
-     # TODO: Should we reject ndarray subclasses?
-     return isinstance(x, cp.ndarray)  # pyright: ignore[reportUnknownMemberType]
+     cls = cast(Hashable, type(x))
+     return _issubclass_fast(cls, "cupy", "ndarray")


def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
@@ -156,14 +163,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
    is_jax_array
    is_pydata_sparse_array
    """
-     # Avoid importing torch if it isn't already
-     if "torch" not in sys.modules:
-         return False
-
-     import torch
-
-     # TODO: Should we reject ndarray subclasses?
-     return isinstance(x, torch.Tensor)
+     cls = cast(Hashable, type(x))
+     return _issubclass_fast(cls, "torch", "Tensor")


def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
@@ -185,13 +186,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
    is_jax_array
    is_pydata_sparse_array
    """
-     # Avoid importing torch if it isn't already
-     if "ndonnx" not in sys.modules:
-         return False
-
-     import ndonnx as ndx
-
-     return isinstance(x, ndx.Array)
+     cls = cast(Hashable, type(x))
+     return _issubclass_fast(cls, "ndonnx", "Array")


def is_dask_array(x: object) -> TypeIs[da.Array]:
@@ -213,13 +209,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]:
    is_jax_array
    is_pydata_sparse_array
    """
-     # Avoid importing dask if it isn't already
-     if "dask.array" not in sys.modules:
-         return False
-
-     import dask.array
-
-     return isinstance(x, dask.array.Array)
+     cls = cast(Hashable, type(x))
+     return _issubclass_fast(cls, "dask.array", "Array")


def is_jax_array(x: object) -> TypeIs[jax.Array]:
@@ -242,13 +233,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
    is_dask_array
    is_pydata_sparse_array
    """
-     # Avoid importing jax if it isn't already
-     if "jax" not in sys.modules:
-         return False
-
-     import jax
-
-     return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
+     cls = cast(Hashable, type(x))
+     return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)


def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
@@ -271,14 +257,9 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
    is_dask_array
    is_jax_array
    """
-     # Avoid importing jax if it isn't already
-     if "sparse" not in sys.modules:
-         return False
-
-     import sparse
-
    # TODO: Account for other backends.
-     return isinstance(x, sparse.SparseArray)
+     cls = cast(Hashable, type(x))
+     return _issubclass_fast(cls, "sparse", "SparseArray")


def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
@@ -297,13 +278,23 @@ def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
    is_jax_array
    """
    return (
-         is_numpy_array(x)
-         or is_cupy_array(x)
-         or is_torch_array(x)
-         or is_dask_array(x)
-         or is_jax_array(x)
-         or is_pydata_sparse_array(x)
-         or hasattr(x, "__array_namespace__")
+         hasattr(x, '__array_namespace__')
+         or _is_array_api_cls(cast(Hashable, type(x)))
+     )
+
+
+ @lru_cache(100)
+ def _is_array_api_cls(cls: type) -> bool:
+     return (
+         # TODO: drop support for numpy<2 which didn't have __array_namespace__
+         _issubclass_fast(cls, "numpy", "ndarray")
+         or _issubclass_fast(cls, "numpy", "generic")
+         or _issubclass_fast(cls, "cupy", "ndarray")
+         or _issubclass_fast(cls, "torch", "Tensor")
+         or _issubclass_fast(cls, "dask.array", "Array")
+         or _issubclass_fast(cls, "sparse", "SparseArray")
+         # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
+         or _issubclass_fast(cls, "jax", "Array")
    )

@@ -312,6 +303,7 @@ def _compat_module_name() -> str:
    return __name__.removesuffix(".common._helpers")


+ @lru_cache(100)
def is_numpy_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is a NumPy namespace.
@@ -333,6 +325,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
    return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"}


+ @lru_cache(100)
def is_cupy_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is a CuPy namespace.
@@ -354,6 +347,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
    return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"}


+ @lru_cache(100)
def is_torch_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is a PyTorch namespace.
@@ -394,6 +388,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
    return xp.__name__ == "ndonnx"


+ @lru_cache(100)
def is_dask_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is a Dask namespace.
@@ -934,6 +929,19 @@ def size(x: HasShape[float | None]) -> int | None:
    return None if math.isnan(out) else cast(int, out)


+ @lru_cache(100)
+ def _is_writeable_cls(cls: type) -> bool | None:
+     if (
+         _issubclass_fast(cls, "numpy", "generic")
+         or _issubclass_fast(cls, "jax", "Array")
+         or _issubclass_fast(cls, "sparse", "SparseArray")
+     ):
+         return False
+     if _is_array_api_cls(cls):
+         return True
+     return None
+
+
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
    """
    Return False if ``x.__setitem__`` is expected to raise; True otherwise.
@@ -944,11 +952,32 @@ def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
    As there is no standard way to check if an array is writeable without actually
    writing to it, this function blindly returns True for all unknown array types.
    """
-     if is_numpy_array(x):
-         return x.flags.writeable
-     if is_jax_array(x) or is_pydata_sparse_array(x):
+     cls = cast(Hashable, type(x))
+     if _issubclass_fast(cls, "numpy", "ndarray"):
+         return cast("npt.NDArray", x).flags.writeable
+     res = _is_writeable_cls(cls)
+     if res is not None:
+         return res
+     return hasattr(x, '__array_namespace__')
+
+
+ @lru_cache(100)
+ def _is_lazy_cls(cls: type) -> bool | None:
+     if (
+         _issubclass_fast(cls, "numpy", "ndarray")
+         or _issubclass_fast(cls, "numpy", "generic")
+         or _issubclass_fast(cls, "cupy", "ndarray")
+         or _issubclass_fast(cls, "torch", "Tensor")
+         or _issubclass_fast(cls, "sparse", "SparseArray")
+     ):
        return False
-     return is_array_api_obj(x)
+     if (
+         _issubclass_fast(cls, "jax", "Array")
+         or _issubclass_fast(cls, "dask.array", "Array")
+         or _issubclass_fast(cls, "ndonnx", "Array")
+     ):
+         return True
+     return None


def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
@@ -964,14 +993,6 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
    This function errs on the side of caution for array types that may or may not be
    lazy, e.g. JAX arrays, by always returning True for them.
    """
-     if (
-         is_numpy_array(x)
-         or is_cupy_array(x)
-         or is_torch_array(x)
-         or is_pydata_sparse_array(x)
-     ):
-         return False
-
    # **JAX note:** while it is possible to determine if you're inside or outside
    # jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
    # as we do below for unknown arrays, this is not recommended by JAX best practices.
@@ -981,10 +1002,14 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
    # compatibility, is highly detrimental to performance as the whole graph will end
    # up being computed multiple times.

-     if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
-         return True
+     # Note: skipping reclassification of JAX zero gradient arrays, as one will
+     # exclusively get them once they leave a jax.grad JIT context.
+     cls = cast(Hashable, type(x))
+     res = _is_lazy_cls(cls)
+     if res is not None:
+         return res

-     if not is_array_api_obj(x):
+     if not hasattr(x, "__array_namespace__"):
        return False

    # Unknown Array API compatible object. Note that this test may have dire consequences
1037
1062
"to_device" ,
1038
1063
]
1039
1064
1040
- _all_ignore = [" sys" , " math" , " inspect" , " warnings" ]
1065
+ _all_ignore = ['lru_cache' , ' sys' , ' math' , ' inspect' , ' warnings' ]
1041
1066
1042
1067
def __dir__ () -> list [str ]:
1043
1068
return __all__
0 commit comments