We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
v
put()
1 parent e7ff741 commit a51cfbaCopy full SHA for a51cfba
torch_np/_funcs_impl.py
@@ -916,11 +916,14 @@ def put(
916
mode: NotImplementedType = "raise",
917
):
918
v = v.type(a.dtype)
919
- numel_ratio = ind.numel() / v.numel()
920
- if numel_ratio.is_integer():
921
- sizes = [int(numel_ratio)]
922
- sizes.extend([1 for _ in range(v.dim() - 1)])
923
- v = v.repeat(*sizes)
+ # If ind is larger than v, broadcast v to the would-be resulting shape. Any
+ # unnecessary trailing elements are then trimmed.
+ if ind.numel() > v.numel():
+ result_shape = torch.broadcast_shapes(v.shape, ind.shape)
+ v = torch.broadcast_to(v, result_shape)
924
+ if ind.numel() < v.numel():
925
+ v = v.flatten()
926
+ v = v[: ind.numel()]
927
a.put_(ind, v)
928
return None
929
0 commit comments