Skip to content

Commit af9d538

Browse files
authored
MAINT: A few updates to the array_api (#20066)
* Allow casting in the array API asarray() * Restrict multidimensional indexing in the array API namespace The spec has recently been updated to only require multiaxis (i.e., tuple) indices in the case where every axis is indexed, meaning there are either as many indices as axes or the index has an ellipsis. * Fix type promotion for numpy.array_api.where where does value-based promotion for 0-dimensional arrays, so we use the same trick as in the Array operators to avoid this. * Print empty array_api arrays using empty() Printing behavior isn't required by the spec. This is just to make things easier to understand, especially with the array API test suite. * Fix an incorrect slice bounds guard in the array API * Disallow multiple different dtypes in the input to np.array_api.meshgrid * Remove DLPack support from numpy.array_api.asarray() from_dlpack() should be used to create arrays using DLPack. * Remove __len__ from the array API array object * Add astype() to numpy.array_api * Update the unique_* functions in numpy.array_api unique() in the array API was replaced with three separate functions, unique_all(), unique_inverse(), and unique_values(), in order to avoid polymorphic return types. Additionally, it should be noted that these functions do not currently conform to the spec with respect to NaN behavior. The spec requires multiple NaNs to be returned, but np.unique() returns a single NaN. Since this is currently an open issue in NumPy to possibly revert, I have not yet worked around this. See numpy/numpy#20326. * Add the stream argument to the array API to_device method This does nothing in NumPy, and is just present so that the signature is valid according to the spec. 
* Use the NamedTuple classes for the type signatures * Add unique_counts to the array API namespace * Remove some unused imports * Update the array_api indexing restrictions The "multiaxis indexing must index every axis explicitly or use an ellipsis" was supposed to include any type of index, not just tuple indices. * Use a simpler type annotation for the array API to_device method * Fix a test failure in the array_api submodule The array_api cannot use the NumPy testing functions because array_api arrays do not mix with NumPy arrays, and also NumPy testing functions may use APIs that aren't supported in the array API. * Add dlpack support to the array_api submodule Original NumPy Commit: ff2e2a1e7eea29d925063b13922e096d14331222
1 parent 10ae3ea commit af9d538

8 files changed

+156
-49
lines changed

array_api_strict/__init__.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@
136136
empty,
137137
empty_like,
138138
eye,
139-
_from_dlpack,
139+
from_dlpack,
140140
full,
141141
full_like,
142142
linspace,
@@ -155,7 +155,7 @@
155155
"empty",
156156
"empty_like",
157157
"eye",
158-
"_from_dlpack",
158+
"from_dlpack",
159159
"full",
160160
"full_like",
161161
"linspace",
@@ -169,6 +169,7 @@
169169
]
170170

171171
from ._data_type_functions import (
172+
astype,
172173
broadcast_arrays,
173174
broadcast_to,
174175
can_cast,
@@ -178,6 +179,7 @@
178179
)
179180

