|
17 | 17 | from array_api_compat import (
|
18 | 18 | device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
|
19 | 19 | )
|
| 20 | +from array_api_compat.common._helpers import _DASK_DEVICE |
20 | 21 | from ._helpers import all_libraries, import_, wrapped_libraries, xfail
|
21 | 22 |
|
22 | 23 |
|
@@ -189,23 +190,26 @@ class C:
|
189 | 190 |
|
190 | 191 |
|
191 | 192 | @pytest.mark.parametrize("library", all_libraries)
|
192 |
| -def test_device(library, request): |
| 193 | +def test_device_to_device(library, request): |
193 | 194 | if library == "ndonnx":
|
194 |
| - xfail(request, reason="Needs ndonnx >=0.9.4") |
| 195 | + xfail(request, reason="Stub raises ValueError") |
| 196 | + if library == "sparse": |
| 197 | + xfail(request, reason="No __array_namespace_info__()") |
195 | 198 |
|
196 | 199 | xp = import_(library, wrapper=True)
|
| 200 | + devices = xp.__array_namespace_info__().devices() |
197 | 201 |
|
198 |
| - # We can't test much for device() and to_device() other than that |
199 |
| - # x.to_device(x.device) works. |
200 |
| - |
| 202 | + # Default device |
201 | 203 | x = xp.asarray([1, 2, 3])
|
202 | 204 | dev = device(x)
|
203 | 205 |
|
204 |
| - x2 = to_device(x, dev) |
205 |
| - assert device(x2) == device(x) |
206 |
| - |
207 |
| - x3 = xp.asarray(x, device=dev) |
208 |
| - assert device(x3) == device(x) |
| 206 | + for dev in devices: |
| 207 | + if dev is None: # JAX >=0.5.3 |
| 208 | + continue |
| 209 | + if dev is _DASK_DEVICE: # TODO this needs a better design |
| 210 | + continue |
| 211 | + y = to_device(x, dev) |
| 212 | + assert device(y) == dev |
209 | 213 |
|
210 | 214 |
|
211 | 215 | @pytest.mark.parametrize("library", wrapped_libraries)
|
|
0 commit comments