From ae5d1a55f22fd182af3e9ff9f51cebab74882239 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 28 Mar 2023 11:05:21 +0100 Subject: [PATCH 1/3] Remove `take()`-not-implemented xfails --- torch_np/_ndarray.py | 3 ++ .../tests/numpy_tests/core/test_indexing.py | 7 +-- .../tests/numpy_tests/core/test_multiarray.py | 51 +++++++------------ .../tests/numpy_tests/core/test_numeric.py | 1 - .../tests/numpy_tests/lib/test_arraysetops.py | 10 +++- 5 files changed, 32 insertions(+), 40 deletions(-) diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 3a02e277..16d6f407 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -407,6 +407,9 @@ def __setitem__(self, index, value): value = _helpers.ndarrays_to_tensors(value) return self.tensor.__setitem__(index, value) + def take(*a, **kw): + raise NotImplementedError() + # This is the ideally the only place which talks to ndarray directly. # The rest goes through asarray (preferred) or array. diff --git a/torch_np/tests/numpy_tests/core/test_indexing.py b/torch_np/tests/numpy_tests/core/test_indexing.py index 2867ec56..d083a480 100644 --- a/torch_np/tests/numpy_tests/core/test_indexing.py +++ b/torch_np/tests/numpy_tests/core/test_indexing.py @@ -1053,9 +1053,8 @@ def test_non_integer_argument_errors(self): assert_raises(TypeError, np.reshape, a, (1., 1., -1)) assert_raises(TypeError, np.reshape, a, (np.array(1.), -1)) - pytest.xfail("XXX: take not implemented") assert_raises(TypeError, np.take, a, [0], 1.) - assert_raises(TypeError, np.take, a, [0], np.float64(1.)) + assert_raises((TypeError, RuntimeError), np.take, a, [0], np.float64(1.)) @pytest.mark.skip( reason=( @@ -1089,7 +1088,6 @@ def test_bool_as_int_argument_errors(self): # array is thus also deprecated, but not with the same message: assert_warns(DeprecationWarning, operator.index, np.True_) - pytest.xfail("XXX: take not implemented") assert_raises(TypeError, np.take, args=(a, [0], False)) pytest.skip("torch consumes boolean tensors as ints, no bother raising here") @@ -1138,8 +1136,7 @@ def test_array_to_index_error(self): # so no exception is expected. The raising is effectively tested above. a = np.array([[[1]]]) - pytest.xfail("XXX: take not implemented") - assert_raises(TypeError, np.take, a, [0], a) + assert_raises((TypeError, RuntimeError), np.take, a, [0], a) pytest.skip( "Multi-dimensional tensors are indexable just as long as they only " diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py index b44e89d2..f5da7ec2 100644 --- a/torch_np/tests/numpy_tests/core/test_multiarray.py +++ b/torch_np/tests/numpy_tests/core/test_multiarray.py @@ -3817,63 +3817,48 @@ def test_kwargs(self): np.putmask(a=x, values=[-1, -2], mask=[0, 1]) -@pytest.mark.xfail(reason='TODO') class TestTake: def tst_basic(self, x): ind = list(range(x.shape[0])) - assert_array_equal(x.take(ind, axis=0), x) + assert_array_equal(np.take(x, ind, axis=0), x) def test_ip_types(self): - unchecked_types = [bytes, str, np.void] - x = np.random.random(24)*100 - x.shape = 2, 3, 4 + x = np.reshape(x, (2, 3, 4)) for types in np.sctypes.values(): for T in types: - if T not in unchecked_types: - self.tst_basic(x.copy().astype(T)) - - # Also test string of a length which uses an untypical length - self.tst_basic(x.astype("S3")) + self.tst_basic(x.copy().astype(T)) def test_raise(self): x = np.random.random(24)*100 - x.shape = 2, 3, 4 - assert_raises(IndexError, x.take, [0, 1, 2], axis=0) - assert_raises(IndexError, x.take, [-3], axis=0) - assert_array_equal(x.take([-1], axis=0)[0], x[1]) + x = np.reshape(x, (2, 3, 4)) + assert_raises(IndexError, np.take, x, [0, 1, 2], axis=0) + assert_raises(IndexError, np.take, x, [-3], axis=0) + assert_array_equal(np.take(x, [-1], axis=0)[0], x[1]) + @pytest.mark.xfail(reason="XXX: take(..., mode='clip')") def test_clip(self): x = np.random.random(24)*100 - x.shape = 2, 3, 4 - assert_array_equal(x.take([-1], axis=0, mode='clip')[0], x[0]) - assert_array_equal(x.take([2], axis=0, mode='clip')[0], x[1]) + x = np.reshape(x, (2, 3, 4)) + assert_array_equal(np.take(x, [-1], axis=0, mode='clip')[0], x[0]) + assert_array_equal(np.take(x, [2], axis=0, mode='clip')[0], x[1]) + @pytest.mark.xfail(reason="XXX: take(..., mode='wrap')") def test_wrap(self): x = np.random.random(24)*100 - x.shape = 2, 3, 4 - assert_array_equal(x.take([-1], axis=0, mode='wrap')[0], x[1]) - assert_array_equal(x.take([2], axis=0, mode='wrap')[0], x[0]) - assert_array_equal(x.take([3], axis=0, mode='wrap')[0], x[1]) - - @pytest.mark.parametrize('dtype', ('>i4', 'f8'), ('z', '>> x = np.array([[[1, 1], [0, 1]], [[1, 0], [0, 0]]]) + # >>> np.unique(x, axis=2) + # [[1, 1], [0, 1]], [[1, 0], [0, 0]] + # >>> torch.unique(torch.as_tensor(x), dim=2) + # [[1, 1], [1, 0]], [[0, 1], [0, 0]] + # msg = 'Unique with 3d array and axis=2 failed' data3d = np.array([[[1, 1], [1, 0]], From e55688f68f4488940872a943c07a4c0640960528 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 31 Mar 2023 17:39:10 +0300 Subject: [PATCH 2/3] `take()` implementation --- torch_np/_funcs.py | 15 +++++++++++++++ torch_np/_ndarray.py | 3 +-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index 115d7c64..963c9abc 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -929,6 +929,21 @@ def asfarray(): # ### put/take_along_axis ### +@normalizer +def take(a: ArrayLike, indices: ArrayLike, axis=None, out : Optional[NDArray]=None, mode="raise"): + if out is not None: + raise NotImplementedError(f"{out=}") + if mode != "raise": + raise NotImplementedError(f"{mode=}") + + (a,), axis = _util.axis_none_ravel(a, axis=axis) + axis = _util.normalize_axis_index(axis, a.ndim) + idx = (slice(None),) * axis + (indices, ...) + result = a[idx] + return result + + + @normalizer def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis): diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 16d6f407..08b0ed18 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -407,8 +407,7 @@ def __setitem__(self, index, value): value = _helpers.ndarrays_to_tensors(value) return self.tensor.__setitem__(index, value) - def take(*a, **kw): - raise NotImplementedError() + take = _funcs.take # This is the ideally the only place which talks to ndarray directly. From 73ea76ca8b3fd035c30841acb8a68b25d21743d7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 31 Mar 2023 17:56:14 +0300 Subject: [PATCH 3/3] TST: un-xfail tests of take(..., out=out) --- torch_np/_funcs.py | 13 ++++++++----- torch_np/tests/numpy_tests/core/test_multiarray.py | 3 +-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index 963c9abc..e230f319 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -929,10 +929,15 @@ def asfarray(): # ### put/take_along_axis ### + @normalizer -def take(a: ArrayLike, indices: ArrayLike, axis=None, out : Optional[NDArray]=None, mode="raise"): - if out is not None: - raise NotImplementedError(f"{out=}") +def take( + a: ArrayLike, + indices: ArrayLike, + axis=None, + out: Optional[NDArray] = None, + mode="raise", +): if mode != "raise": raise NotImplementedError(f"{mode=}") @@ -943,8 +948,6 @@ def take(a: ArrayLike, indices: ArrayLike, axis=None, out : Optional[NDArray]=No return result - - @normalizer def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis): (arr,), axis = _util.axis_none_ravel(arr, axis=axis) diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py index f5da7ec2..38caf6d9 100644 --- a/torch_np/tests/numpy_tests/core/test_multiarray.py +++ b/torch_np/tests/numpy_tests/core/test_multiarray.py @@ -3851,14 +3851,13 @@ def test_wrap(self): assert_array_equal(np.take(x, [2], axis=0, mode='wrap')[0], x[0]) assert_array_equal(np.take(x, [3], axis=0, mode='wrap')[0], x[1]) - @pytest.mark.xfail(reason="XXX: take(out=...)") + @pytest.mark.xfail(reason="XXX: take(mode='wrap')") def test_out_overlap(self): # gh-6272 check overlap on out x = np.arange(5) y = np.take(x, [1, 2, 3], out=x[2:5], mode='wrap') assert_equal(y, np.array([1, 2, 3])) - @pytest.mark.xfail(reason="XXX: take(out=...)") @pytest.mark.parametrize('shape', [(1, 2), (1,), ()]) def test_ret_is_out(self, shape): # 0d arrays should not be an exception to this rule