@@ -50,6 +50,7 @@ def is_numpy_array(x):
50
50
is_torch_array
51
51
is_dask_array
52
52
is_jax_array
53
+ is_sparse_array
53
54
"""
54
55
# Avoid importing NumPy if it isn't already
55
56
if 'numpy' not in sys .modules :
@@ -79,6 +80,7 @@ def is_cupy_array(x):
79
80
is_torch_array
80
81
is_dask_array
81
82
is_jax_array
83
+ is_sparse_array
82
84
"""
83
85
# Avoid importing NumPy if it isn't already
84
86
if 'cupy' not in sys .modules :
@@ -105,6 +107,7 @@ def is_torch_array(x):
105
107
is_cupy_array
106
108
is_dask_array
107
109
is_jax_array
110
+ is_sparse_array
108
111
"""
109
112
# Avoid importing torch if it isn't already
110
113
if 'torch' not in sys .modules :
@@ -131,6 +134,7 @@ def is_dask_array(x):
131
134
is_cupy_array
132
135
is_torch_array
133
136
is_jax_array
137
+ is_sparse_array
134
138
"""
135
139
# Avoid importing dask if it isn't already
136
140
if 'dask.array' not in sys .modules :
@@ -157,6 +161,7 @@ def is_jax_array(x):
157
161
is_cupy_array
158
162
is_torch_array
159
163
is_dask_array
164
+ is_sparse_array
160
165
"""
161
166
# Avoid importing jax if it isn't already
162
167
if 'jax' not in sys .modules :
@@ -166,6 +171,35 @@ def is_jax_array(x):
166
171
167
172
return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
168
173
174
+
175
+ def is_sparse_array (x ) -> bool :
176
+ """
177
+ Return True if `x` is a `sparse` array.
178
+
179
+ This function does not import `sparse` if it has not already been imported
180
+ and is therefore cheap to use.
181
+
182
+
183
+ See Also
184
+ --------
185
+
186
+ array_namespace
187
+ is_array_api_obj
188
+ is_numpy_array
189
+ is_cupy_array
190
+ is_torch_array
191
+ is_dask_array
192
+ is_jax_array
193
+ """
194
+ # Avoid importing jax if it isn't already
195
+ if 'sparse' not in sys .modules :
196
+ return False
197
+
198
+ import sparse
199
+
200
+ # TODO: Account for other backends.
201
+ return isinstance (x , sparse .SparseArray )
202
+
169
203
def is_array_api_obj (x ):
170
204
"""
171
205
Return True if `x` is an array API compatible array object.
@@ -185,6 +219,7 @@ def is_array_api_obj(x):
185
219
or is_torch_array (x ) \
186
220
or is_dask_array (x ) \
187
221
or is_jax_array (x ) \
222
+ or is_sparse_array (x ) \
188
223
or hasattr (x , '__array_namespace__' )
189
224
190
225
def _check_api_version (api_version ):
@@ -253,6 +288,7 @@ def your_function(x, y):
253
288
is_torch_array
254
289
is_dask_array
255
290
is_jax_array
291
+ is_sparse_array
256
292
257
293
"""
258
294
if use_compat not in [None , True , False ]:
@@ -312,6 +348,13 @@ def your_function(x, y):
312
348
# not have a wrapper submodule for it.
313
349
import jax .experimental .array_api as jnp
314
350
namespaces .add (jnp )
351
+ elif is_sparse_array (x ):
352
+ if use_compat is True :
353
+ _check_api_version (api_version )
354
+ raise ValueError ("`sparse` does not have an array-api-compat wrapper" )
355
+ else :
356
+ import sparse
357
+ namespaces .add (sparse )
315
358
elif hasattr (x , '__array_namespace__' ):
316
359
if use_compat is True :
317
360
raise ValueError ("The given array does not have an array-api-compat wrapper" )
@@ -406,8 +449,23 @@ def device(x: Array, /) -> Device:
406
449
return x .device ()
407
450
else :
408
451
return x .device
452
+ elif is_sparse_array (x ):
453
+ # `sparse` will gain `.device`, so check for this first.
454
+ x_device = getattr (x , 'device' , None )
455
+ if x_device is not None :
456
+ return x_device
457
+ # Everything but DOK has this attr.
458
+ try :
459
+ inner = x .data
460
+ except AttributeError :
461
+ return "cpu"
462
+ # Return the device of the constituent array
463
+ return device (inner )
409
464
return x .device
410
465
466
+ # Prevent shadowing, used below
467
+ _device = device
468
+
411
469
# Based on cupy.array_api.Array.to_device
412
470
def _cupy_to_device (x , device , / , stream = None ):
413
471
import cupy as cp
@@ -523,6 +581,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
523
581
# This import adds to_device to x
524
582
import jax .experimental .array_api # noqa: F401
525
583
return x .to_device (device , stream = stream )
584
+ elif is_sparse_array (x ) and device == _device (x ):
585
+ # Perform trivial check to return the same array if
586
+ # device is same instead of err-ing.
587
+ return x
526
588
return x .to_device (device , stream = stream )
527
589
528
590
def size (x ):
@@ -549,6 +611,7 @@ def size(x):
549
611
"is_jax_array" ,
550
612
"is_numpy_array" ,
551
613
"is_torch_array" ,
614
+ "is_sparse_array" ,
552
615
"size" ,
553
616
"to_device" ,
554
617
]
0 commit comments