Skip to content

Commit d844d51

Browse files
committed
tnp.put and test_put
1 parent c5066ba commit d844d51

File tree

2 files changed

+81
-9
lines changed

2 files changed

+81
-9
lines changed

torch_np/_funcs_impl.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,28 @@ def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
913913
return torch.take_along_dim(arr, indices, axis)
914914

915915

916+
def put(a: ArrayLike, ind: Sequence[ArrayLike], v: ArrayLike, mode="raise"):
917+
if mode != "raise":
918+
raise NotImplementedError(f"{mode=}")
919+
920+
index = torch.concat(ind)
921+
index[index < 0] += a.numel() # normalise negative indices
922+
index_u, index_c = torch.unique(index, return_counts=True)
923+
duplicated_indices = index_u[index_c > 1]
924+
if duplicated_indices.numel() > 0:
925+
raise NotImplementedError(
926+
"duplicated indices are not supported. duplicated indices: "
927+
f"{list(duplicated_indices)}"
928+
)
929+
source = v
930+
if source.numel() < index.numel():
931+
source = torch.broadcast_to(source, index.size())
932+
# Note Tensor.put_ acts in-place, while Tensor.put (no trailing underscore)
933+
# is seemingly out-of-place.
934+
a.put_(index, source)
935+
return None
936+
937+
916938
def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
917939
(arr,), axis = _util.axis_none_ravel(arr, axis=axis)
918940
axis = _util.normalize_axis_index(axis, arr.ndim)

torch_np/tests/test_xps.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,35 @@
1010

1111
pytest.importorskip("hypothesis")
1212

13-
from hypothesis import given
13+
import numpy as np
14+
import torch
15+
from hypothesis import given, note
1416
from hypothesis import strategies as st
1517
from hypothesis.errors import HypothesisWarning
18+
from hypothesis.extra import numpy as nps
1619
from hypothesis.extra.array_api import make_strategies_namespace
1720

18-
import torch_np as np
21+
import torch_np as tnp
22+
from torch_np._dtypes import sctypes
23+
from torch_np.testing import assert_array_equal
1924

2025
__all__ = ["xps"]
2126

2227
with warnings.catch_warnings():
2328
warnings.filterwarnings("ignore", category=HypothesisWarning)
24-
np.bool = np.bool_
25-
xps = make_strategies_namespace(np, api_version="2022.12")
29+
tnp.bool = tnp.bool_
30+
xps = make_strategies_namespace(tnp, api_version="2022.12")
2631

2732

28-
default_dtypes = [np.bool, np.int64, np.float64, np.complex128]
33+
default_dtypes = [tnp.bool, tnp.int64, tnp.float64, tnp.complex128]
2934
kind_to_strat = {
3035
"b": xps.boolean_dtypes(),
3136
"i": xps.integer_dtypes(),
3237
"u": xps.unsigned_integer_dtypes(sizes=8),
3338
"f": xps.floating_dtypes(),
3439
"c": xps.complex_dtypes(),
3540
}
36-
scalar_dtype_strat = st.one_of(kind_to_strat.values()).map(np.dtype)
41+
scalar_dtype_strat = st.one_of(kind_to_strat.values()).map(tnp.dtype)
3742

3843

3944
@pytest.mark.skip(reason="flaky")
@@ -55,14 +60,14 @@ def test_full(shape, data):
5560
else:
5661
values_dtypes_strat = st.just(_dtype)
5762
values_strat = values_dtypes_strat.flatmap(
58-
lambda d: values_strat.map(lambda v: np.asarray(v, dtype=d))
63+
lambda d: values_strat.map(lambda v: tnp.asarray(v, dtype=d))
5964
)
6065
fill_value = data.draw(values_strat, label="fill_value")
61-
out = np.full(shape, fill_value, **kw)
66+
out = tnp.full(shape, fill_value, **kw)
6267
assert out.dtype == _dtype
6368
assert out.shape == shape
6469
if cmath.isnan(fill_value):
65-
assert np.isnan(out).all()
70+
assert tnp.isnan(out).all()
6671
else:
6772
assert (out == fill_value).all()
6873

@@ -89,3 +94,48 @@ def test_integer_indexing(x, data):
8994
idx = data.draw(integer_array_indices(x.shape, result_shape), label="idx")
9095
result = x[idx]
9196
assert result.shape == result_shape
97+
98+
99+
@given(
100+
np_x=nps.arrays(
101+
# We specifically use namespaced dtypes to prevent non-native byte-order issues
102+
dtype=scalar_dtype_strat.map(lambda d: getattr(np, d.name)),
103+
shape=nps.array_shapes(),
104+
),
105+
data=st.data(),
106+
)
107+
def test_put(np_x, data):
108+
# We cast arrays from torch_np.asarray as currently it doesn't carry over
109+
# dtypes. XXX: Remove the below sanity check and subsequent casting when
110+
# this is fixed.
111+
assert tnp.asarray(np.zeros(5, dtype=np.int16)).dtype != tnp.int16
112+
113+
tnp_x = tnp.asarray(np_x.copy()).astype(np_x.dtype.name)
114+
115+
result_shapes = st.shared(nps.array_shapes())
116+
ind = data.draw(
117+
nps.integer_array_indices(np_x.shape, result_shape=result_shapes), label="ind"
118+
)
119+
v = data.draw(nps.arrays(dtype=np_x.dtype, shape=result_shapes), label="v")
120+
121+
tnp_x_copy = tnp_x.copy()
122+
np.put(np_x, ind, v)
123+
note(f"(after put) {np_x=}")
124+
assert_array_equal(tnp_x, tnp_x_copy) # sanity check
125+
126+
note(f"{tnp_x=}")
127+
tnp_ind = []
128+
for np_indices in ind:
129+
tnp_indices = tnp.asarray(np_indices).astype(np_indices.dtype.name)
130+
tnp_ind.append(tnp_indices)
131+
tnp_ind = tuple(tnp_ind)
132+
note(f"{tnp_ind=}")
133+
tnp_v = tnp.asarray(v.copy()).astype(v.dtype.name)
134+
note(f"{tnp_v=}")
135+
try:
136+
tnp.put(tnp_x, tnp_ind, tnp_v)
137+
except NotImplementedError:
138+
return
139+
note(f"(after put) {tnp_x=}")
140+
141+
assert_array_equal(tnp_x, tnp.asarray(np_x).astype(tnp_x.dtype))

0 commit comments

Comments
 (0)