diff --git a/autogen/numpy_api_dump.py b/autogen/numpy_api_dump.py index 4ae0df25..1a754d2c 100644 --- a/autogen/numpy_api_dump.py +++ b/autogen/numpy_api_dump.py @@ -176,18 +176,6 @@ def compress(condition, a, axis=None, out=None): raise NotImplementedError -def convolve(a, v, mode="full"): - raise NotImplementedError - - -def copyto(dst, src, casting="same_kind", where=True): - raise NotImplementedError - - -def correlate(a, v, mode="valid"): - raise NotImplementedError - - def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): raise NotImplementedError diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py index 0f70893c..09217725 100644 --- a/torch_np/_funcs_impl.py +++ b/torch_np/_funcs_impl.py @@ -570,6 +570,39 @@ def cov( return result +def _conv_corr_impl(a, v, mode): + dt = _dtypes_impl.result_type_impl((a.dtype, v.dtype)) + a = _util.cast_if_needed(a, dt) + v = _util.cast_if_needed(v, dt) + + padding = v.shape[0] - 1 if mode == "full" else mode + + # NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights + aa = a[None, :] + vv = v[None, None, :] + + result = torch.nn.functional.conv1d(aa, vv, padding=padding) + + # torch returns a 2D result, numpy returns a 1D array + return result[0, :] + + +def convolve(a: ArrayLike, v: ArrayLike, mode="full"): + # NumPy: if v is longer than a, the arrays are swapped before computation + if a.shape[0] < v.shape[0]: + a, v = v, a + + # flip the weights since numpy does and torch does not + v = torch.flip(v, (0,)) + + return _conv_corr_impl(a, v, mode) + + +def correlate(a: ArrayLike, v: ArrayLike, mode="valid"): + v = torch.conj_physical(v) + return _conv_corr_impl(a, v, mode) + + # ### logic & element selection ### diff --git a/torch_np/tests/numpy_tests/core/test_numeric.py b/torch_np/tests/numpy_tests/core/test_numeric.py index 6f7f9725..a74057e3 100644 --- a/torch_np/tests/numpy_tests/core/test_numeric.py +++ b/torch_np/tests/numpy_tests/core/test_numeric.py @@ -13,6 +13,7 @@ assert_, assert_equal, assert_raises_regex, assert_array_equal, assert_almost_equal, assert_array_almost_equal, assert_warns, # assert_array_max_ulp, HAS_REFCOUNT, IS_WASM + assert_allclose, ) from numpy.core._rational_tests import rational @@ -2368,7 +2369,6 @@ def test_dtype_str_bytes(self, likefunc, dtype): assert result.strides == (4, 1) -@pytest.mark.xfail(reason="TODO") class TestCorrelate: def _setup(self, dt): self.x = np.array([1, 2, 3, 4, 5], dtype=dt) @@ -2390,20 +2390,13 @@ def test_float(self): assert_array_almost_equal(z, self.z1_4) z = np.correlate(self.y, self.x, 'full') assert_array_almost_equal(z, self.z2) - z = np.correlate(self.x[::-1], self.y, 'full') + z = np.correlate(np.flip(self.x), self.y, 'full') assert_array_almost_equal(z, self.z1r) - z = np.correlate(self.y, self.x[::-1], 'full') + z = np.correlate(self.y, np.flip(self.x), 'full') assert_array_almost_equal(z, self.z2r) z = np.correlate(self.xs, self.y, 'full') assert_array_almost_equal(z, self.zs) - def test_object(self): - self._setup(Decimal) - z = np.correlate(self.x, self.y, 'full') - assert_array_almost_equal(z, self.z1) - z = np.correlate(self.y, self.x, 'full') - assert_array_almost_equal(z, self.z2) - def test_no_overwrite(self): d = np.ones(100) k = np.ones(3) @@ -2415,16 +2408,17 @@ def test_complex(self): x = np.array([1, 2, 3, 4+1j], dtype=complex) y = np.array([-1, -2j, 3+1j], dtype=complex) r_z = np.array([3-1j, 6, 8+1j, 11+5j, -5+8j, -4-1j], dtype=complex) - r_z = r_z[::-1].conjugate() + r_z = np.flip(r_z).conjugate() z = np.correlate(y, x, mode='full') assert_array_almost_equal(z, r_z) def test_zero_size(self): - with pytest.raises(ValueError): + with pytest.raises((ValueError, RuntimeError)): np.correlate(np.array([]), np.ones(1000), mode='full') - with pytest.raises(ValueError): + with pytest.raises((ValueError, RuntimeError)): np.correlate(np.ones(1000), np.array([]), mode='full') + @pytest.mark.skip(reason='do not implement deprecated behavior') def test_mode(self): d = np.ones(100) k = np.ones(3) @@ -2441,7 +2435,6 @@ def test_mode(self): np.correlate(d, k, mode=None) -@pytest.mark.xfail(reason="TODO") class TestConvolve: def test_object(self): d = [1.] * 100 @@ -2455,6 +2448,7 @@ def test_no_overwrite(self): assert_array_equal(d, np.ones(100)) assert_array_equal(k, np.ones(3)) + @pytest.mark.skip(reason='do not implement deprecated behavior') def test_mode(self): d = np.ones(100) k = np.ones(3) @@ -2470,6 +2464,16 @@ def test_mode(self): with assert_raises(TypeError): np.convolve(d, k, mode=None) + def test_numpy_doc_examples(self): + conv = np.convolve([1, 2, 3], [0, 1, 0.5]) + assert_allclose(conv, [0., 1., 2.5, 4., 1.5], atol=1e-15) + + conv = np.convolve([1, 2, 3], [0, 1, 0.5], 'same') + assert_allclose(conv, [1., 2.5, 4.], atol=1e-15) + + conv = np.convolve([1, 2, 3], [0, 1, 0.5], 'valid') + assert_allclose(conv, [2.5], atol=1e-15) + class TestDtypePositional: