diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index 115d7c64..e230f319 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -930,6 +930,24 @@ def asfarray(): # ### put/take_along_axis ### +@normalizer +def take( + a: ArrayLike, + indices: ArrayLike, + axis=None, + out: Optional[NDArray] = None, + mode="raise", +): + if mode != "raise": + raise NotImplementedError(f"{mode=}") + + (a,), axis = _util.axis_none_ravel(a, axis=axis) + axis = _util.normalize_axis_index(axis, a.ndim) + idx = (slice(None),) * axis + (indices, ...) + result = a[idx] + return result + + @normalizer def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis): (arr,), axis = _util.axis_none_ravel(arr, axis=axis) diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 3a02e277..08b0ed18 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -407,6 +407,8 @@ def __setitem__(self, index, value): value = _helpers.ndarrays_to_tensors(value) return self.tensor.__setitem__(index, value) + take = _funcs.take + # This is the ideally the only place which talks to ndarray directly. # The rest goes through asarray (preferred) or array. diff --git a/torch_np/tests/numpy_tests/core/test_indexing.py b/torch_np/tests/numpy_tests/core/test_indexing.py index 2867ec56..d083a480 100644 --- a/torch_np/tests/numpy_tests/core/test_indexing.py +++ b/torch_np/tests/numpy_tests/core/test_indexing.py @@ -1053,9 +1053,8 @@ def test_non_integer_argument_errors(self): assert_raises(TypeError, np.reshape, a, (1., 1., -1)) assert_raises(TypeError, np.reshape, a, (np.array(1.), -1)) - pytest.xfail("XXX: take not implemented") assert_raises(TypeError, np.take, a, [0], 1.) - assert_raises(TypeError, np.take, a, [0], np.float64(1.)) + assert_raises((TypeError, RuntimeError), np.take, a, [0], np.float64(1.)) @pytest.mark.skip( reason=( @@ -1089,7 +1088,6 @@ def test_bool_as_int_argument_errors(self): # array is thus also deprecated, but not with the same message: assert_warns(DeprecationWarning, operator.index, np.True_) - pytest.xfail("XXX: take not implemented") assert_raises(TypeError, np.take, args=(a, [0], False)) pytest.skip("torch consumes boolean tensors as ints, no bother raising here") @@ -1138,8 +1136,7 @@ def test_array_to_index_error(self): # so no exception is expected. The raising is effectively tested above. a = np.array([[[1]]]) - pytest.xfail("XXX: take not implemented") - assert_raises(TypeError, np.take, a, [0], a) + assert_raises((TypeError, RuntimeError), np.take, a, [0], a) pytest.skip( "Multi-dimensional tensors are indexable just as long as they only " diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py index b44e89d2..38caf6d9 100644 --- a/torch_np/tests/numpy_tests/core/test_multiarray.py +++ b/torch_np/tests/numpy_tests/core/test_multiarray.py @@ -3817,57 +3817,41 @@ def test_kwargs(self): np.putmask(a=x, values=[-1, -2], mask=[0, 1]) -@pytest.mark.xfail(reason='TODO') class TestTake: def tst_basic(self, x): ind = list(range(x.shape[0])) - assert_array_equal(x.take(ind, axis=0), x) + assert_array_equal(np.take(x, ind, axis=0), x) def test_ip_types(self): - unchecked_types = [bytes, str, np.void] - x = np.random.random(24)*100 - x.shape = 2, 3, 4 + x = np.reshape(x, (2, 3, 4)) for types in np.sctypes.values(): for T in types: - if T not in unchecked_types: - self.tst_basic(x.copy().astype(T)) - - # Also test string of a length which uses an untypical length - self.tst_basic(x.astype("S3")) + self.tst_basic(x.copy().astype(T)) def test_raise(self): x = np.random.random(24)*100 - x.shape = 2, 3, 4 - assert_raises(IndexError, x.take, [0, 1, 2], axis=0) - assert_raises(IndexError, x.take, [-3], axis=0) - assert_array_equal(x.take([-1], axis=0)[0], x[1]) + x = np.reshape(x, (2, 3, 4)) + assert_raises(IndexError, np.take, x, [0, 1, 2], axis=0) + assert_raises(IndexError, np.take, x, [-3], axis=0) + assert_array_equal(np.take(x, [-1], axis=0)[0], x[1]) + @pytest.mark.xfail(reason="XXX: take(..., mode='clip')") def test_clip(self): x = np.random.random(24)*100 - x.shape = 2, 3, 4 - assert_array_equal(x.take([-1], axis=0, mode='clip')[0], x[0]) - assert_array_equal(x.take([2], axis=0, mode='clip')[0], x[1]) + x = np.reshape(x, (2, 3, 4)) + assert_array_equal(np.take(x, [-1], axis=0, mode='clip')[0], x[0]) + assert_array_equal(np.take(x, [2], axis=0, mode='clip')[0], x[1]) + @pytest.mark.xfail(reason="XXX: take(..., mode='wrap')") def test_wrap(self): x = np.random.random(24)*100 - x.shape = 2, 3, 4 - assert_array_equal(x.take([-1], axis=0, mode='wrap')[0], x[1]) - assert_array_equal(x.take([2], axis=0, mode='wrap')[0], x[0]) - assert_array_equal(x.take([3], axis=0, mode='wrap')[0], x[1]) - - @pytest.mark.parametrize('dtype', ('>i4', 'f8'), ('z', '>> x = np.array([[[1, 1], [0, 1]], [[1, 0], [0, 0]]]) + # >>> np.unique(x, axis=2) + # [[1, 1], [0, 1]], [[1, 0], [0, 0]] + # >>> torch.unique(torch.as_tensor(x), dim=2) + # [[1, 1], [1, 0]], [[0, 1], [0, 0]] + # msg = 'Unique with 3d array and axis=2 failed' data3d = np.array([[[1, 1], [1, 0]],