Skip to content

Commit 7b9da30

Browse files
committed
Merge branch 'main' into test_all
2 parents fd3cd70 + 2b5e289 commit 7b9da30

File tree

8 files changed

+67
-5
lines changed

8 files changed

+67
-5
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ def count_nonzero(
123123
return result
124124

125125

126+
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
127+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
128+
return cp.take_along_axis(x, indices, axis=axis)
129+
130+
126131
# These functions are completely new here. If the library already has them
127132
# (i.e., numpy 2.0), use the library version instead of our wrapper.
128133
if hasattr(cp, 'vecdot'):
@@ -144,7 +149,8 @@ def count_nonzero(
144149
'acos', 'acosh', 'asin', 'asinh', 'atan',
145150
'atan2', 'atanh', 'bitwise_left_shift',
146151
'bitwise_invert', 'bitwise_right_shift',
147-
'bool', 'concat', 'count_nonzero', 'pow', 'sign']
152+
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
153+
'take_along_axis']
148154

149155
def __dir__() -> list[str]:
150156
return __all__

array_api_compat/numpy/_aliases.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ def count_nonzero(
139139
return result
140140

141141

142+
# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
143+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
144+
return np.take_along_axis(x, indices, axis=axis)
145+
146+
142147
# These functions are completely new here. If the library already has them
143148
# (i.e., numpy 2.0), use the library version instead of our wrapper.
144149
if hasattr(np, "vecdot"):
@@ -173,6 +178,7 @@ def count_nonzero(
173178
"concat",
174179
"count_nonzero",
175180
"pow",
181+
"take_along_axis"
176182
]
177183

178184

array_api_compat/torch/_aliases.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from functools import reduce as _reduce, wraps as _wraps
44
from builtins import all as _builtin_all, any as _builtin_any
5-
from typing import Any, List, Optional, Sequence, Tuple, Union
5+
from typing import Any, List, Optional, Sequence, Tuple, Union, Literal
66

77
import torch
88

@@ -547,8 +547,12 @@ def count_nonzero(
547547
) -> Array:
548548
result = torch.count_nonzero(x, dim=axis)
549549
if keepdims:
550-
if axis is not None:
550+
if isinstance(axis, int):
551551
return result.unsqueeze(axis)
552+
elif isinstance(axis, tuple):
553+
n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis]
554+
sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)]
555+
return torch.reshape(result, sh)
552556
return _axis_none_keepdims(result, x.ndim, keepdims)
553557
else:
554558
return result
@@ -823,6 +827,12 @@ def sign(x: Array, /) -> Array:
823827
return out
824828

825829

830+
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]:
831+
# enforce the default of 'xy'
832+
# TODO: is the return type a list or a tuple
833+
return list(torch.meshgrid(*arrays, indexing='xy'))
834+
835+
826836
__all__ = ['asarray', 'result_type', 'can_cast',
827837
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
828838
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
@@ -839,4 +849,4 @@ def sign(x: Array, /) -> Array:
839849
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
840850
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
841851
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
842-
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat']
852+
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid']

cupy-xfails.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub
3434
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)]
3535
# floating point inaccuracy
3636
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
37+
# incomplete NEP50 support in CuPy 13.x (fixed in 14.0.0a1)
38+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
39+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
40+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
41+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
3742

3843
# cupy (arg)min/max wrong with infinities
3944
# https://github.com/cupy/cupy/issues/7424
@@ -182,7 +187,6 @@ array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
182187
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
183188

184189
# 2024.12 support
185-
array_api_tests/test_signatures.py::test_func_signature[count_nonzero]
186190
array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
187191
array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
188192
array_api_tests/test_signatures.py::test_func_signature[bitwise_or]

dask-xfails.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ array_api_tests/test_creation_functions.py::test_linspace
2424
# Shape mismatch
2525
array_api_tests/test_indexing_functions.py::test_take
2626

27+
# missing `take_along_axis`, https://github.com/dask/dask/issues/3663
28+
array_api_tests/test_indexing_functions.py::test_take_along_axis
29+
2730
# Array methods and attributes not already on da.Array cannot be wrapped
2831
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
2932
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]

numpy-1-22-xfails.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,20 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtr
123123
array_api_tests/test_searching_functions.py::test_where
124124
array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0]
125125

126+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[add]
127+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
128+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
129+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract]
130+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
131+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
132+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply]
133+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
134+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
135+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
136+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
137+
138+
array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
139+
126140
# 2023.12 support
127141
array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
128142
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]

numpy-1-26-xfails.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
5050
array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
5151
array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
5252

53+
array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
54+
5355
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
5456
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
5557
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]

tests/test_torch.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,20 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b):
100100
assert dtype_1 == dtype_2
101101
finally:
102102
torch.set_default_dtype(prev_default)
103+
104+
105+
def test_meshgrid():
106+
"""Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""
107+
108+
x, y = xp.asarray([1, 2]), xp.asarray([4])
109+
110+
X, Y = xp.meshgrid(x, y)
111+
112+
# output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different
113+
X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]])
114+
115+
assert X.shape == X_xy.shape
116+
assert xp.all(X == X_xy)
117+
118+
assert Y.shape == Y_xy.shape
119+
assert xp.all(Y == Y_xy)

0 commit comments

Comments
 (0)