|
8 | 8 | from array_api_compat import ( # noqa: F401
|
9 | 9 | is_numpy_array, is_cupy_array, is_torch_array,
|
10 | 10 | is_dask_array, is_jax_array, is_pydata_sparse_array,
|
| 11 | + is_ndonnx_array, |
11 | 12 | is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
|
12 | 13 | is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
|
13 |
| - is_array_api_strict_namespace, |
| 14 | + is_array_api_strict_namespace, is_ndonnx_namespace, |
14 | 15 | )
|
15 | 16 |
|
16 | 17 | from array_api_compat import (
|
17 | 18 | device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
|
18 | 19 | )
|
19 |
| -from ._helpers import import_, wrapped_libraries, all_libraries |
| 20 | +from ._helpers import all_libraries, import_, wrapped_libraries, xfail |
| 21 | + |
20 | 22 |
|
21 | 23 | is_array_functions = {
|
22 | 24 | 'numpy': 'is_numpy_array',
|
|
25 | 27 | 'dask.array': 'is_dask_array',
|
26 | 28 | 'jax.numpy': 'is_jax_array',
|
27 | 29 | 'sparse': 'is_pydata_sparse_array',
|
| 30 | + 'ndonnx': 'is_ndonnx_array', |
28 | 31 | }
|
29 | 32 |
|
30 | 33 | is_namespace_functions = {
|
|
35 | 38 | 'jax.numpy': 'is_jax_namespace',
|
36 | 39 | 'sparse': 'is_pydata_sparse_namespace',
|
37 | 40 | 'array_api_strict': 'is_array_api_strict_namespace',
|
| 41 | + 'ndonnx': 'is_ndonnx_namespace', |
38 | 42 | }
|
39 | 43 |
|
40 | 44 |
|
@@ -185,7 +189,10 @@ class C:
|
185 | 189 |
|
186 | 190 |
|
187 | 191 | @pytest.mark.parametrize("library", all_libraries)
|
188 |
| -def test_device(library): |
| 192 | +def test_device(library, request): |
| 193 | + if library == "ndonnx": |
| 194 | + xfail(request, reason="Needs ndonnx >=0.9.4") |
| 195 | + |
189 | 196 | xp = import_(library, wrapper=True)
|
190 | 197 |
|
191 | 198 | # We can't test much for device() and to_device() other than that
|
@@ -223,17 +230,19 @@ def test_to_device_host(library):
|
223 | 230 | @pytest.mark.parametrize("target_library", is_array_functions.keys())
|
224 | 231 | @pytest.mark.parametrize("source_library", is_array_functions.keys())
|
225 | 232 | def test_asarray_cross_library(source_library, target_library, request):
|
226 |
| - def _xfail(reason: str) -> None: |
227 |
| - # Allow rest of test to execute instead of immediately xfailing |
228 |
| - # xref https://github.com/pandas-dev/pandas/issues/38902 |
229 |
| - request.node.add_marker(pytest.mark.xfail(reason=reason)) |
230 |
| - |
231 | 233 | if source_library == "dask.array" and target_library == "torch":
|
232 | 234 | # TODO: remove xfail once
|
233 | 235 | # https://github.com/dask/dask/issues/8260 is resolved
|
234 |
| - _xfail(reason="Bug in dask raising error on conversion") |
| 236 | + xfail(request, reason="Bug in dask raising error on conversion") |
| 237 | + elif ( |
| 238 | + source_library == "ndonnx" |
| 239 | + and target_library not in ("array_api_strict", "ndonnx", "numpy") |
| 240 | + ): |
| 241 | + xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown") |
| 242 | + elif source_library == "ndonnx" and target_library == "numpy": |
| 243 | + xfail(request, reason="produces numpy array of ndonnx scalar arrays") |
235 | 244 | elif source_library == "jax.numpy" and target_library == "torch":
|
236 |
| - _xfail(reason="casts int to float") |
| 245 | + xfail(request, reason="casts int to float") |
237 | 246 | elif source_library == "cupy" and target_library != "cupy":
|
238 | 247 | # cupy explicitly disallows implicit conversions to CPU
|
239 | 248 | pytest.skip(reason="cupy does not support implicit conversion to CPU")
|
|
0 commit comments