Skip to content

Commit e7ff741

Browse files
committed
Simplify what test_put tests for
* `list_at_ind` stuff should be covered when testing the normilizer * Manually converting `ind` and `v` to `tnp.ndarray` is mostly redundant
1 parent 17ca051 commit e7ff741

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

torch_np/tests/test_xps.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ def test_integer_indexing(x, data):
9696
assert result.shape == result_shape
9797

9898

99+
@pytest.mark.filterwarnings(
100+
"ignore:Creating a tensor from a list of numpy.ndarrays.*:UserWarning"
101+
)
99102
@given(
100103
np_x=nps.arrays(
101104
# We specifically use namespaced dtypes to prevent non-native byte-order issues
@@ -124,22 +127,7 @@ def test_put(np_x, data):
124127
assert_array_equal(tnp_x, tnp_x_copy) # sanity check
125128

126129
note(f"{tnp_x=}")
127-
tnp_ind = []
128-
list_at_ind = data.draw(
129-
st.lists(st.booleans(), min_size=len(ind), max_size=len(ind)),
130-
label="list_at_ind",
131-
)
132-
for np_indices, use_list in zip(ind, list_at_ind):
133-
if use_list:
134-
indices = np_indices.tolist()
135-
else:
136-
indices = tnp.asarray(np_indices).astype(np_indices.dtype.name)
137-
tnp_ind.append(indices)
138-
tnp_ind = tuple(tnp_ind)
139-
note(f"{tnp_ind=}")
140-
tnp_v = tnp.asarray(v.copy()).astype(v.dtype.name)
141-
note(f"{tnp_v=}")
142-
tnp.put(tnp_x, tnp_ind, tnp_v)
130+
tnp.put(tnp_x, ind, v)
143131
note(f"(after put) {tnp_x=}")
144132

145133
assert_array_equal(tnp_x, tnp.asarray(np_x).astype(tnp_x.dtype))

0 commit comments

Comments
 (0)