Skip to content

Commit 0aaefd6

Browse files
committed
BUG: cannot add take_along_axis to common/_aliases because of dask
1 parent d8fa04a commit 0aaefd6

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

array_api_compat/common/_aliases.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -517,12 +517,6 @@ def sort(
517517
return res
518518

519519

520-
# take_along_axis: axis defaults to -1; numpy, cupy do not have a default value;
521-
# pytorch defaults to None, which ravels.
522-
def take_along_axis(x: Array, indices: Array, /, *, xp: Namespace, axis: int = -1):
523-
return xp.take_along_axis(x, indices, axis=axis)
524-
525-
526520
# nonzero should error for zero-dimensional arrays
527521
def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
528522
if x.ndim == 0:
@@ -719,7 +713,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
719713
"matmul",
720714
"matrix_transpose",
721715
"tensordot",
722-
"take_along_axis",
723716
"vecdot",
724717
"isdtype",
725718
"unstack",

array_api_compat/cupy/_aliases.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@
6363
sign = get_xp(cp)(_aliases.sign)
6464
finfo = get_xp(cp)(_aliases.finfo)
6565
iinfo = get_xp(cp)(_aliases.iinfo)
66-
take_along_axis = get_xp(cp)(_aliases.take_along_axis)
6766

6867
_copy_default = object()
6968

@@ -139,6 +138,11 @@ def count_nonzero(
139138
return result
140139

141140

141+
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
142+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
143+
return xp.take_along_axis(x, indices, axis=axis)
144+
145+
142146
# These functions are completely new here. If the library already has them
143147
# (i.e., numpy 2.0), use the library version instead of our wrapper.
144148
if hasattr(cp, 'vecdot'):
@@ -160,6 +164,7 @@ def count_nonzero(
160164
'acos', 'acosh', 'asin', 'asinh', 'atan',
161165
'atan2', 'atanh', 'bitwise_left_shift',
162166
'bitwise_invert', 'bitwise_right_shift',
163-
'bool', 'concat', 'count_nonzero', 'pow', 'sign']
167+
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
168+
'take_along_axis']
164169

165170
_all_ignore = ['cp', 'get_xp']

array_api_compat/numpy/_aliases.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@
7272
sign = get_xp(np)(_aliases.sign)
7373
finfo = get_xp(np)(_aliases.finfo)
7474
iinfo = get_xp(np)(_aliases.iinfo)
75-
take_along_axis = get_xp(np)(_aliases.take_along_axis)
7675

7776

7877
def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction]
@@ -141,6 +140,11 @@ def count_nonzero(
141140
return result
142141

143142

143+
# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
144+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
145+
return xp.take_along_axis(x, indices, axis=axis)
146+
147+
144148
# These functions are completely new here. If the library already has them
145149
# (i.e., numpy 2.0), use the library version instead of our wrapper.
146150
if hasattr(np, "vecdot"):
@@ -158,7 +162,6 @@ def count_nonzero(
158162
else:
159163
unstack = get_xp(np)(_aliases.unstack)
160164

161-
162165
__all__ = [
163166
"__array_namespace_info__",
164167
"asarray",
@@ -185,4 +188,3 @@ def count_nonzero(
185188

186189
def __dir__() -> list[str]:
187190
return __all__
188-

0 commit comments

Comments
 (0)