File tree Expand file tree Collapse file tree 2 files changed +18
-6
lines changed Expand file tree Collapse file tree 2 files changed +18
-6
lines changed Original file line number Diff line number Diff line change @@ -904,8 +904,10 @@ def put(
904
904
# If ind is larger than v, broadcast v to the would-be resulting shape. Any
905
905
# unnecessary trailing elements are then trimmed.
906
906
if ind .numel () > v .numel ():
907
- result_shape = torch .broadcast_shapes (v .shape , ind .shape )
908
- v = torch .broadcast_to (v , result_shape )
907
+ ratio = (ind .numel () + v .numel () - 1 ) // v .numel ()
908
+ sizes = [ratio ]
909
+ sizes .extend ([1 for _ in range (v .dim () - 1 )])
910
+ v = v .repeat (* sizes )
909
911
# Trim unnecessary elements, regarldess if v was broadcasted or not. Note
910
912
# np.put() trims v to match ind by default too.
911
913
if ind .numel () < v .numel ():
Original file line number Diff line number Diff line change 4
4
These tests aren't specifically for testing Array API adoption!
5
5
"""
6
6
import cmath
7
+ import math
7
8
import warnings
8
9
9
10
import pytest
@@ -115,11 +116,20 @@ def test_put(np_x, data):
115
116
116
117
tnp_x = tnp .asarray (np_x .copy ()).astype (np_x .dtype .name )
117
118
118
- result_shapes = st .shared (nps .array_shapes ())
119
- ind = data .draw (
120
- nps .integer_array_indices (np_x .shape , result_shape = result_shapes ), label = "ind"
119
+ result_shape = data .draw (nps .array_shapes (), label = "result_shape" )
120
+ ind_strat = nps .integer_array_indices (
121
+ np_x .shape , result_shape = st .just (result_shape )
122
+ )
123
+ ind = data .draw (ind_strat | ind_strat .map (np .asarray ), label = "ind" )
124
+ v = data .draw (
125
+ nps .arrays (
126
+ dtype = np_x .dtype ,
127
+ shape = nps .array_shapes ().filter (
128
+ lambda s : math .prod (s ) > math .prod (result_shape )
129
+ ),
130
+ ),
131
+ label = "v" ,
121
132
)
122
- v = data .draw (nps .arrays (dtype = np_x .dtype , shape = result_shapes ), label = "v" )
123
133
124
134
tnp_x_copy = tnp_x .copy ()
125
135
np .put (np_x , ind , v )
You can’t perform that action at this time.
0 commit comments