Skip to content

Commit 8670aaf

Browse files
committed
Faster implementation of numba convolve1d
1 parent 2cc864b commit 8670aaf

File tree

2 files changed

+83
-8
lines changed

2 files changed

+83
-8
lines changed
Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,70 @@
11
import numpy as np
2+
from numba.np.arraymath import _get_inner_prod
23

34
from pytensor.link.numba.dispatch import numba_funcify
45
from pytensor.link.numba.dispatch.basic import numba_njit
56
from pytensor.tensor.signal.conv import Convolve1d
67

78

89
@numba_funcify.register(Convolve1d)
9-
def numba_funcify_Conv1d(op, node, **kwargs):
10+
def numba_funcify_Convolve1d(op, node, **kwargs):
11+
# This specialized version is faster than the overloaded numba np.convolve
1012
mode = op.mode
13+
a_dtype, b_dtype = node.inputs[0].type.dtype, node.inputs[1].type.dtype
14+
out_dtype = node.outputs[0].type.dtype
15+
innerprod = _get_inner_prod(a_dtype, b_dtype)
1116

12-
@numba_njit
13-
def conv1d(data, kernel):
14-
return np.convolve(data, kernel, mode=mode)
17+
if mode == "valid":
1518

16-
return conv1d
19+
def valid_convolve1d(x, y):
20+
nx = len(x)
21+
ny = len(y)
22+
if nx < ny:
23+
x, y = y, x
24+
nx, ny = ny, nx
25+
y_flipped = y[::-1]
26+
27+
length = nx - ny + 1
28+
ret = np.empty(length, out_dtype)
29+
30+
for i in range(length):
31+
ret[i] = innerprod(x[i : i + ny], y_flipped)
32+
33+
return ret
34+
35+
return numba_njit(valid_convolve1d)
36+
37+
elif mode == "full":
38+
39+
def full_convolve1d(x, y):
40+
nx = len(x)
41+
ny = len(y)
42+
if nx < ny:
43+
x, y = y, x
44+
nx, ny = ny, nx
45+
y_flipped = y[::-1]
46+
47+
length = nx + ny - 1
48+
ret = np.empty(length, out_dtype)
49+
idx = 0
50+
51+
for i in range(ny - 1):
52+
k = i + 1
53+
ret[idx] = innerprod(x[:k], y_flipped[-k:])
54+
idx = idx + 1
55+
56+
for i in range(nx - ny + 1):
57+
ret[idx] = innerprod(x[i : i + ny], y_flipped)
58+
idx = idx + 1
59+
60+
for i in range(ny - 1):
61+
k = ny - i - 1
62+
ret[idx] = innerprod(x[-k:], y_flipped[:k])
63+
idx = idx + 1
64+
65+
return ret
66+
67+
return numba_njit(full_convolve1d)
68+
69+
else:
70+
raise ValueError(f"Unsupported mode: {mode}")

tests/link/numba/signal/test_conv.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
22
import pytest
33

4-
from pytensor.tensor import dmatrix
4+
from pytensor import function
5+
from pytensor.tensor import dmatrix, vector
56
from pytensor.tensor.signal import convolve1d
67
from tests.link.numba.test_basic import compare_numba_and_py
78

@@ -10,13 +11,33 @@
1011

1112

1213
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
13-
def test_convolve1d(mode):
14+
@pytest.mark.parametrize("x_smaller", (False, True))
15+
def test_convolve1d(x_smaller, mode):
1416
x = dmatrix("x")
1517
y = dmatrix("y")
16-
out = convolve1d(x[None], y[:, None], mode=mode)
18+
if x_smaller:
19+
out = convolve1d(x[None], y[:, None], mode=mode)
20+
else:
21+
out = convolve1d(y[:, None], x[None], mode=mode)
1722

1823
rng = np.random.default_rng()
1924
test_x = rng.normal(size=(3, 5))
2025
test_y = rng.normal(size=(7, 11))
2126
# Blockwise dispatch for numba can't be run on object mode
2227
compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
28+
29+
30+
@pytest.mark.parametrize("mode", ("full", "valid"))
31+
def test_convolve_benchmark(mode, benchmark):
32+
x = vector(shape=(183,))
33+
y = vector(shape=(6,))
34+
out = convolve1d(x, y, mode=mode)
35+
fn = function([x, y], out, mode="NUMBA", trust_input=True)
36+
37+
rng = np.random.default_rng()
38+
x_test = rng.normal(size=(x.type.shape)).astype(x.type.dtype)
39+
y_test = rng.normal(size=(y.type.shape)).astype(y.type.dtype)
40+
np.testing.assert_allclose(
41+
fn(x_test, y_test), np.convolve(x_test, y_test, mode=mode)
42+
)
43+
benchmark(fn, x_test, y_test)

0 commit comments

Comments
 (0)