Skip to content

Commit 27692be

Browse files
committed
Add tests
1 parent 8c31248 commit 27692be

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

tests/test_common.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from array_api_compat import (
1818
device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
1919
)
20+
from array_api_compat.common._helpers import _DASK_DEVICE
2021
from ._helpers import all_libraries, import_, wrapped_libraries, xfail
2122

2223

@@ -189,23 +190,26 @@ class C:
189190

190191

191192
@pytest.mark.parametrize("library", all_libraries)
192-
def test_device(library, request):
193+
def test_device_to_device(library, request):
193194
if library == "ndonnx":
194-
xfail(request, reason="Needs ndonnx >=0.9.4")
195+
xfail(request, reason="Stub raises ValueError")
196+
if library == "sparse":
197+
xfail(request, reason="No __array_namespace_info__()")
195198

196199
xp = import_(library, wrapper=True)
200+
devices = xp.__array_namespace_info__().devices()
197201

198-
# We can't test much for device() and to_device() other than that
199-
# x.to_device(x.device) works.
200-
202+
# Default device
201203
x = xp.asarray([1, 2, 3])
202204
dev = device(x)
203205

204-
x2 = to_device(x, dev)
205-
assert device(x2) == device(x)
206-
207-
x3 = xp.asarray(x, device=dev)
208-
assert device(x3) == device(x)
206+
for dev in devices:
207+
if dev is None: # JAX >=0.5.3
208+
continue
209+
if dev is _DASK_DEVICE: # TODO this needs a better design
210+
continue
211+
y = to_device(x, dev)
212+
assert device(y) == dev
209213

210214

211215
@pytest.mark.parametrize("library", wrapped_libraries)

tests/test_cupy.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
from array_api_compat import device, to_device
3+
4+
xp = pytest.importorskip("array_api_compat.cupy")
5+
from cupy.cuda import Stream
6+
7+
8+
def test_to_device_with_stream():
9+
devices = xp.__array_namespace_info__().devices()
10+
streams = [
11+
Stream(),
12+
Stream(non_blocking=True),
13+
Stream(null=True),
14+
Stream(ptds=True),
15+
123, # lapack stream
16+
]
17+
18+
a = xp.asarray([1, 2, 3])
19+
for dev in devices:
20+
for stream in streams:
21+
b = to_device(a, dev, stream=stream)
22+
assert device(b) == dev

0 commit comments

Comments
 (0)