Skip to content

Commit 17ca051

Browse files
committed
put(): a: ArrayLike -> a: NDArray
Prevents normalising non-ndarray arguments
1 parent c84cc52 commit 17ca051

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

torch_np/_funcs_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
910910

911911

912912
def put(
913-
a: ArrayLike,
913+
a: NDArray,
914914
ind: ArrayLike,
915915
v: ArrayLike,
916916
mode: NotImplementedType = "raise",

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,7 +2637,6 @@ def test_trace(self):
26372637
ret = a.trace(out=out)
26382638
assert ret is out
26392639

2640-
@pytest.mark.xfail(reason="TODO: implement put")
26412640
def test_put(self):
26422641
icodes = np.typecodes['AllInteger']
26432642
fcodes = np.typecodes['AllFloat']
@@ -2670,7 +2669,6 @@ def test_put(self):
26702669
# when calling np.put, make sure a
26712670
# TypeError is raised if the object
26722671
# isn't an ndarray
2673-
pytest.xfail(reason="XXX: Argument normalisation prevents catching this")
26742672
bad_array = [1, 2, 3]
26752673
assert_raises(TypeError, np.put, bad_array, [0, 2], 5)
26762674

0 commit comments

Comments
 (0)