Skip to content

Commit c829ef7

Browse files
committed
nits
1 parent 7c3d68c commit c829ef7

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

tests/test_cupy.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,26 @@
55
from cupy.cuda import Stream
66

77

8-
def test_to_device_with_stream():
9-
devices = xp.__array_namespace_info__().devices()
10-
streams = [
8+
@pytest.mark.parametrize(
9+
"make_stream",
10+
[
1111
lambda: Stream(),
12-
lambda: Stream(non_blocking=True),
12+
lambda: Stream(non_blocking=True),
1313
lambda: Stream(null=True),
14-
lambda: Stream(ptds=True),
14+
lambda: Stream(ptds=True),
1515
lambda: 123, # dlpack stream
16-
]
16+
],
17+
)
18+
def test_to_device_with_stream(make_stream):
19+
devices = xp.__array_namespace_info__().devices()
1720

1821
a = xp.asarray([1, 2, 3])
1922
for dev in devices:
20-
for stream_gen in streams:
21-
with dev:
22-
stream = stream_gen()
23-
b = to_device(a, dev, stream=stream)
24-
assert device(b) == dev
23+
# Streams are device-specific and must be created within
24+
# the context of the device...
25+
with dev:
26+
stream = make_stream()
27+
# ... however, to_device() does not need to be inside the
28+
# device context.
29+
b = to_device(a, dev, stream=stream)
30+
assert device(b) == dev

0 commit comments

Comments
 (0)