Skip to content

Commit 7a85a36

Browse files
authored
Merge branch 'main' into count_nonzero_torch_tuple_axis_
2 parents df16fc2 + 2adea00 commit 7a85a36

File tree

8 files changed

+84
-68
lines changed

8 files changed

+84
-68
lines changed

array_api_compat/common/_helpers.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -775,42 +775,28 @@ def _cupy_to_device(
775775
/,
776776
stream: int | Any | None = None,
777777
) -> _CupyArray:
778-
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
779-
from cupy.cuda import Device as _Device # pyright: ignore
780-
from cupy.cuda import stream as stream_module # pyright: ignore
781-
from cupy_backends.cuda.api import runtime # pyright: ignore
778+
import cupy as cp
782779

783-
if device == x.device:
784-
return x
785-
elif device == "cpu":
780+
if device == "cpu":
786781
# allowing us to use `to_device(x, "cpu")`
787782
# is useful for portable test swapping between
788783
# host and device backends
789784
return x.get()
790-
elif not isinstance(device, _Device):
791-
raise ValueError(f"Unsupported device {device!r}")
792-
else:
793-
# see cupy/cupy#5985 for the reason how we handle device/stream here
794-
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
795-
prev_stream = None
796-
if stream is not None:
797-
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore
798-
# stream can be an int as specified in __dlpack__, or a CuPy stream
799-
if isinstance(stream, int):
800-
stream = cp.cuda.ExternalStream(stream) # pyright: ignore
801-
elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType]
802-
pass
803-
else:
804-
raise ValueError("the input stream is not recognized")
805-
stream.use() # pyright: ignore[reportUnknownMemberType]
806-
try:
807-
runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType]
808-
arr = x.copy()
809-
finally:
810-
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType]
811-
if stream is not None:
812-
prev_stream.use()
813-
return arr
785+
if not isinstance(device, cp.cuda.Device):
786+
raise TypeError(f"Unsupported device type {device!r}")
787+
788+
if stream is None:
789+
with device:
790+
return cp.asarray(x)
791+
792+
# stream can be an int as specified in __dlpack__, or a CuPy stream
793+
if isinstance(stream, int):
794+
stream = cp.cuda.ExternalStream(stream)
795+
elif not isinstance(stream, cp.cuda.Stream):
796+
raise TypeError(f"Unsupported stream type {stream!r}")
797+
798+
with device, stream:
799+
return cp.asarray(x)
814800

815801

816802
def _torch_to_device(

array_api_compat/cupy/_aliases.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@
6464
finfo = get_xp(cp)(_aliases.finfo)
6565
iinfo = get_xp(cp)(_aliases.iinfo)
6666

67-
_copy_default = object()
68-
6967

7068
# asarray also adds the copy keyword, which is not present in numpy 1.0.
7169
def asarray(
@@ -79,7 +77,7 @@ def asarray(
7977
*,
8078
dtype: Optional[DType] = None,
8179
device: Optional[Device] = None,
82-
copy: Optional[bool] = _copy_default,
80+
copy: Optional[bool] = None,
8381
**kwargs,
8482
) -> Array:
8583
"""
@@ -89,25 +87,13 @@ def asarray(
8987
specification for more details.
9088
"""
9189
with cp.cuda.Device(device):
92-
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
93-
# in asarray in numpy/_aliases.py.
94-
if copy is not _copy_default:
95-
# A future version of CuPy will change the meaning of copy=False
96-
# to mean no-copy. We don't know for certain what version it will
97-
# be yet, so to avoid breaking that version, we use a different
98-
# default value for copy so asarray(obj) with no copy kwarg will
99-
# always do the copy-if-needed behavior.
100-
101-
# This will still need to be updated to remove the
102-
# NotImplementedError for copy=False, but at least this won't
103-
# break the default or existing behavior.
104-
if copy is None:
105-
copy = False
106-
elif copy is False:
107-
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
108-
kwargs['copy'] = copy
109-
110-
return cp.array(obj, dtype=dtype, **kwargs)
90+
if copy is None:
91+
return cp.asarray(obj, dtype=dtype, **kwargs)
92+
else:
93+
res = cp.array(obj, dtype=dtype, copy=copy, **kwargs)
94+
if not copy and res is not obj:
95+
raise ValueError("Unable to avoid copy while creating an array as requested")
96+
return res
11197

11298

11399
def astype(
@@ -138,6 +124,11 @@ def count_nonzero(
138124
return result
139125

140126

127+
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
128+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
129+
return cp.take_along_axis(x, indices, axis=axis)
130+
131+
141132
# These functions are completely new here. If the library already has them
142133
# (i.e., numpy 2.0), use the library version instead of our wrapper.
143134
if hasattr(cp, 'vecdot'):
@@ -159,6 +150,7 @@ def count_nonzero(
159150
'acos', 'acosh', 'asin', 'asinh', 'atan',
160151
'atan2', 'atanh', 'bitwise_left_shift',
161152
'bitwise_invert', 'bitwise_right_shift',
162-
'bool', 'concat', 'count_nonzero', 'pow', 'sign']
153+
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
154+
'take_along_axis']
163155

164156
_all_ignore = ['cp', 'get_xp']

array_api_compat/numpy/_aliases.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ def count_nonzero(
140140
return result
141141

142142

143+
# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
144+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
145+
return np.take_along_axis(x, indices, axis=axis)
146+
147+
143148
# These functions are completely new here. If the library already has them
144149
# (i.e., numpy 2.0), use the library version instead of our wrapper.
145150
if hasattr(np, "vecdot"):
@@ -175,6 +180,7 @@ def count_nonzero(
175180
"concat",
176181
"count_nonzero",
177182
"pow",
183+
"take_along_axis"
178184
]
179185
__all__ += _aliases.__all__
180186
_all_ignore = ["np", "get_xp"]

cupy-xfails.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@ array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)]
1111
# testsuite bug (https://github.com/data-apis/array-api-tests/issues/172)
1212
array_api_tests/test_array_object.py::test_getitem
1313

14-
# copy=False is not yet implemented
15-
array_api_tests/test_creation_functions.py::test_asarray_arrays
16-
1714
# attributes are np.float32 instead of float
1815
# (see also https://github.com/data-apis/array-api/issues/405)
1916
array_api_tests/test_data_type_functions.py::test_finfo[float32]
@@ -37,6 +34,11 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub
3734
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)]
3835
# floating point inaccuracy
3936
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
37+
# incomplete NEP50 support in CuPy 13.x (fixed in 14.0.0a1)
38+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
39+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
40+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
41+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
4042

