Skip to content

Commit a51cfba

Browse files
committed
Broadcast rather than extend v in put()
1 parent e7ff741 commit a51cfba

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

torch_np/_funcs_impl.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -916,11 +916,14 @@ def put(
916916
mode: NotImplementedType = "raise",
917917
):
918918
v = v.type(a.dtype)
919-
numel_ratio = ind.numel() / v.numel()
920-
if numel_ratio.is_integer():
921-
sizes = [int(numel_ratio)]
922-
sizes.extend([1 for _ in range(v.dim() - 1)])
923-
v = v.repeat(*sizes)
919+
# If ind is larger than v, broadcast v to the would-be resulting shape. Any
920+
# unnecessary trailing elements are then trimmed.
921+
if ind.numel() > v.numel():
922+
result_shape = torch.broadcast_shapes(v.shape, ind.shape)
923+
v = torch.broadcast_to(v, result_shape)
924+
if ind.numel() < v.numel():
925+
v = v.flatten()
926+
v = v[: ind.numel()]
924927
a.put_(ind, v)
925928
return None
926929

0 commit comments

Comments
 (0)