Skip to content

Commit 43d60b5

Browse files
committed
Add draft implementation of take_along_axis for 2024.12
As far as I can tell, NumPy matches the standard specification, except for the fact that NumPy does not set a default value for axis.
1 parent 9726bc0 commit 43d60b5

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,9 @@
262262
"trunc",
263263
]
264264

265-
from ._indexing_functions import take
265+
from ._indexing_functions import take, take_along_axis
266266

267-
__all__ += ["take"]
267+
__all__ += ["take", "take_along_axis"]
268268

269269
from ._info import __array_namespace_info__
270270

array_api_strict/_indexing_functions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._array_object import Array
44
from ._dtypes import _integer_dtypes
5+
from ._flags import requires_api_version
56

67
from typing import TYPE_CHECKING
78

@@ -25,3 +26,14 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
2526
if x.device != indices.device:
2627
raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
2728
return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device)
29+
30+
@requires_api_version('2024.12')
31+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
32+
"""
33+
Array API compatible wrapper for :py:func:`np.take_along_axis <numpy.take_along_axis>`.
34+
35+
See its docstring for more information.
36+
"""
37+
if x.device != indices.device:
38+
raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
39+
return Array._new(np.take_along_axis(x._array, indices._array, axis), device=x.device)

array_api_strict/tests/test_flags.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,6 @@ def test_fft(func_name):
284284
'unstack': lambda: xp.unstack(xp.ones((3, 3)), axis=0),
285285
}
286286

287-
api_version_2024_12_examples = {
288-
'diff': lambda: xp.diff(xp.asarray([0, 1, 2])),
289-
'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)),
290-
'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])),
291-
}
292-
293287
@pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys())
294288
def test_api_version_2023_12(func_name):
295289
func = api_version_2023_12_examples[func_name]
@@ -308,6 +302,14 @@ def test_api_version_2023_12(func_name):
308302
set_array_api_strict_flags(api_version='2022.12')
309303
pytest.raises(RuntimeError, func)
310304

305+
api_version_2024_12_examples = {
306+
'diff': lambda: xp.diff(xp.asarray([0, 1, 2])),
307+
'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)),
308+
'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])),
309+
'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)),
310+
xp.zeros((1, 4), dtype=xp.int64)),
311+
}
312+
311313
@pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys())
312314
def test_api_version_2024_12(func_name):
313315
func = api_version_2024_12_examples[func_name]

0 commit comments

Comments
 (0)