Skip to content

Commit 855756d

Browse files
committed
Modify testsuite to add sparse support.
1 parent d3c6636 commit 855756d

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

tests/_helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
wrapped_libraries = ["cupy", "torch", "dask.array"]
6-
all_libraries = wrapped_libraries + ["numpy", "jax.numpy"]
6+
all_libraries = wrapped_libraries + ["numpy", "jax.numpy", "sparse"]
77
import numpy as np
88
if np.__version__[0] == '1':
99
wrapped_libraries.append("numpy")
@@ -14,6 +14,8 @@ def import_(library, wrapper=False):
1414
if wrapper:
1515
if 'jax' in library:
1616
library = 'jax.experimental.array_api'
17+
elif library.startswith('sparse'):
18+
library = 'sparse'
1719
else:
1820
library = 'array_api_compat.' + library
1921

tests/test_array_namespace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_array_namespace(library, api_version, use_compat):
1919
xp = import_(library)
2020

2121
array = xp.asarray([1.0, 2.0, 3.0])
22-
if use_compat is True and library in ['array_api_strict', 'jax.numpy']:
22+
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
2323
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
2424
return
2525
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)

tests/test_common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
2-
is_dask_array, is_jax_array)
2+
is_dask_array, is_jax_array, is_pydata_sparse)
33

44
from array_api_compat import is_array_api_obj, device, to_device
55

@@ -16,6 +16,7 @@
1616
'torch': 'is_torch_array',
1717
'dask.array': 'is_dask_array',
1818
'jax.numpy': 'is_jax_array',
19+
'sparse': 'is_pydata_sparse',
1920
}
2021

2122
@pytest.mark.parametrize('library', is_functions.keys())
@@ -76,6 +77,8 @@ def test_asarray_cross_library(source_library, target_library, request):
7677
if source_library == "cupy" and target_library != "cupy":
7778
# cupy explicitly disallows implicit conversions to CPU
7879
pytest.skip(reason="cupy does not support implicit conversion to CPU")
80+
elif source_library == "sparse" and target_library != "sparse":
81+
pytest.skip(reason="`sparse` does not allow implicit densification")
7982
src_lib = import_(source_library, wrapper=True)
8083
tgt_lib = import_(target_library, wrapper=True)
8184
is_tgt_type = globals()[is_functions[target_library]]

tests/test_no_dependencies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _test_dependency(mod):
3333

3434
# array-api-strict is an example of an array API library that isn't
3535
# wrapped by array-api-compat.
36-
if "strict" not in mod:
36+
if "strict" not in mod and mod != "sparse":
3737
is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array")
3838
assert not is_mod_array(a)
3939
assert mod not in sys.modules
@@ -50,7 +50,7 @@ def _test_dependency(mod):
5050
# Y (except most array libraries actually do themselves depend on numpy).
5151

5252
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
53-
"jax.numpy", "array_api_strict"])
53+
"jax.numpy", "sparse", "array_api_strict"])
5454
def test_numpy_dependency(library):
5555
# This import is here because it imports numpy
5656
from ._helpers import import_

0 commit comments

Comments
 (0)