take() implementation, rebase #98

Merged · 3 commits · Mar 31, 2023
18 changes: 18 additions & 0 deletions torch_np/_funcs.py
@@ -930,6 +930,24 @@ def asfarray():
# ### put/take_along_axis ###


@normalizer
def take(
a: ArrayLike,
indices: ArrayLike,
axis=None,
out: Optional[NDArray] = None,
mode="raise",
):
if mode != "raise":
raise NotImplementedError(f"{mode=}")

(a,), axis = _util.axis_none_ravel(a, axis=axis)
axis = _util.normalize_axis_index(axis, a.ndim)
idx = (slice(None),) * axis + (indices, ...)
result = a[idx]
return result


@normalizer
def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
(arr,), axis = _util.axis_none_ravel(arr, axis=axis)
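The core of the new function is the indexing tuple `(slice(None),) * axis + (indices, ...)`: `axis` leading full slices followed by the index array, so a take along an arbitrary axis reduces to ordinary fancy indexing on `a`. A minimal NumPy-only sketch of the same idea (the helper name is illustrative, not part of the PR):

import numpy as np

def take_via_indexing(a, indices, axis=None):
    # Mirror np.take's axis=None behaviour: operate on the flattened array.
    if axis is None:
        a, axis = np.ravel(a), 0
    # `axis` full slices, then the index array, then Ellipsis for the remaining dims.
    idx = (slice(None),) * axis + (np.asarray(indices), ...)
    return a[idx]

x = np.arange(24).reshape(2, 3, 4)
assert np.array_equal(take_via_indexing(x, [2, 0], axis=1),
                      np.take(x, [2, 0], axis=1))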
2 changes: 2 additions & 0 deletions torch_np/_ndarray.py
@@ -407,6 +407,8 @@ def __setitem__(self, index, value):
value = _helpers.ndarrays_to_tensors(value)
return self.tensor.__setitem__(index, value)

take = _funcs.take


# This is the ideally the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.
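Binding `take = _funcs.take` on the ndarray class makes the method form delegate to the same implementation as the module-level function. A small usage sketch, assuming the package is imported as `torch_np` and that `arange`/`reshape` behave as in NumPy:

import torch_np as np  # assumption: package importable under this name

a = np.arange(6).reshape(2, 3)
# Function and method spellings should give the same result.
assert (np.take(a, [0, 2], axis=1) == a.take([0, 2], axis=1)).all()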
7 changes: 2 additions & 5 deletions torch_np/tests/numpy_tests/core/test_indexing.py
@@ -1053,9 +1053,8 @@ def test_non_integer_argument_errors(self):

assert_raises(TypeError, np.reshape, a, (1., 1., -1))
assert_raises(TypeError, np.reshape, a, (np.array(1.), -1))
pytest.xfail("XXX: take not implemented")
assert_raises(TypeError, np.take, a, [0], 1.)
assert_raises(TypeError, np.take, a, [0], np.float64(1.))
assert_raises((TypeError, RuntimeError), np.take, a, [0], np.float64(1.))

@pytest.mark.skip(
reason=(
@@ -1089,7 +1088,6 @@ def test_bool_as_int_argument_errors(self):
# array is thus also deprecated, but not with the same message:
assert_warns(DeprecationWarning, operator.index, np.True_)

pytest.xfail("XXX: take not implemented")
assert_raises(TypeError, np.take, args=(a, [0], False))

pytest.skip("torch consumes boolean tensors as ints, no bother raising here")
@@ -1138,8 +1136,7 @@ def test_array_to_index_error(self):
# so no exception is expected. The raising is effectively tested above.
a = np.array([[[1]]])

pytest.xfail("XXX: take not implemented")
assert_raises(TypeError, np.take, a, [0], a)
assert_raises((TypeError, RuntimeError), np.take, a, [0], a)

pytest.skip(
"Multi-dimensional tensors are indexable just as long as they only "
50 changes: 17 additions & 33 deletions torch_np/tests/numpy_tests/core/test_multiarray.py
@@ -3817,57 +3817,41 @@ def test_kwargs(self):
np.putmask(a=x, values=[-1, -2], mask=[0, 1])


@pytest.mark.xfail(reason='TODO')
class TestTake:
def tst_basic(self, x):
ind = list(range(x.shape[0]))
assert_array_equal(x.take(ind, axis=0), x)
assert_array_equal(np.take(x, ind, axis=0), x)

def test_ip_types(self):
unchecked_types = [bytes, str, np.void]

x = np.random.random(24)*100
x.shape = 2, 3, 4
x = np.reshape(x, (2, 3, 4))
for types in np.sctypes.values():
for T in types:
if T not in unchecked_types:
self.tst_basic(x.copy().astype(T))

# Also test string of a length which uses an untypical length
self.tst_basic(x.astype("S3"))
self.tst_basic(x.copy().astype(T))

def test_raise(self):
x = np.random.random(24)*100
x.shape = 2, 3, 4
assert_raises(IndexError, x.take, [0, 1, 2], axis=0)
assert_raises(IndexError, x.take, [-3], axis=0)
assert_array_equal(x.take([-1], axis=0)[0], x[1])
x = np.reshape(x, (2, 3, 4))
assert_raises(IndexError, np.take, x, [0, 1, 2], axis=0)
assert_raises(IndexError, np.take, x, [-3], axis=0)
assert_array_equal(np.take(x, [-1], axis=0)[0], x[1])

@pytest.mark.xfail(reason="XXX: take(..., mode='clip')")
def test_clip(self):
x = np.random.random(24)*100
x.shape = 2, 3, 4
assert_array_equal(x.take([-1], axis=0, mode='clip')[0], x[0])
assert_array_equal(x.take([2], axis=0, mode='clip')[0], x[1])
x = np.reshape(x, (2, 3, 4))
assert_array_equal(np.take(x, [-1], axis=0, mode='clip')[0], x[0])
assert_array_equal(np.take(x, [2], axis=0, mode='clip')[0], x[1])

@pytest.mark.xfail(reason="XXX: take(..., mode='wrap')")
def test_wrap(self):
x = np.random.random(24)*100
x.shape = 2, 3, 4
assert_array_equal(x.take([-1], axis=0, mode='wrap')[0], x[1])
assert_array_equal(x.take([2], axis=0, mode='wrap')[0], x[0])
assert_array_equal(x.take([3], axis=0, mode='wrap')[0], x[1])

@pytest.mark.parametrize('dtype', ('>i4', '<i4'))
def test_byteorder(self, dtype):
x = np.array([1, 2, 3], dtype)
assert_array_equal(x.take([0, 2, 1]), [1, 3, 2])

def test_record_array(self):
# Note mixed byteorder.
rec = np.array([(-5, 2.0, 3.0), (5.0, 4.0, 3.0)],
dtype=[('x', '<f8'), ('y', '>f8'), ('z', '<f8')])
rec1 = rec.take([1])
assert_(rec1['x'] == 5.0 and rec1['y'] == 4.0)
x = np.reshape(x, (2, 3, 4))
assert_array_equal(np.take(x, [-1], axis=0, mode='wrap')[0], x[1])
assert_array_equal(np.take(x, [2], axis=0, mode='wrap')[0], x[0])
assert_array_equal(np.take(x, [3], axis=0, mode='wrap')[0], x[1])

@pytest.mark.xfail(reason="XXX: take(mode='wrap')")
def test_out_overlap(self):
# gh-6272 check overlap on out
x = np.arange(5)
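The remaining xfails for mode='clip' and mode='wrap' follow directly from the implementation raising NotImplementedError for any mode other than "raise". For reference only (not code from this PR), both modes are conventionally emulated by preprocessing the indices before the same indexing path:

import numpy as np

def preprocess_indices(indices, length, mode):
    indices = np.asarray(indices)
    if mode == "clip":
        return np.clip(indices, 0, length - 1)  # out-of-range indices saturate
    if mode == "wrap":
        return indices % length                 # out-of-range indices wrap around
    return indices                              # "raise": let indexing do the checking

x = np.arange(24).reshape(2, 3, 4)
assert np.array_equal(np.take(x, [3], axis=0, mode="wrap"),
                      x[preprocess_indices(np.array([3]), x.shape[0], "wrap")])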
1 change: 0 additions & 1 deletion torch_np/tests/numpy_tests/core/test_numeric.py
@@ -282,7 +282,6 @@ def test_sum(self):

assert_equal(tgt, out)

@pytest.mark.xfail(reason="TODO implement take(...)")
def test_take(self):
tgt = [2, 3, 5]
indices = [1, 2, 4]
10 changes: 9 additions & 1 deletion torch_np/tests/numpy_tests/lib/test_arraysetops.py
@@ -768,7 +768,6 @@ def test_unique_axis_list(self):
assert_array_equal(unique(inp, axis=0), unique(inp_arr, axis=0), msg)
assert_array_equal(unique(inp, axis=1), unique(inp_arr, axis=1), msg)

@pytest.mark.xfail(reason='TODO: implement take')
def test_unique_axis(self):
types = []
types.extend(np.typecodes['AllInteger'])
@@ -857,6 +856,15 @@ def _run_axis_tests(self, dtype):
result = np.array([[0, 0, 1], [0, 1, 0], [0, 0, 1], [0, 1, 0]])
assert_array_equal(unique(data, axis=1), result.astype(dtype), msg)

pytest.xfail("torch has different unique ordering behaviour")
# e.g.
#
# >>> x = np.array([[[1, 1], [0, 1]], [[1, 0], [0, 0]]])
# >>> np.unique(x, axis=2)
# [[1, 1], [0, 1]], [[1, 0], [0, 0]]
# >>> torch.unique(torch.as_tensor(x), dim=2)
# [[1, 1], [1, 0]], [[0, 1], [0, 0]]
#
msg = 'Unique with 3d array and axis=2 failed'
data3d = np.array([[[1, 1],
[1, 0]],