Skip to content

Commit d8fa04a

Browse files
committed
MAINT: move take_along_axis to common/_aliases
1 parent bbf346c commit d8fa04a

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

array_api_compat/common/_aliases.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,12 @@ def sort(
517517
return res
518518

519519

520+
# take_along_axis: axis defaults to -1; numpy, cupy do not have a default value;
521+
# pytorch defaults to None, which ravels.
522+
def take_along_axis(x: Array, indices: Array, /, *, xp: Namespace, axis: int = -1):
523+
return xp.take_along_axis(x, indices, axis=axis)
524+
525+
520526
# nonzero should error for zero-dimensional arrays
521527
def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
522528
if x.ndim == 0:
@@ -713,6 +719,7 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
713719
"matmul",
714720
"matrix_transpose",
715721
"tensordot",
722+
"take_along_axis",
716723
"vecdot",
717724
"isdtype",
718725
"unstack",

array_api_compat/cupy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
sign = get_xp(cp)(_aliases.sign)
6464
finfo = get_xp(cp)(_aliases.finfo)
6565
iinfo = get_xp(cp)(_aliases.iinfo)
66+
take_along_axis = get_xp(cp)(_aliases.take_along_axis)
6667

6768
_copy_default = object()
6869

array_api_compat/numpy/_aliases.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
sign = get_xp(np)(_aliases.sign)
7373
finfo = get_xp(np)(_aliases.finfo)
7474
iinfo = get_xp(np)(_aliases.iinfo)
75+
take_along_axis = get_xp(np)(_aliases.take_along_axis)
7576

7677

7778
def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction]
@@ -140,13 +141,6 @@ def count_nonzero(
140141
return result
141142

142143

143-
# "axis=-1" is an optional argument of `take_along_axis` but numpy has no default
144-
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
145-
if axis is None:
146-
axis = -1
147-
return np.take_along_axis(x, indices, axis=axis)
148-
149-
150144
# These functions are completely new here. If the library already has them
151145
# (i.e., numpy 2.0), use the library version instead of our wrapper.
152146
if hasattr(np, "vecdot"):

0 commit comments

Comments
 (0)