Skip to content

Commit 7c5408c

Browse files
committed
Merge branch 'main' into typ_v4
2 parents 014e20f + e600449 commit 7c5408c

File tree

5 files changed

+59
-64
lines changed

5 files changed

+59
-64
lines changed

array_api_compat/common/_helpers.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -772,41 +772,27 @@ def _cupy_to_device(
772772
stream: int | Any | None = None,
773773
) -> cp.ndarray:
774774
import cupy as cp
775-
from cupy.cuda import Device as _Device # pyright: ignore
776-
from cupy.cuda import stream as stream_module # pyright: ignore
777-
from cupy_backends.cuda.api import runtime # pyright: ignore
778775

779-
if device == x.device:
780-
return x
781-
elif device == "cpu":
776+
if device == "cpu":
782777
# allowing us to use `to_device(x, "cpu")`
783778
# is useful for portable test swapping between
784779
# host and device backends
785780
return x.get()
786-
elif not isinstance(device, _Device):
787-
raise ValueError(f"Unsupported device {device!r}")
788-
else:
789-
# see cupy/cupy#5985 for the reason how we handle device/stream here
790-
prev_device: Device = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
791-
prev_stream = None
792-
if stream is not None:
793-
prev_stream = stream_module.get_current_stream() # pyright: ignore
794-
# stream can be an int as specified in __dlpack__, or a CuPy stream
795-
if isinstance(stream, int):
796-
stream = cp.cuda.ExternalStream(stream) # pyright: ignore
797-
elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType]
798-
pass
799-
else:
800-
raise ValueError("the input stream is not recognized")
801-
stream.use() # pyright: ignore[reportUnknownMemberType]
802-
try:
803-
runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType]
804-
arr = x.copy()
805-
finally:
806-
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType]
807-
if prev_stream is not None:
808-
prev_stream.use()
809-
return arr
781+
if not isinstance(device, cp.cuda.Device):
782+
raise TypeError(f"Unsupported device type {device!r}")
783+
784+
if stream is None:
785+
with device:
786+
return cp.asarray(x)
787+
788+
# stream can be an int as specified in __dlpack__, or a CuPy stream
789+
if isinstance(stream, int):
790+
stream = cp.cuda.ExternalStream(stream)
791+
elif not isinstance(stream, cp.cuda.Stream):
792+
raise TypeError(f"Unsupported stream type {stream!r}")
793+
794+
with device, stream:
795+
return cp.asarray(x)
810796

811797

812798
def _torch_to_device(

array_api_compat/cupy/_aliases.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@
6363
finfo = get_xp(cp)(_aliases.finfo)
6464
iinfo = get_xp(cp)(_aliases.iinfo)
6565

66-
_copy_default = object()
67-
6866

6967
# asarray also adds the copy keyword, which is not present in numpy 1.0.
7068
def asarray(
@@ -83,25 +81,13 @@ def asarray(
8381
specification for more details.
8482
"""
8583
with cp.cuda.Device(device):
86-
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
87-
# in asarray in numpy/_aliases.py.
88-
if copy is not _copy_default:
89-
# A future version of CuPy will change the meaning of copy=False
90-
# to mean no-copy. We don't know for certain what version it will
91-
# be yet, so to avoid breaking that version, we use a different
92-
# default value for copy so asarray(obj) with no copy kwarg will
93-
# always do the copy-if-needed behavior.
94-
95-
# This will still need to be updated to remove the
96-
# NotImplementedError for copy=False, but at least this won't
97-
# break the default or existing behavior.
98-
if copy is None:
99-
copy = False
100-
elif copy is False:
101-
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
102-
kwargs['copy'] = copy
103-
104-
return cp.array(obj, dtype=dtype, **kwargs)
84+
if copy is None:
85+
return cp.asarray(obj, dtype=dtype, **kwargs)
86+
else:
87+
res = cp.array(obj, dtype=dtype, copy=copy, **kwargs)
88+
if not copy and res is not obj:
89+
raise ValueError("Unable to avoid copy while creating an array as requested")
90+
return res
10591

10692

10793
def astype(

cupy-xfails.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@ array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)]
1111
# testsuite bug (https://github.com/data-apis/array-api-tests/issues/172)
1212
array_api_tests/test_array_object.py::test_getitem
1313

14-
# copy=False is not yet implemented
15-
array_api_tests/test_creation_functions.py::test_asarray_arrays
16-
1714
# attributes are np.float32 instead of float
1815
# (see also https://github.com/data-apis/array-api/issues/405)
1916
array_api_tests/test_data_type_functions.py::test_finfo[float32]

tests/test_common.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from array_api_compat import (
1818
device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
1919
)
20+
from array_api_compat.common._helpers import _DASK_DEVICE
2021
from ._helpers import all_libraries, import_, wrapped_libraries, xfail
2122

2223

@@ -189,23 +190,26 @@ class C:
189190

190191

191192
@pytest.mark.parametrize("library", all_libraries)
192-
def test_device(library, request):
193+
def test_device_to_device(library, request):
193194
if library == "ndonnx":
194-
xfail(request, reason="Needs ndonnx >=0.9.4")
195+
xfail(request, reason="Stub raises ValueError")
196+
if library == "sparse":
197+
xfail(request, reason="No __array_namespace_info__()")
195198

196199
xp = import_(library, wrapper=True)
200+
devices = xp.__array_namespace_info__().devices()
197201

198-
# We can't test much for device() and to_device() other than that
199-
# x.to_device(x.device) works.
200-
202+
# Default device
201203
x = xp.asarray([1, 2, 3])
202204
dev = device(x)
203205

204-
x2 = to_device(x, dev)
205-
assert device(x2) == device(x)
206-
207-
x3 = xp.asarray(x, device=dev)
208-
assert device(x3) == device(x)
206+
for dev in devices:
207+
if dev is None: # JAX >=0.5.3
208+
continue
209+
if dev is _DASK_DEVICE: # TODO this needs a better design
210+
continue
211+
y = to_device(x, dev)
212+
assert device(y) == dev
209213

210214

211215
@pytest.mark.parametrize("library", wrapped_libraries)

tests/test_cupy.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
from array_api_compat import device, to_device
3+
4+
xp = pytest.importorskip("array_api_compat.cupy")
5+
from cupy.cuda import Stream
6+
7+
8+
def test_to_device_with_stream():
9+
devices = xp.__array_namespace_info__().devices()
10+
streams = [
11+
Stream(),
12+
Stream(non_blocking=True),
13+
Stream(null=True),
14+
Stream(ptds=True),
15+
123, # dlpack stream
16+
]
17+
18+
a = xp.asarray([1, 2, 3])
19+
for dev in devices:
20+
for stream in streams:
21+
b = to_device(a, dev, stream=stream)
22+
assert device(b) == dev

0 commit comments

Comments
 (0)