Skip to content

Commit ae5d1a5

Browse files
honnoev-br
authored andcommitted
Remove take()-not-implemented xfails
1 parent 0e8e6b5 commit ae5d1a5

File tree

5 files changed

+32
-40
lines changed

5 files changed

+32
-40
lines changed

torch_np/_ndarray.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,9 @@ def __setitem__(self, index, value):
407407
value = _helpers.ndarrays_to_tensors(value)
408408
return self.tensor.__setitem__(index, value)
409409

410+
def take(*a, **kw):
411+
raise NotImplementedError()
412+
410413

411414
# This is the ideally the only place which talks to ndarray directly.
412415
# The rest goes through asarray (preferred) or array.

torch_np/tests/numpy_tests/core/test_indexing.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,9 +1053,8 @@ def test_non_integer_argument_errors(self):
10531053

10541054
assert_raises(TypeError, np.reshape, a, (1., 1., -1))
10551055
assert_raises(TypeError, np.reshape, a, (np.array(1.), -1))
1056-
pytest.xfail("XXX: take not implemented")
10571056
assert_raises(TypeError, np.take, a, [0], 1.)
1058-
assert_raises(TypeError, np.take, a, [0], np.float64(1.))
1057+
assert_raises((TypeError, RuntimeError), np.take, a, [0], np.float64(1.))
10591058

