Skip to content

Commit e55688f

Browse files
honnoev-br
authored andcommitted
take() implementation
1 parent ae5d1a5 commit e55688f

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

torch_np/_funcs.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,21 @@ def asfarray():
929929

930930
# ### put/take_along_axis ###
931931

932+
@normalizer
933+
def take(a: ArrayLike, indices: ArrayLike, axis=None, out : Optional[NDArray]=None, mode="raise"):
934+
if out is not None:
935+
raise NotImplementedError(f"{out=}")
936+
if mode != "raise":
937+
raise NotImplementedError(f"{mode=}")
938+
939+
(a,), axis = _util.axis_none_ravel(a, axis=axis)
940+
axis = _util.normalize_axis_index(axis, a.ndim)
941+
idx = (slice(None),) * axis + (indices, ...)
942+
result = a[idx]
943+
return result
944+
945+
946+
932947

933948
@normalizer
934949
def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):

torch_np/_ndarray.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,7 @@ def __setitem__(self, index, value):
407407
value = _helpers.ndarrays_to_tensors(value)
408408
return self.tensor.__setitem__(index, value)
409409

410-
def take(*a, **kw):
411-
raise NotImplementedError()
410+
take = _funcs.take
412411

413412

414413
# This is the ideally the only place which talks to ndarray directly.

0 commit comments

Comments
 (0)