Skip to content

Commit 65c2127

Browse files
committed
put(): expand over repeat internally
Also test 0d indices
1 parent cae6d94 commit 65c2127

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

torch_np/_funcs_impl.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -901,14 +901,12 @@ def put(
901901
mode: NotImplementedType = "raise",
902902
):
903903
v = v.type(a.dtype)
904-
# If ind is larger than v, broadcast v to the would-be resulting shape. Any
904+
# If ind is larger than v, expand v to at least the size of ind. Any
905905
# unnecessary trailing elements are then trimmed.
906906
if ind.numel() > v.numel():
907907
ratio = (ind.numel() + v.numel() - 1) // v.numel()
908-
sizes = [ratio]
909-
sizes.extend([1 for _ in range(v.dim() - 1)])
910-
v = v.repeat(*sizes)
911-
# Trim unnecessary elements, regarldess if v was broadcasted or not. Note
908+
v = v.unsqueeze(0).expand((ratio,) + v.shape)
909+
# Trim unnecessary elements, regarldess if v was expanded or not. Note
912910
# np.put() trims v to match ind by default too.
913911
if ind.numel() < v.numel():
914912
v = v.flatten()

torch_np/tests/test_xps.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,12 @@ def test_put(np_x, data):
117117
tnp_x = tnp.asarray(np_x.copy()).astype(np_x.dtype.name)
118118

119119
result_shape = data.draw(nps.array_shapes(), label="result_shape")
120-
ind_strat = nps.integer_array_indices(
121-
np_x.shape, result_shape=st.just(result_shape)
122-
)
120+
if result_shape == ():
121+
ind_strat = st.integers(np_x.size)
122+
else:
123+
ind_strat = nps.integer_array_indices(
124+
np_x.shape, result_shape=st.just(result_shape)
125+
)
123126
ind = data.draw(ind_strat | ind_strat.map(np.asarray), label="ind")
124127
v = data.draw(
125128
nps.arrays(

0 commit comments

Comments
 (0)