Skip to content

take() implementation #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions torch_np/_detail/implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,14 @@ def take_along_dim(tensor, t_indices, axis):
return result


def take(tensor, t_indices, axis):
    """Select entries of `tensor` along `axis` at the positions in `t_indices`.

    Following NumPy semantics, ``axis=None`` means the tensor is ravelled
    before indexing (handled by ``_util.axis_none_ravel``).
    """
    (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)
    axis = _util.normalize_axis_index(axis, tensor.ndim)
    # Full slices for every axis before the target one, then the index
    # tensor, then Ellipsis for the remaining trailing axes.
    selector = tuple(slice(None) for _ in range(axis)) + (t_indices, Ellipsis)
    return tensor[selector]


def put_along_dim(tensor, t_indices, t_values, axis):
(tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)
axis = _util.normalize_axis_index(axis, tensor.ndim)
Expand Down
3 changes: 3 additions & 0 deletions torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,9 @@ def __setitem__(self, index, value):
value = _helpers.ndarrays_to_tensors(value)
return self.tensor.__setitem__(index, value)

def take(*a, **kw):
    """Placeholder for ``ndarray.take``.

    The method form is deliberately not implemented; raise with a
    pointer to the supported free-function spelling instead of a bare,
    message-less NotImplementedError.
    """
    raise NotImplementedError("ndarray.take is not implemented; use np.take instead")


# This is the ideally the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.
Expand Down
12 changes: 11 additions & 1 deletion torch_np/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def asfarray():
raise NotImplementedError


# ### put/take_along_axis ###
# ### put/take et al ###


@normalizer
Expand All @@ -824,6 +824,16 @@ def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
return result


@normalizer
def take(a: ArrayLike, indices: ArrayLike, axis=None, out=None, mode="raise"):
    """np.take analogue: pick elements of `a` along `axis` at `indices`.

    Only the defaults ``out=None`` and ``mode="raise"`` are supported;
    anything else raises NotImplementedError.
    """
    if out is not None:
        raise NotImplementedError(f"{out=}")
    if mode != "raise":
        raise NotImplementedError(f"{mode=}")
    return _impl.take(a, indices, axis)


@normalizer
def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
# modify the argument in-place : here `arr` is `arr._tensor` of the original `arr` argument
Expand Down
7 changes: 2 additions & 5 deletions torch_np/tests/numpy_tests/core/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,9 +1053,8 @@ def test_non_integer_argument_errors(self):

assert_raises(TypeError, np.reshape, a, (1., 1., -1))
assert_raises(TypeError, np.reshape, a, (np.array(1.), -1))
pytest.xfail("XXX: take not implemented")
assert_raises(TypeError, np.take, a, [0], 1.)
assert_raises(TypeError, np.take, a, [0], np.float64(1.))
assert_raises((TypeError, RuntimeError), np.take, a, [0], np.float64(1.))

