Skip to content

Commit 765b82a

Browse files
committed
Treat ind as ArrayLike and rely on its normalisation
1 parent a13de94 commit 765b82a

File tree

1 file changed

+5
-19
lines changed

1 file changed

+5
-19
lines changed

torch_np/_funcs_impl.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -911,29 +911,15 @@ def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
911911

912912
def put(
913913
a: ArrayLike,
914-
ind: Sequence[ArrayLike],
914+
ind: ArrayLike,
915915
v: ArrayLike,
916916
mode: NotImplementedType = "raise",
917917
):
918-
indexes = list(ind)
919-
for i, index in enumerate(indexes):
920-
if not isinstance(index, torch.Tensor):
921-
indexes[i] = torch.as_tensor(index)
922-
index = torch.concat(indexes)
923-
index[index < 0] += a.numel() # normalise negative indices
924-
index_u, index_c = torch.unique(index, return_counts=True)
925-
duplicated_indices = index_u[index_c > 1]
926-
if duplicated_indices.numel() > 0:
927-
raise NotImplementedError(
928-
"duplicated indices are not supported. duplicated indices: "
929-
f"{duplicated_indices}"
930-
)
931-
source = v
932-
if source.numel() < index.numel():
933-
numel_ratio = float(index.numel() / source.numel())
918+
if v.numel() < ind.numel():
919+
numel_ratio = float(ind.numel() / v.numel())
934920
if numel_ratio.is_integer():
935-
source = torch.stack([source for _ in range(int(numel_ratio))])
936-
a.put_(index, source)
921+
v = torch.stack([v for _ in range(int(numel_ratio))])
922+
a.put_(ind, v)
937923
return None
938924

939925

0 commit comments

Comments
 (0)