Skip to content

Commit 9b1f5f7

Browse files
committed
Rename is_pydata_sparse to is_pydata_sparse_array
1 parent 44bf2af commit 9b1f5f7

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

array_api_compat/common/_helpers.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def is_numpy_array(x):
5050
is_torch_array
5151
is_dask_array
5252
is_jax_array
53-
is_pydata_sparse
53+
is_pydata_sparse_array
5454
"""
5555
# Avoid importing NumPy if it isn't already
5656
if 'numpy' not in sys.modules:
@@ -80,7 +80,7 @@ def is_cupy_array(x):
8080
is_torch_array
8181
is_dask_array
8282
is_jax_array
83-
is_pydata_sparse
83+
is_pydata_sparse_array
8484
"""
8585
# Avoid importing NumPy if it isn't already
8686
if 'cupy' not in sys.modules:
@@ -107,7 +107,7 @@ def is_torch_array(x):
107107
is_cupy_array
108108
is_dask_array
109109
is_jax_array
110-
is_pydata_sparse
110+
is_pydata_sparse_array
111111
"""
112112
# Avoid importing torch if it isn't already
113113
if 'torch' not in sys.modules:
@@ -134,7 +134,7 @@ def is_dask_array(x):
134134
is_cupy_array
135135
is_torch_array
136136
is_jax_array
137-
is_pydata_sparse
137+
is_pydata_sparse_array
138138
"""
139139
# Avoid importing dask if it isn't already
140140
if 'dask.array' not in sys.modules:
@@ -161,7 +161,7 @@ def is_jax_array(x):
161161
is_cupy_array
162162
is_torch_array
163163
is_dask_array
164-
is_pydata_sparse
164+
is_pydata_sparse_array
165165
"""
166166
# Avoid importing jax if it isn't already
167167
if 'jax' not in sys.modules:
@@ -172,7 +172,7 @@ def is_jax_array(x):
172172
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
173173

174174

175-
def is_pydata_sparse(x) -> bool:
175+
def is_pydata_sparse_array(x) -> bool:
176176
"""
177177
Return True if `x` is an array from the `sparse` package.
178178
@@ -219,7 +219,7 @@ def is_array_api_obj(x):
219219
or is_torch_array(x) \
220220
or is_dask_array(x) \
221221
or is_jax_array(x) \
222-
or is_pydata_sparse(x) \
222+
or is_pydata_sparse_array(x) \
223223
or hasattr(x, '__array_namespace__')
224224

225225
def _check_api_version(api_version):
@@ -288,7 +288,7 @@ def your_function(x, y):
288288
is_torch_array
289289
is_dask_array
290290
is_jax_array
291-
is_pydata_sparse
291+
is_pydata_sparse_array
292292
293293
"""
294294
if use_compat not in [None, True, False]:
@@ -348,7 +348,7 @@ def your_function(x, y):
348348
# not have a wrapper submodule for it.
349349
import jax.experimental.array_api as jnp
350350
namespaces.add(jnp)
351-
elif is_pydata_sparse(x):
351+
elif is_pydata_sparse_array(x):
352352
if use_compat is True:
353353
_check_api_version(api_version)
354354
raise ValueError("`sparse` does not have an array-api-compat wrapper")
@@ -451,7 +451,7 @@ def device(x: Array, /) -> Device:
451451
return x.device()
452452
else:
453453
return x.device
454-
elif is_pydata_sparse(x):
454+
elif is_pydata_sparse_array(x):
455455
# `sparse` will gain `.device`, so check for this first.
456456
x_device = getattr(x, 'device', None)
457457
if x_device is not None:
@@ -583,7 +583,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
583583
# This import adds to_device to x
584584
import jax.experimental.array_api # noqa: F401
585585
return x.to_device(device, stream=stream)
586-
elif is_pydata_sparse(x) and device == _device(x):
586+
elif is_pydata_sparse_array(x) and device == _device(x):
587587
# Perform trivial check to return the same array if
588588
# device is same instead of err-ing.
589589
return x
@@ -613,7 +613,7 @@ def size(x):
613613
"is_jax_array",
614614
"is_numpy_array",
615615
"is_torch_array",
616-
"is_pydata_sparse",
616+
"is_pydata_sparse_array",
617617
"size",
618618
"to_device",
619619
]

docs/helper-functions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,4 @@ yet.
4949
.. autofunction:: is_torch_array
5050
.. autofunction:: is_dask_array
5151
.. autofunction:: is_jax_array
52+
.. autofunction:: is_pydata_sparse_array

tests/test_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
2-
is_dask_array, is_jax_array, is_pydata_sparse)
2+
is_dask_array, is_jax_array, is_pydata_sparse_array)
33

44
from array_api_compat import is_array_api_obj, device, to_device
55

@@ -16,7 +16,7 @@
1616
'torch': 'is_torch_array',
1717
'dask.array': 'is_dask_array',
1818
'jax.numpy': 'is_jax_array',
19-
'sparse': 'is_pydata_sparse',
19+
'sparse': 'is_pydata_sparse_array',
2020
}
2121

2222
@pytest.mark.parametrize('library', is_functions.keys())

0 commit comments

Comments
 (0)