diff --git a/torch_np/_detail/_ufunc_impl.py b/torch_np/_detail/_ufunc_impl.py index 32db069b..2cd1ecee 100644 --- a/torch_np/_detail/_ufunc_impl.py +++ b/torch_np/_detail/_ufunc_impl.py @@ -1,6 +1,6 @@ import torch -from . import _util +from . import _dtypes_impl, _util def deco_ufunc(torch_func): @@ -70,7 +70,6 @@ def wrapped( logical_and = deco_ufunc(torch.logical_and) logical_or = deco_ufunc(torch.logical_or) logical_xor = deco_ufunc(torch.logical_xor) -matmul = deco_ufunc(torch.matmul) maximum = deco_ufunc(torch.maximum) minimum = deco_ufunc(torch.minimum) remainder = deco_ufunc(torch.remainder) @@ -144,6 +143,16 @@ def _absolute(x): return torch.absolute(x) +def _matmul(x, y): + # work around RuntimeError: expected scalar type Int but found Double + dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype)) + x = _util.cast_if_needed(x, dtype) + y = _util.cast_if_needed(y, dtype) + result = torch.matmul(x, y) + return result + + cbrt = deco_ufunc(_cbrt) positive = deco_ufunc(_positive) absolute = deco_ufunc(_absolute) +matmul = deco_ufunc(_matmul) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index b6a4a120..315b28b0 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -492,17 +492,31 @@ def arange(start=None, stop=None, step=1, dtype=None): if start is None: start = 0 + # the dtype of the result if dtype is None: - dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)] dtype = _dtypes_impl.default_int_dtype - dt_list.append(dtype) - dtype = _dtypes_impl.result_type_impl(dt_list) + dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)] + dt_list.append(dtype) + dtype = _dtypes_impl.result_type_impl(dt_list) + + # work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat' + if dtype.is_complex: + work_dtype, target_dtype = torch.float64, dtype + else: + work_dtype, target_dtype = dtype, dtype + + if (step > 0 and start > stop) or (step < 0 and start < stop): + # empty range + return torch.empty(0, dtype=target_dtype) try: - return torch.arange(start, stop, step, dtype=dtype) + result = torch.arange(start, stop, step, dtype=work_dtype) + result = _util.cast_if_needed(result, target_dtype) except RuntimeError: raise ValueError("Maximum allowed size exceeded") + return result + # ### empty/full et al ### @@ -783,6 +797,10 @@ def vdot(t_a, t_b, /): def dot(t_a, t_b): + dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype)) + t_a = _util.cast_if_needed(t_a, dtype) + t_b = _util.cast_if_needed(t_b, dtype) + if t_a.ndim == 0 or t_b.ndim == 0: result = t_a * t_b elif t_a.ndim == 1 and t_b.ndim == 1: diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index b9ae7643..9ebcd364 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -1,7 +1,7 @@ import torch from . import _decorators, _helpers -from ._detail import _flips, _util +from ._detail import _dtypes_impl, _flips, _util from ._detail import implementations as _impl @@ -93,7 +93,17 @@ def fill_diagonal(a, val, wrap=False): return _helpers.array_from(result) -# ### sorting ### +def vdot(a, b, /): + t_a, t_b = _helpers.to_tensors(a, b) + result = _impl.vdot(t_a, t_b) + return result.item() + + +def dot(a, b, out=None): + t_a, t_b = _helpers.to_tensors(a, b) + result = _impl.dot(t_a, t_b) + return _helpers.result_or_out(result, out) + # ### sort and partition ### diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 072b58f8..ea3e0739 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -362,6 +362,7 @@ def reshape(self, *shape, order="C"): diagonal = _funcs.diagonal trace = _funcs.trace + dot = _funcs.dot ### sorting ### diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index ac03e46d..dbf0bf9a 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -421,18 +421,6 @@ def where(condition, x=None, y=None, /): return asarray(result) -def vdot(a, b, /): - t_a, t_b = _helpers.to_tensors(a, b) - result = _impl.vdot(t_a, t_b) - return result.item() - - -def dot(a, b, out=None): - t_a, t_b = _helpers.to_tensors(a, b) - result = _impl.dot(t_a, t_b) - return _helpers.result_or_out(result, out) - - ###### module-level queries of object properties diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py index cfe17a23..49e8f01d 100644 --- a/torch_np/tests/numpy_tests/core/test_multiarray.py +++ b/torch_np/tests/numpy_tests/core/test_multiarray.py @@ -2386,7 +2386,6 @@ def test_flatten(self): assert_equal(x1.flatten('F'), x1.T.flatten()) - @pytest.mark.xfail(reason="TODO np.dot") @pytest.mark.parametrize('func', (np.dot, np.matmul)) def test_arr_mult(self, func): a = np.array([[1, 0], [0, 1]]) @@ -2428,7 +2427,27 @@ def test_arr_mult(self, func): assert_equal(func(ebf.T, ebf), eaf) assert_equal(func(ebf, ebf.T), eaf) assert_equal(func(ebf.T, ebf.T), eaf) + # syrk - different shape + for et in [np.float32, np.float64, np.complex64, np.complex128]: + edf = d.astype(et) + eddtf = ddt.astype(et) + edtdf = dtd.astype(et) + assert_equal(func(edf, edf.T), eddtf) + assert_equal(func(edf.T, edf), edtdf) + + assert_equal( + func(edf[:edf.shape[0] // 2, :], edf[::2, :].T), + func(edf[:edf.shape[0] // 2, :].copy(), edf[::2, :].T.copy()) + ) + assert_equal( + func(edf[::2, :], edf[:edf.shape[0] // 2, :].T), + func(edf[::2, :].copy(), edf[:edf.shape[0] // 2, :].T.copy()) + ) + + @pytest.mark.skip(reason="dot/matmul with negative strides") + @pytest.mark.parametrize('func', (np.dot, np.matmul)) + def test_arr_mult_2(self, func): # syrk - different shape, stride, and view validations for et in [np.float32, np.float64, np.complex64, np.complex128]: edf = d.astype(et) @@ -2448,24 +2467,7 @@ def test_arr_mult(self, func): func(edf, edf[:, ::-1].T), func(edf, edf[:, ::-1].T.copy()) ) - assert_equal( - func(edf[:edf.shape[0] // 2, :], edf[::2, :].T), - func(edf[:edf.shape[0] // 2, :].copy(), edf[::2, :].T.copy()) - ) - assert_equal( - func(edf[::2, :], edf[:edf.shape[0] // 2, :].T), - func(edf[::2, :].copy(), edf[:edf.shape[0] // 2, :].T.copy()) - ) - # syrk - different shape - for et in [np.float32, np.float64, np.complex64, np.complex128]: - edf = d.astype(et) - eddtf = ddt.astype(et) - edtdf = dtd.astype(et) - assert_equal(func(edf, edf.T), eddtf) - assert_equal(func(edf.T, edf), edtdf) - - @pytest.mark.xfail(reason="TODO np.dot") @pytest.mark.parametrize('func', (np.dot, np.matmul)) @pytest.mark.parametrize('dtype', 'ifdFD') def test_no_dgemv(self, func, dtype): @@ -2481,6 +2483,11 @@ def test_no_dgemv(self, func, dtype): ret2 = func(b.T.copy(), a.T) assert_equal(ret1, ret2) + + @pytest.mark.skip(reason="__array_interface__") + @pytest.mark.parametrize('func', (np.dot, np.matmul)) + @pytest.mark.parametrize('dtype', 'ifdFD') + def test_no_dgemv_2(self, func, dtype): # check for unaligned data dt = np.dtype(dtype) a = np.zeros(8 * dt.itemsize // 2 + 1, dtype='int16')[1:].view(dtype) @@ -2496,7 +2503,6 @@ def test_no_dgemv(self, func, dtype): ret2 = func(b.T.copy(), a.T.copy()) assert_equal(ret1, ret2) - @pytest.mark.xfail(reason="TODO np.dot") def test_dot(self): a = np.array([[1, 0], [0, 1]]) b = np.array([[0, 1], [1, 0]]) @@ -2515,15 +2521,8 @@ def test_dot(self): a.dot(b=b, out=c) assert_equal(c, np.dot(a, b)) - @pytest.mark.xfail(reason="TODO np.dot") - def test_dot_type_mismatch(self): - c = 1. - A = np.array((1,1), dtype='i,i') - assert_raises(TypeError, np.dot, c, A) - assert_raises(TypeError, np.dot, A, c) - - @pytest.mark.xfail(reason="TODO np.dot") + @pytest.mark.xfail(reason="_aligned_zeros") def test_dot_out_mem_overlap(self): np.random.seed(1) @@ -5627,7 +5626,7 @@ def test_dot_array_order(self): assert_equal(np.dot(b, a), res) assert_equal(np.dot(b, b), res) - @pytest.mark.skip(reason='TODO: nbytes, view') + @pytest.mark.skip(reason='TODO: nbytes, view, __array_interface__') def test_accelerate_framework_sgemv_fix(self): def aligned_array(shape, align, dtype, order='C'): @@ -7878,7 +7877,6 @@ def test_view_discard_refcount(self): assert_equal(arr, orig) -@pytest.mark.xfail(reason='TODO') class TestArange: def test_infinite(self): assert_raises_regex( @@ -7887,8 +7885,8 @@ def test_infinite(self): ) def test_nan_step(self): - assert_raises_regex( - ValueError, "cannot compute length", + assert_raises( + ValueError, # "cannot compute length", np.arange, 0, 1, np.nan ) @@ -7904,6 +7902,9 @@ def test_require_range(self): assert_raises(TypeError, np.arange) assert_raises(TypeError, np.arange, step=3) assert_raises(TypeError, np.arange, dtype='int64') + + @pytest.mark.xfail(reason="weird arange signature (optionals before required args)") + def test_require_range_2(self): assert_raises(TypeError, np.arange, start=4) def test_start_stop_kwarg(self): @@ -7916,6 +7917,7 @@ def test_start_stop_kwarg(self): assert len(keyword_start_stop) == 6 assert_array_equal(keyword_stop, keyword_zerotostop) + @pytest.mark.skip(reason="arange for booleans: numpy maybe deprecates?") def test_arange_booleans(self): # Arange makes some sense for booleans and works up to length 2. # But it is weird since `arange(2, 4, dtype=bool)` works. @@ -7936,28 +7938,6 @@ def test_arange_booleans(self): with pytest.raises(TypeError): np.arange(3, dtype="bool") - @pytest.mark.parametrize("dtype", ["S3", "U", "5i"]) - def test_rejects_bad_dtypes(self, dtype): - dtype = np.dtype(dtype) - DType_name = re.escape(str(type(dtype))) - with pytest.raises(TypeError, - match=rf"arange\(\) not supported for inputs .* {DType_name}"): - np.arange(2, dtype=dtype) - - def test_rejects_strings(self): - # Explicitly test error for strings which may call "b" - "a": - DType_name = re.escape(str(type(np.array("a").dtype))) - with pytest.raises(TypeError, - match=rf"arange\(\) not supported for inputs .* {DType_name}"): - np.arange("a", "b") - - def test_byteswapped(self): - res_be = np.arange(1, 1000, dtype=">i4") - res_le = np.arange(1, 1000, dtype="i4" - assert res_le.dtype == "