diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 62c0a2e5..6df0aba6 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -743,6 +743,14 @@ def take_along_dim(tensor, t_indices, axis): return result +def take(tensor, t_indices, axis): + (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis) + axis = _util.normalize_axis_index(axis, tensor.ndim) + idx = (slice(None),) * axis + (t_indices, ...) + result = tensor[idx] + return result + + def put_along_dim(tensor, t_indices, t_values, axis): (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis) axis = _util.normalize_axis_index(axis, tensor.ndim) diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 8cecd76b..f99f7c86 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -404,6 +404,9 @@ def __setitem__(self, index, value): value = _helpers.ndarrays_to_tensors(value) return self.tensor.__setitem__(index, value) + def take(*a, **kw): + raise NotImplementedError() + # This is the ideally the only place which talks to ndarray directly. # The rest goes through asarray (preferred) or array. diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 1a9ed8d6..565cb102 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -815,7 +815,7 @@ def asfarray(): raise NotImplementedError -# ### put/take_along_axis ### +# ### put/take et al ### @normalizer @@ -824,6 +824,16 @@ def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis): return result +@normalizer +def take(a: ArrayLike, indices: ArrayLike, axis=None, out=None, mode="raise"): + if out is not None: + raise NotImplementedError(f"{out=}") + if mode != "raise": + raise NotImplementedError(f"{mode=}") + result = _impl.take(a, indices, axis) + return result + + @normalizer def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis): # modify the argument in-place : here `arr` is `arr._tensor` of the orignal `arr` argument diff --git a/torch_np/tests/numpy_tests/core/test_indexing.py b/torch_np/tests/numpy_tests/core/test_indexing.py index 2867ec56..d083a480 100644 --- a/torch_np/tests/numpy_tests/core/test_indexing.py +++ b/torch_np/tests/numpy_tests/core/test_indexing.py @@ -1053,9 +1053,8 @@ def test_non_integer_argument_errors(self): assert_raises(TypeError, np.reshape, a, (1., 1., -1)) assert_raises(TypeError, np.reshape, a, (np.array(1.), -1)) - pytest.xfail("XXX: take not implemented") assert_raises(TypeError, np.take, a, [0], 1.) - assert_raises(TypeError, np.take, a, [0], np.float64(1.)) + assert_raises((TypeError, RuntimeError), np.take, a, [0], np.float64(1.)) @pytest.mark.skip( reason=( @@ -1089,7 +1088,6 @@ def test_bool_as_int_argument_errors(self): # array is thus also deprecated, but not with the same message: assert_warns(DeprecationWarning, operator.index, np.True_) - pytest.xfail("XXX: take not implemented") assert_raises(TypeError, np.take, args=(a, [0], False)) pytest.skip("torch consumes boolean tensors as ints, no bother raising here") @@ -1138,8 +1136,7 @@ def test_array_to_index_error(self): # so no exception is expected. The raising is effectively tested above. a = np.array([[[1]]]) - pytest.xfail("XXX: take not implemented") - assert_raises(TypeError, np.take, a, [0], a) + assert_raises((TypeError, RuntimeError), np.take, a, [0], a) pytest.skip( "Multi-dimensional tensors are indexable just as long as they only " diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py index 4e513a4a..68100d4c 100644 --- a/torch_np/tests/numpy_tests/core/test_multiarray.py +++ b/torch_np/tests/numpy_tests/core/test_multiarray.py @@ -4179,63 +4179,48 @@ def test_kwargs(self): np.putmask(a=x, values=[-1, -2], mask=[0, 1]) -@pytest.mark.xfail(reason='TODO') class TestTake: def tst_basic(self, x): ind = list(range(x.shape[0])) - assert_array_equal(x.take(ind, axis=0), x) + assert_array_equal(np.take(x, ind, axis=0), x) def test_ip_types(self): - unchecked_types = [bytes, str, np.void] - x = np.random.random(24)*100 - x.shape = 2, 3, 4 + x = np.reshape(x, (2, 3, 4)) for types in np.sctypes.values(): for T in types: - if T not in unchecked_types: - self.tst_basic(x.copy().astype(T)) - - # Also test string of a length which uses an untypical length - self.tst_basic(x.astype("S3")) + self.tst_basic(x.copy().astype(T)) def test_raise(self): x = np.random.random(24)*100 - x.shape = 2, 3, 4 - assert_raises(IndexError, x.take, [0, 1, 2], axis=0) - assert_raises(IndexError, x.take, [-3], axis=0) - assert_array_equal(x.take([-1], axis=0)[0], x[1]) + x = np.reshape(x, (2, 3, 4)) + assert_raises(IndexError, np.take, x, [0, 1, 2], axis=0) + assert_raises(IndexError, np.take, x, [-3], axis=0) + assert_array_equal(np.take(x, [-1], axis=0)[0], x[1]) + @pytest.mark.xfail(reason="XXX: take(..., mode='clip')") def test_clip(self): x = np.random.random(24)*100 - x.shape = 2, 3, 4 - assert_array_equal(x.take([-1], axis=0, mode='clip')[0], x[0]) - assert_array_equal(x.take([2], axis=0, mode='clip')[0], x[1]) + x = np.reshape(x, (2, 3, 4)) + assert_array_equal(np.take(x, [-1], axis=0, mode='clip')[0], x[0]) + assert_array_equal(np.take(x, [2], axis=0, mode='clip')[0], x[1]) + @pytest.mark.xfail(reason="XXX: take(..., mode='wrap')") def test_wrap(self): x = np.random.random(24)*100 - x.shape = 2, 3, 4 - assert_array_equal(x.take([-1], axis=0, mode='wrap')[0], x[1]) - assert_array_equal(x.take([2], axis=0, mode='wrap')[0], x[0]) - assert_array_equal(x.take([3], axis=0, mode='wrap')[0], x[1]) - - @pytest.mark.parametrize('dtype', ('>i4', 'f8'), ('z', '>> x = np.array([[[1, 1], [0, 1]], [[1, 0], [0, 0]]]) + # >>> np.unique(x, axis=2) + # [[1, 1], [0, 1]], [[1, 0], [0, 0]] + # >>> torch.unique(torch.as_tensor(x), dim=2) + # [[1, 1], [1, 0]], [[0, 1], [0, 0]] + # msg = 'Unique with 3d array and axis=2 failed' data3d = np.array([[[1, 1], [1, 0]],