Skip to content

Commit 04220ab

Browse files
committed
Faster implementation of numba convolve1d
1 parent 2cc864b commit 04220ab

File tree

2 files changed

+84
-7
lines changed

2 files changed

+84
-7
lines changed
Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from numba.np.arraymath import _get_inner_prod
23

34
from pytensor.link.numba.dispatch import numba_funcify
45
from pytensor.link.numba.dispatch.basic import numba_njit
@@ -7,10 +8,65 @@
78

89
@numba_funcify.register(Convolve1d)
910
def numba_funcify_Conv1d(op, node, **kwargs):
11+
# This specialized version is faster than the overloaded numba np.convolve
1012
mode = op.mode
1113

12-
@numba_njit
13-
def conv1d(data, kernel):
14-
return np.convolve(data, kernel, mode=mode)
14+
a_dt = np.dtype(node.inputs[0].dtype)
15+
b_dt = np.dtype(node.inputs[1].dtype)
16+
dt = np.promote_types(a_dt, b_dt)
17+
innerprod = _get_inner_prod(a_dt, b_dt)
1518

16-
return conv1d
19+
if mode == "valid":
20+
21+
def valid_convolve1d(x, y):
22+
nx = len(x)
23+
ny = len(y)
24+
if nx < ny:
25+
x, y = y, x
26+
nx, ny = ny, nx
27+
y_flipped = y[::-1]
28+
29+
length = nx - ny + 1
30+
ret = np.empty(length, dt)
31+
32+
for i in range(length):
33+
ret[i] = innerprod(x[i : i + ny], y_flipped)
34+
35+
return ret
36+
37+
return numba_njit(valid_convolve1d)
38+
39+
elif mode == "full":
40+
41+
def full_convolve1d(x, y):
42+
nx = len(x)
43+
ny = len(y)
44+
if nx < ny:
45+
x, y = y, x
46+
nx, ny = ny, nx
47+
y_flipped = y[::-1]
48+
49+
length = nx + ny - 1
50+
ret = np.empty(length, dt)
51+
idx = 0
52+
53+
for i in range(ny - 1):
54+
k = i + 1
55+
ret[idx] = innerprod(x[:k], y_flipped[-k:])
56+
idx = idx + 1
57+
58+
for i in range(nx - ny + 1):
59+
ret[idx] = innerprod(x[i : i + ny], y_flipped)
60+
idx = idx + 1
61+
62+
for i in range(ny - 1):
63+
k = ny - i - 1
64+
ret[idx] = innerprod(x[-k:], y_flipped[:k])
65+
idx = idx + 1
66+
67+
return ret
68+
69+
return numba_njit(full_convolve1d)
70+
71+
else:
72+
raise ValueError(f"Unsupported mode: {mode}")

tests/link/numba/signal/test_conv.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
22
import pytest
33

4-
from pytensor.tensor import dmatrix
4+
from pytensor import function
5+
from pytensor.tensor import dmatrix, vector
56
from pytensor.tensor.signal import convolve1d
67
from tests.link.numba.test_basic import compare_numba_and_py
78

@@ -10,13 +11,33 @@
1011

1112

1213
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
13-
def test_convolve1d(mode):
14+
@pytest.mark.parametrize("x_smaller", (False, True))
15+
def test_convolve1d(x_smaller, mode):
1416
x = dmatrix("x")
1517
y = dmatrix("y")
16-
out = convolve1d(x[None], y[:, None], mode=mode)
18+
if x_smaller:
19+
out = convolve1d(x[None], y[:, None], mode=mode)
20+
else:
21+
out = convolve1d(y[:, None], x[None], mode=mode)
1722

1823
rng = np.random.default_rng()
1924
test_x = rng.normal(size=(3, 5))
2025
test_y = rng.normal(size=(7, 11))
2126
# Blockwise dispatch for numba can't be run on object mode
2227
compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
28+
29+
30+
@pytest.mark.parametrize("mode", ("full", "valid"))
31+
def test_convolve_benchmark(mode, benchmark):
32+
x = vector(shape=(183,))
33+
y = vector(shape=(6,))
34+
out = convolve1d(x, y, mode=mode)
35+
fn = function([x, y], out, mode="NUMBA", trust_input=True)
36+
37+
rng = np.random.default_rng()
38+
x_test = rng.normal(size=(x.type.shape)).astype(x.type.dtype)
39+
y_test = rng.normal(size=(y.type.shape)).astype(y.type.dtype)
40+
np.testing.assert_allclose(
41+
fn(x_test, y_test), np.convolve(x_test, y_test, mode=mode)
42+
)
43+
benchmark(fn, x_test, y_test)

0 commit comments

Comments
 (0)