@pytest.mark.skip(
reason=(
Expand Down Expand Up @@ -1089,7 +1088,6 @@ def test_bool_as_int_argument_errors(self):
# array is thus also deprecated, but not with the same message:
assert_warns(DeprecationWarning, operator.index, np.True_)

pytest.xfail("XXX: take not implemented")
assert_raises(TypeError, np.take, args=(a, [0], False))

pytest.skip("torch consumes boolean tensors as ints, no bother raising here")
Expand Down Expand Up @@ -1138,8 +1136,7 @@ def test_array_to_index_error(self):
# so no exception is expected. The raising is effectively tested above.
a = np.array([[[1]]])

pytest.xfail("XXX: take not implemented")
assert_raises(TypeError, np.take, a, [0], a)
assert_raises((TypeError, RuntimeError), np.take, a, [0], a)

pytest.skip(
"Multi-dimensional tensors are indexable just as long as they only "
Expand Down
51 changes: 18 additions & 33 deletions torch_np/tests/numpy_tests/core/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4179,63 +4179,48 @@ def test_kwargs(self):
np.putmask(a=x, values=[-1, -2], mask=[0, 1])


@pytest.mark.xfail(reason='TODO')
class TestTake:
def tst_basic(self, x):
ind = list(range(x.shape[0]))
assert_array_equal(x.take(ind, axis=0), x)
assert_array_equal(np.take(x, ind, axis=0), x)

def test_ip_types(self):
unchecked_types = [bytes, str, np.void]

x = np.random.random(24)*100
x.shape = 2, 3, 4
x = np.reshape(x, (2, 3, 4))
for types in np.sctypes.values():
for T in types:
if T not in unchecked_types:
self.tst_basic(x.copy().astype(T))

# Also test string of a length which uses an untypical length
self.tst_basic(x.astype("S3"))
self.tst_basic(x.copy().astype(T))

def test_raise(self):
x = np.random.random(24)*100
x.shape = 2, 3, 4
assert_raises(IndexError, x.take, [0, 1, 2], axis=0)
assert_raises(IndexError, x.take, [-3], axis=0)
assert_array_equal(x.take([-1], axis=0)[0], x[1])
x = np.reshape(x, (2, 3, 4))
assert_raises(IndexError, np.take, x, [0, 1, 2], axis=0)
assert_raises(IndexError, np.take, x, [-3], axis=0)
assert_array_equal(np.take(x, [-1], axis=0)[0], x[1])

@pytest.mark.xfail(reason="XXX: take(..., mode='clip')")
def test_clip(self):
x = np.random.random(24)*100
x.shape = 2, 3, 4
assert_array_equal(x.take([-1], axis=0, mode='clip')[0], x[0])
assert_array_equal(x.take([2], axis=0, mode='clip')[0], x[1])
x = np.reshape(x, (2, 3, 4))
assert_array_equal(np.take(x, [-1], axis=0, mode='clip')[0], x[0])
assert_array_equal(np.take(x, [2], axis=0, mode='clip')[0], x[1])

@pytest.mark.xfail(reason="XXX: take(..., mode='wrap')")
def test_wrap(self):
x = np.random.random(24)*100
x.shape = 2, 3, 4
assert_array_equal(x.take([-1], axis=0, mode='wrap')[0], x[1])
assert_array_equal(x.take([2], axis=0, mode='wrap')[0], x[0])
assert_array_equal(x.take([3], axis=0, mode='wrap')[0], x[1])

@pytest.mark.parametrize('dtype', ('>i4', '<i4'))
def test_byteorder(self, dtype):
x = np.array([1, 2, 3], dtype)
assert_array_equal(x.take([0, 2, 1]), [1, 3, 2])

def test_record_array(self):
# Note mixed byteorder.
rec = np.array([(-5, 2.0, 3.0), (5.0, 4.0, 3.0)],
dtype=[('x', '<f8'), ('y', '>f8'), ('z', '<f8')])
rec1 = rec.take([1])
assert_(rec1['x'] == 5.0 and rec1['y'] == 4.0)
x = np.reshape(x, (2, 3, 4))
assert_array_equal(np.take(x, [-1], axis=0, mode='wrap')[0], x[1])
assert_array_equal(np.take(x, [2], axis=0, mode='wrap')[0], x[0])
assert_array_equal(np.take(x, [3], axis=0, mode='wrap')[0], x[1])

@pytest.mark.xfail(reason="XXX: take(out=...)")
def test_out_overlap(self):
# gh-6272 check overlap on out
x = np.arange(5)
y = np.take(x, [1, 2, 3], out=x[2:5], mode='wrap')
assert_equal(y, np.array([1, 2, 3]))

@pytest.mark.xfail(reason="XXX: take(out=...)")
@pytest.mark.parametrize('shape', [(1, 2), (1,), ()])
def test_ret_is_out(self, shape):
# 0d arrays should not be an exception to this rule
Expand Down
1 change: 0 additions & 1 deletion torch_np/tests/numpy_tests/core/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ def test_sum(self):

assert_equal(tgt, out)

@pytest.mark.xfail(reason="TODO implement take(...)")
def test_take(self):
tgt = [2, 3, 5]
indices = [1, 2, 4]
Expand Down
10 changes: 9 additions & 1 deletion torch_np/tests/numpy_tests/lib/test_arraysetops.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,6 @@ def test_unique_axis_list(self):
assert_array_equal(unique(inp, axis=0), unique(inp_arr, axis=0), msg)
assert_array_equal(unique(inp, axis=1), unique(inp_arr, axis=1), msg)

@pytest.mark.xfail(reason='TODO: implement take')
def test_unique_axis(self):
types = []
types.extend(np.typecodes['AllInteger'])
Expand Down Expand Up @@ -857,6 +856,15 @@ def _run_axis_tests(self, dtype):
result = np.array([[0, 0, 1], [0, 1, 0], [0, 0, 1], [0, 1, 0]])
assert_array_equal(unique(data, axis=1), result.astype(dtype), msg)

pytest.xfail("torch has different unique ordering behaviour")
# e.g.
#
# >>> x = np.array([[[1, 1], [0, 1]], [[1, 0], [0, 0]]])
# >>> np.unique(x, axis=2)
# [[1, 1], [0, 1]], [[1, 0], [0, 0]]
# >>> torch.unique(torch.as_tensor(x), dim=2)
# [[1, 1], [1, 0]], [[0, 1], [0, 0]]
#
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something we need to fix in our unique?

The Notes section in https://numpy.org/doc/stable/reference/generated/numpy.unique.html is probably relevant, if somewhat obscure.

WDYT @lezcano ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(unique against axis gave me a headache so I had left this xfail'd and documented in case you/Mario had an immediate answer 😅)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uf. This looks like a bit of a pain to implement. I'd say we don't touch it and leave it xfailed.

msg = 'Unique with 3d array and axis=2 failed'
data3d = np.array([[[1, 1],
[1, 0]],
Expand Down