File tree 2 files changed +9
-8
lines changed 2 files changed +9
-8
lines changed Original file line number Diff line number Diff line change @@ -901,14 +901,12 @@ def put(
901
901
mode : NotImplementedType = "raise" ,
902
902
):
903
903
v = v .type (a .dtype )
904
- # If ind is larger than v, broadcast v to the would-be resulting shape . Any
904
+ # If ind is larger than v, expand v to at least the size of ind . Any
905
905
# unnecessary trailing elements are then trimmed.
906
906
if ind .numel () > v .numel ():
907
907
ratio = (ind .numel () + v .numel () - 1 ) // v .numel ()
908
- sizes = [ratio ]
909
- sizes .extend ([1 for _ in range (v .dim () - 1 )])
910
- v = v .repeat (* sizes )
911
- # Trim unnecessary elements, regarldess if v was broadcasted or not. Note
908
+ v = v .unsqueeze (0 ).expand ((ratio ,) + v .shape )
909
+ # Trim unnecessary elements, regarldess if v was expanded or not. Note
912
910
# np.put() trims v to match ind by default too.
913
911
if ind .numel () < v .numel ():
914
912
v = v .flatten ()
Original file line number Diff line number Diff line change @@ -117,9 +117,12 @@ def test_put(np_x, data):
117
117
tnp_x = tnp .asarray (np_x .copy ()).astype (np_x .dtype .name )
118
118
119
119
result_shape = data .draw (nps .array_shapes (), label = "result_shape" )
120
- ind_strat = nps .integer_array_indices (
121
- np_x .shape , result_shape = st .just (result_shape )
122
- )
120
+ if result_shape == ():
121
+ ind_strat = st .integers (np_x .size )
122
+ else :
123
+ ind_strat = nps .integer_array_indices (
124
+ np_x .shape , result_shape = st .just (result_shape )
125
+ )
123
126
ind = data .draw (ind_strat | ind_strat .map (np .asarray ), label = "ind" )
124
127
v = data .draw (
125
128
nps .arrays (
You can’t perform that action at this time.
0 commit comments