File tree Expand file tree Collapse file tree 1 file changed +10
-5
lines changed Expand file tree Collapse file tree 1 file changed +10
-5
lines changed Original file line number Diff line number Diff line change @@ -916,11 +916,16 @@ def put(
916
916
mode : NotImplementedType = "raise" ,
917
917
):
918
918
v = v .type (a .dtype )
919
- numel_ratio = ind .numel () / v .numel ()
920
- if numel_ratio .is_integer ():
921
- sizes = [int (numel_ratio )]
922
- sizes .extend ([1 for _ in range (v .dim () - 1 )])
923
- v = v .repeat (* sizes )
919
+ # If ind is larger than v, broadcast v to the would-be resulting shape. Any
920
+ # unnecessary trailing elements are then trimmed.
921
+ if ind .numel () > v .numel ():
922
+ result_shape = torch .broadcast_shapes (v .shape , ind .shape )
923
+ v = torch .broadcast_to (v , result_shape )
924
+ # Trim unnecessary elements, regarldess if v was broadcasted or not. Note
925
+ # np.put() trims v to match ind by default too.
926
+ if ind .numel () < v .numel ():
927
+ v = v .flatten ()
928
+ v = v [: ind .numel ()]
924
929
a .put_ (ind , v )
925
930
return None
926
931
You can’t perform that action at this time.
0 commit comments