diff --git a/tests/test_cupy.py b/tests/test_cupy.py index f8b4a4d8..4745b983 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -5,18 +5,41 @@ from cupy.cuda import Stream -def test_to_device_with_stream(): +@pytest.mark.parametrize( + "make_stream", + [ + lambda: Stream(), + lambda: Stream(non_blocking=True), + lambda: Stream(null=True), + lambda: Stream(ptds=True), + ], +) +def test_to_device_with_stream(make_stream): devices = xp.__array_namespace_info__().devices() - streams = [ - Stream(), - Stream(non_blocking=True), - Stream(null=True), - Stream(ptds=True), - 123, # dlpack stream - ] a = xp.asarray([1, 2, 3]) for dev in devices: - for stream in streams: - b = to_device(a, dev, stream=stream) - assert device(b) == dev + # Streams are device-specific and must be created within + # the context of the device... + with dev: + stream = make_stream() + # ... however, to_device() does not need to be inside the + # device context. + b = to_device(a, dev, stream=stream) + assert device(b) == dev + + +def test_to_device_with_dlpack_stream(): + devices = xp.__array_namespace_info__().devices() + + a = xp.asarray([1, 2, 3]) + for dev in devices: + # Streams are device-specific and must be created within + # the context of the device... + with dev: + s1 = Stream() + + # ... however, to_device() does not need to be inside the + # device context. + b = to_device(a, dev, stream=s1.ptr) + assert device(b) == dev