@@ -48,9 +48,10 @@ def is_numpy_array(x):
48
48
is_array_api_obj
49
49
is_cupy_array
50
50
is_torch_array
51
+ is_ndonnx_array
51
52
is_dask_array
52
53
is_jax_array
53
- is_pydata_sparse
54
+ is_pydata_sparse_array
54
55
"""
55
56
# Avoid importing NumPy if it isn't already
56
57
if 'numpy' not in sys .modules :
@@ -78,11 +79,12 @@ def is_cupy_array(x):
78
79
is_array_api_obj
79
80
is_numpy_array
80
81
is_torch_array
82
+ is_ndonnx_array
81
83
is_dask_array
82
84
is_jax_array
83
- is_pydata_sparse
85
+ is_pydata_sparse_array
84
86
"""
85
- # Avoid importing NumPy if it isn't already
87
+ # Avoid importing CuPy if it isn't already
86
88
if 'cupy' not in sys .modules :
87
89
return False
88
90
@@ -107,7 +109,7 @@ def is_torch_array(x):
107
109
is_cupy_array
108
110
is_dask_array
109
111
is_jax_array
110
- is_pydata_sparse
112
+ is_pydata_sparse_array
111
113
"""
112
114
# Avoid importing torch if it isn't already
113
115
if 'torch' not in sys .modules :
@@ -118,6 +120,33 @@ def is_torch_array(x):
118
120
# TODO: Should we reject ndarray subclasses?
119
121
return isinstance (x , torch .Tensor )
120
122
123
+ def is_ndonnx_array (x ):
124
+ """
125
+ Return True if `x` is a ndonnx Array.
126
+
127
+ This function does not import ndonnx if it has not already been imported
128
+ and is therefore cheap to use.
129
+
130
+ See Also
131
+ --------
132
+
133
+ array_namespace
134
+ is_array_api_obj
135
+ is_numpy_array
136
+ is_cupy_array
137
+ is_ndonnx_array
138
+ is_dask_array
139
+ is_jax_array
140
+ is_pydata_sparse_array
141
+ """
142
+ # Avoid importing torch if it isn't already
143
+ if 'ndonnx' not in sys .modules :
144
+ return False
145
+
146
+ import ndonnx as ndx
147
+
148
+ return isinstance (x , ndx .Array )
149
+
121
150
def is_dask_array (x ):
122
151
"""
123
152
Return True if `x` is a dask.array Array.
@@ -133,8 +162,9 @@ def is_dask_array(x):
133
162
is_numpy_array
134
163
is_cupy_array
135
164
is_torch_array
165
+ is_ndonnx_array
136
166
is_jax_array
137
- is_pydata_sparse
167
+ is_pydata_sparse_array
138
168
"""
139
169
# Avoid importing dask if it isn't already
140
170
if 'dask.array' not in sys .modules :
@@ -160,8 +190,9 @@ def is_jax_array(x):
160
190
is_numpy_array
161
191
is_cupy_array
162
192
is_torch_array
193
+ is_ndonnx_array
163
194
is_dask_array
164
- is_pydata_sparse
195
+ is_pydata_sparse_array
165
196
"""
166
197
# Avoid importing jax if it isn't already
167
198
if 'jax' not in sys .modules :
@@ -172,7 +203,7 @@ def is_jax_array(x):
172
203
return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
173
204
174
205
175
- def is_pydata_sparse (x ) -> bool :
206
+ def is_pydata_sparse_array (x ) -> bool :
176
207
"""
177
208
Return True if `x` is an array from the `sparse` package.
178
209
@@ -188,6 +219,7 @@ def is_pydata_sparse(x) -> bool:
188
219
is_numpy_array
189
220
is_cupy_array
190
221
is_torch_array
222
+ is_ndonnx_array
191
223
is_dask_array
192
224
is_jax_array
193
225
"""
@@ -211,6 +243,7 @@ def is_array_api_obj(x):
211
243
is_numpy_array
212
244
is_cupy_array
213
245
is_torch_array
246
+ is_ndonnx_array
214
247
is_dask_array
215
248
is_jax_array
216
249
"""
@@ -219,7 +252,7 @@ def is_array_api_obj(x):
219
252
or is_torch_array (x ) \
220
253
or is_dask_array (x ) \
221
254
or is_jax_array (x ) \
222
- or is_pydata_sparse (x ) \
255
+ or is_pydata_sparse_array (x ) \
223
256
or hasattr (x , '__array_namespace__' )
224
257
225
258
def _check_api_version (api_version ):
@@ -288,7 +321,7 @@ def your_function(x, y):
288
321
is_torch_array
289
322
is_dask_array
290
323
is_jax_array
291
- is_pydata_sparse
324
+ is_pydata_sparse_array
292
325
293
326
"""
294
327
if use_compat not in [None , True , False ]:
@@ -307,12 +340,9 @@ def your_function(x, y):
307
340
elif use_compat is False :
308
341
namespaces .add (np )
309
342
else :
310
- # numpy 2.0 has __array_namespace__ and is fully array API
343
+ # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
311
344
# compatible.
312
- if hasattr (x , '__array_namespace__' ):
313
- namespaces .add (x .__array_namespace__ (api_version = api_version ))
314
- else :
315
- namespaces .add (numpy_namespace )
345
+ namespaces .add (numpy_namespace )
316
346
elif is_cupy_array (x ):
317
347
if _use_compat :
318
348
_check_api_version (api_version )
@@ -344,11 +374,15 @@ def your_function(x, y):
344
374
elif use_compat is False :
345
375
import jax .numpy as jnp
346
376
else :
347
- # jax.experimental.array_api is already an array namespace. We do
348
- # not have a wrapper submodule for it.
349
- import jax .experimental .array_api as jnp
377
+ # JAX v0.4.32 and newer implements the array API directly in jax.numpy.
378
+ # For older JAX versions, it is available via jax.experimental.array_api.
379
+ import jax .numpy
380
+ if hasattr (jax .numpy , "__array_api_version__" ):
381
+ jnp = jax .numpy
382
+ else :
383
+ import jax .experimental .array_api as jnp
350
384
namespaces .add (jnp )
351
- elif is_pydata_sparse (x ):
385
+ elif is_pydata_sparse_array (x ):
352
386
if use_compat is True :
353
387
_check_api_version (api_version )
354
388
raise ValueError ("`sparse` does not have an array-api-compat wrapper" )
@@ -451,7 +485,7 @@ def device(x: Array, /) -> Device:
451
485
return x .device ()
452
486
else :
453
487
return x .device
454
- elif is_pydata_sparse (x ):
488
+ elif is_pydata_sparse_array (x ):
455
489
# `sparse` will gain `.device`, so check for this first.
456
490
x_device = getattr (x , 'device' , None )
457
491
if x_device is not None :
@@ -580,10 +614,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
580
614
return x
581
615
raise ValueError (f"Unsupported device { device !r} " )
582
616
elif is_jax_array (x ):
583
- # This import adds to_device to x
584
- import jax .experimental .array_api # noqa: F401
617
+ if not hasattr (x , "__array_namespace__" ):
618
+ # In JAX v0.4.31 and older, this import adds to_device method to x.
619
+ import jax .experimental .array_api # noqa: F401
585
620
return x .to_device (device , stream = stream )
586
- elif is_pydata_sparse (x ) and device == _device (x ):
621
+ elif is_pydata_sparse_array (x ) and device == _device (x ):
587
622
# Perform trivial check to return the same array if
588
623
# device is same instead of err-ing.
589
624
return x
@@ -613,7 +648,8 @@ def size(x):
613
648
"is_jax_array" ,
614
649
"is_numpy_array" ,
615
650
"is_torch_array" ,
616
- "is_pydata_sparse" ,
651
+ "is_ndonnx_array" ,
652
+ "is_pydata_sparse_array" ,
617
653
"size" ,
618
654
"to_device" ,
619
655
]
0 commit comments