Skip to content

ENH: add np.convolve and np.correlate #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions autogen/numpy_api_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,18 +176,6 @@ def compress(condition, a, axis=None, out=None):
raise NotImplementedError


def convolve(a, v, mode="full"):
raise NotImplementedError


def copyto(dst, src, casting="same_kind", where=True):
raise NotImplementedError


def correlate(a, v, mode="valid"):
raise NotImplementedError


def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
raise NotImplementedError

Expand Down
33 changes: 33 additions & 0 deletions torch_np/_funcs_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,39 @@ def cov(
return result


def _conv_corr_impl(a, v, mode):
dt = _dtypes_impl.result_type_impl((a.dtype, v.dtype))
a = _util.cast_if_needed(a, dt)
v = _util.cast_if_needed(v, dt)

padding = v.shape[0] - 1 if mode == "full" else mode

# NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights
aa = a[None, :]
vv = v[None, None, :]

result = torch.nn.functional.conv1d(aa, vv, padding=padding)

# torch returns a 2D result, numpy returns a 1D array
return result[0, :]


def convolve(a: ArrayLike, v: ArrayLike, mode="full"):
# NumPy: if v is longer than a, the arrays are swapped before computation
if a.shape[0] < v.shape[0]:
a, v = v, a

# flip the weights since numpy does and torch does not
v = torch.flip(v, (0,))

return _conv_corr_impl(a, v, mode)


def correlate(a: ArrayLike, v: ArrayLike, mode="valid"):
v = torch.conj_physical(v)
return _conv_corr_impl(a, v, mode)


# ### logic & element selection ###


Expand Down
32 changes: 18 additions & 14 deletions torch_np/tests/numpy_tests/core/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
assert_, assert_equal, assert_raises_regex,
assert_array_equal, assert_almost_equal, assert_array_almost_equal,
assert_warns, # assert_array_max_ulp, HAS_REFCOUNT, IS_WASM
assert_allclose,
)
from numpy.core._rational_tests import rational

Expand Down Expand Up @@ -2368,7 +2369,6 @@ def test_dtype_str_bytes(self, likefunc, dtype):
assert result.strides == (4, 1)


@pytest.mark.xfail(reason="TODO")
class TestCorrelate:
def _setup(self, dt):
self.x = np.array([1, 2, 3, 4, 5], dtype=dt)
Expand All @@ -2390,20 +2390,13 @@ def test_float(self):
assert_array_almost_equal(z, self.z1_4)
z = np.correlate(self.y, self.x, 'full')
assert_array_almost_equal(z, self.z2)
z = np.correlate(self.x[::-1], self.y, 'full')
z = np.correlate(np.flip(self.x), self.y, 'full')
assert_array_almost_equal(z, self.z1r)
z = np.correlate(self.y, self.x[::-1], 'full')
z = np.correlate(self.y, np.flip(self.x), 'full')
assert_array_almost_equal(z, self.z2r)
z = np.correlate(self.xs, self.y, 'full')
assert_array_almost_equal(z, self.zs)

def test_object(self):
self._setup(Decimal)
z = np.correlate(self.x, self.y, 'full')
assert_array_almost_equal(z, self.z1)
z = np.correlate(self.y, self.x, 'full')
assert_array_almost_equal(z, self.z2)

def test_no_overwrite(self):
d = np.ones(100)
k = np.ones(3)
Expand All @@ -2415,16 +2408,17 @@ def test_complex(self):
x = np.array([1, 2, 3, 4+1j], dtype=complex)
y = np.array([-1, -2j, 3+1j], dtype=complex)
r_z = np.array([3-1j, 6, 8+1j, 11+5j, -5+8j, -4-1j], dtype=complex)
r_z = r_z[::-1].conjugate()
r_z = np.flip(r_z).conjugate()
z = np.correlate(y, x, mode='full')
assert_array_almost_equal(z, r_z)

def test_zero_size(self):
with pytest.raises(ValueError):
with pytest.raises((ValueError, RuntimeError)):
np.correlate(np.array([]), np.ones(1000), mode='full')
with pytest.raises(ValueError):
with pytest.raises((ValueError, RuntimeError)):
np.correlate(np.ones(1000), np.array([]), mode='full')

@pytest.mark.skip(reason='do not implement deprecated behavior')
def test_mode(self):
d = np.ones(100)
k = np.ones(3)
Expand All @@ -2441,7 +2435,6 @@ def test_mode(self):
np.correlate(d, k, mode=None)


@pytest.mark.xfail(reason="TODO")
class TestConvolve:
def test_object(self):
d = [1.] * 100
Expand All @@ -2455,6 +2448,7 @@ def test_no_overwrite(self):
assert_array_equal(d, np.ones(100))
assert_array_equal(k, np.ones(3))

@pytest.mark.skip(reason='do not implement deprecated behavior')
def test_mode(self):
d = np.ones(100)
k = np.ones(3)
Expand All @@ -2470,6 +2464,16 @@ def test_mode(self):
with assert_raises(TypeError):
np.convolve(d, k, mode=None)

def test_numpy_doc_examples(self):
conv = np.convolve([1, 2, 3], [0, 1, 0.5])
assert_allclose(conv, [0., 1., 2.5, 4., 1.5], atol=1e-15)

conv = np.convolve([1, 2, 3], [0, 1, 0.5], 'same')
assert_allclose(conv, [1., 2.5, 4.], atol=1e-15)

conv = np.convolve([1, 2, 3], [0, 1, 0.5], 'valid')
assert_allclose(conv, [2.5], atol=1e-15)


class TestDtypePositional:

Expand Down