Skip to content

Commit 88abd2b

Browse files
committed
WIP: Add sparse compatibility layer.
1 parent 376038e commit 88abd2b

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

array_api_compat/common/_helpers.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def is_numpy_array(x):
5050
is_torch_array
5151
is_dask_array
5252
is_jax_array
53+
is_sparse_array
5354
"""
5455
# Avoid importing NumPy if it isn't already
5556
if 'numpy' not in sys.modules:
@@ -79,6 +80,7 @@ def is_cupy_array(x):
7980
is_torch_array
8081
is_dask_array
8182
is_jax_array
83+
is_sparse_array
8284
"""
8385
# Avoid importing NumPy if it isn't already
8486
if 'cupy' not in sys.modules:
@@ -105,6 +107,7 @@ def is_torch_array(x):
105107
is_cupy_array
106108
is_dask_array
107109
is_jax_array
110+
is_sparse_array
108111
"""
109112
# Avoid importing torch if it isn't already
110113
if 'torch' not in sys.modules:
@@ -131,6 +134,7 @@ def is_dask_array(x):
131134
is_cupy_array
132135
is_torch_array
133136
is_jax_array
137+
is_sparse_array
134138
"""
135139
# Avoid importing dask if it isn't already
136140
if 'dask.array' not in sys.modules:
@@ -157,6 +161,7 @@ def is_jax_array(x):
157161
is_cupy_array
158162
is_torch_array
159163
is_dask_array
164+
is_sparse_array
160165
"""
161166
# Avoid importing jax if it isn't already
162167
if 'jax' not in sys.modules:
@@ -166,6 +171,35 @@ def is_jax_array(x):
166171

167172
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
168173

174+
175+
def is_sparse_array(x) -> bool:
176+
"""
177+
Return True if `x` is a `sparse` array.
178+
179+
This function does not import `sparse` if it has not already been imported
180+
and is therefore cheap to use.
181+
182+
183+
See Also
184+
--------
185+
186+
array_namespace
187+
is_array_api_obj
188+
is_numpy_array
189+
is_cupy_array
190+
is_torch_array
191+
is_dask_array
192+
is_jax_array
193+
"""
194+
# Avoid importing jax if it isn't already
195+
if 'sparse' not in sys.modules:
196+
return False
197+
198+
import sparse
199+
200+
# TODO: Account for other backends.
201+
return isinstance(x, sparse.SparseArray)
202+
169203
def is_array_api_obj(x):
170204
"""
171205
Return True if `x` is an array API compatible array object.
@@ -185,6 +219,7 @@ def is_array_api_obj(x):
185219
or is_torch_array(x) \
186220
or is_dask_array(x) \
187221
or is_jax_array(x) \
222+
or is_sparse_array(x) \
188223
or hasattr(x, '__array_namespace__')
189224

190225
def _check_api_version(api_version):
@@ -253,6 +288,7 @@ def your_function(x, y):
253288
is_torch_array
254289
is_dask_array
255290
is_jax_array
291+
is_sparse_array
256292
257293
"""
258294
if use_compat not in [None, True, False]:
@@ -312,6 +348,13 @@ def your_function(x, y):
312348
# not have a wrapper submodule for it.
313349
import jax.experimental.array_api as jnp
314350
namespaces.add(jnp)
351+
elif is_sparse_array(x):
352+
if use_compat is True:
353+
_check_api_version(api_version)
354+
raise ValueError("`sparse` does not have an array-api-compat wrapper")
355+
else:
356+
import sparse
357+
namespaces.add(sparse)
315358
elif hasattr(x, '__array_namespace__'):
316359
if use_compat is True:
317360
raise ValueError("The given array does not have an array-api-compat wrapper")
@@ -406,8 +449,23 @@ def device(x: Array, /) -> Device:
406449
return x.device()
407450
else:
408451
return x.device
452+
elif is_sparse_array(x):
453+
# `sparse` will gain `.device`, so check for this first.
454+
x_device = getattr(x, 'device', None)
455+
if x_device is not None:
456+
return x_device
457+
# Everything but DOK has this attr.
458+
try:
459+
inner = x.data
460+
except AttributeError:
461+
return "cpu"
462+
# Return the device of the constituent array
463+
return device(inner)
409464
return x.device
410465

466+
# Prevent shadowing, used below
467+
_device = device
468+
411469
# Based on cupy.array_api.Array.to_device
412470
def _cupy_to_device(x, device, /, stream=None):
413471
import cupy as cp
@@ -523,6 +581,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
523581
# This import adds to_device to x
524582
import jax.experimental.array_api # noqa: F401
525583
return x.to_device(device, stream=stream)
584+
elif is_sparse_array(x) and device == _device(x):
585+
# Perform trivial check to return the same array if
586+
# device is same instead of err-ing.
587+
return x
526588
return x.to_device(device, stream=stream)
527589

528590
def size(x):
@@ -549,6 +611,7 @@ def size(x):
549611
"is_jax_array",
550612
"is_numpy_array",
551613
"is_torch_array",
614+
"is_sparse_array",
552615
"size",
553616
"to_device",
554617
]

array_api_compat/sparse/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from sparse import * # noqa: F403
2+
from ..common._aliases import * # noqa: F403
3+
from ..common._helpers import * # noqa: F401,F403
4+
5+
__array_api_version__ = '2022.12'

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ jax[cpu]
44
numpy
55
pytest
66
torch
7+
sparse >=0.15.1

0 commit comments

Comments
 (0)