Skip to content

Commit d70cab6

Browse files
committed
MAINT: address review comments
1 parent 0e2438e commit d70cab6

File tree

5 files changed

+18
-30
lines changed

5 files changed

+18
-30
lines changed

torch_np/_detail/_reductions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def mean(tensor, axis=None, dtype=None, *, where=NoValue):
145145

146146
is_half = dtype == torch.float16
147147
if is_half:
148+
# XXX revisit when the pytorch version has pytorch/pytorch#95166
148149
dtype = torch.float32
149150

150151
if axis is None:

torch_np/_detail/implementations.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,6 @@ def trace(tensor, offset=0, axis1=0, axis2=1, dtype=None, out=None):
179179
def diagonal(tensor, offset=0, axis1=0, axis2=1):
180180
axis1 = _util.normalize_axis_index(axis1, tensor.ndim)
181181
axis2 = _util.normalize_axis_index(axis2, tensor.ndim)
182-
if axis1 == axis2:
183-
raise ValueError("axis1 and axis2 cannot be the same")
184182
result = torch.diagonal(tensor, offset, axis1, axis2)
185183
return result
186184

torch_np/_ndarray.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,7 @@ def ravel(self, order="C"):
371371
def flatten(self, order="C"):
372372
if order != "C":
373373
raise NotImplementedError
374-
# return a copy
375-
result = self._tensor.ravel().clone()
374+
result = self._tensor.flatten()
376375
return asarray(result)
377376

378377
def nonzero(self):

torch_np/random.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
]
3030

3131

32-
def array_or_scalar(values, py_type=float, size=None):
33-
if size is None:
32+
def array_or_scalar(values, py_type=float, return_scalar=False):
33+
if return_scalar:
3434
return py_type(values.item())
3535
else:
3636
return asarray(values)
@@ -45,7 +45,7 @@ def random_sample(size=None):
4545
if size is None:
4646
size = ()
4747
values = torch.empty(size, dtype=_default_dtype).uniform_()
48-
return array_or_scalar(values, size=size)
48+
return array_or_scalar(values, return_scalar=size is None)
4949

5050

5151
def rand(*size):
@@ -60,19 +60,19 @@ def uniform(low=0.0, high=1.0, size=None):
6060
if size is None:
6161
size = ()
6262
values = torch.empty(size, dtype=_default_dtype).uniform_(low, high)
63-
return array_or_scalar(values, size=size)
63+
return array_or_scalar(values, return_scalar=size is None)
6464

6565

6666
def randn(*size):
6767
values = torch.randn(size, dtype=_default_dtype)
68-
return array_or_scalar(values, size=size)
68+
return array_or_scalar(values, return_scalar=size is None)
6969

7070

7171
def normal(loc=0.0, scale=1.0, size=None):
7272
if size is None:
7373
size = ()
7474
values = torch.empty(size, dtype=_default_dtype).normal_(loc, scale)
75-
return array_or_scalar(values, size=size)
75+
return array_or_scalar(values, return_scalar=size is None)
7676

7777

7878
def shuffle(x):
@@ -90,7 +90,7 @@ def randint(low, high=None, size=None):
9090
if high is None:
9191
low, high = 0, low
9292
values = torch.randint(low, high, size=size)
93-
return array_or_scalar(values, int, size=size)
93+
return array_or_scalar(values, int, return_scalar=size is None)
9494

9595

9696
def choice(a, size=None, replace=True, p=None):

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,7 +1686,7 @@ def assert_c(arr):
16861686
assert_fortran(a.copy('F'))
16871687
assert_c(a.copy('A'))
16881688

1689-
@pytest.mark.xfail(reason="no .ctypes attribute")
1689+
@pytest.mark.skip(reason="no .ctypes attribute")
16901690
@pytest.mark.parametrize("dtype", [np.int32])
16911691
def test__deepcopy__(self, dtype):
16921692
# Force the entry of NULLs into array
@@ -1715,7 +1715,7 @@ def test_argsort(self):
17151715
assert_equal(a.copy().argsort(kind=kind), a, msg)
17161716
assert_equal(b.copy().argsort(kind=kind), b, msg)
17171717

1718-
@pytest.mark.xfail(reason='argsort complex')
1718+
@pytest.mark.skip(reason='argsort complex')
17191719
def test_argsort_complex(self):
17201720
a = np.arange(101, dtype=np.float32)
17211721
b = np.flip(a)
@@ -2546,7 +2546,7 @@ def test_dot_out_mem_overlap(self):
25462546
assert_raises(ValueError, np.dot, a, b, out=b[::2])
25472547
assert_raises(ValueError, np.dot, a, b, out=b.T)
25482548

