Skip to content

Commit b975238

Browse files
authored
ENH: add np.convolve and np.correlate (#134)
1 parent 7e39bfd commit b975238

File tree

3 files changed

+51
-26
lines changed

3 files changed

+51
-26
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -176,18 +176,6 @@ def compress(condition, a, axis=None, out=None):
176176
raise NotImplementedError
177177

178178

179-
def convolve(a, v, mode="full"):
180-
raise NotImplementedError
181-
182-
183-
def copyto(dst, src, casting="same_kind", where=True):
184-
raise NotImplementedError
185-
186-
187-
def correlate(a, v, mode="valid"):
188-
raise NotImplementedError
189-
190-
191179
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
192180
raise NotImplementedError
193181

torch_np/_funcs_impl.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,39 @@ def cov(
570570
return result
571571

572572

573+
def _conv_corr_impl(a, v, mode):
574+
dt = _dtypes_impl.result_type_impl((a.dtype, v.dtype))
575+
a = _util.cast_if_needed(a, dt)
576+
v = _util.cast_if_needed(v, dt)
577+
578+
padding = v.shape[0] - 1 if mode == "full" else mode
579+
580+
# NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights
581+
aa = a[None, :]
582+
vv = v[None, None, :]
583+
584+
result = torch.nn.functional.conv1d(aa, vv, padding=padding)
585+
586+
# torch returns a 2D result, numpy returns a 1D array
587+
return result[0, :]
588+
589+
590+
def convolve(a: ArrayLike, v: ArrayLike, mode="full"):
591+
# NumPy: if v is longer than a, the arrays are swapped before computation
592+
if a.shape[0] < v.shape[0]:
593+
a, v = v, a
594+
595+
# flip the weights since numpy does and torch does not
596+
v = torch.flip(v, (0,))
597+
598+
return _conv_corr_impl(a, v, mode)
599+
600+
601+
def correlate(a: ArrayLike, v: ArrayLike, mode="valid"):
602+
v = torch.conj_physical(v)
603+
return _conv_corr_impl(a, v, mode)
604+
605+
573606
# ### logic & element selection ###
574607

575608

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
assert_, assert_equal, assert_raises_regex,
1414
assert_array_equal, assert_almost_equal, assert_array_almost_equal,
1515
assert_warns, # assert_array_max_ulp, HAS_REFCOUNT, IS_WASM
16+
assert_allclose,
1617
)
1718
from numpy.core._rational_tests import rational
1819

@@ -2368,7 +2369,6 @@ def test_dtype_str_bytes(self, likefunc, dtype):
23682369
assert result.strides == (4, 1)
23692370

23702371

2371-
@pytest.mark.xfail(reason="TODO")
23722372
class TestCorrelate:
23732373
def _setup(self, dt):
23742374
self.x = np.array([1, 2, 3, 4, 5], dtype=dt)
@@ -2390,20 +2390,13 @@ def test_float(self):
23902390
assert_array_almost_equal(z, self.z1_4)
23912391
z = np.correlate(self.y, self.x, 'full')
23922392
assert_array_almost_equal(z, self.z2)
2393-
z = np.correlate(self.x[::-1], self.y, 'full')
2393+
z = np.correlate(np.flip(self.x), self.y, 'full')
23942394
assert_array_almost_equal(z, self.z1r)
2395-
z = np.correlate(self.y, self.x[::-1], 'full')
2395+
z = np.correlate(self.y, np.flip(self.x), 'full')
23962396
assert_array_almost_equal(z, self.z2r)
23972397
z = np.correlate(self.xs, self.y, 'full')
23982398
assert_array_almost_equal(z, self.zs)
23992399

2400-
def test_object(self):
2401-
self._setup(Decimal)
2402-
z = np.correlate(self.x, self.y, 'full')
2403-
assert_array_almost_equal(z, self.z1)
2404-
z = np.correlate(self.y, self.x, 'full')
2405-
assert_array_almost_equal(z, self.z2)
2406-
24072400
def test_no_overwrite(self):
24082401
d = np.ones(100)
24092402
k = np.ones(3)
@@ -2415,16 +2408,17 @@ def test_complex(self):
24152408
x = np.array([1, 2, 3, 4+1j], dtype=complex)
24162409
y = np.array([-1, -2j, 3+1j], dtype=complex)
24172410
r_z = np.array([3-1j, 6, 8+1j, 11+5j, -5+8j, -4-1j], dtype=complex)
2418-
r_z = r_z[::-1].conjugate()
2411+
r_z = np.flip(r_z).conjugate()
24192412
z = np.correlate(y, x, mode='full')
24202413
assert_array_almost_equal(z, r_z)
24212414

24222415
def test_zero_size(self):
2423-
with pytest.raises(ValueError):
2416+
with pytest.raises((ValueError, RuntimeError)):
24242417
np.correlate(np.array([]), np.ones(1000), mode='full')
2425-
with pytest.raises(ValueError):
2418+
with pytest.raises((ValueError, RuntimeError)):
24262419
np.correlate(np.ones(1000), np.array([]), mode='full')
24272420

2421+
@pytest.mark.skip(reason='do not implement deprecated behavior')
24282422
def test_mode(self):
24292423
d = np.ones(100)
24302424
k = np.ones(3)
@@ -2441,7 +2435,6 @@ def test_mode(self):
24412435
np.correlate(d, k, mode=None)
24422436

24432437

2444-
@pytest.mark.xfail(reason="TODO")
24452438
class TestConvolve:
24462439
def test_object(self):
24472440
d = [1.] * 100
@@ -2455,6 +2448,7 @@ def test_no_overwrite(self):
24552448
assert_array_equal(d, np.ones(100))
24562449
assert_array_equal(k, np.ones(3))
24572450

2451+
@pytest.mark.skip(reason='do not implement deprecated behavior')
24582452
def test_mode(self):
24592453
d = np.ones(100)
24602454
k = np.ones(3)
@@ -2470,6 +2464,16 @@ def test_mode(self):
24702464
with assert_raises(TypeError):
24712465
np.convolve(d, k, mode=None)
24722466

2467+
def test_numpy_doc_examples(self):
2468+
conv = np.convolve([1, 2, 3], [0, 1, 0.5])
2469+
assert_allclose(conv, [0., 1., 2.5, 4., 1.5], atol=1e-15)
2470+
2471+
conv = np.convolve([1, 2, 3], [0, 1, 0.5], 'same')
2472+
assert_allclose(conv, [1., 2.5, 4.], atol=1e-15)
2473+
2474+
conv = np.convolve([1, 2, 3], [0, 1, 0.5], 'valid')
2475+
assert_allclose(conv, [2.5], atol=1e-15)
2476+
24732477

24742478
class TestDtypePositional:
24752479

0 commit comments

Comments
 (0)