Skip to content

Commit 9eecff7

Browse files
honnolezcano
andcommitted
Broadcast rather than extend v in put()
Co-authored-by: Mario Lezcano Casado <[email protected]>
1 parent e7ff741 commit 9eecff7

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

torch_np/_funcs_impl.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -916,11 +916,16 @@ def put(
916916
mode: NotImplementedType = "raise",
917917
):
918918
v = v.type(a.dtype)
919-
numel_ratio = ind.numel() / v.numel()
920-
if numel_ratio.is_integer():
921-
sizes = [int(numel_ratio)]
922-
sizes.extend([1 for _ in range(v.dim() - 1)])
923-
v = v.repeat(*sizes)
919+
# If ind is larger than v, broadcast v to the would-be resulting shape. Any
920+
# unnecessary trailing elements are then trimmed.
921+
if ind.numel() > v.numel():
922+
result_shape = torch.broadcast_shapes(v.shape, ind.shape)
923+
v = torch.broadcast_to(v, result_shape)
924+
# Trim unnecessary elements, regarldess if v was broadcasted or not. Note
925+
# np.put() trims v to match ind by default too.
926+
if ind.numel() < v.numel():
927+
v = v.flatten()
928+
v = v[: ind.numel()]
924929
a.put_(ind, v)
925930
return None
926931

0 commit comments

Comments
 (0)