2549-
@pytest.mark.xfail(reason="TODO [::-1]")
2549+
@pytest.mark.xfail(reason="TODO: overlapping memory in matmul")
25502550
def test_matmul_out(self):
25512551
# overlapping memory
25522552
a = np.arange(18).reshape(2, 3, 3)
@@ -2568,14 +2568,14 @@ def test_diagonal(self):
25682568
assert_raises(np.AxisError, a.diagonal, axis1=0, axis2=5)
25692569
assert_raises(np.AxisError, a.diagonal, axis1=5, axis2=0)
25702570
assert_raises(np.AxisError, a.diagonal, axis1=5, axis2=5)
2571-
assert_raises(ValueError, a.diagonal, axis1=1, axis2=1)
2571+
assert_raises((ValueError, RuntimeError), a.diagonal, axis1=1, axis2=1)
25722572

25732573
b = np.arange(8).reshape((2, 2, 2))
25742574
assert_equal(b.diagonal(), [[0, 6], [1, 7]])
25752575
assert_equal(b.diagonal(0), [[0, 6], [1, 7]])
25762576
assert_equal(b.diagonal(1), [[2], [3]])
25772577
assert_equal(b.diagonal(-1), [[4], [5]])
2578-
assert_raises(ValueError, b.diagonal, axis1=0, axis2=0)
2578+
assert_raises((ValueError, RuntimeError), b.diagonal, axis1=0, axis2=0)
25792579
assert_equal(b.diagonal(0, 1, 2), [[0, 3], [4, 7]])
25802580
assert_equal(b.diagonal(0, 0, 1), [[0, 6], [1, 7]])
25812581
assert_equal(b.diagonal(offset=1, axis1=0, axis2=2), [[1], [3]])
@@ -2805,7 +2805,6 @@ def test_swapaxes(self):
28052805
if k == 1:
28062806
b = c
28072807

2808-
@pytest.mark.xfail(reason="TODO: ndarray.conjugate")
28092808
def test_conjugate(self):
28102809
a = np.array([1-1j, 1+1j, 23+23.0j])
28112810
ac = a.conj()
@@ -2833,17 +2832,8 @@ def test_conjugate(self):
28332832
assert_equal(ac, a.conjugate())
28342833
assert_equal(ac, np.conjugate(a))
28352834

2836-
a = np.array([1-1j, 1+1j, 1, 2.0], object)
2837-
ac = a.conj()
2838-
assert_equal(ac, [k.conjugate() for k in a])
2839-
assert_equal(ac, a.conjugate())
2840-
assert_equal(ac, np.conjugate(a))
28412835

2842-
a = np.array([1-1j, 1, 2.0, 'f'], object)
2843-
assert_raises(TypeError, lambda: a.conj())
2844-
assert_raises(TypeError, lambda: a.conjugate())
2845-
2846-
@pytest.mark.xfail(reason="TODO: ndarray.conjugate")
2836+
@pytest.mark.xfail(reason="TODO: ndarray.conjugate with out")
28472837
def test_conjugate_out(self):
28482838
# Minimal test for the out argument being passed on correctly
28492839
# NOTE: The ability to pass `out` is currently undocumented!
@@ -3754,7 +3744,7 @@ def test_ret_is_out(self, ndim, method):
37543744
ret = arg_method(axis=0, out=out)
37553745
assert ret is out
37563746

3757-
@pytest.mark.xfail(reason='FIXME: keepdims w/ positional args?')
3747+
@pytest.mark.xfail(reason='FIXME: out w/ positional args?')
37583748
@pytest.mark.parametrize('arr_method, np_method',
37593749
[('argmax', np.argmax),
37603750
('argmin', np.argmin)])
@@ -5438,7 +5428,7 @@ class TestDot:
54385428
def setup_method(self):
54395429
np.random.seed(128)
54405430

5441-
# Numpy guarantees the random stream, and we don't. So inline the
5431+
# Numpy and pytorch random streams differ, so inline the
54425432
# values from numpy 1.24.1
54435433
# self.A = np.random.rand(4, 2)
54445434
self.A = np.array([[0.86663704, 0.26314485],
@@ -5626,7 +5616,7 @@ def test_dot_3args_errors(self):
56265616
r = np.empty((1024, 32), dtype=int)
56275617
assert_raises(ValueError, dot, f, v, r)
56285618

5629-
@pytest.mark.skip(reason="TODO order='F'")
5619+
@pytest.mark.xfail(reason="TODO order='F'")
56305620
def test_dot_array_order(self):
56315621
a = np.array([[1, 2], [3, 4]], order='C')
56325622
b = np.array([[1, 2], [3, 4]], order='F')

0 commit comments

Comments
 (0)