Skip to content

Commit cb6a3ec

Browse files
authored
Merge pull request #232 from crusaderky/ndonnx
ENH: ndonnx device() support; TST: better ndonnx test coverage
2 parents 7948ac0 + 8434019 commit cb6a3ec

File tree

5 files changed

+43
-14
lines changed

5 files changed

+43
-14
lines changed

array_api_compat/common/_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
776776
device : Hardware device the array data resides on.
777777
778778
"""
779-
if is_numpy_array(x):
779+
if is_numpy_array(x) or is_ndonnx_array(x):
780780
if stream is not None:
781781
raise ValueError("The stream argument to to_device() is not supported")
782782
if device == 'cpu':

docs/supported-array-libraries.md

+5
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ The minimum supported Dask version is 2023.12.0.
138138

139139
Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
140140

141+
(ndonnx-support)=
142+
## [ndonnx](https://github.com/quantco/ndonnx)
143+
144+
Similar to JAX, `ndonnx` Array API support is contained directly in `ndonnx`.
145+
141146
(array-api-strict-support)=
142147
## [array-api-strict](https://data-apis.org/array-api-strict/)
143148

tests/_helpers.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import pytest
44

55
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
6-
all_libraries = wrapped_libraries + ["array_api_strict", "jax.numpy", "sparse"]
7-
6+
all_libraries = wrapped_libraries + [
7+
"array_api_strict", "jax.numpy", "ndonnx", "sparse"
8+
]
89

910
def import_(library, wrapper=False):
10-
if library == 'cupy':
11+
if library in ('cupy', 'ndonnx'):
1112
pytest.importorskip(library)
1213
if wrapper:
1314
if 'jax' in library:
@@ -20,3 +21,14 @@ def import_(library, wrapper=False):
2021
library = 'array_api_compat.' + library
2122

2223
return import_module(library)
24+
25+
26+
def xfail(request: pytest.FixtureRequest, reason: str) -> None:
27+
"""
28+
XFAIL the currently running test.
29+
30+
Unlike ``pytest.xfail``, allow rest of test to execute instead of immediately
31+
halting it, so that it may result in a XPASS.
32+
xref https://github.com/pandas-dev/pandas/issues/38902
33+
"""
34+
request.node.add_marker(pytest.mark.xfail(reason=reason))

tests/test_array_namespace.py

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ def test_array_namespace(library, api_version, use_compat):
2222
if use_compat and library not in wrapped_libraries:
2323
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
2424
return
25+
if library == "ndonnx" and api_version in ("2021.12", "2022.12"):
26+
pytest.skip("Unsupported API version")
27+
2528
namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)
2629

2730
if use_compat is False or use_compat is None and library not in wrapped_libraries:

tests/test_common.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@
88
from array_api_compat import ( # noqa: F401
99
is_numpy_array, is_cupy_array, is_torch_array,
1010
is_dask_array, is_jax_array, is_pydata_sparse_array,
11+
is_ndonnx_array,
1112
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
1213
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
13-
is_array_api_strict_namespace,
14+
is_array_api_strict_namespace, is_ndonnx_namespace,
1415
)
1516

1617
from array_api_compat import (
1718
device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
1819
)
19-
from ._helpers import import_, wrapped_libraries, all_libraries
20+
from ._helpers import all_libraries, import_, wrapped_libraries, xfail
21+
2022

2123
is_array_functions = {
2224
'numpy': 'is_numpy_array',
@@ -25,6 +27,7 @@
2527
'dask.array': 'is_dask_array',
2628
'jax.numpy': 'is_jax_array',
2729
'sparse': 'is_pydata_sparse_array',
30+
'ndonnx': 'is_ndonnx_array',
2831
}
2932

3033
is_namespace_functions = {
@@ -35,6 +38,7 @@
3538
'jax.numpy': 'is_jax_namespace',
3639
'sparse': 'is_pydata_sparse_namespace',
3740
'array_api_strict': 'is_array_api_strict_namespace',
41+
'ndonnx': 'is_ndonnx_namespace',
3842
}
3943

4044

@@ -185,7 +189,10 @@ class C:
185189

186190

187191
@pytest.mark.parametrize("library", all_libraries)
188-
def test_device(library):
192+
def test_device(library, request):
193+
if library == "ndonnx":
194+
xfail(request, reason="Needs ndonnx >=0.9.4")
195+
189196
xp = import_(library, wrapper=True)
190197

191198
# We can't test much for device() and to_device() other than that
@@ -223,17 +230,19 @@ def test_to_device_host(library):
223230
@pytest.mark.parametrize("target_library", is_array_functions.keys())
224231
@pytest.mark.parametrize("source_library", is_array_functions.keys())
225232
def test_asarray_cross_library(source_library, target_library, request):
226-
def _xfail(reason: str) -> None:
227-
# Allow rest of test to execute instead of immediately xfailing
228-
# xref https://github.com/pandas-dev/pandas/issues/38902
229-
request.node.add_marker(pytest.mark.xfail(reason=reason))
230-
231233
if source_library == "dask.array" and target_library == "torch":
232234
# TODO: remove xfail once
233235
# https://github.com/dask/dask/issues/8260 is resolved
234-
_xfail(reason="Bug in dask raising error on conversion")
236+
xfail(request, reason="Bug in dask raising error on conversion")
237+
elif (
238+
source_library == "ndonnx"
239+
and target_library not in ("array_api_strict", "ndonnx", "numpy")
240+
):
241+
xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown")
242+
elif source_library == "ndonnx" and target_library == "numpy":
243+
xfail(request, reason="produces numpy array of ndonnx scalar arrays")
235244
elif source_library == "jax.numpy" and target_library == "torch":
236-
_xfail(reason="casts int to float")
245+
xfail(request, reason="casts int to float")
237246
elif source_library == "cupy" and target_library != "cupy":
238247
# cupy explicitly disallows implicit conversions to CPU
239248
pytest.skip(reason="cupy does not support implicit conversion to CPU")

0 commit comments

Comments
 (0)