From e1ad9d8119753c3490f37847026352ef760e43b7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 27 Feb 2023 20:57:24 +0300 Subject: [PATCH 1/5] add ndarray.dot --- torch_np/_funcs.py | 13 ++++- torch_np/_ndarray.py | 1 + torch_np/_wrapper.py | 11 ---- .../tests/numpy_tests/core/test_multiarray.py | 52 +++++++++---------- 4 files changed, 39 insertions(+), 38 deletions(-) diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index b9ae7643..f9f3b841 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -93,7 +93,18 @@ def fill_diagonal(a, val, wrap=False): return _helpers.array_from(result) -# ### sorting ### +def vdot(a, b, /): + t_a, t_b = _helpers.to_tensors(a, b) + result = _impl.vdot(t_a, t_b) + return result.item() + + +def dot(a, b, out=None): + t_a, t_b = _helpers.to_tensors(a, b) + result = _impl.dot(t_a, t_b) + return _helpers.result_or_out(result, out) + + # ### sort and partition ### diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 072b58f8..ea3e0739 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -362,6 +362,7 @@ def reshape(self, *shape, order="C"): diagonal = _funcs.diagonal trace = _funcs.trace + dot = _funcs.dot ### sorting ### diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index ac03e46d..047c969d 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -421,17 +421,6 @@ def where(condition, x=None, y=None, /): return asarray(result) -def vdot(a, b, /): - t_a, t_b = _helpers.to_tensors(a, b) - result = _impl.vdot(t_a, t_b) - return result.item() - - -def dot(a, b, out=None): - t_a, t_b = _helpers.to_tensors(a, b) - result = _impl.dot(t_a, t_b) - return _helpers.result_or_out(result, out) - ###### module-level queries of object properties diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py index cfe17a23..9269e586 100644 --- a/torch_np/tests/numpy_tests/core/test_multiarray.py +++ b/torch_np/tests/numpy_tests/core/test_multiarray.py @@ -2386,7 +2386,6 @@ def test_flatten(self): assert_equal(x1.flatten('F'), x1.T.flatten()) - @pytest.mark.xfail(reason="TODO np.dot") @pytest.mark.parametrize('func', (np.dot, np.matmul)) def test_arr_mult(self, func): a = np.array([[1, 0], [0, 1]]) @@ -2428,7 +2427,27 @@ def test_arr_mult(self, func): assert_equal(func(ebf.T, ebf), eaf) assert_equal(func(ebf, ebf.T), eaf) assert_equal(func(ebf.T, ebf.T), eaf) + # syrk - different shape + for et in [np.float32, np.float64, np.complex64, np.complex128]: + edf = d.astype(et) + eddtf = ddt.astype(et) + edtdf = dtd.astype(et) + assert_equal(func(edf, edf.T), eddtf) + assert_equal(func(edf.T, edf), edtdf) + assert_equal( + func(edf[:edf.shape[0] // 2, :], edf[::2, :].T), + func(edf[:edf.shape[0] // 2, :].copy(), edf[::2, :].T.copy()) + ) + assert_equal( + func(edf[::2, :], edf[:edf.shape[0] // 2, :].T), + func(edf[::2, :].copy(), edf[:edf.shape[0] // 2, :].T.copy()) + ) + + + @pytest.mark.skip(reason="dot/matmul with negative strides") + @pytest.mark.parametrize('func', (np.dot, np.matmul)) + def test_arr_mult_2(self, func): # syrk - different shape, stride, and view validations for et in [np.float32, np.float64, np.complex64, np.complex128]: edf = d.astype(et) @@ -2448,22 +2467,6 @@ def test_arr_mult(self, func): func(edf, edf[:, ::-1].T), func(edf, edf[:, ::-1].T.copy()) ) - assert_equal( - func(edf[:edf.shape[0] // 2, :], edf[::2, :].T), - func(edf[:edf.shape[0] // 2, :].copy(), edf[::2, :].T.copy()) - ) - assert_equal( - func(edf[::2, :], edf[:edf.shape[0] // 2, :].T), - func(edf[::2, :].copy(), edf[:edf.shape[0] // 2, :].T.copy()) - ) - - # syrk - different shape - for et in [np.float32, np.float64, np.complex64, np.complex128]: - edf = d.astype(et) - eddtf = ddt.astype(et) - edtdf = dtd.astype(et) - assert_equal(func(edf, edf.T), eddtf) - assert_equal(func(edf.T, edf), edtdf) @pytest.mark.xfail(reason="TODO np.dot") @pytest.mark.parametrize('func', (np.dot, np.matmul)) @@ -2481,6 +2484,11 @@ def test_no_dgemv(self, func, dtype): ret2 = func(b.T.copy(), a.T) assert_equal(ret1, ret2) + + @pytest.mark.skip(reason="__array_interface__") + @pytest.mark.parametrize('func', (np.dot, np.matmul)) + @pytest.mark.parametrize('dtype', 'ifdFD') + def test_no_dgemv_2(self, func, dtype): # check for unaligned data dt = np.dtype(dtype) a = np.zeros(8 * dt.itemsize // 2 + 1, dtype='int16')[1:].view(dtype) @@ -2496,7 +2504,6 @@ def test_no_dgemv(self, func, dtype): ret2 = func(b.T.copy(), a.T.copy()) assert_equal(ret1, ret2) - @pytest.mark.xfail(reason="TODO np.dot") def test_dot(self): a = np.array([[1, 0], [0, 1]]) b = np.array([[0, 1], [1, 0]]) @@ -2515,15 +2522,8 @@ def test_dot(self): a.dot(b=b, out=c) assert_equal(c, np.dot(a, b)) - @pytest.mark.xfail(reason="TODO np.dot") - def test_dot_type_mismatch(self): - c = 1. - A = np.array((1,1), dtype='i,i') - assert_raises(TypeError, np.dot, c, A) - assert_raises(TypeError, np.dot, A, c) - - @pytest.mark.xfail(reason="TODO np.dot") + @pytest.mark.xfail(reason="_aligned_zeros") def test_dot_out_mem_overlap(self): np.random.seed(1) From 32e98440c427c46c75aed23fe839fbe13afb115c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 28 Feb 2023 01:18:05 +0300 Subject: [PATCH 2/5] BUG: fix up matmult/dot --- torch_np/_detail/_ufunc_impl.py | 10 ++++++++- torch_np/_detail/implementations.py | 22 ++++++++++++++----- torch_np/_funcs.py | 5 ++++- .../tests/numpy_tests/core/test_multiarray.py | 1 - 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/torch_np/_detail/_ufunc_impl.py b/torch_np/_detail/_ufunc_impl.py index 32db069b..e3741da1 100644 --- a/torch_np/_detail/_ufunc_impl.py +++ b/torch_np/_detail/_ufunc_impl.py @@ -1,6 +1,7 @@ import torch from . import _util +from . import _dtypes_impl def deco_ufunc(torch_func): @@ -70,7 +71,6 @@ def wrapped( logical_and = deco_ufunc(torch.logical_and) logical_or = deco_ufunc(torch.logical_or) logical_xor = deco_ufunc(torch.logical_xor) -matmul = deco_ufunc(torch.matmul) maximum = deco_ufunc(torch.maximum) minimum = deco_ufunc(torch.minimum) remainder = deco_ufunc(torch.remainder) @@ -143,7 +143,15 @@ def _absolute(x): return x return torch.absolute(x) +def _matmul(x, y): + # work around RuntimeError: expected scalar type Int but found Double + dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype)) + x = x.to(dtype) + y = y.to(dtype) + result = torch.matmul(x, y) + return result cbrt = deco_ufunc(_cbrt) positive = deco_ufunc(_positive) absolute = deco_ufunc(_absolute) +matmul = deco_ufunc(_matmul) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index b6a4a120..c90d821d 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -178,6 +178,8 @@ def trace(tensor, offset=0, axis1=0, axis2=1, dtype=None, out=None): def diagonal(tensor, offset=0, axis1=0, axis2=1): axis1 = _util.normalize_axis_index(axis1, tensor.ndim) axis2 = _util.normalize_axis_index(axis2, tensor.ndim) + if axis1 == axis2: + raise ValueError("axis1 and axis2 cannot be the same") result = torch.diagonal(tensor, offset, axis1, axis2) return result @@ -492,17 +494,25 @@ def arange(start=None, stop=None, step=1, dtype=None): if start is None: start = 0 - if dtype is None: - dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)] - dtype = _dtypes_impl.default_int_dtype - dt_list.append(dtype) - dtype = _dtypes_impl.result_type_impl(dt_list) +# if dtype is None: + dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)] + dtype = _dtypes_impl.default_int_dtype + dt_list.append(dtype) + dtype = _dtypes_impl.result_type_impl(dt_list) + # work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat' + orig_dtype = dtype + is_complex = dtype is not None and dtype.is_complex try: - return torch.arange(start, stop, step, dtype=dtype) + if is_complex: + dtype = torch.float64 + result = torch.arange(start, stop, step, dtype=orig_dtype) + if is_complex: + result = result.to(dttype) except RuntimeError: raise ValueError("Maximum allowed size exceeded") + return result # ### empty/full et al ### diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index f9f3b841..582189c8 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -1,7 +1,7 @@ import torch from . import _decorators, _helpers -from ._detail import _flips, _util +from ._detail import _flips, _util, _dtypes_impl from ._detail import implementations as _impl @@ -101,6 +101,9 @@ def vdot(a, b, /): def dot(a, b, out=None): t_a, t_b = _helpers.to_tensors(a, b) + dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype)) + t_a = t_a.to(dtype) + t_b = t_b.to(dtype) result = _impl.dot(t_a, t_b) return _helpers.result_or_out(result, out) diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py index 9269e586..21c34ee3 100644 --- a/torch_np/tests/numpy_tests/core/test_multiarray.py +++ b/torch_np/tests/numpy_tests/core/test_multiarray.py @@ -2468,7 +2468,6 @@ def test_arr_mult_2(self, func): func(edf, edf[:, ::-1].T.copy()) ) - @pytest.mark.xfail(reason="TODO np.dot") @pytest.mark.parametrize('func', (np.dot, np.matmul)) @pytest.mark.parametrize('dtype', 'ifdFD') def test_no_dgemv(self, func, dtype): From 9328f11120aa0cbd6ff6ea4d098acdc4b081dd82 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 28 Feb 2023 11:10:36 +0300 Subject: [PATCH 3/5] BUG: fix arange w/complex dtypes --- torch_np/_detail/_ufunc_impl.py | 5 ++- torch_np/_detail/implementations.py | 28 ++++++------ torch_np/_funcs.py | 3 +- torch_np/_wrapper.py | 1 - .../tests/numpy_tests/core/test_multiarray.py | 43 ++++--------------- 5 files changed, 28 insertions(+), 52 deletions(-) diff --git a/torch_np/_detail/_ufunc_impl.py b/torch_np/_detail/_ufunc_impl.py index e3741da1..a19c74f7 100644 --- a/torch_np/_detail/_ufunc_impl.py +++ b/torch_np/_detail/_ufunc_impl.py @@ -1,7 +1,6 @@ import torch -from . import _util -from . import _dtypes_impl +from . import _dtypes_impl, _util def deco_ufunc(torch_func): @@ -143,6 +142,7 @@ def _absolute(x): return x return torch.absolute(x) + def _matmul(x, y): # work around RuntimeError: expected scalar type Int but found Double dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype)) @@ -151,6 +151,7 @@ def _matmul(x, y): result = torch.matmul(x, y) return result + cbrt = deco_ufunc(_cbrt) positive = deco_ufunc(_positive) absolute = deco_ufunc(_absolute) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index c90d821d..30b95214 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -178,8 +178,6 @@ def trace(tensor, offset=0, axis1=0, axis2=1, dtype=None, out=None): def diagonal(tensor, offset=0, axis1=0, axis2=1): axis1 = _util.normalize_axis_index(axis1, tensor.ndim) axis2 = _util.normalize_axis_index(axis2, tensor.ndim) - if axis1 == axis2: - raise ValueError("axis1 and axis2 cannot be the same") result = torch.diagonal(tensor, offset, axis1, axis2) return result @@ -494,26 +492,32 @@ def arange(start=None, stop=None, step=1, dtype=None): if start is None: start = 0 -# if dtype is None: + # the dtype of the result + if dtype is None: + dtype = _dtypes_impl.default_int_dtype dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)] - dtype = _dtypes_impl.default_int_dtype dt_list.append(dtype) dtype = _dtypes_impl.result_type_impl(dt_list) - # work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat' - orig_dtype = dtype - is_complex = dtype is not None and dtype.is_complex + # work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat' + if dtype.is_complex: + work_dtype, target_dtype = torch.float64, dtype + else: + work_dtype, target_dtype = dtype, dtype + + if (step > 0 and start > stop) or (step < 0 and start < stop): + # empty range + return torch.as_tensor([], dtype=target_dtype) + try: - if is_complex: - dtype = torch.float64 - result = torch.arange(start, stop, step, dtype=orig_dtype) - if is_complex: - result = result.to(dttype) + result = torch.arange(start, stop, step, dtype=work_dtype) + result = _util.cast_if_needed(result, target_dtype) except RuntimeError: raise ValueError("Maximum allowed size exceeded") return result + # ### empty/full et al ### diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index 582189c8..80963dd5 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -1,7 +1,7 @@ import torch from . import _decorators, _helpers -from ._detail import _flips, _util, _dtypes_impl +from ._detail import _dtypes_impl, _flips, _util from ._detail import implementations as _impl @@ -108,7 +108,6 @@ def dot(a, b, out=None): return _helpers.result_or_out(result, out) - # ### sort and partition ### diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 047c969d..dbf0bf9a 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -421,7 +421,6 @@ def where(condition, x=None, y=None, /): return asarray(result) - ###### module-level queries of object properties diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py index 21c34ee3..49e8f01d 100644 --- a/torch_np/tests/numpy_tests/core/test_multiarray.py +++ b/torch_np/tests/numpy_tests/core/test_multiarray.py @@ -5626,7 +5626,7 @@ def test_dot_array_order(self): assert_equal(np.dot(b, a), res) assert_equal(np.dot(b, b), res) - @pytest.mark.skip(reason='TODO: nbytes, view') + @pytest.mark.skip(reason='TODO: nbytes, view, __array_interface__') def test_accelerate_framework_sgemv_fix(self): def aligned_array(shape, align, dtype, order='C'): @@ -7877,7 +7877,6 @@ def test_view_discard_refcount(self): assert_equal(arr, orig) -@pytest.mark.xfail(reason='TODO') class TestArange: def test_infinite(self): assert_raises_regex( @@ -7886,8 +7885,8 @@ def test_infinite(self): ) def test_nan_step(self): - assert_raises_regex( - ValueError, "cannot compute length", + assert_raises( + ValueError, # "cannot compute length", np.arange, 0, 1, np.nan ) @@ -7903,6 +7902,9 @@ def test_require_range(self): assert_raises(TypeError, np.arange) assert_raises(TypeError, np.arange, step=3) assert_raises(TypeError, np.arange, dtype='int64') + + @pytest.mark.xfail(reason="weird arange signature (optionals before required args)") + def test_require_range_2(self): assert_raises(TypeError, np.arange, start=4) def test_start_stop_kwarg(self): @@ -7915,6 +7917,7 @@ def test_start_stop_kwarg(self): assert len(keyword_start_stop) == 6 assert_array_equal(keyword_stop, keyword_zerotostop) + @pytest.mark.skip(reason="arange for booleans: numpy maybe deprecates?") def test_arange_booleans(self): # Arange makes some sense for booleans and works up to length 2. # But it is weird since `arange(2, 4, dtype=bool)` works. @@ -7935,28 +7938,6 @@ def test_arange_booleans(self): with pytest.raises(TypeError): np.arange(3, dtype="bool") - @pytest.mark.parametrize("dtype", ["S3", "U", "5i"]) - def test_rejects_bad_dtypes(self, dtype): - dtype = np.dtype(dtype) - DType_name = re.escape(str(type(dtype))) - with pytest.raises(TypeError, - match=rf"arange\(\) not supported for inputs .* {DType_name}"): - np.arange(2, dtype=dtype) - - def test_rejects_strings(self): - # Explicitly test error for strings which may call "b" - "a": - DType_name = re.escape(str(type(np.array("a").dtype))) - with pytest.raises(TypeError, - match=rf"arange\(\) not supported for inputs .* {DType_name}"): - np.arange("a", "b") - - def test_byteswapped(self): - res_be = np.arange(1, 1000, dtype=">i4") - res_le = np.arange(1, 1000, dtype="i4" - assert res_le.dtype == " Date: Tue, 28 Feb 2023 10:45:52 +0200 Subject: [PATCH 4/5] Update torch_np/_detail/implementations.py Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> --- torch_np/_detail/implementations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 30b95214..1f320587 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -507,7 +507,7 @@ def arange(start=None, stop=None, step=1, dtype=None): if (step > 0 and start > stop) or (step < 0 and start < stop): # empty range - return torch.as_tensor([], dtype=target_dtype) + return torch.empty(0, dtype=target_type) try: result = torch.arange(start, stop, step, dtype=work_dtype) From ff75f7a3136b4c78ba44bed96082fdbeba35ffd1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 28 Feb 2023 11:54:09 +0300 Subject: [PATCH 5/5] MAINT: address review comments --- torch_np/_detail/_ufunc_impl.py | 4 ++-- torch_np/_detail/implementations.py | 6 +++++- torch_np/_funcs.py | 3 --- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torch_np/_detail/_ufunc_impl.py b/torch_np/_detail/_ufunc_impl.py index a19c74f7..2cd1ecee 100644 --- a/torch_np/_detail/_ufunc_impl.py +++ b/torch_np/_detail/_ufunc_impl.py @@ -146,8 +146,8 @@ def _absolute(x): def _matmul(x, y): # work around RuntimeError: expected scalar type Int but found Double dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype)) - x = x.to(dtype) - y = y.to(dtype) + x = _util.cast_if_needed(x, dtype) + y = _util.cast_if_needed(y, dtype) result = torch.matmul(x, y) return result diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 1f320587..315b28b0 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -507,7 +507,7 @@ def arange(start=None, stop=None, step=1, dtype=None): if (step > 0 and start > stop) or (step < 0 and start < stop): # empty range - return torch.empty(0, dtype=target_type) + return torch.empty(0, dtype=target_dtype) try: result = torch.arange(start, stop, step, dtype=work_dtype) @@ -797,6 +797,10 @@ def vdot(t_a, t_b, /): def dot(t_a, t_b): + dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype)) + t_a = _util.cast_if_needed(t_a, dtype) + t_b = _util.cast_if_needed(t_b, dtype) + if t_a.ndim == 0 or t_b.ndim == 0: result = t_a * t_b elif t_a.ndim == 1 and t_b.ndim == 1: diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index 80963dd5..9ebcd364 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -101,9 +101,6 @@ def vdot(a, b, /): def dot(a, b, out=None): t_a, t_b = _helpers.to_tensors(a, b) - dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype)) - t_a = t_a.to(dtype) - t_b = t_b.to(dtype) result = _impl.dot(t_a, t_b) return _helpers.result_or_out(result, out)