4143
# cupy (arg)min/max wrong with infinities
4244
# https://github.com/cupy/cupy/issues/7424
@@ -185,7 +187,6 @@ array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
185187
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
186188

187189
# 2024.12 support
188-
array_api_tests/test_signatures.py::test_func_signature[count_nonzero]
189190
array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
190191
array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
191192
array_api_tests/test_signatures.py::test_func_signature[bitwise_or]

dask-xfails.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ array_api_tests/test_creation_functions.py::test_linspace
2424
# Shape mismatch
2525
array_api_tests/test_indexing_functions.py::test_take
2626

27+
# missing `take_along_axis`, https://github.com/dask/dask/issues/3663
28+
array_api_tests/test_indexing_functions.py::test_take_along_axis
29+
2730
# Array methods and attributes not already on da.Array cannot be wrapped
2831
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
2932
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]

numpy-1-26-xfails.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
5050
array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
5151
array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
5252

53+
array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
54+
5355
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
5456
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
5557
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]

tests/test_common.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from array_api_compat import (
1818
device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
1919
)
20+
from array_api_compat.common._helpers import _DASK_DEVICE
2021
from ._helpers import all_libraries, import_, wrapped_libraries, xfail
2122

2223

@@ -189,23 +190,26 @@ class C:
189190

190191

191192
@pytest.mark.parametrize("library", all_libraries)
192-
def test_device(library, request):
193+
def test_device_to_device(library, request):
193194
if library == "ndonnx":
194-
xfail(request, reason="Needs ndonnx >=0.9.4")
195+
xfail(request, reason="Stub raises ValueError")
196+
if library == "sparse":
197+
xfail(request, reason="No __array_namespace_info__()")
195198

196199
xp = import_(library, wrapper=True)
200+
devices = xp.__array_namespace_info__().devices()
197201

198-
# We can't test much for device() and to_device() other than that
199-
# x.to_device(x.device) works.
200-
202+
# Default device
201203
x = xp.asarray([1, 2, 3])
202204
dev = device(x)
203205

204-
x2 = to_device(x, dev)
205-
assert device(x2) == device(x)
206-
207-
x3 = xp.asarray(x, device=dev)
208-
assert device(x3) == device(x)
206+
for dev in devices:
207+
if dev is None: # JAX >=0.5.3
208+
continue
209+
if dev is _DASK_DEVICE: # TODO this needs a better design
210+
continue
211+
y = to_device(x, dev)
212+
assert device(y) == dev
209213

210214

211215
@pytest.mark.parametrize("library", wrapped_libraries)

tests/test_cupy.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
from array_api_compat import device, to_device
3+
4+
xp = pytest.importorskip("array_api_compat.cupy")
5+
from cupy.cuda import Stream
6+
7+
8+
def test_to_device_with_stream():
9+
devices = xp.__array_namespace_info__().devices()
10+
streams = [
11+
Stream(),
12+
Stream(non_blocking=True),
13+
Stream(null=True),
14+
Stream(ptds=True),
15+
123, # dlpack stream
16+
]
17+
18+
a = xp.asarray([1, 2, 3])
19+
for dev in devices:
20+
for stream in streams:
21+
b = to_device(a, dev, stream=stream)
22+
assert device(b) == dev

0 commit comments

Comments
 (0)