Skip to content

Commit a149892

Browse files
committed
ENH: add np.convolve
1 parent 7e39bfd commit a149892

File tree

3 files changed

+34
-9
lines changed

3 files changed

+34
-9
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,6 @@ def compress(condition, a, axis=None, out=None):
176176
raise NotImplementedError
177177

178178

179-
def convolve(a, v, mode="full"):
180-
raise NotImplementedError
181-
182-
183-
def copyto(dst, src, casting="same_kind", where=True):
184-
raise NotImplementedError
185-
186-
187179
def correlate(a, v, mode="valid"):
188180
raise NotImplementedError
189181

torch_np/_funcs_impl.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,28 @@ def cov(
570570
return result
571571

572572

573+
def convolve(a: ArrayLike, v: ArrayLike, mode="full"):
574+
# NumPy: if v is longer than a, the arrays are swapped before computation
575+
if a.shape[0] < v.shape[0]:
576+
a, v = v, a
577+
578+
dt = _dtypes_impl.result_type_impl((a.dtype, v.dtype))
579+
a = _util.cast_if_needed(a, dt)
580+
v = _util.cast_if_needed(v, dt)
581+
582+
padding = v.shape[0] - 1 if mode == "full" else mode
583+
584+
# 1. NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights
585+
# 2. flip the weights since numpy does and torch does not
586+
aa = a[None, :]
587+
vv = torch.flip(v, (0,))[None, None, :]
588+
589+
result = torch.nn.functional.conv1d(aa, vv, padding=padding)
590+
591+
# torch returns a 2D result, numpy returns a 1D array
592+
return result[0, :]
593+
594+
573595
# ### logic & element selection ###
574596

575597

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
assert_, assert_equal, assert_raises_regex,
1414
assert_array_equal, assert_almost_equal, assert_array_almost_equal,
1515
assert_warns, # assert_array_max_ulp, HAS_REFCOUNT, IS_WASM
16+
assert_allclose,
1617
)
1718
from numpy.core._rational_tests import rational
1819

@@ -2441,7 +2442,6 @@ def test_mode(self):
24412442
np.correlate(d, k, mode=None)
24422443

24432444

2444-
@pytest.mark.xfail(reason="TODO")
24452445
class TestConvolve:
24462446
def test_object(self):
24472447
d = [1.] * 100
@@ -2455,6 +2455,7 @@ def test_no_overwrite(self):
24552455
assert_array_equal(d, np.ones(100))
24562456
assert_array_equal(k, np.ones(3))
24572457

2458+
@pytest.mark.skip(reason='do not implement deprecated behavior')
24582459
def test_mode(self):
24592460
d = np.ones(100)
24602461
k = np.ones(3)
@@ -2470,6 +2471,16 @@ def test_mode(self):
24702471
with assert_raises(TypeError):
24712472
np.convolve(d, k, mode=None)
24722473

2474+
def test_numpy_doc_examples(self):
2475+
conv = np.convolve([1, 2, 3], [0, 1, 0.5])
2476+
assert_allclose(conv, [0., 1., 2.5, 4., 1.5], atol=1e-15)
2477+
2478+
conv = np.convolve([1, 2, 3], [0, 1, 0.5], 'same')
2479+
assert_allclose(conv, [1., 2.5, 4.], atol=1e-15)
2480+
2481+
conv = np.convolve([1, 2, 3], [0, 1, 0.5], 'valid')
2482+
assert_allclose(conv, [2.5], atol=1e-15)
2483+
24732484

24742485
class TestDtypePositional:
24752486

0 commit comments

Comments
 (0)