Skip to content

Commit cae6d94

Browse files
committedApr 25, 2023
put(): support and test broadcast-incompatible args
1 parent 40f7ccb commit cae6d94

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed
 

‎torch_np/_funcs_impl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -904,8 +904,10 @@ def put(
904904
# If ind is larger than v, broadcast v to the would-be resulting shape. Any
905905
# unnecessary trailing elements are then trimmed.
906906
if ind.numel() > v.numel():
907-
result_shape = torch.broadcast_shapes(v.shape, ind.shape)
908-
v = torch.broadcast_to(v, result_shape)
907+
ratio = (ind.numel() + v.numel() - 1) // v.numel()
908+
sizes = [ratio]
909+
sizes.extend([1 for _ in range(v.dim() - 1)])
910+
v = v.repeat(*sizes)
909911
# Trim unnecessary elements, regarldess if v was broadcasted or not. Note
910912
# np.put() trims v to match ind by default too.
911913
if ind.numel() < v.numel():

‎torch_np/tests/test_xps.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
These tests aren't specifically for testing Array API adoption!
55
"""
66
import cmath
7+
import math
78
import warnings
89

910
import pytest
@@ -115,11 +116,20 @@ def test_put(np_x, data):
115116

116117
tnp_x = tnp.asarray(np_x.copy()).astype(np_x.dtype.name)
117118

118-
result_shapes = st.shared(nps.array_shapes())
119-
ind = data.draw(
120-
nps.integer_array_indices(np_x.shape, result_shape=result_shapes), label="ind"
119+
result_shape = data.draw(nps.array_shapes(), label="result_shape")
120+
ind_strat = nps.integer_array_indices(
121+
np_x.shape, result_shape=st.just(result_shape)
122+
)
123+
ind = data.draw(ind_strat | ind_strat.map(np.asarray), label="ind")
124+
v = data.draw(
125+
nps.arrays(
126+
dtype=np_x.dtype,
127+
shape=nps.array_shapes().filter(
128+
lambda s: math.prod(s) > math.prod(result_shape)
129+
),
130+
),
131+
label="v",
121132
)
122-
v = data.draw(nps.arrays(dtype=np_x.dtype, shape=result_shapes), label="v")
123133

124134
tnp_x_copy = tnp_x.copy()
125135
np.put(np_x, ind, v)

0 commit comments

Comments
 (0)
Please sign in to comment.