10601059
@pytest.mark.skip(
10611060
reason=(
@@ -1089,7 +1088,6 @@ def test_bool_as_int_argument_errors(self):
10891088
# array is thus also deprecated, but not with the same message:
10901089
assert_warns(DeprecationWarning, operator.index, np.True_)
10911090

1092-
pytest.xfail("XXX: take not implemented")
10931091
assert_raises(TypeError, np.take, args=(a, [0], False))
10941092

10951093
pytest.skip("torch consumes boolean tensors as ints, no bother raising here")
@@ -1138,8 +1136,7 @@ def test_array_to_index_error(self):
11381136
# so no exception is expected. The raising is effectively tested above.
11391137
a = np.array([[[1]]])
11401138

1141-
pytest.xfail("XXX: take not implemented")
1142-
assert_raises(TypeError, np.take, a, [0], a)
1139+
assert_raises((TypeError, RuntimeError), np.take, a, [0], a)
11431140

11441141
pytest.skip(
11451142
"Multi-dimensional tensors are indexable just as long as they only "

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3817,63 +3817,48 @@ def test_kwargs(self):
38173817
np.putmask(a=x, values=[-1, -2], mask=[0, 1])
38183818

38193819

3820-
@pytest.mark.xfail(reason='TODO')
38213820
class TestTake:
38223821
def tst_basic(self, x):
38233822
ind = list(range(x.shape[0]))
3824-
assert_array_equal(x.take(ind, axis=0), x)
3823+
assert_array_equal(np.take(x, ind, axis=0), x)
38253824

38263825
def test_ip_types(self):
3827-
unchecked_types = [bytes, str, np.void]
3828-
38293826
x = np.random.random(24)*100
3830-
x.shape = 2, 3, 4
3827+
x = np.reshape(x, (2, 3, 4))
38313828
for types in np.sctypes.values():
38323829
for T in types:
3833-
if T not in unchecked_types:
3834-
self.tst_basic(x.copy().astype(T))
3835-
3836-
# Also test string of a length which uses an untypical length
3837-
self.tst_basic(x.astype("S3"))
3830+
self.tst_basic(x.copy().astype(T))
38383831

38393832
def test_raise(self):
38403833
x = np.random.random(24)*100
3841-
x.shape = 2, 3, 4
3842-
assert_raises(IndexError, x.take, [0, 1, 2], axis=0)
3843-
assert_raises(IndexError, x.take, [-3], axis=0)
3844-
assert_array_equal(x.take([-1], axis=0)[0], x[1])
3834+
x = np.reshape(x, (2, 3, 4))
3835+
assert_raises(IndexError, np.take, x, [0, 1, 2], axis=0)
3836+
assert_raises(IndexError, np.take, x, [-3], axis=0)
3837+
assert_array_equal(np.take(x, [-1], axis=0)[0], x[1])
38453838

3839+
@pytest.mark.xfail(reason="XXX: take(..., mode='clip')")
38463840
def test_clip(self):
38473841
x = np.random.random(24)*100
3848-
x.shape = 2, 3, 4
3849-
assert_array_equal(x.take([-1], axis=0, mode='clip')[0], x[0])
3850-
assert_array_equal(x.take([2], axis=0, mode='clip')[0], x[1])
3842+
x = np.reshape(x, (2, 3, 4))
3843+
assert_array_equal(np.take(x, [-1], axis=0, mode='clip')[0], x[0])
3844+
assert_array_equal(np.take(x, [2], axis=0, mode='clip')[0], x[1])
38513845

3846+
@pytest.mark.xfail(reason="XXX: take(..., mode='wrap')")
38523847
def test_wrap(self):
38533848
x = np.random.random(24)*100
3854-
x.shape = 2, 3, 4
3855-
assert_array_equal(x.take([-1], axis=0, mode='wrap')[0], x[1])
3856-
assert_array_equal(x.take([2], axis=0, mode='wrap')[0], x[0])
3857-
assert_array_equal(x.take([3], axis=0, mode='wrap')[0], x[1])
3858-
3859-
@pytest.mark.parametrize('dtype', ('>i4', '<i4'))
3860-
def test_byteorder(self, dtype):
3861-
x = np.array([1, 2, 3], dtype)
3862-
assert_array_equal(x.take([0, 2, 1]), [1, 3, 2])
3863-
3864-
def test_record_array(self):
3865-
# Note mixed byteorder.
3866-
rec = np.array([(-5, 2.0, 3.0), (5.0, 4.0, 3.0)],
3867-
dtype=[('x', '<f8'), ('y', '>f8'), ('z', '<f8')])
3868-
rec1 = rec.take([1])
3869-
assert_(rec1['x'] == 5.0 and rec1['y'] == 4.0)
3849+
x = np.reshape(x, (2, 3, 4))
3850+
assert_array_equal(np.take(x, [-1], axis=0, mode='wrap')[0], x[1])
3851+
assert_array_equal(np.take(x, [2], axis=0, mode='wrap')[0], x[0])
3852+
assert_array_equal(np.take(x, [3], axis=0, mode='wrap')[0], x[1])
38703853

3854+
@pytest.mark.xfail(reason="XXX: take(out=...)")
38713855
def test_out_overlap(self):
38723856
# gh-6272 check overlap on out
38733857
x = np.arange(5)
38743858
y = np.take(x, [1, 2, 3], out=x[2:5], mode='wrap')
38753859
assert_equal(y, np.array([1, 2, 3]))
38763860

3861+
@pytest.mark.xfail(reason="XXX: take(out=...)")
38773862
@pytest.mark.parametrize('shape', [(1, 2), (1,), ()])
38783863
def test_ret_is_out(self, shape):
38793864
# 0d arrays should not be an exception to this rule

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,6 @@ def test_sum(self):
282282

283283
assert_equal(tgt, out)
284284

285-
@pytest.mark.xfail(reason="TODO implement take(...)")
286285
def test_take(self):
287286
tgt = [2, 3, 5]
288287
indices = [1, 2, 4]

torch_np/tests/numpy_tests/lib/test_arraysetops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,6 @@ def test_unique_axis_list(self):
768768
assert_array_equal(unique(inp, axis=0), unique(inp_arr, axis=0), msg)
769769
assert_array_equal(unique(inp, axis=1), unique(inp_arr, axis=1), msg)
770770

771-
@pytest.mark.xfail(reason='TODO: implement take')
772771
def test_unique_axis(self):
773772
types = []
774773
types.extend(np.typecodes['AllInteger'])
@@ -857,6 +856,15 @@ def _run_axis_tests(self, dtype):
857856
result = np.array([[0, 0, 1], [0, 1, 0], [0, 0, 1], [0, 1, 0]])
858857
assert_array_equal(unique(data, axis=1), result.astype(dtype), msg)
859858

859+
pytest.xfail("torch has different unique ordering behaviour")
860+
# e.g.
861+
#
862+
# >>> x = np.array([[[1, 1], [0, 1]], [[1, 0], [0, 0]]])
863+
# >>> np.unique(x, axis=2)
864+
# [[1, 1], [0, 1]], [[1, 0], [0, 0]]
865+
# >>> torch.unique(torch.as_tensor(x), dim=2)
866+
# [[1, 1], [1, 0]], [[0, 1], [0, 0]]
867+
#
860868
msg = 'Unique with 3d array and axis=2 failed'
861869
data3d = np.array([[[1, 1],
862870
[1, 0]],

0 commit comments

Comments
 (0)