180181
__all__ += [
182+
"astype",
181183
"broadcast_arrays",
182184
"broadcast_to",
183185
"can_cast",
@@ -358,9 +360,9 @@
358360

359361
__all__ += ["argmax", "argmin", "nonzero", "where"]
360362

361-
from ._set_functions import unique
363+
from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
362364

363-
__all__ += ["unique"]
365+
__all__ += ["unique_all", "unique_counts", "unique_inverse", "unique_values"]
364366

365367
from ._sorting_functions import argsort, sort
366368

array_api_strict/_array_object.py

+31-17
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
3333

3434
if TYPE_CHECKING:
35-
from ._typing import PyCapsule, Device, Dtype
35+
from ._typing import Any, PyCapsule, Device, Dtype
3636

3737
import numpy as np
3838

@@ -99,9 +99,13 @@ def __repr__(self: Array, /) -> str:
9999
"""
100100
Performs the operation __repr__.
101101
"""
102-
prefix = "Array("
103102
suffix = f", dtype={self.dtype.name})"
104-
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
103+
if 0 in self.shape:
104+
prefix = "empty("
105+
mid = str(self.shape)
106+
else:
107+
prefix = "Array("
108+
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
105109
return prefix + mid + suffix
106110

107111
# These are various helper functions to make the array behavior match the
@@ -244,6 +248,10 @@ def _validate_index(key, shape):
244248
The following cases are allowed by NumPy, but not specified by the array
245249
API specification:
246250
251+
- Indices do not include an implicit ellipsis at the end. That is,
252+
every axis of an array must be explicitly indexed or an ellipsis
253+
included.
254+
247255
- The start and stop of a slice may not be out of bounds. In
248256
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
249257
following are allowed:
@@ -270,14 +278,18 @@ def _validate_index(key, shape):
270278
return key
271279
if shape == ():
272280
return key
281+
if len(shape) > 1:
282+
raise IndexError(
283+
"Multidimensional arrays must include an index for every axis or use an ellipsis"
284+
)
273285
size = shape[0]
274286
# Ensure invalid slice entries are passed through.
275287
if key.start is not None:
276288
try:
277289
operator.index(key.start)
278290
except TypeError:
279291
return key
280-
if not (-size <= key.start <= max(0, size - 1)):
292+
if not (-size <= key.start <= size):
281293
raise IndexError(
282294
"Slices with out-of-bounds start are not allowed in the array API namespace"
283295
)
@@ -322,6 +334,10 @@ def _validate_index(key, shape):
322334
zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])
323335
):
324336
Array._validate_index(idx, (size,))
337+
if n_ellipsis == 0 and len(key) < len(shape):
338+
raise IndexError(
339+
"Multidimensional arrays must include an index for every axis or use an ellipsis"
340+
)
325341
return key
326342
elif isinstance(key, bool):
327343
return key
@@ -339,7 +355,12 @@ def _validate_index(key, shape):
339355
"newaxis indices are not allowed in the array API namespace"
340356
)
341357
try:
342-
return operator.index(key)
358+
key = operator.index(key)
359+
if shape is not None and len(shape) > 1:
360+
raise IndexError(
361+
"Multidimensional arrays must include an index for every axis or use an ellipsis"
362+
)
363+
return key
343364
except TypeError:
344365
# Note: This also omits boolean arrays that are not already in
345366
# Array() form, like a list of booleans.
@@ -403,16 +424,14 @@ def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
403424
"""
404425
Performs the operation __dlpack__.
405426
"""
406-
res = self._array.__dlpack__(stream=stream)
407-
return self.__class__._new(res)
427+
return self._array.__dlpack__(stream=stream)
408428

409429
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
410430
"""
411431
Performs the operation __dlpack_device__.
412432
"""
413433
# Note: device support is required for this
414-
res = self._array.__dlpack_device__()
415-
return self.__class__._new(res)
434+
return self._array.__dlpack_device__()
416435

417436
def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
418437
"""
@@ -527,13 +546,6 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
527546
res = self._array.__le__(other._array)
528547
return self.__class__._new(res)
529548

530-
# Note: __len__ may end up being removed from the array API spec.
531-
def __len__(self, /) -> int:
532-
"""
533-
Performs the operation __len__.
534-
"""
535-
return self._array.__len__()
536-
537549
def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
538550
"""
539551
Performs the operation __lshift__.
@@ -995,7 +1007,9 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
9951007
res = self._array.__rxor__(other._array)
9961008
return self.__class__._new(res)
9971009

998-
def to_device(self: Array, device: Device, /) -> Array:
1010+
def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
1011+
if stream is not None:
1012+
raise ValueError("The stream argument to to_device() is not supported")
9991013
if device == 'cpu':
10001014
return self
10011015
raise ValueError(f"Unsupported device {device!r}")

