From 7c3d68c47147663399cf4f23de24b9a4193d6f65 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:01:14 +0100 Subject: [PATCH 1/4] TST: fix cupy `to_device` test on multiple devices --- tests/test_cupy.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index f8b4a4d8..fb0c69e4 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -8,15 +8,17 @@ def test_to_device_with_stream(): devices = xp.__array_namespace_info__().devices() streams = [ - Stream(), - Stream(non_blocking=True), - Stream(null=True), - Stream(ptds=True), - 123, # dlpack stream + lambda: Stream(), + lambda: Stream(non_blocking=True), + lambda: Stream(null=True), + lambda: Stream(ptds=True), + lambda: 123, # dlpack stream ] a = xp.asarray([1, 2, 3]) for dev in devices: - for stream in streams: + for stream_gen in streams: + with dev: + stream = stream_gen() b = to_device(a, dev, stream=stream) assert device(b) == dev From c829ef744cb04474b8eedf520557f1ca05bb77dc Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:07:15 +0100 Subject: [PATCH 2/4] nits --- tests/test_cupy.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index fb0c69e4..5aac36f8 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -5,20 +5,26 @@ from cupy.cuda import Stream -def test_to_device_with_stream(): - devices = xp.__array_namespace_info__().devices() - streams = [ +@pytest.mark.parametrize( + "make_stream", + [ lambda: Stream(), - lambda: Stream(non_blocking=True), + lambda: Stream(non_blocking=True), lambda: Stream(null=True), - lambda: Stream(ptds=True), + lambda: Stream(ptds=True), lambda: 123, # dlpack stream - ] + ], +) +def test_to_device_with_stream(make_stream): + devices = xp.__array_namespace_info__().devices() a = xp.asarray([1, 2, 3]) for dev in devices: - for stream_gen in streams: - with dev: - stream = stream_gen() - b = to_device(a, dev, stream=stream) - assert device(b) == dev + # Streams are device-specific and must be created within + # the context of the device... + with dev: + stream = make_stream() + # ... however, to_device() does not need to be inside the + # device context. + b = to_device(a, dev, stream=stream) + assert device(b) == dev From 0433b8e94ca802d9f6402acacb81f9c4fef6f84a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:26:53 +0100 Subject: [PATCH 3/4] skip segmentation fault --- tests/test_cupy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index 5aac36f8..8b71d978 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -12,7 +12,11 @@ lambda: Stream(non_blocking=True), lambda: Stream(null=True), lambda: Stream(ptds=True), - lambda: 123, # dlpack stream + pytest.param( + lambda: 123, + id="dlpack stream", + marks=pytest.mark.skip(reason="segmentation fault reported (#326)") + ), ], ) def test_to_device_with_stream(make_stream): From ebd3fd9356664c0502506adba96d1df72c47ec49 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:35:08 +0100 Subject: [PATCH 4/4] Use pointers --- tests/test_cupy.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index 8b71d978..4745b983 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -12,11 +12,6 @@ lambda: Stream(non_blocking=True), lambda: Stream(null=True), lambda: Stream(ptds=True), - pytest.param( - lambda: 123, - id="dlpack stream", - marks=pytest.mark.skip(reason="segmentation fault reported (#326)") - ), ], ) def test_to_device_with_stream(make_stream): @@ -32,3 +27,19 @@ def test_to_device_with_stream(make_stream): # device context. b = to_device(a, dev, stream=stream) assert device(b) == dev + + +def test_to_device_with_dlpack_stream(): + devices = xp.__array_namespace_info__().devices() + + a = xp.asarray([1, 2, 3]) + for dev in devices: + # Streams are device-specific and must be created within + # the context of the device... + with dev: + s1 = Stream() + + # ... however, to_device() does not need to be inside the + # device context. + b = to_device(a, dev, stream=s1.ptr) + assert device(b) == dev