array_api_strict/_creation_functions.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
Device,
1010
Dtype,
1111
NestedSequence,
12-
SupportsDLPack,
1312
SupportsBufferProtocol,
1413
)
1514
from collections.abc import Sequence
@@ -36,7 +35,6 @@ def asarray(
3635
int,
3736
float,
3837
NestedSequence[bool | int | float],
39-
SupportsDLPack,
4038
SupportsBufferProtocol,
4139
],
4240
/,
@@ -60,7 +58,9 @@ def asarray(
6058
if copy is False:
6159
# Note: copy=False is not yet implemented in np.asarray
6260
raise NotImplementedError("copy=False is not yet implemented")
63-
if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype):
61+
if isinstance(obj, Array):
62+
if dtype is not None and obj.dtype != dtype:
63+
copy = True
6464
if copy is True:
6565
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
6666
return obj
@@ -151,9 +151,10 @@ def eye(
151151
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
152152

153153

154-
def _from_dlpack(x: object, /) -> Array:
155-
# Note: dlpack support is not yet implemented on Array
156-
raise NotImplementedError("DLPack support is not yet implemented")
154+
def from_dlpack(x: object, /) -> Array:
155+
from ._array_object import Array
156+
157+
return Array._new(np._from_dlpack(x))
157158

158159

159160
def full(
@@ -240,6 +241,12 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
240241
"""
241242
from ._array_object import Array
242243

244+
# Note: unlike np.meshgrid, only inputs with all the same dtype are
245+
# allowed
246+
247+
if len({a.dtype for a in arrays}) > 1:
248+
raise ValueError("meshgrid inputs must all have the same dtype")
249+
243250
return [
244251
Array._new(array)
245252
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)

array_api_strict/_data_type_functions.py

+7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313
import numpy as np
1414

1515

16+
# Note: astype is a function, not an array method as in NumPy.
17+
def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
18+
if not copy and dtype == x.dtype:
19+
return x
20+
return Array._new(x._array.astype(dtype=dtype, copy=copy))
21+
22+
1623
def broadcast_arrays(*arrays: Array) -> List[Array]:
1724
"""
1825
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.

array_api_strict/_searching_functions.py

+1
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,5 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
4343
"""
4444
# Call result type here just to raise on disallowed type combinations
4545
_result_type(x1.dtype, x2.dtype)
46+
x1, x2 = Array._normalize_two_args(x1, x2)
4647
return Array._new(np.where(condition._array, x1._array, x2._array))

array_api_strict/_set_functions.py

+75-14
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,91 @@
22

33
from ._array_object import Array
44

5-
from typing import Tuple, Union
5+
from typing import NamedTuple
66

77
import numpy as np
88

9+
# Note: np.unique() is split into four functions in the array API:
10+
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
11+
# to remove polymorphic return types).
912

10-
def unique(
11-
x: Array,
12-
/,
13-
*,
14-
return_counts: bool = False,
15-
return_index: bool = False,
16-
return_inverse: bool = False,
17-
) -> Union[Array, Tuple[Array, ...]]:
13+
# Note: The various unique() functions are supposed to return multiple NaNs.
14+
# This does not match the NumPy behavior; however, this is currently left as a
15+
# TODO in this implementation as this behavior may be reverted in np.unique().
16+
# See https://github.com/numpy/numpy/issues/20326.
17+
18+
# Note: The functions here return a namedtuple (np.unique() returns a normal
19+
# tuple).
20+
21+
class UniqueAllResult(NamedTuple):
22+
values: Array
23+
indices: Array
24+
inverse_indices: Array
25+
counts: Array
26+
27+
28+
class UniqueCountsResult(NamedTuple):
29+
values: Array
30+
counts: Array
31+
32+
33+
class UniqueInverseResult(NamedTuple):
34+
values: Array
35+
inverse_indices: Array
36+
37+
38+
def unique_all(x: Array, /) -> UniqueAllResult:
39+
"""
40+
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
41+
42+
See its docstring for more information.
43+
"""
44+
res = np.unique(
45+
x._array,
46+
return_counts=True,
47+
return_index=True,
48+
return_inverse=True,
49+
)
50+
51+
return UniqueAllResult(*[Array._new(i) for i in res])
52+
53+
54+
def unique_counts(x: Array, /) -> UniqueCountsResult:
55+
res = np.unique(
56+
x._array,
57+
return_counts=True,
58+
return_index=False,
59+
return_inverse=False,
60+
)
61+
62+
return UniqueCountsResult(*[Array._new(i) for i in res])
63+
64+
65+
def unique_inverse(x: Array, /) -> UniqueInverseResult:
66+
"""
67+
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
68+
69+
See its docstring for more information.
70+
"""
71+
res = np.unique(
72+
x._array,
73+
return_counts=False,
74+
return_index=False,
75+
return_inverse=True,
76+
)
77+
return UniqueInverseResult(*[Array._new(i) for i in res])
78+
79+
80+
def unique_values(x: Array, /) -> Array:
1881
"""
1982
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
2083
2184
See its docstring for more information.
2285
"""
2386
res = np.unique(
2487
x._array,
25-
return_counts=return_counts,
26-
return_index=return_index,
27-
return_inverse=return_inverse,
88+
return_counts=False,
89+
return_index=False,
90+
return_inverse=False,
2891
)
29-
if isinstance(res, tuple):
30-
return tuple(Array._new(i) for i in res)
3192
return Array._new(res)

0 commit comments

Comments
 (0)